From fd0c9a130530eeab28c212cec601a3ef5ba1f843 Mon Sep 17 00:00:00 2001
From: erio
Date: Fri, 17 Apr 2026 17:00:29 +0800
Subject: [PATCH 001/189] fix(payment): store provider config as plaintext JSON
with legacy ciphertext fallback
Without TOTP_ENCRYPTION_KEY, saved payment configs were lost on restart because
the AES round-trip failed silently. Write new records as plaintext JSON; read
path tries JSON first, falls back to legacy AES decrypt when a key is present,
and treats unreadable values as empty so admins can re-enter them via the UI.
---
backend/internal/payment/crypto.go | 11 ++-
backend/internal/payment/load_balancer.go | 30 ++++--
.../internal/payment/load_balancer_test.go | 97 +++++++++++++++++++
.../service/payment_config_providers.go | 38 +++++---
4 files changed, 151 insertions(+), 25 deletions(-)
diff --git a/backend/internal/payment/crypto.go b/backend/internal/payment/crypto.go
index e39e957f..5467e50b 100644
--- a/backend/internal/payment/crypto.go
+++ b/backend/internal/payment/crypto.go
@@ -10,12 +10,15 @@ import (
"strings"
)
+// AES256KeySize is the required key length (in bytes) for AES-256-GCM.
+const AES256KeySize = 32
+
// Encrypt encrypts plaintext using AES-256-GCM with the given 32-byte key.
// The output format is "iv:authTag:ciphertext" where each component is base64-encoded,
// matching the Node.js crypto.ts format for cross-compatibility.
func Encrypt(plaintext string, key []byte) (string, error) {
- if len(key) != 32 {
- return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key))
+ if len(key) != AES256KeySize {
+ return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key))
}
block, err := aes.NewCipher(key)
@@ -52,8 +55,8 @@ func Encrypt(plaintext string, key []byte) (string, error) {
// Decrypt decrypts a ciphertext string produced by Encrypt.
// The input format is "iv:authTag:ciphertext" where each component is base64-encoded.
func Decrypt(ciphertext string, key []byte) (string, error) {
- if len(key) != 32 {
- return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key))
+ if len(key) != AES256KeySize {
+ return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key))
}
parts := strings.SplitN(ciphertext, ":", 3)
diff --git a/backend/internal/payment/load_balancer.go b/backend/internal/payment/load_balancer.go
index f0353173..52a1b011 100644
--- a/backend/internal/payment/load_balancer.go
+++ b/backend/internal/payment/load_balancer.go
@@ -261,6 +261,9 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns
if err != nil {
return nil, fmt.Errorf("decrypt instance %d config: %w", selected.ID, err)
}
+ if config == nil {
+ config = map[string]string{}
+ }
if selected.PaymentMode != "" {
config["paymentMode"] = selected.PaymentMode
@@ -275,16 +278,29 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns
}, nil
}
-func (lb *DefaultLoadBalancer) decryptConfig(encrypted string) (map[string]string, error) {
- plaintext, err := Decrypt(encrypted, lb.encryptionKey)
- if err != nil {
- return nil, err
+// decryptConfig parses a stored provider config.
+// New records are plaintext JSON; legacy records are AES-256-GCM ciphertext.
+// Unreadable values (legacy ciphertext without a valid key, or malformed data)
+// are treated as empty so the service keeps running while the admin re-enters
+// the config via the UI.
+func (lb *DefaultLoadBalancer) decryptConfig(stored string) (map[string]string, error) {
+ if stored == "" {
+ return nil, nil
}
var config map[string]string
- if err := json.Unmarshal([]byte(plaintext), &config); err != nil {
- return nil, fmt.Errorf("unmarshal config: %w", err)
+ if err := json.Unmarshal([]byte(stored), &config); err == nil {
+ return config, nil
}
- return config, nil
+ if len(lb.encryptionKey) == AES256KeySize {
+ if plaintext, err := Decrypt(stored, lb.encryptionKey); err == nil {
+ if err := json.Unmarshal([]byte(plaintext), &config); err == nil {
+ return config, nil
+ }
+ }
+ }
+ slog.Warn("payment provider config unreadable, treating as empty for re-entry",
+ "stored_len", len(stored))
+ return nil, nil
}
// GetInstanceDailyAmount returns the total completed order amount for an instance today.
diff --git a/backend/internal/payment/load_balancer_test.go b/backend/internal/payment/load_balancer_test.go
index 04b3c25b..2bf4f6ac 100644
--- a/backend/internal/payment/load_balancer_test.go
+++ b/backend/internal/payment/load_balancer_test.go
@@ -452,6 +452,103 @@ func TestStartOfDay(t *testing.T) {
}
}
+func TestDecryptConfig_PlaintextAndLegacyCompat(t *testing.T) {
+ t.Parallel()
+
+ key := make([]byte, AES256KeySize)
+ for i := range key {
+ key[i] = byte(i + 1)
+ }
+ wrongKey := make([]byte, AES256KeySize)
+ for i := range wrongKey {
+ wrongKey[i] = byte(0xFF - i)
+ }
+
+ plaintextJSON := `{"appId":"app-123","secret":"sec-xyz"}`
+
+ legacyEncrypted, err := Encrypt(plaintextJSON, key)
+ if err != nil {
+ t.Fatalf("seed Encrypt: %v", err)
+ }
+
+ tests := []struct {
+ name string
+ stored string
+ key []byte
+ want map[string]string
+ }{
+ {
+ name: "empty stored returns nil map",
+ stored: "",
+ key: key,
+ want: nil,
+ },
+ {
+ name: "plaintext JSON parses directly",
+ stored: plaintextJSON,
+ key: nil,
+ want: map[string]string{"appId": "app-123", "secret": "sec-xyz"},
+ },
+ {
+ name: "plaintext JSON works even with key present",
+ stored: plaintextJSON,
+ key: key,
+ want: map[string]string{"appId": "app-123", "secret": "sec-xyz"},
+ },
+ {
+ name: "legacy ciphertext with correct key decrypts",
+ stored: legacyEncrypted,
+ key: key,
+ want: map[string]string{"appId": "app-123", "secret": "sec-xyz"},
+ },
+ {
+ name: "legacy ciphertext with no key treated as empty",
+ stored: legacyEncrypted,
+ key: nil,
+ want: nil,
+ },
+ {
+ name: "legacy ciphertext with wrong key treated as empty",
+ stored: legacyEncrypted,
+ key: wrongKey,
+ want: nil,
+ },
+ {
+ name: "garbage data treated as empty",
+ stored: "not-json-and-not-ciphertext",
+ key: key,
+ want: nil,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ lb := NewDefaultLoadBalancer(nil, tt.key)
+ got, err := lb.decryptConfig(tt.stored)
+ if err != nil {
+ t.Fatalf("decryptConfig unexpected error: %v", err)
+ }
+ if !stringMapEqual(got, tt.want) {
+ t.Fatalf("decryptConfig = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+// stringMapEqual compares two map[string]string values; nil and empty are equal.
+func stringMapEqual(a, b map[string]string) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for k, v := range a {
+ if bv, ok := b[k]; !ok || bv != v {
+ return false
+ }
+ }
+ return true
+}
+
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go
index 3c406b45..59337ad6 100644
--- a/backend/internal/service/payment_config_providers.go
+++ b/backend/internal/service/payment_config_providers.go
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "log/slog"
"strconv"
"strings"
@@ -290,19 +291,29 @@ func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newCon
return existing, nil
}
-func (s *PaymentConfigService) decryptConfig(encrypted string) (map[string]string, error) {
- if encrypted == "" {
+// decryptConfig parses a stored provider config.
+// New records are plaintext JSON; legacy records are AES-256-GCM ciphertext
+// ("iv:authTag:ciphertext"). Values that cannot be parsed as either — including
+// legacy ciphertext with no/invalid TOTP_ENCRYPTION_KEY — are treated as empty,
+// letting the admin re-enter the config via the UI to complete the migration.
+func (s *PaymentConfigService) decryptConfig(stored string) (map[string]string, error) {
+ if stored == "" {
return nil, nil
}
- decrypted, err := payment.Decrypt(encrypted, s.encryptionKey)
- if err != nil {
- return nil, fmt.Errorf("decrypt config: %w", err)
+ var cfg map[string]string
+ if err := json.Unmarshal([]byte(stored), &cfg); err == nil {
+ return cfg, nil
}
- var raw map[string]string
- if err := json.Unmarshal([]byte(decrypted), &raw); err != nil {
- return nil, fmt.Errorf("unmarshal decrypted config: %w", err)
+ if len(s.encryptionKey) == payment.AES256KeySize {
+ if plaintext, err := payment.Decrypt(stored, s.encryptionKey); err == nil {
+ if err := json.Unmarshal([]byte(plaintext), &cfg); err == nil {
+ return cfg, nil
+ }
+ }
}
- return raw, nil
+ slog.Warn("payment provider config unreadable, treating as empty for re-entry",
+ "stored_len", len(stored))
+ return nil, nil
}
func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id int64) error {
@@ -317,14 +328,13 @@ func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id in
return s.entClient.PaymentProviderInstance.DeleteOneID(id).Exec(ctx)
}
+// encryptConfig serialises a provider config for storage.
+// New records are written as plaintext JSON; the historical AES-GCM wrapping
+// has been dropped but decryptConfig still accepts old ciphertext during migration.
func (s *PaymentConfigService) encryptConfig(cfg map[string]string) (string, error) {
data, err := json.Marshal(cfg)
if err != nil {
return "", fmt.Errorf("marshal config: %w", err)
}
- enc, err := payment.Encrypt(string(data), s.encryptionKey)
- if err != nil {
- return "", fmt.Errorf("encrypt config: %w", err)
- }
- return enc, nil
+ return string(data), nil
}
From 44cdef7934168f167c5d433e5947aea3ac5a279d Mon Sep 17 00:00:00 2001
From: erio
Date: Fri, 17 Apr 2026 17:00:45 +0800
Subject: [PATCH 002/189] fix(usage): subscription billing honours group rate
multiplier
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Subscription-mode billing was consuming quota at TotalCost (raw) instead of
ActualCost (TotalCost * RateMultiplier), so per-group rate multipliers —
including free subscriptions (multiplier = 0) — were silently ignored.
Switch the three subscription cost writes in buildUsageBillingCommand,
finalizePostUsageBilling, and the legacy postUsageBilling fallback to
ActualCost, and add a table-driven test covering 2x / 0.5x / free multipliers
plus a balance-mode regression check.
---
backend/internal/service/gateway_service.go | 16 ++--
...teway_service_subscription_billing_test.go | 85 +++++++++++++++++++
2 files changed, 96 insertions(+), 5 deletions(-)
create mode 100644 backend/internal/service/gateway_service_subscription_billing_test.go
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 4b4fc0bf..07a9e41c 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -7317,8 +7317,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
cost := p.Cost
if p.IsSubscriptionBill {
- if cost.TotalCost > 0 {
- if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil {
+ // Subscription usage tracked by ActualCost so group rate multiplier
+ // consumes the quota at the expected speed.
+ if cost.ActualCost > 0 {
+ if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.ActualCost); err != nil {
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
}
}
@@ -7417,9 +7419,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
}
}
+ // Record subscription / balance cost using ActualCost so the group (and any
+ // user-specific) rate multiplier consumes subscription quota at the expected
+ // speed. TotalCost remains the raw (pre-multiplier) value; downstream guards
+ // on "> 0" still correctly skip free subscriptions (RateMultiplier == 0).
if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 {
cmd.SubscriptionID = &p.Subscription.ID
- cmd.SubscriptionCost = p.Cost.TotalCost
+ cmd.SubscriptionCost = p.Cost.ActualCost
} else if p.Cost.ActualCost > 0 {
cmd.BalanceCost = p.Cost.ActualCost
}
@@ -7478,8 +7484,8 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, resu
}
if p.IsSubscriptionBill {
- if p.Cost.TotalCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil {
- deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.TotalCost)
+ if p.Cost.ActualCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil {
+ deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.ActualCost)
}
} else if p.Cost.ActualCost > 0 && p.User != nil {
deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost)
diff --git a/backend/internal/service/gateway_service_subscription_billing_test.go b/backend/internal/service/gateway_service_subscription_billing_test.go
new file mode 100644
index 00000000..42a81035
--- /dev/null
+++ b/backend/internal/service/gateway_service_subscription_billing_test.go
@@ -0,0 +1,85 @@
+//go:build unit
+
+package service
+
+import (
+ "testing"
+)
+
+// TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier locks in the fix
+// that subscription-mode billing honours the group (and any user-specific) rate
+// multiplier — i.e. cmd.SubscriptionCost tracks ActualCost (= TotalCost *
+// RateMultiplier), not raw TotalCost.
+func TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier(t *testing.T) {
+ t.Parallel()
+
+ groupID := int64(7)
+ subID := int64(42)
+
+ tests := []struct {
+ name string
+ totalCost float64
+ actualCost float64
+ isSubscription bool
+ wantSub float64
+ wantBalance float64
+ }{
+ {
+ name: "subscription with 2x multiplier consumes 2x quota",
+ totalCost: 1.0,
+ actualCost: 2.0,
+ isSubscription: true,
+ wantSub: 2.0,
+ wantBalance: 0,
+ },
+ {
+ name: "subscription with 0.5x multiplier consumes 0.5x quota",
+ totalCost: 1.0,
+ actualCost: 0.5,
+ isSubscription: true,
+ wantSub: 0.5,
+ wantBalance: 0,
+ },
+ {
+ name: "free subscription (multiplier 0) consumes no quota",
+ totalCost: 1.0,
+ actualCost: 0,
+ isSubscription: true,
+ wantSub: 0,
+ wantBalance: 0,
+ },
+ {
+ name: "balance billing keeps using ActualCost (regression)",
+ totalCost: 1.0,
+ actualCost: 2.0,
+ isSubscription: false,
+ wantSub: 0,
+ wantBalance: 2.0,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ p := &postUsageBillingParams{
+ Cost: &CostBreakdown{TotalCost: tt.totalCost, ActualCost: tt.actualCost},
+ User: &User{ID: 1},
+ APIKey: &APIKey{ID: 2, GroupID: &groupID},
+ Account: &Account{ID: 3},
+ Subscription: &UserSubscription{ID: subID},
+ IsSubscriptionBill: tt.isSubscription,
+ }
+
+ cmd := buildUsageBillingCommand("req-1", nil, p)
+ if cmd == nil {
+ t.Fatal("buildUsageBillingCommand returned nil")
+ }
+ if cmd.SubscriptionCost != tt.wantSub {
+ t.Errorf("SubscriptionCost = %v, want %v", cmd.SubscriptionCost, tt.wantSub)
+ }
+ if cmd.BalanceCost != tt.wantBalance {
+ t.Errorf("BalanceCost = %v, want %v", cmd.BalanceCost, tt.wantBalance)
+ }
+ })
+ }
+}
From 948d8e6d024412cc2efda34ec41540fddb4bfb0b Mon Sep 17 00:00:00 2001
From: erio
Date: Fri, 17 Apr 2026 17:01:01 +0800
Subject: [PATCH 003/189] fix(admin): prevent browser password manager from
autofilling account API key
Chrome's password manager matched the apikey-type account's Base URL + API Key
inputs as a login form and autofilled the last saved password by domain, so
editing a Gemini account could overwrite its apikey with a Claude key that
shared the same Base URL. Add autocomplete="new-password" plus data-*-ignore
attributes for 1Password / LastPass / Bitwarden to opt the field out of every
major password manager's autofill.
---
frontend/src/components/account/EditAccountModal.vue | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue
index 1da32e2c..59ca0b9c 100644
--- a/frontend/src/components/account/EditAccountModal.vue
+++ b/frontend/src/components/account/EditAccountModal.vue
@@ -52,6 +52,10 @@
v-model="editApiKey"
type="password"
class="input font-mono"
+ autocomplete="new-password"
+ data-1p-ignore
+ data-lpignore="true"
+ data-bwignore="true"
:placeholder="
account.platform === 'openai'
? 'sk-proj-...'
From df57d2776b1a74e37470970f1bcf8942ad810e1d Mon Sep 17 00:00:00 2001
From: erio
Date: Fri, 17 Apr 2026 18:32:12 +0800
Subject: [PATCH 004/189] fix(billing): reject rate_multiplier <= 0 on save;
clamp negatives to 0 in compute
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
分组倍率和用户专属倍率在保存时没有校验,0 会触发计费层的 `<=0 → 1.0`
防御条款,结果订阅/余额分组按标准价扣费;完全是沉默地绕过了业务规则。
- 保存校验(admin_service):CreateGroup / UpdateGroup / BatchSetGroupRateMultipliers /
UpdateUser.SyncUserGroupRates 全部要求 > 0
- 计算层(billing_service):三处 `<=0 → 1.0` 改为 `<0 → 0`;负数按 0 结算,
避免配置异常被静默按 1x 收费
- 前端:分组倍率 / 用户专属倍率输入 min 统一到 0.001
- 删除未使用的 IsFreeSubscription 方法
测试:新增 billing_service_rate_multiplier_test.go 端到端验证;更新原有锁定
旧 `<=0 → 1.0` 行为的测试。
---
backend/internal/service/admin_service.go | 21 +++++++
.../service/admin_service_group_test.go | 6 ++
backend/internal/service/billing_service.go | 16 ++---
.../service/billing_service_image_test.go | 5 +-
.../billing_service_rate_multiplier_test.go | 63 +++++++++++++++++++
.../internal/service/billing_service_test.go | 28 ---------
.../service/billing_service_unified_test.go | 40 ++++--------
backend/internal/service/group.go | 4 --
.../openai_gateway_record_usage_test.go | 2 +-
.../admin/group/GroupRateMultipliersModal.vue | 2 +-
.../admin/user/UserAllowedGroupsModal.vue | 4 +-
11 files changed, 119 insertions(+), 72 deletions(-)
create mode 100644 backend/internal/service/billing_service_rate_multiplier_test.go
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 7c26a47c..701f3659 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -586,6 +586,15 @@ func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userI
}
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
+ // 校验用户专属分组倍率:必须 > 0(nil 合法,表示清除专属倍率)
+ if input.GroupRates != nil {
+ for groupID, rate := range input.GroupRates {
+ if rate != nil && *rate <= 0 {
+ return nil, fmt.Errorf("rate_multiplier must be > 0 (group_id=%d)", groupID)
+ }
+ }
+ }
+
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return nil, err
@@ -811,6 +820,10 @@ func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, erro
}
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) {
+ if input.RateMultiplier <= 0 {
+ return nil, errors.New("rate_multiplier must be > 0")
+ }
+
platform := input.Platform
if platform == "" {
platform = PlatformAnthropic
@@ -1050,6 +1063,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
group.Platform = input.Platform
}
if input.RateMultiplier != nil {
+ if *input.RateMultiplier <= 0 {
+ return nil, errors.New("rate_multiplier must be > 0")
+ }
group.RateMultiplier = *input.RateMultiplier
}
if input.IsExclusive != nil {
@@ -1286,6 +1302,11 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro
if s.userGroupRateRepo == nil {
return nil
}
+ for _, e := range entries {
+ if e.RateMultiplier <= 0 {
+ return fmt.Errorf("rate_multiplier must be > 0 (user_id=%d)", e.UserID)
+ }
+ }
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
}
diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go
index a4c6d0ca..41d2c26a 100644
--- a/backend/internal/service/admin_service_group_test.go
+++ b/backend/internal/service/admin_service_group_test.go
@@ -621,6 +621,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatfo
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformOpenAI,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -641,6 +642,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *t
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeSubscription,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -695,6 +697,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -713,6 +716,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) {
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -733,6 +737,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAntigravity,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -750,6 +755,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &zero,
})
diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go
index 32a54cbe..c9f32b3b 100644
--- a/backend/internal/service/billing_service.go
+++ b/backend/internal/service/billing_service.go
@@ -448,8 +448,9 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown,
})
}
- if input.RateMultiplier <= 0 {
- input.RateMultiplier = 1.0
+ // 保存时强制 > 0;若仍有负数泄漏(缓存/迁移残留),按 0 处理避免按 1x 误扣。
+ if input.RateMultiplier < 0 {
+ input.RateMultiplier = 0
}
var breakdown *CostBreakdown
@@ -493,8 +494,9 @@ func (s *BillingService) computeTokenBreakdown(
rateMultiplier float64, serviceTier string,
applyLongCtx bool,
) *CostBreakdown {
- if rateMultiplier <= 0 {
- rateMultiplier = 1.0
+ // 保存时强制 > 0;若仍有负数泄漏,按 0 处理避免按 1x 误扣。
+ if rateMultiplier < 0 {
+ rateMultiplier = 0
}
inputPrice := pricing.InputPricePerToken
@@ -831,9 +833,9 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
// 计算总费用
totalCost := unitPrice * float64(imageCount)
- // 应用倍率
- if rateMultiplier <= 0 {
- rateMultiplier = 1.0
+ // 应用倍率(保存时强制 > 0;负数按 0 处理避免按 1x 误扣)
+ if rateMultiplier < 0 {
+ rateMultiplier = 0
}
actualCost := totalCost * rateMultiplier
diff --git a/backend/internal/service/billing_service_image_test.go b/backend/internal/service/billing_service_image_test.go
index fa90f6bb..8d3ca987 100644
--- a/backend/internal/service/billing_service_image_test.go
+++ b/backend/internal/service/billing_service_image_test.go
@@ -90,13 +90,14 @@ func TestCalculateImageCost_NegativeCount(t *testing.T) {
require.Equal(t, 0.0, cost.ActualCost)
}
-// TestCalculateImageCost_ZeroRateMultiplier 测试费率倍数为 0 时默认使用 1.0
+// TestCalculateImageCost_ZeroRateMultiplier 锁定新行为:倍率 0 直接按 0 计费
+// (保存时已强制 > 0;若仍有 0 泄漏到计费层,零消耗比历史的 1.0 更安全)。
func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) {
svc := &BillingService{}
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0)
require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
- require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理
+ require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
}
// TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格
diff --git a/backend/internal/service/billing_service_rate_multiplier_test.go b/backend/internal/service/billing_service_rate_multiplier_test.go
new file mode 100644
index 00000000..83788196
--- /dev/null
+++ b/backend/internal/service/billing_service_rate_multiplier_test.go
@@ -0,0 +1,63 @@
+//go:build unit
+
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// TestCalculateCost_RateMultiplier_NegativeClampedToZero 锁定负数倍率被
+// 钳制为 0(而非历史上的 1.0),避免配置异常导致静默按标准价扣费。
+func TestCalculateCost_RateMultiplier_NegativeClampedToZero(t *testing.T) {
+ svc := newTestBillingService()
+ tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
+
+ tests := []struct {
+ name string
+ multiplier float64
+ wantRatio float64 // ActualCost / TotalCost
+ }{
+ {"negative clamped to 0", -1.5, 0},
+ {"zero passes through as 0 (defense in depth)", 0, 0},
+ {"positive 2x applied", 2.0, 2.0},
+ {"positive 0.5x applied", 0.5, 0.5},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cost, err := svc.CalculateCost("claude-sonnet-4", tokens, tt.multiplier)
+ require.NoError(t, err)
+ require.Greater(t, cost.TotalCost, 0.0, "TotalCost should be non-zero")
+ require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9)
+ })
+ }
+}
+
+// TestCalculateImageCost_RateMultiplier_NegativeClampedToZero 图片按次计费路径
+// 同样遵循"负数 → 0"语义。
+func TestCalculateImageCost_RateMultiplier_NegativeClampedToZero(t *testing.T) {
+ svc := newTestBillingService()
+ price := 0.04
+ cfg := &ImagePriceConfig{Price1K: &price}
+
+ tests := []struct {
+ name string
+ multiplier float64
+ wantRatio float64
+ }{
+ {"negative clamped to 0", -0.5, 0},
+ {"zero passes through", 0, 0},
+ {"positive 3x applied", 3.0, 3.0},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cost := svc.CalculateImageCost("imagen-3", "1K", 2, cfg, tt.multiplier)
+ require.NotNil(t, cost)
+ require.Greater(t, cost.TotalCost, 0.0)
+ require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9)
+ })
+ }
+}
diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go
index 2cf134e2..fc8361c7 100644
--- a/backend/internal/service/billing_service_test.go
+++ b/backend/internal/service/billing_service_test.go
@@ -71,34 +71,6 @@ func TestCalculateCost_RateMultiplier(t *testing.T) {
require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10)
}
-func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) {
- svc := newTestBillingService()
-
- tokens := UsageTokens{InputTokens: 1000}
-
- costZero, err := svc.CalculateCost("claude-sonnet-4", tokens, 0)
- require.NoError(t, err)
-
- costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
- require.NoError(t, err)
-
- require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
-}
-
-func TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) {
- svc := newTestBillingService()
-
- tokens := UsageTokens{InputTokens: 1000}
-
- costNeg, err := svc.CalculateCost("claude-sonnet-4", tokens, -1.0)
- require.NoError(t, err)
-
- costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
- require.NoError(t, err)
-
- require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
-}
-
func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) {
svc := newTestBillingService()
diff --git a/backend/internal/service/billing_service_unified_test.go b/backend/internal/service/billing_service_unified_test.go
index 694c3384..e6a92d1a 100644
--- a/backend/internal/service/billing_service_unified_test.go
+++ b/backend/internal/service/billing_service_unified_test.go
@@ -147,40 +147,35 @@ func TestCalculateCostUnified_ImageMode(t *testing.T) {
require.Equal(t, string(BillingModeImage), cost.BillingMode)
}
-func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) {
+// TestCalculateCostUnified_RateMultiplierZeroProducesZero 锁定新行为:
+// 保存时强制 > 0;若 0 仍泄漏到计费层,按 0 计费(而非历史上的 1.0)。
+func TestCalculateCostUnified_RateMultiplierZeroProducesZero(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
- costZero, err := bs.CalculateCostUnified(CostInput{
+ cost, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
- RateMultiplier: 0, // should default to 1.0
+ RateMultiplier: 0,
Resolver: resolver,
})
require.NoError(t, err)
-
- costOne, err := bs.CalculateCostUnified(CostInput{
- Ctx: context.Background(),
- Model: "claude-sonnet-4",
- Tokens: tokens,
- RateMultiplier: 1.0,
- Resolver: resolver,
- })
- require.NoError(t, err)
-
- require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
+ require.Greater(t, cost.TotalCost, 0.0)
+ require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
}
-func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) {
+// TestCalculateCostUnified_NegativeRateMultiplierClampedToZero 锁定新行为:
+// 负数倍率按 0 计费,避免历史的 <=0 → 1.0 把配置异常静默按标准价扣费。
+func TestCalculateCostUnified_NegativeRateMultiplierClampedToZero(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
tokens := UsageTokens{InputTokens: 1000}
- costNeg, err := bs.CalculateCostUnified(CostInput{
+ cost, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
@@ -188,17 +183,8 @@ func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T)
Resolver: resolver,
})
require.NoError(t, err)
-
- costOne, err := bs.CalculateCostUnified(CostInput{
- Ctx: context.Background(),
- Model: "claude-sonnet-4",
- Tokens: tokens,
- RateMultiplier: 1.0,
- Resolver: resolver,
- })
- require.NoError(t, err)
-
- require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
+ require.Greater(t, cost.TotalCost, 0.0)
+ require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
}
func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) {
diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go
index 12262613..64434ae1 100644
--- a/backend/internal/service/group.go
+++ b/backend/internal/service/group.go
@@ -76,10 +76,6 @@ func (g *Group) IsSubscriptionType() bool {
return g.SubscriptionType == SubscriptionTypeSubscription
}
-func (g *Group) IsFreeSubscription() bool {
- return g.IsSubscriptionType() && g.RateMultiplier == 0
-}
-
func (g *Group) HasDailyLimit() bool {
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
}
diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go
index e6fa94aa..6fa8a5bd 100644
--- a/backend/internal/service/openai_gateway_record_usage_test.go
+++ b/backend/internal/service/openai_gateway_record_usage_test.go
@@ -1031,7 +1031,7 @@ func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFiel
Model: "gpt-5.1",
Duration: time.Second,
},
- APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription}},
+ APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription, RateMultiplier: 1.0}},
User: &User{ID: 200},
Account: &Account{ID: 300},
Subscription: subscription,
diff --git a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue
index bf79bea2..41b2e63c 100644
--- a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue
+++ b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue
@@ -166,7 +166,7 @@
Date: Fri, 17 Apr 2026 22:07:15 +0800
Subject: [PATCH 005/189] feat(gateway): raise upstream response read limit 8MB
-> 128MB (configurable)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
图片生成 API 返回的 base64 内联图响应经常超过 8MB 单次读取上限,被
ReadUpstreamResponseBody 拦截成 502 upstream_error。
单张 4K PNG base64 最坏约 67MB,多张候选图或 imageSize=4K 的 image_generation
一次请求能轻松到 30MB+。把默认上限提到 128MB 能覆盖 2-3 张 4K 图,相对
请求体上限 256MB 仍有缓冲;同时抽出 config.DefaultUpstreamResponseReadMaxBytes
共享常量,viper 默认值和 service 层兜底共用,消除两处同步魔法数字。
仍可通过 gateway.upstream_response_read_max_bytes 配置项覆盖。
---
backend/internal/config/config.go | 7 ++++++-
backend/internal/service/upstream_response_limit.go | 4 +++-
2 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index dd9a4e58..15592905 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -52,6 +52,11 @@ const (
ConnectionPoolIsolationAccountProxy = "account_proxy"
)
+// DefaultUpstreamResponseReadMaxBytes 上游非流式响应体的默认读取上限。
+// 128 MB 可容纳 2-3 张 4K PNG(base64 膨胀 33%,单张 4K PNG 最坏约 67MB base64)。
+// 可通过 gateway.upstream_response_read_max_bytes 配置项覆盖。
+const DefaultUpstreamResponseReadMaxBytes int64 = 128 * 1024 * 1024
+
type Config struct {
Server ServerConfig `mapstructure:"server"`
Log LogConfig `mapstructure:"log"`
@@ -1407,7 +1412,7 @@ func setDefaults() {
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10)
viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
- viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
+ viper.SetDefault("gateway.upstream_response_read_max_bytes", DefaultUpstreamResponseReadMaxBytes)
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
viper.SetDefault("gateway.gemini_debug_response_headers", false)
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
diff --git a/backend/internal/service/upstream_response_limit.go b/backend/internal/service/upstream_response_limit.go
index a0444d52..ddf0e818 100644
--- a/backend/internal/service/upstream_response_limit.go
+++ b/backend/internal/service/upstream_response_limit.go
@@ -12,7 +12,9 @@ import (
var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large")
-const defaultUpstreamResponseReadMaxBytes int64 = 8 * 1024 * 1024
+// defaultUpstreamResponseReadMaxBytes 源自 config.DefaultUpstreamResponseReadMaxBytes,
+// 仅在 cfg 为 nil 时作为兜底(测试或极端场景)。
+const defaultUpstreamResponseReadMaxBytes = config.DefaultUpstreamResponseReadMaxBytes
func resolveUpstreamResponseReadLimit(cfg *config.Config) int64 {
if cfg != nil && cfg.Gateway.UpstreamResponseReadMaxBytes > 0 {
From 61a008f7e4e0e019f00e63409389aa5aa1058b0a Mon Sep 17 00:00:00 2001
From: erio
Date: Fri, 17 Apr 2026 23:05:58 +0800
Subject: [PATCH 006/189] chore(payment): mark legacy AES ciphertext fallback
as deprecated
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
明文 JSON 已经是新写入的默认格式;保留 AES 密文读取仅为兼容迁移期间的旧
记录,一旦所有部署通过管理后台重存过一次即可删除。标记为 deprecated 并加
TODO,几个版本后统一清理掉:payment.Encrypt / payment.Decrypt、两处
decryptConfig 的 AES 分支、PaymentConfigService.encryptionKey 和
DefaultLoadBalancer.encryptionKey 字段。
---
backend/internal/payment/crypto.go | 10 ++++++++++
backend/internal/payment/load_balancer.go | 7 +++++++
backend/internal/service/payment_config_providers.go | 6 ++++++
3 files changed, 23 insertions(+)
diff --git a/backend/internal/payment/crypto.go b/backend/internal/payment/crypto.go
index 5467e50b..0581469d 100644
--- a/backend/internal/payment/crypto.go
+++ b/backend/internal/payment/crypto.go
@@ -16,6 +16,11 @@ const AES256KeySize = 32
// Encrypt encrypts plaintext using AES-256-GCM with the given 32-byte key.
// The output format is "iv:authTag:ciphertext" where each component is base64-encoded,
// matching the Node.js crypto.ts format for cross-compatibility.
+//
+// Deprecated: payment provider configs are now stored as plaintext JSON.
+// This function is kept only for seeding legacy ciphertext in tests and for
+// the transitional Decrypt fallback. Scheduled for removal after all live
+// deployments complete migration by re-saving their configs.
func Encrypt(plaintext string, key []byte) (string, error) {
if len(key) != AES256KeySize {
return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key))
@@ -54,6 +59,11 @@ func Encrypt(plaintext string, key []byte) (string, error) {
// Decrypt decrypts a ciphertext string produced by Encrypt.
// The input format is "iv:authTag:ciphertext" where each component is base64-encoded.
+//
+// Deprecated: payment provider configs are now stored as plaintext JSON.
+// This function remains only as a read-path fallback for pre-migration
+// ciphertext records. Scheduled for removal once all deployments re-save
+// their provider configs through the admin UI.
func Decrypt(ciphertext string, key []byte) (string, error) {
if len(key) != AES256KeySize {
return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key))
diff --git a/backend/internal/payment/load_balancer.go b/backend/internal/payment/load_balancer.go
index 52a1b011..ec244cd6 100644
--- a/backend/internal/payment/load_balancer.go
+++ b/backend/internal/payment/load_balancer.go
@@ -283,6 +283,11 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns
// Unreadable values (legacy ciphertext without a valid key, or malformed data)
// are treated as empty so the service keeps running while the admin re-enters
// the config via the UI.
+//
+// TODO(deprecated-legacy-ciphertext): The AES fallback branch below is a
+// transitional compatibility shim for pre-plaintext records. Remove it (and
+// the encryptionKey field + the Decrypt import) after a few releases once all
+// live deployments have re-saved their provider configs through the UI.
func (lb *DefaultLoadBalancer) decryptConfig(stored string) (map[string]string, error) {
if stored == "" {
return nil, nil
@@ -291,7 +296,9 @@ func (lb *DefaultLoadBalancer) decryptConfig(stored string) (map[string]string,
if err := json.Unmarshal([]byte(stored), &config); err == nil {
return config, nil
}
+ // Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal.
if len(lb.encryptionKey) == AES256KeySize {
+ //nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal
if plaintext, err := Decrypt(stored, lb.encryptionKey); err == nil {
if err := json.Unmarshal([]byte(plaintext), &config); err == nil {
return config, nil
diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go
index 59337ad6..8e470525 100644
--- a/backend/internal/service/payment_config_providers.go
+++ b/backend/internal/service/payment_config_providers.go
@@ -296,6 +296,10 @@ func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newCon
// ("iv:authTag:ciphertext"). Values that cannot be parsed as either — including
// legacy ciphertext with no/invalid TOTP_ENCRYPTION_KEY — are treated as empty,
// letting the admin re-enter the config via the UI to complete the migration.
+//
+// TODO(deprecated-legacy-ciphertext): The AES fallback branch is a transitional
+// shim for pre-plaintext records. Remove it (and the encryptionKey field) after
+// a few releases once all live deployments have re-saved their provider configs.
func (s *PaymentConfigService) decryptConfig(stored string) (map[string]string, error) {
if stored == "" {
return nil, nil
@@ -304,7 +308,9 @@ func (s *PaymentConfigService) decryptConfig(stored string) (map[string]string,
if err := json.Unmarshal([]byte(stored), &cfg); err == nil {
return cfg, nil
}
+ // Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal.
if len(s.encryptionKey) == payment.AES256KeySize {
+ //nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal
if plaintext, err := payment.Decrypt(stored, s.encryptionKey); err == nil {
if err := json.Unmarshal([]byte(plaintext), &cfg); err == nil {
return cfg, nil
From 37123cef8fc14ca3d4b336a709167075ff8ea0df Mon Sep 17 00:00:00 2001
From: erio
Date: Sat, 18 Apr 2026 14:42:55 +0800
Subject: [PATCH 007/189] docs(payment): add Kyren Topup as international
EasyPay provider option
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Restructure the EasyPay recommendation block to present two options side
by side so users can pick by funding channel and settlement currency:
- Domestic / CNY — ZPay: official Alipay/WeChat API with 1.6% fee and
T+1 automatic settlement (existing recommendation, expanded with fee
and settlement details).
- International / USDT or USD — Kyren Topup (https://kyren.top): global
payment stack supporting WeChat Pay and Alipay with local-currency
checkout, USD settlement, and USDT/USD withdrawal. Fees: WeChat 2%,
Alipay 2.5%, withdrawal 0.1% ($40 min / $150 max). Fills the gap for
users who cannot use domestic Chinese channels or tolerate Stripe's
6%+ fees.
Both recommendations share a single disclaimer at the end. The Chinese
heading uses "易支付" while the English one keeps "EasyPay".
---
docs/PAYMENT.md | 7 ++++++-
docs/PAYMENT_CN.md | 7 ++++++-
2 files changed, 12 insertions(+), 2 deletions(-)
diff --git a/docs/PAYMENT.md b/docs/PAYMENT.md
index b66a791c..2735aea3 100644
--- a/docs/PAYMENT.md
+++ b/docs/PAYMENT.md
@@ -28,7 +28,12 @@ Sub2API has a built-in payment system that enables user self-service top-up with
> Alipay/WeChat Pay direct and EasyPay can coexist. Direct channels connect to payment APIs directly with lower fees; EasyPay aggregates through third-party platforms with easier setup.
-> **EasyPay Recommendation**: [ZPay](https://z-pay.cn/?uid=23808) (`https://z-pay.cn/?uid=23808`) is recommended as an EasyPay provider (link contains the referral code of [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) original author [@touwaeriol](https://github.com/touwaeriol) — feel free to remove it). ZPay supports **individual users** (no business license required) with up to 10,000 CNY daily transactions; business-licensed accounts have no limit. Please evaluate the security, reliability, and compliance of any third-party payment provider on your own — this project does not endorse or guarantee any of them.
+> **EasyPay Provider Recommendations**: Both options below are third-party aggregators compatible with the EasyPay protocol. Pick based on the funding channel and settlement currency you need:
+>
+> - **Domestic channel / CNY settlement** — [ZPay](https://z-pay.cn/?uid=23808) (`https://z-pay.cn/?uid=23808`): direct integration with official Alipay / WeChat Pay APIs, fee **1.6%**; funds go straight to the merchant account with **T+1 automatic settlement**. Supports **individual users** (no business license required) with up to 10,000 CNY daily transactions; business-licensed accounts have no limit. Link contains the referral code of [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) original author [@touwaeriol](https://github.com/touwaeriol) — feel free to remove it.
+> - **International channel / USDT or USD settlement** — [Kyren Topup](https://kyren.top/?code=SUB2API) (`https://kyren.top/?code=SUB2API`): a ready-to-launch global payment stack for AI startups with WeChat Pay and Alipay support, local-currency checkout, and USD settlement. Fees: WeChat 2%, Alipay 2.5%; withdrawal 0.1% (min $40, max $150), settled in **USDT or USD**. No qualification review required — sign up and use immediately, making it the lowest barrier to entry. Withdrawal threshold is relatively high, recommended for users **who do not use domestic Chinese payment channels, cannot tolerate Stripe's 6%+ fees, have high transaction volume, and have USD or USDT channels to receive withdrawn funds**. Link contains Sub2Api author [@Wei-Shaw](https://github.com/Wei-Shaw)'s referral code — feel free to remove it.
+>
+> Please evaluate the security, reliability, and compliance of any third-party payment provider on your own — this project does not endorse or guarantee any of them.
---
diff --git a/docs/PAYMENT_CN.md b/docs/PAYMENT_CN.md
index 9d96557f..474700e5 100644
--- a/docs/PAYMENT_CN.md
+++ b/docs/PAYMENT_CN.md
@@ -28,7 +28,12 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
> 支付宝官方 / 微信官方与 EasyPay 可以共存。官方渠道直接对接 API,资金直达商户账户,手续费更低;EasyPay 通过第三方平台聚合,接入门槛更低。
-> **EasyPay 推荐**:个人推荐 [ZPay](https://z-pay.cn/?uid=23808)(`https://z-pay.cn/?uid=23808`)作为 EasyPay 服务商(链接含 [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) 原作者 [@touwaeriol](https://github.com/touwaeriol) 的邀请码,介意可去掉)。ZPay 支持**个人用户**(无营业执照)每日 1 万元以内交易;拥有营业执照则无限额。支付渠道的安全性、稳定性及合规性请自行鉴别,本项目不对任何第三方支付服务商做担保或背书。
+> **易支付服务商推荐**:以下两家均为兼容易支付协议的第三方聚合支付,按资金通道与结算方式选择:
+>
+> - **国内渠道 / 人民币结算** — [ZPay](https://z-pay.cn/?uid=23808)(`https://z-pay.cn/?uid=23808`):支付宝 / 微信官方 API 直连,手续费 **1.6%**;资金直达商家账户,**T+1 自动到账**。支持**个人用户**(无营业执照)每日 1 万元以内交易;拥有营业执照则无限额。链接含 [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) 原作者 [@touwaeriol](https://github.com/touwaeriol) 的邀请码,介意可去掉。
+> - **国际渠道 / USDT 或美元结算** — [启润支付](https://kyren.top/?code=SUB2API)(`https://kyren.top/?code=SUB2API`):为 AI 项目提供低门槛国际收款通道,支持国际版微信支付与支付宝,本地货币支付、美元结算。手续费:微信 2%、支付宝 2.5%;提现 0.1%(最低 40 美元、最高 150 美元),以 **USDT 或美元**到账。无资质审核、注册即用,使用门槛最低;提现门槛略高,适合**不使用国内支付渠道、无法接受 Stripe 高达 6%+ 手续费、流水较大,且拥有美元或 USDT 渠道可接收提现资金**的用户。链接含 Sub2Api 作者 [@Wei-Shaw](https://github.com/Wei-Shaw) 邀请码,介意可去掉。
+>
+> 支付渠道的安全性、稳定性及合规性请自行鉴别,本项目不对任何第三方支付服务商做担保或背书。
---
From 6ae1cc8f3f893ea1c6896ee045c8e26ff4b2f16d Mon Sep 17 00:00:00 2001
From: erio
Date: Sat, 18 Apr 2026 14:45:25 +0800
Subject: [PATCH 008/189] =?UTF-8?q?docs:=20use=20=E6=98=93=E6=94=AF?=
=?UTF-8?q?=E4=BB=98=20in=20Chinese=20coexistence=20note?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
docs/PAYMENT_CN.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/PAYMENT_CN.md b/docs/PAYMENT_CN.md
index 474700e5..11325a80 100644
--- a/docs/PAYMENT_CN.md
+++ b/docs/PAYMENT_CN.md
@@ -26,7 +26,7 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
| **微信官方** | Native 扫码支付、H5 支付 | 直接对接微信支付 APIv3,移动端优先 H5 |
| **Stripe** | 银行卡、支付宝、微信支付、Link 等 | 国际支付,支持多币种 |
-> 支付宝官方 / 微信官方与 EasyPay 可以共存。官方渠道直接对接 API,资金直达商户账户,手续费更低;EasyPay 通过第三方平台聚合,接入门槛更低。
+> 支付宝官方 / 微信官方与易支付可以共存。官方渠道直接对接 API,资金直达商户账户,手续费更低;易支付通过第三方平台聚合,接入门槛更低。
> **易支付服务商推荐**:以下两家均为兼容易支付协议的第三方聚合支付,按资金通道与结算方式选择:
>
From 0c538a584fbf48414fe3c11b9a3c4ddb2dad96d5 Mon Sep 17 00:00:00 2001
From: erio
Date: Sat, 18 Apr 2026 14:48:42 +0800
Subject: [PATCH 009/189] docs: note Kyren Topup $200 account fee waived via
referral link
---
docs/PAYMENT.md | 2 +-
docs/PAYMENT_CN.md | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/docs/PAYMENT.md b/docs/PAYMENT.md
index 2735aea3..755b313a 100644
--- a/docs/PAYMENT.md
+++ b/docs/PAYMENT.md
@@ -31,7 +31,7 @@ Sub2API has a built-in payment system that enables user self-service top-up with
> **EasyPay Provider Recommendations**: Both options below are third-party aggregators compatible with the EasyPay protocol. Pick based on the funding channel and settlement currency you need:
>
> - **Domestic channel / CNY settlement** — [ZPay](https://z-pay.cn/?uid=23808) (`https://z-pay.cn/?uid=23808`): direct integration with official Alipay / WeChat Pay APIs, fee **1.6%**; funds go straight to the merchant account with **T+1 automatic settlement**. Supports **individual users** (no business license required) with up to 10,000 CNY daily transactions; business-licensed accounts have no limit. Link contains the referral code of [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) original author [@touwaeriol](https://github.com/touwaeriol) — feel free to remove it.
-> - **International channel / USDT or USD settlement** — [Kyren Topup](https://kyren.top/?code=SUB2API) (`https://kyren.top/?code=SUB2API`): a ready-to-launch global payment stack for AI startups with WeChat Pay and Alipay support, local-currency checkout, and USD settlement. Fees: WeChat 2%, Alipay 2.5%; withdrawal 0.1% (min $40, max $150), settled in **USDT or USD**. No qualification review required — sign up and use immediately, making it the lowest barrier to entry. Withdrawal threshold is relatively high, recommended for users **who do not use domestic Chinese payment channels, cannot tolerate Stripe's 6%+ fees, have high transaction volume, and have USD or USDT channels to receive withdrawn funds**. Link contains Sub2Api author [@Wei-Shaw](https://github.com/Wei-Shaw)'s referral code — feel free to remove it.
+> - **International channel / USDT or USD settlement** — [Kyren Topup](https://kyren.top/?code=SUB2API) (`https://kyren.top/?code=SUB2API`): a ready-to-launch global payment stack for AI startups with WeChat Pay and Alipay support, local-currency checkout, and USD settlement. Fees: WeChat 2%, Alipay 2.5%; withdrawal 0.1% (min $40, max $150), settled in **USDT or USD**. No qualification review required — sign up and use immediately, making it the lowest barrier to entry. Withdrawal threshold is relatively high, recommended for users **who do not use domestic Chinese payment channels, cannot tolerate Stripe's 6%+ fees, have high transaction volume, and have USD or USDT channels to receive withdrawn funds**. Kyren Topup charges a $200 account opening fee; signing up via this link (which contains Sub2Api author [@Wei-Shaw](https://github.com/Wei-Shaw)'s referral code) **waives the opening fee**. Feel free to remove it if you prefer.
>
> Please evaluate the security, reliability, and compliance of any third-party payment provider on your own — this project does not endorse or guarantee any of them.
diff --git a/docs/PAYMENT_CN.md b/docs/PAYMENT_CN.md
index 11325a80..aca3c866 100644
--- a/docs/PAYMENT_CN.md
+++ b/docs/PAYMENT_CN.md
@@ -31,7 +31,7 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
> **易支付服务商推荐**:以下两家均为兼容易支付协议的第三方聚合支付,按资金通道与结算方式选择:
>
> - **国内渠道 / 人民币结算** — [ZPay](https://z-pay.cn/?uid=23808)(`https://z-pay.cn/?uid=23808`):支付宝 / 微信官方 API 直连,手续费 **1.6%**;资金直达商家账户,**T+1 自动到账**。支持**个人用户**(无营业执照)每日 1 万元以内交易;拥有营业执照则无限额。链接含 [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) 原作者 [@touwaeriol](https://github.com/touwaeriol) 的邀请码,介意可去掉。
-> - **国际渠道 / USDT 或美元结算** — [启润支付](https://kyren.top/?code=SUB2API)(`https://kyren.top/?code=SUB2API`):为 AI 项目提供低门槛国际收款通道,支持国际版微信支付与支付宝,本地货币支付、美元结算。手续费:微信 2%、支付宝 2.5%;提现 0.1%(最低 40 美元、最高 150 美元),以 **USDT 或美元**到账。无资质审核、注册即用,使用门槛最低;提现门槛略高,适合**不使用国内支付渠道、无法接受 Stripe 高达 6%+ 手续费、流水较大,且拥有美元或 USDT 渠道可接收提现资金**的用户。链接含 Sub2Api 作者 [@Wei-Shaw](https://github.com/Wei-Shaw) 邀请码,介意可去掉。
+> - **国际渠道 / USDT 或美元结算** — [启润支付](https://kyren.top/?code=SUB2API)(`https://kyren.top/?code=SUB2API`):为 AI 项目提供低门槛国际收款通道,支持国际版微信支付与支付宝,本地货币支付、美元结算。手续费:微信 2%、支付宝 2.5%;提现 0.1%(最低 40 美元、最高 150 美元),以 **USDT 或美元**到账。无资质审核、注册即用,使用门槛最低;提现门槛略高,适合**不使用国内支付渠道、无法接受 Stripe 高达 6%+ 手续费、流水较大,且拥有美元或 USDT 渠道可接收提现资金**的用户。启润支付开户费 200 美元,通过本链接注册(含 Sub2Api 作者 [@Wei-Shaw](https://github.com/Wei-Shaw) 邀请码)可**免开户费**,介意可去掉。
>
> 支付渠道的安全性、稳定性及合规性请自行鉴别,本项目不对任何第三方支付服务商做担保或背书。
From c3cb0280ef8dff949ad7ad8f84793872b9304686 Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 19 Apr 2026 01:40:25 +0800
Subject: [PATCH 010/189] fix(payment): alipay redirect-only flow, H5 detection
and popup sizing
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
The native Alipay provider previously tried to embed the payment page
URL into a QR code on the client — the URL is not a scannable payload
so the QR never worked. Merchants also hit a H5 detection mismatch
whenever the backend UA sniffer missed iPadOS 13+ or embedded browsers,
and the popup window was too small for Alipay's standard checkout
layout (QR + account-login panel on the right), forcing the user to
scroll horizontally and vertically.
Changes:
Backend
- alipay.go: drop QR-on-URL path. Use redirect-only flow —
alipay.trade.page.pay for PC (returns a gateway URL the browser
opens in a new window) and alipay.trade.wap.pay for H5 (returns a
URL the browser jumps to). Both flows produce pages on
openapi.alipaydev.com / excashier.alipay.com; the client never
renders a QR itself.
- payment_handler.go: add optional is_mobile bool to
CreateOrderRequest so the frontend can declare the device
explicitly. Server still falls back to UA sniffing when absent.
Frontend
- types/payment.ts, PaymentView.vue: declare is_mobile in
CreateOrderRequest and pass the computed isMobileDevice() value.
- providerConfig.ts: replace the two fixed POPUP_WINDOW_FEATURES
constants with getPaymentPopupFeatures(), which prefers 1250×900
(Alipay's checkout footprint), clamps to window.screen.avail* and
centers the popup so it never overflows on smaller laptops.
- PaymentQRDialog.vue, PaymentStatusPanel.vue, StripePaymentInline.vue,
PaymentView.vue: use the new helper at all popup call sites.
---
backend/internal/handler/payment_handler.go | 10 +++-
backend/internal/payment/provider/alipay.go | 48 ++++++++++---------
.../components/payment/PaymentQRDialog.vue | 4 +-
.../components/payment/PaymentStatusPanel.vue | 4 +-
.../payment/StripePaymentInline.vue | 4 +-
.../src/components/payment/providerConfig.ts | 21 ++++++--
frontend/src/types/payment.ts | 1 +
frontend/src/views/user/PaymentView.vue | 5 +-
8 files changed, 62 insertions(+), 35 deletions(-)
diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go
index 1ddb8ae2..854dca54 100644
--- a/backend/internal/handler/payment_handler.go
+++ b/backend/internal/handler/payment_handler.go
@@ -206,6 +206,10 @@ type CreateOrderRequest struct {
PaymentType string `json:"payment_type" binding:"required"`
OrderType string `json:"order_type"`
PlanID int64 `json:"plan_id"`
+ // IsMobile lets the frontend declare its mobile status directly. When
+ // nil we fall back to User-Agent heuristics (which miss iPadOS / some
+ // embedded browsers that strip the "Mobile" keyword).
+ IsMobile *bool `json:"is_mobile,omitempty"`
}
// CreateOrder creates a new payment order.
@@ -222,12 +226,16 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
return
}
+ mobile := isMobile(c)
+ if req.IsMobile != nil {
+ mobile = *req.IsMobile
+ }
result, err := h.paymentService.CreateOrder(c.Request.Context(), service.CreateOrderRequest{
UserID: subject.UserID,
Amount: req.Amount,
PaymentType: req.PaymentType,
ClientIP: c.ClientIP(),
- IsMobile: isMobile(c),
+ IsMobile: mobile,
SrcHost: c.Request.Host,
SrcURL: c.Request.Referer(),
OrderType: req.OrderType,
diff --git a/backend/internal/payment/provider/alipay.go b/backend/internal/payment/provider/alipay.go
index af8a90c6..fe8ea89c 100644
--- a/backend/internal/payment/provider/alipay.go
+++ b/backend/internal/payment/provider/alipay.go
@@ -15,8 +15,8 @@ import (
// Alipay product codes.
const (
- alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY"
alipayProductCodeWapPay = "QUICK_WAP_WAY"
+ alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY"
)
// Alipay response constants.
@@ -79,7 +79,12 @@ func (a *Alipay) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeAlipay}
}
-// CreatePayment creates an Alipay payment page URL.
+// CreatePayment creates an Alipay payment using redirect-only flow:
+// - Mobile (H5): alipay.trade.wap.pay — returns a URL the browser jumps to.
+// - PC: alipay.trade.page.pay — returns a gateway URL the browser opens in a
+// new window; Alipay's own page then shows login/QR. We intentionally do
+// NOT encode the URL into a QR on the client (it isn't a scannable payload
+// and would produce an invalid scan result).
func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
client, err := a.getClient()
if err != nil {
@@ -96,31 +101,31 @@ func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentReque
}
if req.IsMobile {
- return a.createTrade(client, req, notifyURL, returnURL, true)
+ return a.createWapTrade(client, req, notifyURL, returnURL)
}
- return a.createTrade(client, req, notifyURL, returnURL, false)
+ return a.createPagePayTrade(client, req, notifyURL, returnURL)
}
-func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string, isMobile bool) (*payment.CreatePaymentResponse, error) {
- if isMobile {
- param := alipay.TradeWapPay{}
- param.OutTradeNo = req.OrderID
- param.TotalAmount = req.Amount
- param.Subject = req.Subject
- param.ProductCode = alipayProductCodeWapPay
- param.NotifyURL = notifyURL
- param.ReturnURL = returnURL
+func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
+ param := alipay.TradeWapPay{}
+ param.OutTradeNo = req.OrderID
+ param.TotalAmount = req.Amount
+ param.Subject = req.Subject
+ param.ProductCode = alipayProductCodeWapPay
+ param.NotifyURL = notifyURL
+ param.ReturnURL = returnURL
- payURL, err := client.TradeWapPay(param)
- if err != nil {
- return nil, fmt.Errorf("alipay TradeWapPay: %w", err)
- }
- return &payment.CreatePaymentResponse{
- TradeNo: req.OrderID,
- PayURL: payURL.String(),
- }, nil
+ payURL, err := client.TradeWapPay(param)
+ if err != nil {
+ return nil, fmt.Errorf("alipay TradeWapPay: %w", err)
}
+ return &payment.CreatePaymentResponse{
+ TradeNo: req.OrderID,
+ PayURL: payURL.String(),
+ }, nil
+}
+func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
param := alipay.TradePagePay{}
param.OutTradeNo = req.OrderID
param.TotalAmount = req.Amount
@@ -136,7 +141,6 @@ func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentReq
return &payment.CreatePaymentResponse{
TradeNo: req.OrderID,
PayURL: payURL.String(),
- QRCode: payURL.String(),
}, nil
}
diff --git a/frontend/src/components/payment/PaymentQRDialog.vue b/frontend/src/components/payment/PaymentQRDialog.vue
index b9026e78..db90c3b6 100644
--- a/frontend/src/components/payment/PaymentQRDialog.vue
+++ b/frontend/src/components/payment/PaymentQRDialog.vue
@@ -79,7 +79,7 @@ import { usePaymentStore } from '@/stores/payment'
import { useAppStore } from '@/stores'
import { paymentAPI } from '@/api/payment'
import { extractApiErrorMessage } from '@/utils/apiError'
-import { POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig'
+import { getPaymentPopupFeatures } from '@/components/payment/providerConfig'
import type { PaymentOrder } from '@/types/payment'
import QRCode from 'qrcode'
import alipayIcon from '@/assets/icons/alipay.svg'
@@ -147,7 +147,7 @@ function getLogoForType(): string | null {
function reopenPopup() {
if (props.payUrl) {
- window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES)
+ window.open(props.payUrl, 'paymentPopup', getPaymentPopupFeatures())
}
}
diff --git a/frontend/src/components/payment/PaymentStatusPanel.vue b/frontend/src/components/payment/PaymentStatusPanel.vue
index 974dee66..17541e59 100644
--- a/frontend/src/components/payment/PaymentStatusPanel.vue
+++ b/frontend/src/components/payment/PaymentStatusPanel.vue
@@ -125,7 +125,7 @@ import { usePaymentStore } from '@/stores/payment'
import { useAppStore } from '@/stores'
import { paymentAPI } from '@/api/payment'
import { extractApiErrorMessage } from '@/utils/apiError'
-import { POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig'
+import { getPaymentPopupFeatures } from '@/components/payment/providerConfig'
import type { PaymentOrder } from '@/types/payment'
import Icon from '@/components/icons/Icon.vue'
import QRCode from 'qrcode'
@@ -194,7 +194,7 @@ const countdownDisplay = computed(() => {
function reopenPopup() {
if (props.payUrl) {
- window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES)
+ window.open(props.payUrl, 'paymentPopup', getPaymentPopupFeatures())
}
}
diff --git a/frontend/src/components/payment/StripePaymentInline.vue b/frontend/src/components/payment/StripePaymentInline.vue
index b8fd55ef..3ddff8c8 100644
--- a/frontend/src/components/payment/StripePaymentInline.vue
+++ b/frontend/src/components/payment/StripePaymentInline.vue
@@ -70,7 +70,7 @@ import { useRouter } from 'vue-router'
import { extractApiErrorMessage } from '@/utils/apiError'
import { paymentAPI } from '@/api/payment'
import { useAppStore } from '@/stores'
-import { STRIPE_POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig'
+import { getPaymentPopupFeatures } from '@/components/payment/providerConfig'
import type { Stripe, StripeElements } from '@stripe/stripe-js'
import Icon from '@/components/icons/Icon.vue'
@@ -151,7 +151,7 @@ async function handlePay() {
amount: String(props.payAmount),
},
}).href
- const popup = window.open(popupUrl, 'paymentPopup', STRIPE_POPUP_WINDOW_FEATURES)
+ const popup = window.open(popupUrl, 'paymentPopup', getPaymentPopupFeatures())
const onReady = (event: MessageEvent) => {
if (event.source !== popup || event.data?.type !== 'STRIPE_POPUP_READY') return
diff --git a/frontend/src/components/payment/providerConfig.ts b/frontend/src/components/payment/providerConfig.ts
index a83787fd..bf2d4177 100644
--- a/frontend/src/components/payment/providerConfig.ts
+++ b/frontend/src/components/payment/providerConfig.ts
@@ -43,11 +43,24 @@ export const METHOD_ORDER = ['alipay', 'alipay_direct', 'wxpay', 'wxpay_direct',
export const PAYMENT_MODE_QRCODE = 'qrcode'
export const PAYMENT_MODE_POPUP = 'popup'
-/** Window features for payment popup windows */
-export const POPUP_WINDOW_FEATURES = 'width=1000,height=750,left=100,top=80,scrollbars=yes,resizable=yes'
+/** Preferred popup size for payment gateways. Alipay's standard checkout
+ * (QR + account login panel) needs ~1200×900 to render without any scrolling. */
+const PAYMENT_POPUP_PREFERRED_WIDTH = 1250
+const PAYMENT_POPUP_PREFERRED_HEIGHT = 900
-/** Wider popup for Stripe redirect methods (Alipay checkout page needs ~1200px) */
-export const STRIPE_POPUP_WINDOW_FEATURES = 'width=1250,height=780,left=80,top=60,scrollbars=yes,resizable=yes'
+/** Build a window.open features string sized to fit within the current screen
+ * while preferring the above dimensions. Centers the popup on the available
+ * work area so nothing is clipped on smaller laptop displays. */
+export function getPaymentPopupFeatures(): string {
+ const screen = typeof window !== 'undefined' ? window.screen : null
+ const availW = screen?.availWidth ?? PAYMENT_POPUP_PREFERRED_WIDTH
+ const availH = screen?.availHeight ?? PAYMENT_POPUP_PREFERRED_HEIGHT
+ const width = Math.min(PAYMENT_POPUP_PREFERRED_WIDTH, availW - 40)
+ const height = Math.min(PAYMENT_POPUP_PREFERRED_HEIGHT, availH - 40)
+ const left = Math.max(0, Math.floor((availW - width) / 2))
+ const top = Math.max(0, Math.floor((availH - height) / 2))
+ return `width=${width},height=${height},left=${left},top=${top},scrollbars=yes,resizable=yes`
+}
/** Webhook paths for each provider (relative to origin). */
export const WEBHOOK_PATHS: Record = {
diff --git a/frontend/src/types/payment.ts b/frontend/src/types/payment.ts
index 7ecbb9a9..6f2eec51 100644
--- a/frontend/src/types/payment.ts
+++ b/frontend/src/types/payment.ts
@@ -154,6 +154,7 @@ export interface CreateOrderRequest {
payment_type: string
order_type: string
plan_id?: number
+ is_mobile?: boolean
}
export interface CreateOrderResult {
diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue
index e91df5da..3f1401b3 100644
--- a/frontend/src/views/user/PaymentView.vue
+++ b/frontend/src/views/user/PaymentView.vue
@@ -277,7 +277,7 @@ import type { SubscriptionPlan, CheckoutInfoResponse, OrderType } from '@/types/
import AppLayout from '@/components/layout/AppLayout.vue'
import AmountInput from '@/components/payment/AmountInput.vue'
import PaymentMethodSelector from '@/components/payment/PaymentMethodSelector.vue'
-import { METHOD_ORDER, POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig'
+import { METHOD_ORDER, getPaymentPopupFeatures } from '@/components/payment/providerConfig'
import { platformAccentBarClass, platformBadgeLightClass, platformBadgeClass, platformTextClass, platformLabel } from '@/utils/platformColors'
import SubscriptionPlanCard from '@/components/payment/SubscriptionPlanCard.vue'
import PaymentStatusPanel from '@/components/payment/PaymentStatusPanel.vue'
@@ -551,9 +551,10 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n
payment_type: selectedMethod.value,
order_type: orderType,
plan_id: planId,
+ is_mobile: isMobileDevice(),
})
const openWindow = (url: string) => {
- const win = window.open(url, 'paymentPopup', POPUP_WINDOW_FEATURES)
+ const win = window.open(url, 'paymentPopup', getPaymentPopupFeatures())
if (!win || win.closed) {
window.location.href = url
}
From 235f710853b3e44b5b2fa236cfefd0b9d0c2a044 Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 19 Apr 2026 01:46:50 +0800
Subject: [PATCH 011/189] feat(payment): redact provider secrets in admin
config API
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Admin GET /api/v1/admin/payment/providers previously returned every
config value — including privateKey / apiV3Key / secretKey etc. —
verbatim. Any future XSS on the admin UI would hand attackers the
full set of production payment credentials, and the plaintext values
sat unnecessarily in browser memory for every operator.
Treat those fields as write-only from the admin surface:
- decryptAndMaskConfig() strips sensitive keys from the GET response.
The authoritative list is an explicit per-provider registry that
mirrors the frontend's PROVIDER_CONFIG_FIELDS sensitive flag:
alipay → privateKey, publicKey, alipayPublicKey
wxpay → privateKey, apiV3Key, publicKey
stripe → secretKey, webhookSecret (publishableKey stays plain)
easypay → pkey
Payment runtime still reads the full config via decryptConfig, so
nothing at the gateway changes.
- mergeConfig() treats an empty value for a sensitive key as "leave
unchanged" — the admin UI omits unchanged secrets so operators can
tweak non-sensitive settings without re-entering credentials.
- Admin dialog (PaymentProviderDialog.vue):
* secret inputs get autocomplete="new-password", data-1p-ignore,
data-lpignore and data-bwignore so password managers do not
offer to save provider credentials
* in edit mode the required-field check skips sensitive fields
(empty is the "keep existing" signal) and the placeholder shows
"leave empty to keep" instead of the default example value
* create mode still requires every non-optional field, including
secrets, since there is nothing to preserve
- Unit test renamed to TestIsSensitiveProviderConfigField, covers
the per-provider registry and specifically asserts that Stripe's
publishableKey is NOT treated as a secret.
---
.../service/payment_config_providers.go | 78 +++++++++++++++----
.../service/payment_config_providers_test.go | 61 +++++++++------
.../payment/PaymentProviderDialog.vue | 23 ++++--
3 files changed, 118 insertions(+), 44 deletions(-)
diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go
index 8e470525..3d1e4dc4 100644
--- a/backend/internal/service/payment_config_providers.go
+++ b/backend/internal/service/payment_config_providers.go
@@ -52,7 +52,7 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
AllowUserRefund: inst.AllowUserRefund,
SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode,
}
- resp.Config, err = s.decryptAndMaskConfig(inst.Config)
+ resp.Config, err = s.decryptAndMaskConfig(inst.ProviderKey, inst.Config)
if err != nil {
return nil, fmt.Errorf("decrypt config for instance %d: %w", inst.ID, err)
}
@@ -61,8 +61,26 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
return result, nil
}
-func (s *PaymentConfigService) decryptAndMaskConfig(encrypted string) (map[string]string, error) {
- return s.decryptConfig(encrypted)
+// decryptAndMaskConfig returns the stored config with sensitive fields omitted.
+// Admin UIs display masked placeholders for these; the raw values never leave
+// the server. Callers that need the full config (e.g. payment runtime) must
+// use decryptConfig directly.
+func (s *PaymentConfigService) decryptAndMaskConfig(providerKey, encrypted string) (map[string]string, error) {
+ cfg, err := s.decryptConfig(encrypted)
+ if err != nil {
+ return nil, err
+ }
+ if cfg == nil {
+ return nil, nil
+ }
+ masked := make(map[string]string, len(cfg))
+ for k, v := range cfg {
+ if isSensitiveProviderConfigField(providerKey, k) {
+ continue
+ }
+ masked[k] = v
+ }
+ return masked, nil
}
// pendingOrderStatuses are order statuses considered "in progress".
@@ -72,16 +90,27 @@ var pendingOrderStatuses = []string{
payment.OrderStatusRecharging,
}
-var sensitiveConfigPatterns = []string{"key", "pkey", "secret", "private", "password"}
+// providerSensitiveConfigFields is the authoritative list of config keys that
+// are treated as secrets per provider. Must stay in sync with the frontend
+// definition at frontend/src/components/payment/providerConfig.ts
+// (PROVIDER_CONFIG_FIELDS, fields with sensitive: true).
+//
+// Key matching is case-insensitive. Non-listed keys (e.g. appId, notifyUrl,
+// stripe publishableKey) are returned in plaintext by the admin GET API.
+var providerSensitiveConfigFields = map[string]map[string]struct{}{
+ payment.TypeEasyPay: {"pkey": {}},
+ payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}},
+ payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}},
+ payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
+}
-func isSensitiveConfigField(fieldName string) bool {
- lower := strings.ToLower(fieldName)
- for _, p := range sensitiveConfigPatterns {
- if strings.Contains(lower, p) {
- return true
- }
+func isSensitiveProviderConfigField(providerKey, fieldName string) bool {
+ fields, ok := providerSensitiveConfigFields[providerKey]
+ if !ok {
+ return false
}
- return false
+ _, found := fields[strings.ToLower(fieldName)]
+ return found
}
func (s *PaymentConfigService) countPendingOrders(ctx context.Context, providerInstanceID int64) (int, error) {
@@ -137,10 +166,26 @@ func validateProviderRequest(providerKey, name, supportedTypes string) error {
// NOTE: This function exceeds 30 lines due to per-field nil-check patch update
// boilerplate and pending-order safety checks.
func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
+ var cachedInst *dbent.PaymentProviderInstance
+ loadInst := func() (*dbent.PaymentProviderInstance, error) {
+ if cachedInst != nil {
+ return cachedInst, nil
+ }
+ inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
+ if err != nil {
+ return nil, fmt.Errorf("load provider instance: %w", err)
+ }
+ cachedInst = inst
+ return inst, nil
+ }
if req.Config != nil {
+ inst, err := loadInst()
+ if err != nil {
+ return nil, err
+ }
hasSensitive := false
- for k := range req.Config {
- if isSensitiveConfigField(k) && req.Config[k] != "" {
+ for k, v := range req.Config {
+ if v != "" && isSensitiveProviderConfigField(inst.ProviderKey, k) {
hasSensitive = true
break
}
@@ -283,9 +328,14 @@ func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newCon
return nil, fmt.Errorf("decrypt existing config for instance %d: %w", id, err)
}
if existing == nil {
- return newConfig, nil
+ existing = map[string]string{}
}
for k, v := range newConfig {
+ // Preserve existing secrets when the client submits an empty value
+ // (admin UI omits the value to indicate "leave unchanged").
+ if v == "" && isSensitiveProviderConfigField(inst.ProviderKey, k) {
+ continue
+ }
existing[k] = v
}
return existing, nil
diff --git a/backend/internal/service/payment_config_providers_test.go b/backend/internal/service/payment_config_providers_test.go
index 2aaa874f..bc2a9b18 100644
--- a/backend/internal/service/payment_config_providers_test.go
+++ b/backend/internal/service/payment_config_providers_test.go
@@ -97,41 +97,52 @@ func TestValidateProviderRequest(t *testing.T) {
}
}
-func TestIsSensitiveConfigField(t *testing.T) {
+func TestIsSensitiveProviderConfigField(t *testing.T) {
t.Parallel()
tests := []struct {
- field string
- wantSen bool
+ providerKey string
+ field string
+ wantSen bool
}{
- // Sensitive fields (contain key/secret/private/password/pkey patterns)
- {"secretKey", true},
- {"apiSecret", true},
- {"pkey", true},
- {"privateKey", true},
- {"apiPassword", true},
- {"appKey", true},
- {"SECRET_TOKEN", true},
- {"PrivateData", true},
- {"PASSWORD", true},
- {"mySecretValue", true},
+ // Stripe: publishableKey is public, only secretKey/webhookSecret are secrets
+ {"stripe", "secretKey", true},
+ {"stripe", "webhookSecret", true},
+ {"stripe", "SecretKey", true}, // case-insensitive
+ {"stripe", "publishableKey", false},
+ {"stripe", "appId", false},
- // Non-sensitive fields
- {"appId", false},
- {"mchId", false},
- {"apiBase", false},
- {"endpoint", false},
- {"merchantNo", false},
- {"paymentMode", false},
- {"notifyUrl", false},
+ // Alipay
+ {"alipay", "privateKey", true},
+ {"alipay", "publicKey", true},
+ {"alipay", "alipayPublicKey", true},
+ {"alipay", "appId", false},
+ {"alipay", "notifyUrl", false},
+
+ // Wxpay
+ {"wxpay", "privateKey", true},
+ {"wxpay", "apiV3Key", true},
+ {"wxpay", "publicKey", true},
+ {"wxpay", "publicKeyId", false},
+ {"wxpay", "certSerial", false},
+ {"wxpay", "mchId", false},
+
+ // EasyPay
+ {"easypay", "pkey", true},
+ {"easypay", "pid", false},
+ {"easypay", "apiBase", false},
+
+ // Unknown provider: never sensitive
+ {"unknown", "secretKey", false},
}
for _, tc := range tests {
- t.Run(tc.field, func(t *testing.T) {
+ tc := tc
+ t.Run(tc.providerKey+"/"+tc.field, func(t *testing.T) {
t.Parallel()
- got := isSensitiveConfigField(tc.field)
- assert.Equal(t, tc.wantSen, got, "isSensitiveConfigField(%q)", tc.field)
+ got := isSensitiveProviderConfigField(tc.providerKey, tc.field)
+ assert.Equal(t, tc.wantSen, got, "isSensitiveProviderConfigField(%q, %q)", tc.providerKey, tc.field)
})
}
}
diff --git a/frontend/src/components/payment/PaymentProviderDialog.vue b/frontend/src/components/payment/PaymentProviderDialog.vue
index 10c1bfea..624ddcdd 100644
--- a/frontend/src/components/payment/PaymentProviderDialog.vue
+++ b/frontend/src/components/payment/PaymentProviderDialog.vue
@@ -88,13 +88,24 @@
v-model="config[field.key]"
rows="3"
class="input font-mono text-xs"
+ autocomplete="new-password"
+ data-1p-ignore
+ data-lpignore="true"
+ data-bwignore="true"
+ spellcheck="false"
+ :placeholder="editing ? t('admin.accounts.leaveEmptyToKeep') : ''"
/>
= {}
for (const [k, v] of Object.entries(config)) {
if (!v || !v.trim()) continue
- // Skip masked values — backend keeps existing credentials
- if (v === '••••••••') continue
filteredConfig[k] = v
}
@@ -470,7 +482,8 @@ function loadProvider(provider: ProviderInstance) {
form.refund_enabled = provider.refund_enabled
form.allow_user_refund = provider.allow_user_refund
clearConfig()
- // Pre-fill config from API response (non-sensitive in cleartext, sensitive masked as ••••••••)
+ // Pre-fill config from API response. Backend omits sensitive fields entirely,
+ // so those inputs stay blank — submitting blank preserves the stored secret.
if (provider.config) {
for (const [k, v] of Object.entries(provider.config)) {
// Skip notifyUrl/returnUrl — they are derived from callbackBaseUrl
From 6530776a62dd355ccd3f9bca1ca3ed41f715bdd7 Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 19 Apr 2026 18:05:25 +0800
Subject: [PATCH 012/189] fix: support xhigh reasoning effort in usage records
for Claude Messages API
Closes #1732
---
backend/internal/service/gateway_request.go | 2 +-
backend/internal/service/gateway_request_test.go | 8 +++++++-
frontend/src/utils/format.ts | 4 +++-
3 files changed, 11 insertions(+), 3 deletions(-)
diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go
index 55cb2c84..498336a4 100644
--- a/backend/internal/service/gateway_request.go
+++ b/backend/internal/service/gateway_request.go
@@ -962,7 +962,7 @@ func NormalizeClaudeOutputEffort(raw string) *string {
return nil
}
switch value {
- case "low", "medium", "high", "max":
+ case "low", "medium", "high", "xhigh", "max":
return &value
default:
return nil
diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go
index d262456d..40bd1186 100644
--- a/backend/internal/service/gateway_request_test.go
+++ b/backend/internal/service/gateway_request_test.go
@@ -1149,6 +1149,11 @@ func TestParseGatewayRequest_OutputEffort(t *testing.T) {
body: `{"model":"claude-opus-4-6","output_config":{"effort":"max"},"messages":[]}`,
wantEffort: "max",
},
+ {
+ name: "output_config.effort xhigh",
+ body: `{"model":"claude-opus-4-7","output_config":{"effort":"xhigh"},"messages":[]}`,
+ wantEffort: "xhigh",
+ },
{
name: "output_config without effort",
body: `{"model":"claude-opus-4-6","output_config":{},"messages":[]}`,
@@ -1186,9 +1191,10 @@ func TestNormalizeClaudeOutputEffort(t *testing.T) {
{"LOW", strPtr("low")},
{"Max", strPtr("max")},
{" medium ", strPtr("medium")},
+ {"xhigh", strPtr("xhigh")},
+ {"XHIGH", strPtr("xhigh")},
{"", nil},
{"unknown", nil},
- {"xhigh", nil},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
diff --git a/frontend/src/utils/format.ts b/frontend/src/utils/format.ts
index ea5d597b..6f13a065 100644
--- a/frontend/src/utils/format.ts
+++ b/frontend/src/utils/format.ts
@@ -193,7 +193,9 @@ export function formatReasoningEffort(effort: string | null | undefined): string
return 'High'
case 'xhigh':
case 'extrahigh':
- return 'Xhigh'
+ return 'XHigh'
+ case 'max':
+ return 'Max'
case 'none':
case 'minimal':
return '-'
From 258fd145ff8a2355befc03166f34c3d498da7aa8 Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 19 Apr 2026 18:45:04 +0800
Subject: [PATCH 013/189] fix(account): prevent quota-exceeded API key/Bedrock
accounts from being scheduled
Add quota exceeded check to IsSchedulable() and refactor
shouldClearStickySession to delegate to IsSchedulable(), eliminating
duplicated logic and fixing missed overload/rate-limit/expired checks.
Frontend displays quota exceeded status independently via quota fields.
---
backend/internal/service/account.go | 3 +
.../service/account_quota_schedulable_test.go | 123 ++++++++++++++++++
backend/internal/service/gateway_service.go | 15 +--
.../internal/service/sticky_session_test.go | 66 ++++++++--
.../account/AccountStatusIndicator.vue | 33 +++--
frontend/src/i18n/locales/en.ts | 1 +
frontend/src/i18n/locales/zh.ts | 1 +
7 files changed, 207 insertions(+), 35 deletions(-)
create mode 100644 backend/internal/service/account_quota_schedulable_test.go
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index 52db3073..af686ae7 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -121,6 +121,9 @@ func (a *Account) IsSchedulable() bool {
if a.TempUnschedulableUntil != nil && now.Before(*a.TempUnschedulableUntil) {
return false
}
+ if a.IsAPIKeyOrBedrock() && a.IsQuotaExceeded() {
+ return false
+ }
return true
}
diff --git a/backend/internal/service/account_quota_schedulable_test.go b/backend/internal/service/account_quota_schedulable_test.go
new file mode 100644
index 00000000..2895b34c
--- /dev/null
+++ b/backend/internal/service/account_quota_schedulable_test.go
@@ -0,0 +1,123 @@
+//go:build unit
+
+package service
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAccountIsSchedulable_QuotaExceeded(t *testing.T) {
+ now := time.Now()
+
+ tests := []struct {
+ name string
+ account *Account
+ want bool
+ }{
+ {
+ name: "apikey daily quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: false,
+ },
+ {
+ name: "apikey weekly quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_weekly_limit": 50.0,
+ "quota_weekly_used": 50.0,
+ "quota_weekly_start": now.Add(-2 * 24 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: false,
+ },
+ {
+ name: "apikey total quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_limit": 100.0,
+ "quota_used": 100.0,
+ },
+ },
+ want: false,
+ },
+ {
+ name: "apikey quota not exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 5.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: true,
+ },
+ {
+ name: "apikey expired daily period restores schedulable",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-25 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: true,
+ },
+ {
+ name: "oauth ignores quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: true,
+ },
+ {
+ name: "bedrock quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeBedrock,
+ Extra: map[string]any{
+ "quota_limit": 200.0,
+ "quota_used": 200.0,
+ },
+ },
+ want: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ require.Equal(t, tt.want, tt.account.IsSchedulable())
+ })
+ }
+}
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 07a9e41c..5a91d0de 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -435,26 +435,19 @@ func prefetchedStickyAccountIDFromContext(ctx context.Context, groupID *int64) i
}
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
-// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间,
-// 或请求的模型处于限流状态时,返回 true。
-// 这确保后续请求不会继续使用不可用的账号。
+// 委托 IsSchedulable() 判断账号级可调度性(状态、配额、过载、限流等),
+// 额外检查模型级限流。
//
// shouldClearStickySession checks if an account is in an unschedulable state
// and the sticky session binding should be cleared.
-// Returns true when account status is error/disabled, schedulable is false,
-// within temporary unschedulable period, or the requested model is rate-limited.
-// This ensures subsequent requests won't continue using unavailable accounts.
+// Delegates to IsSchedulable() for account-level checks, plus model-level rate limiting.
func shouldClearStickySession(account *Account, requestedModel string) bool {
if account == nil {
return false
}
- if account.Status == StatusError || account.Status == StatusDisabled || !account.Schedulable {
+ if !account.IsSchedulable() {
return true
}
- if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
- return true
- }
- // 检查模型限流和 scope 限流,有限流即清除粘性会话
if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > 0 {
return true
}
diff --git a/backend/internal/service/sticky_session_test.go b/backend/internal/service/sticky_session_test.go
index e7ef8982..11ace7bd 100644
--- a/backend/internal/service/sticky_session_test.go
+++ b/backend/internal/service/sticky_session_test.go
@@ -15,20 +15,8 @@ import (
"github.com/stretchr/testify/require"
)
-// TestShouldClearStickySession 测试粘性会话清理判断逻辑。
-// 验证在以下情况下是否正确判断需要清理粘性会话:
-// - nil 账号:不清理(返回 false)
-// - 状态为错误或禁用:清理
-// - 不可调度:清理
-// - 临时不可调度且未过期:清理
-// - 临时不可调度已过期:不清理
-// - 正常可调度状态:不清理
-// - 模型限流(任意时长):清理
-//
-// TestShouldClearStickySession tests the sticky session clearing logic.
-// Verifies correct behavior for various account states including:
-// nil account, error/disabled status, unschedulable, temporary unschedulable,
-// and model rate limiting scenarios.
+// TestShouldClearStickySession tests sticky session clearing via IsSchedulable() delegation
+// plus model-level rate limiting.
func TestShouldClearStickySession(t *testing.T) {
now := time.Now()
future := now.Add(1 * time.Hour)
@@ -101,6 +89,56 @@ func TestShouldClearStickySession(t *testing.T) {
requestedModel: "claude-opus-4", // 请求不同模型
want: false, // 不同模型不受影响
},
+ {
+ name: "apikey quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ requestedModel: "",
+ want: true,
+ },
+ {
+ name: "oauth quota exceeded not cleared",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ requestedModel: "",
+ want: false,
+ },
+ {
+ name: "overloaded account",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ OverloadUntil: &future,
+ },
+ requestedModel: "",
+ want: true,
+ },
+ {
+ name: "account-level rate limited",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ RateLimitResetAt: &future,
+ },
+ requestedModel: "",
+ want: true,
+ },
}
for _, tt := range tests {
diff --git a/frontend/src/components/account/AccountStatusIndicator.vue b/frontend/src/components/account/AccountStatusIndicator.vue
index fc2f7d0c..dd38a49f 100644
--- a/frontend/src/components/account/AccountStatusIndicator.vue
+++ b/frontend/src/components/account/AccountStatusIndicator.vue
@@ -284,6 +284,16 @@ const hasError = computed(() => {
return props.account.status === 'error'
})
+const isQuotaExceeded = computed(() => {
+ const exceeded = (used?: number | null, limit?: number | null) =>
+ typeof limit === 'number' && limit > 0 && typeof used === 'number' && used >= limit
+ return (
+ exceeded(props.account.quota_used, props.account.quota_limit) ||
+ exceeded(props.account.quota_daily_used, props.account.quota_daily_limit) ||
+ exceeded(props.account.quota_weekly_used, props.account.quota_weekly_limit)
+ )
+})
+
// Computed: countdown text for rate limit (429)
const rateLimitCountdown = computed(() => {
return formatCountdown(props.account.rate_limit_reset_at)
@@ -307,19 +317,16 @@ const statusClass = computed(() => {
if (isTempUnschedulable.value) {
return 'badge-warning'
}
+ if (props.account.status !== 'active') {
+ return props.account.status === 'error' ? 'badge-danger' : 'badge-gray'
+ }
+ if (isQuotaExceeded.value) {
+ return 'badge-warning'
+ }
if (!props.account.schedulable) {
return 'badge-gray'
}
- switch (props.account.status) {
- case 'active':
- return 'badge-success'
- case 'inactive':
- return 'badge-gray'
- case 'error':
- return 'badge-danger'
- default:
- return 'badge-gray'
- }
+ return 'badge-success'
})
// Computed: status text
@@ -330,6 +337,12 @@ const statusText = computed(() => {
if (isTempUnschedulable.value) {
return t('admin.accounts.status.tempUnschedulable')
}
+ if (props.account.status !== 'active') {
+ return t(`admin.accounts.status.${props.account.status}`)
+ }
+ if (isQuotaExceeded.value) {
+ return t('admin.accounts.status.quotaExceeded')
+ }
if (!props.account.schedulable) {
return t('admin.accounts.status.paused')
}
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index d1def45c..c0a17d96 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -2126,6 +2126,7 @@ export default {
rateLimited: 'Rate Limited',
overloaded: 'Overloaded',
tempUnschedulable: 'Temp Unschedulable',
+ quotaExceeded: 'Quota Exceeded',
unschedulable: 'Unschedulable',
rateLimitedUntil: 'Rate limited and removed from scheduling. Auto resumes at {time}',
rateLimitedAutoResume: 'Auto resumes in {time}',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 6f57ab3e..ba9edd7f 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -2315,6 +2315,7 @@ export default {
rateLimited: '限流中',
overloaded: '过载中',
tempUnschedulable: '临时不可调度',
+ quotaExceeded: '配额超限',
unschedulable: '不可调度',
rateLimitedUntil: '限流中,当前不参与调度,预计 {time} 自动恢复',
rateLimitedAutoResume: '{time} 自动恢复',
From 6579f28b649486cd3bfc9de3f2ed442f87707daf Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 19 Apr 2026 20:38:57 +0800
Subject: [PATCH 014/189] fix: delete scheduled test plans when account is
deleted
Accounts use soft-delete (setting deleted_at), so PostgreSQL's
ON DELETE CASCADE on scheduled_test_plans.account_id never fires.
Add plan deletion to the existing account deletion transaction
to ensure atomicity.
Closes Wei-Shaw/sub2api#1728
---
backend/internal/repository/account_repo.go | 3 +++
1 file changed, 3 insertions(+)
diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go
index 24115c33..78f739ac 100644
--- a/backend/internal/repository/account_repo.go
+++ b/backend/internal/repository/account_repo.go
@@ -438,6 +438,9 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil {
return err
}
+ if _, err := txClient.ExecContext(ctx, "DELETE FROM scheduled_test_plans WHERE account_id = $1", id); err != nil {
+ return err
+ }
if _, err := txClient.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx); err != nil {
return err
}
From 23def40bc5415c04ca3a05bb6d67a6ff1e4a3566 Mon Sep 17 00:00:00 2001
From: shaw
Date: Sun, 19 Apr 2026 22:06:04 +0800
Subject: [PATCH 015/189] chore: change license from MIT to LGPL v3.0
---
LICENSE | 178 ++++++++++++++++++++++++++++++++++++++++++++++-----
README.md | 4 +-
README_CN.md | 4 +-
README_JA.md | 4 +-
4 files changed, 170 insertions(+), 20 deletions(-)
diff --git a/LICENSE b/LICENSE
index 7a94ca9d..153d416d 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,21 +1,165 @@
-MIT License
+ GNU LESSER GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
-Copyright (c) 2025 Wesley Liddick
+ Copyright (C) 2007 Free Software Foundation, Inc.
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
+ This version of the GNU Lesser General Public License incorporates
+the terms and conditions of version 3 of the GNU General Public
+License, supplemented by the additional permissions listed below.
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
+ 0. Additional Definitions.
+
+ As used herein, "this License" refers to version 3 of the GNU Lesser
+General Public License, and the "GNU GPL" refers to version 3 of the GNU
+General Public License.
+
+ "The Library" refers to a covered work governed by this License,
+other than an Application or a Combined Work as defined below.
+
+ An "Application" is any work that makes use of an interface provided
+by the Library, but which is not otherwise based on the Library.
+Defining a subclass of a class defined by the Library is deemed a mode
+of using an interface provided by the Library.
+
+ A "Combined Work" is a work produced by combining or linking an
+Application with the Library. The particular version of the Library
+with which the Combined Work was made is also called the "Linked
+Version".
+
+ The "Minimal Corresponding Source" for a Combined Work means the
+Corresponding Source for the Combined Work, excluding any source code
+for portions of the Combined Work that, considered in isolation, are
+based on the Application, and not on the Linked Version.
+
+ The "Corresponding Application Code" for a Combined Work means the
+object code and/or source code for the Application, including any data
+and utility programs needed for reproducing the Combined Work from the
+Application, but excluding the System Libraries of the Combined Work.
+
+ 1. Exception to Section 3 of the GNU GPL.
+
+ You may convey a covered work under sections 3 and 4 of this License
+without being bound by section 3 of the GNU GPL.
+
+ 2. Conveying Modified Versions.
+
+ If you modify a copy of the Library, and, in your modifications, a
+facility refers to a function or data to be supplied by an Application
+that uses the facility (other than as an argument passed when the
+facility is invoked), then you may convey a copy of the modified
+version:
+
+ a) under this License, provided that you make a good faith effort to
+ ensure that, in the event an Application does not supply the
+ function or data, the facility still operates, and performs
+ whatever part of its purpose remains meaningful, or
+
+ b) under the GNU GPL, with none of the additional permissions of
+ this License applicable to that copy.
+
+ 3. Object Code Incorporating Material from Library Header Files.
+
+ The object code form of an Application may incorporate material from
+a header file that is part of the Library. You may convey such object
+code under terms of your choice, provided that, if the incorporated
+material is not limited to numerical parameters, data structure
+layouts and accessors, or small macros, inline functions and templates
+(ten or fewer lines in length), you do both of the following:
+
+ a) Give prominent notice with each copy of the object code that the
+ Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the object code with a copy of the GNU GPL and this license
+ document.
+
+ 4. Combined Works.
+
+ You may convey a Combined Work under terms of your choice that,
+taken together, effectively do not restrict modification of the
+portions of the Library contained in the Combined Work and reverse
+engineering for debugging such modifications, if you also do each of
+the following:
+
+ a) Give prominent notice with each copy of the Combined Work that
+ the Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the Combined Work with a copy of the GNU GPL and this license
+ document.
+
+ c) For a Combined Work that displays copyright notices during
+ execution, include the copyright notice for the Library among
+ these notices, as well as a reference directing the user to the
+ copies of the GNU GPL and this license document.
+
+ d) Do one of the following:
+
+ 0) Convey the Minimal Corresponding Source under the terms of this
+ License, and the Corresponding Application Code in a form
+ suitable for, and under terms that permit, the user to
+ recombine or relink the Application with a modified version of
+ the Linked Version to produce a modified Combined Work, in the
+ manner specified by section 6 of the GNU GPL for conveying
+ Corresponding Source.
+
+ 1) Use a suitable shared library mechanism for linking with the
+ Library. A suitable mechanism is one that (a) uses at run time
+ a copy of the Library already present on the user's computer
+ system, and (b) will operate properly with a modified version
+ of the Library that is interface-compatible with the Linked
+ Version.
+
+ e) Provide Installation Information, but only if you would otherwise
+ be required to provide such information under section 6 of the
+ GNU GPL, and only to the extent that such information is
+ necessary to install and execute a modified version of the
+ Combined Work produced by recombining or relinking the
+ Application with a modified version of the Linked Version. (If
+ you use option 4d0, the Installation Information must accompany
+ the Minimal Corresponding Source and Corresponding Application
+ Code. If you use option 4d1, you must provide the Installation
+ Information in the manner specified by section 6 of the GNU GPL
+ for conveying Corresponding Source.)
+
+ 5. Combined Libraries.
+
+ You may place library facilities that are a work based on the
+Library side by side in a single library together with other library
+facilities that are not Applications and are not covered by this
+License, and convey such a combined library under terms of your
+choice, if you do both of the following:
+
+ a) Accompany the combined library with a copy of the same work based
+ on the Library, uncombined with any other library facilities,
+ conveyed under the terms of this License.
+
+ b) Give prominent notice with the combined library that part of it
+ is a work based on the Library, and explaining where to find the
+ accompanying uncombined form of the same work.
+
+ 6. Revised Versions of the GNU Lesser General Public License.
+
+ The Free Software Foundation may publish revised and/or new versions
+of the GNU Lesser General Public License from time to time. Such new
+versions will be similar in spirit to the present version, but may
+differ in detail to address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Library as you received it specifies that a certain numbered version
+of the GNU Lesser General Public License "or any later version"
+applies to it, you have the option of following the terms and
+conditions either of that published version or of any later version
+published by the Free Software Foundation. If the Library as you
+received it does not specify a version number of the GNU Lesser
+General Public License, you may choose any version of the GNU Lesser
+General Public License ever published by the Free Software Foundation.
+
+ If the Library as you received it specifies that a proxy can decide
+whether future versions of the GNU Lesser General Public License shall
+apply, that proxy's public statement of acceptance of any version is
+permanent authorization for you to choose that version for the
+Library.
\ No newline at end of file
diff --git a/README.md b/README.md
index 74ab9af2..bee2e8c3 100644
--- a/README.md
+++ b/README.md
@@ -618,7 +618,9 @@ sub2api/
## License
-MIT License
+This project is licensed under the [GNU Lesser General Public License v3.0](LICENSE) (or later).
+
+Copyright (c) 2026 Wesley Liddick
---
diff --git a/README_CN.md b/README_CN.md
index c701372c..892eee61 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -679,7 +679,9 @@ sub2api/
## 许可证
-MIT License
+本项目基于 [GNU 宽通用公共许可证 v3.0](LICENSE)(或更高版本)授权。
+
+Copyright (c) 2026 Wesley Liddick
---
diff --git a/README_JA.md b/README_JA.md
index 0d4db616..6f0fc900 100644
--- a/README_JA.md
+++ b/README_JA.md
@@ -617,7 +617,9 @@ sub2api/
## ライセンス
-MIT License
+本プロジェクトは [GNU Lesser General Public License v3.0](LICENSE)(またはそれ以降のバージョン)の下でライセンスされています。
+
+Copyright (c) 2026 Wesley Liddick
---
From e01c1eacebc4b4d7cce25de83fb2cb94a76c10fb Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 13:18:30 +0800
Subject: [PATCH 016/189] docs: add auth identity payment foundation design
spec
---
...auth-identity-payment-foundation-design.md | 540 ++++++++++++++++++
1 file changed, 540 insertions(+)
create mode 100644 docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
diff --git a/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md b/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
new file mode 100644
index 00000000..a60cfdca
--- /dev/null
+++ b/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
@@ -0,0 +1,540 @@
+# Auth Identity And Payment Foundation Design
+
+**Date:** 2026-04-20
+
+**Status:** Draft approved in conversation, written for implementation planning
+
+**Goal**
+
+Rebuild the `feat/auth-identity-foundation` intent on a clean branch from `main`, covering unified user identity, third-party login and binding, profile adoption, source-based signup defaults, unified payment routing and UX, admin configuration, compatibility with existing `main` data, and an opt-in OpenAI advanced scheduling switch.
+
+## Scope
+
+This design includes:
+
+- Email login and registration
+- Third-party login and binding for `LinuxDo`, `OIDC`, and `WeChat`
+- Unified identity storage for email and third-party identities
+- Pending auth sessions for callback-to-login/register/bind continuation
+- User-controlled nickname/avatar adoption during first relevant third-party flow
+- Profile binding management and avatar upload/delete
+- Source-based initial grants for balance, concurrency, and subscriptions
+- User management support for `last_login_at` and `last_active_at` sorting
+- Unified payment display methods (`alipay`, `wechat`) mapped to a single active backend source each
+- Alipay and WeChat UX routing rules across PC, mobile, H5, and WeChat environments
+- Admin settings for auth providers, source defaults, payment sources, and OpenAI advanced scheduling
+- Incremental migration and compatibility for existing email users and historical LinuxDo synthetic-email users
+
+This design does not treat unrelated upstream merges, docs churn, or license changes from the old branch as required scope.
+
+## Product Rules
+
+### Auth and identity
+
+- Existing email users remain valid and continue to log in with no manual action.
+- Third-party first login behavior:
+ - Existing bound identity: direct login
+ - Missing identity: start first-login flow
+- If `force_email_on_third_party_signup` is disabled, a first-login user may create an account without binding an email.
+- If `force_email_on_third_party_signup` is enabled, the user must provide an email.
+- If the provided and verified email already exists:
+ - show that the email already exists
+ - allow "verify and bind existing account"
+ - allow "change email and continue registration"
+ - do not allow bypassing the email requirement
+- Upstream provider email verification is not trusted as a local bound email.
+- WeChat login chooses channel by environment:
+ - in WeChat environment: `mp`
+ - outside WeChat: `open`
+- WeChat primary identity key is `unionid`.
+- If a WeChat login/bind flow cannot produce `unionid`, the flow fails and no fallback `openid` identity is created.
+
+### Profile adoption
+
+- During the first relevant third-party flow, the user can independently decide:
+ - replace current nickname or not
+ - replace current avatar or not
+- This applies to first third-party registration and first third-party binding.
+- The decision is explicit user choice, not automatic replacement.
+
+### Source-based initial grants
+
+- Source-specific defaults exist for `email`, `linuxdo`, `oidc`, and `wechat`.
+- Each source defines:
+ - default balance
+ - default concurrency
+ - default subscriptions
+ - grant on signup
+ - grant on first bind
+- Default behavior:
+ - grant on signup: enabled
+ - grant on first bind: disabled
+- First-bind grants are optional and controlled per source.
+- Grants must be idempotent.
+
+### Avatar management
+
+- Avatar supports:
+ - external URL
+ - image `data:` URL
+- `data:` URL images are compressed to at most `100KB` before persistence.
+- Avatar storage is database-backed.
+- Avatar delete is supported.
+
+### Payment UX and routing
+
+- Frontend shows only two display methods:
+ - `alipay`
+ - `wechat`
+- Users never choose between official providers and EasyPay explicitly.
+- Backend allows only one active source per display method at a time.
+- Alipay UX:
+ - PC: show QR code in page
+ - mobile: jump to Alipay app/payment flow
+- WeChat UX:
+ - PC: show QR code in page
+ - non-WeChat H5: prefer H5 pay; if unavailable, tell the user to open in WeChat
+ - WeChat environment: prefer MP/JSAPI pay; if unavailable, fall back to H5 pay
+- Payment success is confirmed by backend order state, webhook, and/or query, not only frontend return.
+
+### OpenAI advanced scheduling
+
+- OpenAI advanced scheduling is supported.
+- It is disabled by default.
+- Admin can enable it explicitly.
+
+## Architecture
+
+Keep `users` as the account owner table and move login identities, channel mappings, pending auth state, and first-bind grant idempotency into dedicated tables and services. Keep email login working while progressively introducing unified identity reads and writes.
+
+Payment uses a similar split between user-visible display methods and backend provider sources. Frontend works only with stable display methods while backend resolves to the currently active source and capability matrix.
+
+Compatibility is a first-class concern: migrations are additive, reads are compatibility-aware, and rollout must tolerate existing `main` data and short-lived frontend/backend version skew.
+
+## Data Model
+
+### `users`
+
+Preserve existing account ownership and local-login fields. Extend or use:
+
+- `email`
+- `password_hash`
+- `totp_enabled`
+- `signup_source`
+- `last_login_at`
+- `last_active_at`
+
+The `users` table remains the primary business subject for balance, concurrency, subscriptions, permissions, and profile.
+
+### `auth_identities`
+
+Represents all canonical login or bindable identities.
+
+Fields:
+
+- `user_id`
+- `provider_type`: `email`, `linuxdo`, `oidc`, `wechat`
+- `provider_key`
+- `provider_subject`
+- `verified_at`
+- `issuer`
+- `metadata`
+- timestamps
+
+Uniqueness:
+
+- `provider_type + provider_key + provider_subject` must be unique
+
+Rules:
+
+- email identity uses canonicalized local email
+- LinuxDo uses stable provider subject
+- OIDC uses stable issuer + subject
+- WeChat uses `unionid` as canonical subject
+
+### `auth_identity_channels`
+
+Stores channel-specific subject mappings for an identity.
+
+Primary use:
+
+- WeChat `open` / `mp` / payment channel mapping
+
+Fields:
+
+- `identity_id`
+- `provider_type`
+- `provider_key`
+- `channel`
+- `channel_app_id`
+- `channel_subject`
+- `metadata`
+- timestamps
+
+Rules:
+
+- canonical WeChat identity still keys on `unionid`
+- `openid` values live here as channel mappings
+
+### `pending_auth_sessions`
+
+Stores callback state between third-party callback and final account action.
+
+Fields:
+
+- `intent`
+- `provider_type`
+- `provider_key`
+- `provider_subject`
+- `target_user_id`
+- `redirect_to`
+- `resolved_email`
+- `pending_password_hash`
+- `upstream_identity_payload`
+- `metadata`
+- `email_verified_at`
+- `password_verified_at`
+- `totp_verified_at`
+- `expires_at`
+- `consumed_at`
+- timestamps
+
+Responsibilities:
+
+- continue provider callback into register/login/bind flows
+- persist nickname/avatar suggestions
+- persist explicit adoption decisions
+- survive navigation between auth pages
+
+### `identity_adoption_decisions`
+
+Persists user adoption preference for a specific identity.
+
+Fields:
+
+- `identity_id`
+- `adopt_display_name`
+- `adopt_avatar`
+- `decided_at`
+- timestamps
+
+### `user_avatars`
+
+Stores the currently effective custom avatar.
+
+Fields:
+
+- `user_id`
+- `storage_provider`
+- `storage_key`
+- `url`
+- `content_type`
+- `byte_size`
+- `sha256`
+- timestamps
+
+Rules:
+
+- supports URL-backed and inline data-backed representations
+- hard maximum payload size is `100KB`
+
+### `user_provider_default_grants`
+
+Stores idempotency state for source grants.
+
+Fields:
+
+- `user_id`
+- `provider_type`
+- `granted_at`
+- timestamps
+
+Responsibilities:
+
+- prevent duplicate first-bind grants
+- allow signup grants and first-bind grants to be reasoned about independently
+
+## Identity Keys And Canonicalization
+
+- Email canonical key: `lower(trim(email))`
+- LinuxDo canonical key: provider subject from LinuxDo
+- OIDC canonical key: `issuer + sub`
+- WeChat canonical key: `unionid`
+
+WeChat-specific rule:
+
+- `openid` never becomes the primary stored identity key
+- if only `openid` is available, login/bind fails with a configuration/identity error
+
+## Core Flows
+
+### Email register/login
+
+- Existing email auth flow remains
+- On email registration, create canonical `email` identity
+- Apply `email` source signup defaults
+
+### Third-party login with existing identity
+
+- Resolve canonical identity
+- Login mapped `user`
+- Update `last_login_at`
+- Do not issue signup or first-bind grants again
+
+### Third-party first login with no identity
+
+- Create `pending_auth_session`
+- Frontend callback flow decides next action
+
+Branches:
+
+- no forced email binding:
+ - user can create account directly
+- forced email binding:
+ - user must supply local email
+
+If supplied local email already exists:
+
+- tell the user the email already exists
+- allow verify-and-bind-existing-account
+- allow changing email to continue registration
+
+On new account creation:
+
+- create `users` row
+- create canonical third-party identity
+- apply source signup grants
+- apply adoption choices if selected
+
+### Bind third-party identity to current logged-in user
+
+- current user starts bind flow
+- callback resolves to `bind_current_user`
+- bind canonical identity to current user
+- if configured and first bind for that provider, apply first-bind grants
+- present nickname/avatar replacement choice
+
+### Bind existing account during first-login flow
+
+- verify password for existing account
+- if account requires TOTP, verify TOTP
+- bind canonical identity to target account
+- optionally apply first-bind grants
+- present nickname/avatar replacement choice
+
+### WeChat login and channel mapping
+
+- environment chooses `mp` or `open`
+- callback must resolve to `unionid`
+- channel `openid` is optionally recorded in `auth_identity_channels`
+- failure to obtain `unionid` aborts flow
+
+### Avatar upload and delete
+
+- URL avatar: validate and persist reference
+- data URL avatar:
+ - decode
+ - validate image type
+ - compress to `<=100KB`
+ - persist database-backed inline representation
+- delete removes current custom avatar entry
+
+## Payment Routing Model
+
+### User-visible methods
+
+- `alipay`
+- `wechat`
+
+### Backend source abstraction
+
+Each display method maps to exactly one active configured backend source:
+
+- `official_alipay`
+- `easypay_alipay`
+- `official_wechat`
+- `easypay_wechat`
+
+Frontend submits display method only. Backend resolves display method to active source and capability set.
+
+### Alipay routing
+
+- PC: create QR-oriented result and show QR in page
+- mobile: create jump/redirect-oriented result
+
+### WeChat routing
+
+- PC: QR result
+- non-WeChat H5:
+ - prefer H5 pay
+ - if unavailable, show "open in WeChat" requirement
+- WeChat environment:
+ - prefer MP/JSAPI
+ - if unavailable, fall back to H5 pay
+
+### Payment completion
+
+- frontend return restores context and UI state
+- backend order state remains source of truth
+- webhook and/or order query remain authoritative for fulfillment
+
+## Admin Configuration Model
+
+### Auth provider settings
+
+- email registration and verification settings
+- force email on third-party signup
+- LinuxDo client settings
+- OIDC issuer/client settings and provider display name
+- WeChat `open` and `mp` settings with config-valid and health indicators
+
+### Source default settings
+
+Per source (`email`, `linuxdo`, `oidc`, `wechat`):
+
+- default balance
+- default concurrency
+- default subscriptions
+- grant on signup
+- grant on first bind
+
+### Payment settings
+
+- active source for `alipay`
+- active source for `wechat`
+- source-specific credentials and enablement
+- WeChat capability matrix:
+ - QR available
+ - H5 available
+ - MP/JSAPI available
+
+### Scheduling settings
+
+- OpenAI advanced scheduling enabled/disabled
+- default disabled
+
+## Compatibility And Rollout
+
+Compatibility is mandatory, especially for:
+
+- existing email users
+- existing LinuxDo users
+- historical LinuxDo synthetic-email accounts
+
+### Additive migrations
+
+- preserve existing `users` data and behavior
+- add identity and pending-session tables
+- avoid destructive schema swaps
+
+### Migration backfill
+
+- backfill canonical `email` identities for valid existing email users
+- backfill canonical `linuxdo` identities during migration for historical synthetic-email LinuxDo users
+- backfill must be idempotent and repeatable
+
+### Compatibility reads
+
+During rollout:
+
+- read new identity model first
+- where necessary, retain compatibility logic for existing email and historical LinuxDo synthetic-email recognition
+
+### Grant idempotency
+
+- migration backfill must not trigger signup or first-bind grants
+- first-bind grants must use explicit idempotency tracking
+
+### API compatibility
+
+Retain transitional support for legacy/new request and response shapes where needed, including:
+
+- `pending_auth_token`
+- `pending_oauth_token`
+- old callback parsing expectations
+- historical profile field mappings
+
+### Settings and payment compatibility
+
+- preserve existing payment configs and order semantics from `main`
+- add new settings incrementally
+- avoid rewriting the entire settings schema in one cutover
+
+### Rolling upgrade tolerance
+
+- do not assume simultaneous frontend/backend deployment
+- new backend must tolerate short-lived older frontend request shapes
+
+## Testing Strategy
+
+### Repository tests
+
+- identity upsert and lookup
+- WeChat channel mapping
+- pending auth session persistence
+- source grant idempotency
+- avatar persistence and delete
+- migration backfill behavior
+
+### Service tests
+
+- direct login by existing identity
+- first third-party signup
+- forced email flow
+- existing-email bind-existing-account flow
+- first-bind grant on/off
+- nickname/avatar adoption choices
+- WeChat `unionid` required behavior
+- payment routing resolution
+
+### Handler and route tests
+
+- LinuxDo/OIDC/WeChat callback handling
+- bind-existing
+- bind-current-user
+- create-account
+- TOTP continuation
+- payment create and recovery
+
+### Frontend tests
+
+- third-party callback flow state machine
+- register/login continuation
+- profile bindings card
+- avatar interactions
+- payment page routing behavior
+- admin settings UI
+
+### Compatibility tests
+
+- existing email users
+- historical LinuxDo synthetic-email users
+- historical payment config
+- legacy auth payload field names
+- historical payment result handling
+
+## Implementation Phases
+
+1. Add schema, migrations, compatibility backfill, and repository support
+2. Implement unified identity services and pending auth session flows
+3. Integrate profile binding, avatar, and adoption decision flows
+4. Add per-source default grants and admin config surfaces
+5. Rebuild payment routing abstraction and frontend payment UX
+6. Add user-management sorting and OpenAI advanced scheduling switch
+7. Run compatibility, rollout, and regression hardening
+
+## External Constraints And Best Practices
+
+Implementation must follow current primary-source guidance:
+
+- OAuth 2.0 Security BCP (RFC 9700): strict redirect handling, state protection, mix-up resistant design
+- PKCE (RFC 7636): use on authorization code flows where applicable
+- OpenID Connect Core: stable issuer/subject handling for OIDC identities
+- Account linking best practice: require explicit user confirmation or re-authentication before linking to existing accounts
+
+References:
+
+- RFC 9700:
+- RFC 7636:
+- OpenID Connect Core 1.0:
+- Auth0 account linking guidance:
From 721d7ab3ab7b9c0ad6c96ec3e0ed23d9c7062951 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 13:40:31 +0800
Subject: [PATCH 017/189] docs: add audit synthesis to auth identity spec
---
...auth-identity-payment-foundation-design.md | 124 ++++++++++++++++++
1 file changed, 124 insertions(+)
diff --git a/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md b/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
index a60cfdca..bfa250d9 100644
--- a/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
+++ b/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
@@ -538,3 +538,127 @@ References:
- RFC 7636:
- OpenID Connect Core 1.0:
- Auth0 account linking guidance:
+
+## Audit Synthesis
+
+The clean rebuild direction is not to copy either existing branch directly.
+
+- `feat/auth-identity-foundation` has the better long-term model:
+ - unified auth identities
+ - pending auth sessions
+ - identity adoption decisions
+ - provider-scoped default grants
+ - payment display-method abstraction
+ - OpenAI advanced scheduler layering
+- `personal-dev-branch` has the better real-world closure:
+ - LinuxDo and WeChat callback flows are more operationally complete
+ - profile binding and avatar UX is more complete
+ - historical synthetic-email users are recognized and recovered in live flows
+ - WeChat payment OAuth and recovery behavior is more complete
+- Primary-source guidance supplies hard constraints for OAuth/OIDC, account linking, WeChat identity handling, and payment completion semantics.
+
+The final rebuild must therefore:
+
+- keep the `feat/auth-identity-foundation` data model direction
+- absorb the strongest business-flow behavior from `personal-dev-branch`
+- reject transitional or half-finished behavior from both branches
+- treat compatibility and rollout as first-class implementation scope
+
+## Keep / Adapt / Drop
+
+### Keep
+
+Keep these architectural choices essentially intact:
+
+- `auth_identities`, `auth_identity_channels`, `pending_auth_sessions`, `identity_adoption_decisions`
+- per-provider default grants with one-time grant tracking
+- WeChat canonical identity plus channel mapping model
+- pending-auth verification gates before final bind
+- payment visible-method abstraction (`alipay`, `wechat`) decoupled from backend provider source
+- OpenAI advanced scheduler layering and test-backed behavior
+
+Keep these operational flow ideas from `personal-dev-branch`:
+
+- LinuxDo pending identity callback flow
+- WeChat pending identity callback flow
+- profile bindings UX and “cannot disconnect last usable login method” rule
+- separate WeChat login OAuth and WeChat payment OAuth entry points
+- historical synthetic-email recognition logic as a migration bridge
+
+### Adapt
+
+These areas must be reimplemented with the same intent but stricter boundaries:
+
+- third-party account creation from pending-auth state must be transactional and must not register a plain local user before identity finalization succeeds
+- email identity lifecycle must become real dual-write state, not just one migration-time backfill
+- `signup_source` must be backfilled more accurately for known historical third-party users
+- WeChat payment recovery state must move from frontend-only storage to server-backed continuation state
+- avatar adoption fetches must be security-hardened and failure-visible
+- pending-auth payload modeling must clearly separate immutable upstream payload from mutable local metadata
+- profile binding/avatar DTOs must be simplified to one authoritative backend contract instead of sprawling frontend fallback parsing
+- admin settings should preserve capability while reducing duplicated or transitional config branches
+
+### Drop
+
+Drop these as long-term design choices:
+
+- `user_external_identities` as the primary long-term identity model
+- synthetic email as a long-term canonical identity representation
+- OIDC as a side-path that does not participate in the same identity foundation as LinuxDo and WeChat
+- frontend multi-endpoint probing and broad compatibility parsing once the clean branch becomes the sole supported contract
+- unrelated branch noise such as generated-file churn, locale-only churn, or upstream merge residue as design inputs
+
+## Audit-Driven Hard Constraints
+
+The audit and source review establish these hard constraints:
+
+### Auth
+
+- all authorization-code providers use PKCE where applicable
+- callback handling uses strict `redirect_uri` discipline and state validation
+- OIDC identity key is `issuer + sub`
+- existing-account linking after email conflict must require explicit user action plus local-account verification
+- WeChat canonical identity key is `unionid`; `openid` is channel-scoped only
+
+### Compatibility
+
+- existing email users must continue to work with no manual intervention
+- existing LinuxDo users must not split into duplicate accounts
+- historical LinuxDo synthetic-email users must be backfilled into canonical LinuxDo identities during migration, not only lazily on next login
+- migration backfills must not trigger signup or first-bind grants
+- legacy `pending_auth_token` and `pending_oauth_token` contracts must remain accepted during rollout
+- legacy auth/public setting aliases needed by older frontend builds must remain available during rollout
+- existing payment configs and historical order semantics must remain valid
+
+### Payment
+
+- frontend return pages do not determine final payment success
+- backend order state, webhook processing, and/or provider status query remain authoritative
+- each visible method (`alipay`, `wechat`) may have only one active backend source at a time
+
+## Known Risks To Eliminate In Implementation
+
+These are specifically observed problems in the existing branches that the clean rebuild must eliminate:
+
+- third-party forced-email account creation currently bypasses the provider-aware account creation path and can leave orphan local accounts if bind finalization fails
+- post-migration email accounts are not fully dual-written into `auth_identities`
+- avatar adoption currently risks silent failure and insecure outbound fetch behavior
+- pending-auth payload responsibilities are internally inconsistent
+- OIDC parity is incomplete in `personal-dev-branch`; it must become a first-class provider in the unified identity model
+- WeChat union/open/channel identity handling is conceptually correct in the feature branch but still partially transitional across the codebase
+- WeChat payment recovery in `personal-dev-branch` is frontend-local and not robust across tabs or concurrent attempts
+- the existing pending-auth migration update is too destructive to reuse unchanged in a safer rollout
+- historical provider provenance should not be permanently flattened to `signup_source = email`
+
+## Rollout Gates
+
+The rebuild is not ready for rollout until all of these are satisfied:
+
+1. Identity schema and migration chain are linearized and production-safe.
+2. Email identity backfill is complete and idempotent.
+3. Historical LinuxDo synthetic-email backfill to canonical LinuxDo identity is complete and idempotent.
+4. `signup_source` backfill is accurate for known historical provider-created users.
+5. Dual token acceptance and required legacy field aliases are present.
+6. Existing payment configs are normalized and verified against current frontend-visible capabilities.
+7. New frontend flows are verified against mixed-version backend compatibility windows.
+8. Duplicate-account creation, first-bind grants, and payment route selection have regression coverage.
From b6751f7ebce43ef115ca583374b2e9c8d3d2dc28 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 13:47:00 +0800
Subject: [PATCH 018/189] docs: add auth identity implementation plan
---
...-04-20-auth-identity-payment-foundation.md | 476 ++++++++++++++++++
1 file changed, 476 insertions(+)
create mode 100644 docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md
diff --git a/docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md b/docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md
new file mode 100644
index 00000000..e8fde9c0
--- /dev/null
+++ b/docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md
@@ -0,0 +1,476 @@
+# Auth Identity Payment Foundation Implementation Plan
+
+> **For agentic workers:** REQUIRED SUB-SKILL: Use `superpowers:subagent-driven-development` (recommended) or `superpowers:executing-plans` to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
+
+**Goal:** Rebuild the auth identity, profile binding, payment routing, and OpenAI advanced scheduler foundation on top of a clean `origin/main` branch while preserving historical compatibility for existing email users and historical LinuxDo users.
+**Architecture:** A unified identity foundation centered on durable provider subjects (`email`, `linuxdo`, `oidc`, `wechat`) and transactional pending-auth sessions; backend-owned payment source routing behind stable frontend methods (`alipay`, `wxpay`); compatibility-first migration/backfill before feature enablement.
+**Tech Stack:** Go, Gin, Ent, PostgreSQL, Redis, Vue 3, Pinia, TypeScript, Vitest, pnpm.
+
+---
+
+## Non-Negotiable Product Rules
+
+- [ ] Preserve login continuity for existing email users and historical LinuxDo users.
+- [ ] During migration, backfill historical LinuxDo synthetic-email users into explicit LinuxDo identities before first post-upgrade login.
+- [ ] Keep existing email login and add third-party login/bind for `linuxdo`, `oidc`, and `wechat`.
+- [ ] On first third-party login:
+ - identity exists: direct login.
+ - identity does not exist: start pending-auth flow.
+ - local email binding is required only when system config says so.
+ - upstream provider email verification never counts as local email verification.
+- [ ] When user-entered and locally verified email already exists:
+ - offer bind-existing-account after local re-authentication.
+ - offer change-email-and-create-new-account.
+ - when email binding is mandatory, do not allow bypass without changing to another email.
+- [ ] On first third-party login or first third-party bind, provider nickname/avatar must be presented as independent replace options for the current nickname and avatar. They are not auto-applied.
+- [ ] Source-specific initial grants must support per-source defaults for balance, concurrency, and subscriptions.
+- [ ] Default grant timing: on successful new-account creation.
+- [ ] Optional grant timing: on first successful bind for the configured source.
+- [ ] Migration/backfill must never trigger first-bind or first-signup grants retroactively.
+- [ ] Avatar profile supports:
+ - direct URL storage.
+ - image data URL upload compressed to `<=100KB` before storing in DB.
+ - explicit delete.
+- [ ] Admin user management must expose and sort by `last_login_at` and `last_active_at`.
+- [ ] WeChat login rules:
+ - WeChat environment uses MP login.
+ - non-WeChat browser uses Open/QR login.
+ - canonical identity uses `unionid`.
+ - when `unionid` is unavailable, fail the login/bind flow under the approved option-1 policy.
+- [ ] Payment UI rules:
+ - user-facing methods stay `支付宝` and `微信支付`.
+ - backend decides whether each method routes to official provider instance or EasyPay.
+ - at runtime, each visible method may only have one active source.
+- [ ] Alipay rules:
+ - PC: in-page QR.
+ - mobile browser: jump to Alipay payment.
+- [ ] WeChat Pay rules:
+ - PC: in-page QR.
+ - WeChat H5: MP/JSAPI first, fallback to H5 pay.
+ - non-WeChat H5: H5 pay, or prompt to open in WeChat when unavailable.
+- [ ] Payment success pages are informational only; actual fulfillment depends on webhook or server-side reconciliation.
+- [ ] OpenAI advanced scheduler is available but default-disabled.
+
+## Hard Technical Constraints From Audit
+
+- [ ] Browser-based third-party auth must use Authorization Code + PKCE `S256`.
+- [ ] OIDC identity primary key must be `(issuer, subject)`, not email.
+- [ ] Email equality must never auto-link accounts.
+- [ ] Bind-existing-account must require explicit local re-authentication and TOTP verification when enabled.
+- [ ] OAuth redirect URI must be fixed server config, exact-match, and never derived from user input.
+- [ ] User-supplied redirect may only choose a normalized same-origin internal route after completion.
+- [ ] WeChat canonical identity must be `unionid`; `openid` remains channel/app-scoped support data only.
+- [ ] Every payment order must snapshot the selected provider instance and reuse that exact instance for callback verification, reconciliation, refund, and audit.
+- [ ] Frontend must not receive first-party bearer tokens through callback URL fragments in the rebuilt flow.
+- [ ] Public payment result polling must not expose order data by raw `out_trade_no` alone; use authenticated lookup or signed opaque result token.
+
+## Baseline Notes
+
+- [ ] Current clean branch head when this plan was written: `721d7ab3`.
+- [ ] Baseline backend verification on clean `origin/main`: `cd backend && go test ./...` passes.
+- [ ] Baseline frontend verification on clean `origin/main`: `cd frontend && pnpm test:run` currently fails in unrelated existing suites. New work must add targeted tests and avoid claiming full frontend green until those baseline failures are addressed separately.
+- [ ] Existing migration directory currently ends at `107_*`; this rebuild reserves `108` through `111`.
+
+## Target File Map
+
+### New backend migrations
+
+- [ ] `backend/migrations/108_auth_identity_foundation_core.sql`
+- [ ] `backend/migrations/109_auth_identity_compat_backfill.sql`
+- [ ] `backend/migrations/110_pending_auth_and_provider_default_grants.sql`
+- [ ] `backend/migrations/111_payment_routing_and_scheduler_flags.sql`
+
+### New or rebuilt Ent schema
+
+- [ ] `backend/ent/schema/auth_identity.go`
+- [ ] `backend/ent/schema/auth_identity_channel.go`
+- [ ] `backend/ent/schema/pending_auth_session.go`
+- [ ] `backend/ent/schema/identity_adoption_decision.go`
+
+### New or rebuilt backend repositories/services/handlers
+
+- [ ] `backend/internal/repository/user_profile_identity_repo.go`
+- [ ] `backend/internal/repository/user_profile_identity_repo_contract_test.go`
+- [ ] `backend/internal/repository/auth_identity_migration_report.go`
+- [ ] `backend/internal/service/auth_identity_flow.go`
+- [ ] `backend/internal/service/auth_identity_flow_test.go`
+- [ ] `backend/internal/service/auth_pending_identity_service.go`
+- [ ] `backend/internal/service/auth_pending_identity_service_test.go`
+- [ ] `backend/internal/service/payment_config_service.go`
+- [ ] `backend/internal/service/payment_order.go`
+- [ ] `backend/internal/service/payment_order_lifecycle.go`
+- [ ] `backend/internal/service/payment_fulfillment.go`
+- [ ] `backend/internal/service/openai_account_scheduler.go`
+- [ ] `backend/internal/handler/auth_pending_identity_flow.go`
+- [ ] `backend/internal/handler/auth_linuxdo_oauth.go`
+- [ ] `backend/internal/handler/auth_oidc_oauth.go`
+- [ ] `backend/internal/handler/auth_wechat_oauth.go`
+- [ ] `backend/internal/handler/auth_handler.go`
+- [ ] `backend/internal/handler/user_handler.go`
+- [ ] `backend/internal/handler/payment_handler.go`
+- [ ] `backend/internal/handler/payment_webhook_handler.go`
+- [ ] `backend/internal/handler/admin/user_handler.go`
+- [ ] `backend/internal/handler/admin/setting_handler.go`
+
+### New or rebuilt frontend API/store/views/components
+
+- [ ] `frontend/src/api/auth.ts`
+- [ ] `frontend/src/api/user.ts`
+- [ ] `frontend/src/api/payment.ts`
+- [ ] `frontend/src/api/admin/settings.ts`
+- [ ] `frontend/src/api/admin/users.ts`
+- [ ] `frontend/src/stores/auth.ts`
+- [ ] `frontend/src/stores/payment.ts`
+- [ ] `frontend/src/components/auth/ThirdPartyAuthCallbackFlow.vue`
+- [ ] `frontend/src/components/auth/LinuxDoOAuthSection.vue`
+- [ ] `frontend/src/components/auth/OidcOAuthSection.vue`
+- [ ] `frontend/src/components/auth/WechatOAuthSection.vue`
+- [ ] `frontend/src/components/user/profile/ProfileAccountBindingsCard.vue`
+- [ ] `frontend/src/components/user/profile/ProfileInfoCard.vue`
+- [ ] `frontend/src/views/auth/LinuxDoCallbackView.vue`
+- [ ] `frontend/src/views/auth/OidcCallbackView.vue`
+- [ ] `frontend/src/views/auth/WechatCallbackView.vue`
+- [ ] `frontend/src/views/user/ProfileView.vue`
+- [ ] `frontend/src/views/user/PaymentView.vue`
+- [ ] `frontend/src/views/user/PaymentQRCodeView.vue`
+- [ ] `frontend/src/views/user/PaymentResultView.vue`
+
+## Phase 1: Migration And Compatibility Foundation
+
+### Task 1. Create core identity schema migration
+
+- [ ] Implement `backend/migrations/108_auth_identity_foundation_core.sql` with:
+ - `auth_identities`
+ - `auth_identity_channels`
+ - `pending_auth_sessions`
+ - `identity_adoption_decisions`
+ - `users.last_login_at`
+ - `users.last_active_at`
+ - grant-tracking columns/tables required to prevent double-award
+- [ ] Add uniqueness/index rules:
+ - one canonical identity per `(provider, provider_subject)`
+ - one channel record per `(provider, provider_channel, provider_app_id, provider_channel_subject)`
+ - one adoption decision per pending session
+- [ ] Preserve null-safe compatibility defaults so historical rows remain readable before backfill finishes.
+- [ ] Add explicit rollback blocks only where safe; never repeat the destructive pattern observed in old `112_update_pending_auth_sessions.sql`.
+
+### Task 2. Materialize historical identities before runtime
+
+- [ ] Implement `backend/migrations/109_auth_identity_compat_backfill.sql` to backfill:
+ - existing email users into `auth_identities(provider=email, provider_subject=normalized_email)`
+ - historical LinuxDo users into `auth_identities(provider=linuxdo, provider_subject=linuxdo_subject)`
+ - historical synthetic-email LinuxDo users into explicit LinuxDo identity rows by parsing legacy email mode and legacy provider metadata
+ - profile/channel rows from historical `user_external_identities`-style data when present in upgraded databases
+- [ ] Write migration report output in `backend/internal/repository/auth_identity_migration_report.go` so production can inspect unmatched rows instead of silently skipping them.
+- [ ] Set `signup_source` and provider provenance when recoverable from historical data. Do not flatten everything to `email`.
+
+### Task 3. Provider default grant and scheduler config migration
+
+- [ ] Implement `backend/migrations/110_pending_auth_and_provider_default_grants.sql` for:
+ - provider-specific initial balance/concurrency/subscription defaults
+ - grant timing flags: `on_signup`, optional `on_first_bind`
+ - email-required-on-third-party-signup flags
+ - profile avatar storage columns/settings
+- [ ] Implement `backend/migrations/111_payment_routing_and_scheduler_flags.sql` for:
+ - stable payment method to provider-instance routing
+ - admin exclusivity flags for `alipay` and `wxpay`
+ - advanced scheduler enable flag defaulting to disabled
+
+### Task 4. Generate Ent and compile migration-safe model layer
+
+- [ ] Add the schema definitions in:
+ - `backend/ent/schema/auth_identity.go`
+ - `backend/ent/schema/auth_identity_channel.go`
+ - `backend/ent/schema/pending_auth_session.go`
+ - `backend/ent/schema/identity_adoption_decision.go`
+- [ ] Run:
+ ```bash
+ cd backend
+ go generate ./ent
+ ```
+- [ ] Compile after generation:
+ ```bash
+ cd backend
+ go test ./... -run '^$'
+ ```
+- [ ] Commit checkpoint:
+ ```bash
+ git add backend/migrations backend/ent/schema backend/ent
+ git commit -m "feat: add auth identity foundation schema"
+ ```
+
+## Phase 2: Backend Identity Flow Rebuild
+
+### Task 5. Build a single repository contract for identity lookups and grants
+
+- [ ] Implement `backend/internal/repository/user_profile_identity_repo.go` with transactional helpers for:
+ - get user by canonical identity
+ - get user by channel identity
+ - create canonical + channel identity together
+ - bind identity to existing user after verified re-auth
+ - record one-time provider grant award
+ - record adoption preference decisions
+ - update `last_login_at` and `last_active_at`
+- [ ] Add repository contract coverage in `backend/internal/repository/user_profile_identity_repo_contract_test.go`.
+- [ ] Enforce dual-write for email registration/login so `users.email` and `auth_identities(provider=email, ...)` stay consistent from this phase onward.
+
+### Task 6. Rebuild transactional pending-auth service
+
+- [ ] Implement `backend/internal/service/auth_pending_identity_service.go` and tests to own these flows:
+ - create pending session from third-party callback
+ - verify local email code
+ - create new account from pending session with correct `signup_source`
+ - bind pending identity to existing account after password/TOTP re-auth
+ - apply configured provider defaults on the correct trigger only once
+ - store provider nickname/avatar candidates and user opt-in replacement decisions independently
+- [ ] Keep pending session payload normalized:
+ - provider identity fields live in typed columns/JSON structure
+ - avoid the old branch’s mixed `metadata` and `upstream_identity_payload` ambiguity
+- [ ] Do not call plain email registration helpers from this flow. The old feature branch bug where pending third-party signup fell back to `RegisterWithVerification` must not reappear.
+
+### Task 7. Rebuild provider callback adapters
+
+- [ ] Refactor these handlers to thin adapters over the shared pending-auth service:
+ - `backend/internal/handler/auth_linuxdo_oauth.go`
+ - `backend/internal/handler/auth_oidc_oauth.go`
+ - `backend/internal/handler/auth_wechat_oauth.go`
+- [ ] For OIDC:
+ - require PKCE `S256`, `state`, and `nonce`
+ - validate `iss`, `aud`, optional `azp`, `exp`, `nonce`
+ - persist canonical identity as `(issuer, sub)`
+- [ ] For WeChat:
+ - MP flow in WeChat UA
+ - Open/QR flow outside WeChat UA
+ - persist channel identity by `(channel, appid, openid)`
+ - persist canonical identity by `unionid`
+ - hard-fail when `unionid` is absent under the approved product policy
+- [ ] Replace callback URL fragment token delivery with backend session completion or one-time exchange code consumed by `frontend/src/stores/auth.ts`.
+
+### Task 8. Rebuild auth endpoints and profile binding endpoints
+
+- [ ] Implement `backend/internal/handler/auth_pending_identity_flow.go` for:
+ - fetch pending session summary
+ - submit verified email
+ - choose create-new-account or bind-existing-account
+ - submit nickname/avatar replacement choices
+- [ ] Update `backend/internal/handler/auth_handler.go` and `backend/internal/handler/user_handler.go` to expose:
+ - current bindings summary
+ - start-bind endpoints for LinuxDo/OIDC/WeChat
+ - disconnect endpoints with safety checks
+ - avatar upload/delete endpoints
+- [ ] Avatar handling requirements:
+ - allow external URL
+ - allow data URL upload
+ - compress image payload to `<=100KB`
+ - store compressed value in DB
+ - deleting custom avatar must not implicitly resurrect stale provider avatar unless the user explicitly chooses provider avatar again
+
+### Task 9. Add admin visibility and sorting
+
+- [ ] Update `backend/internal/handler/admin/user_handler.go` and supporting query/service code so admin list supports:
+ - `last_login_at`
+ - `last_active_at`
+ - sorting by both
+ - binding/provider summary columns
+- [ ] Update `backend/internal/handler/admin/setting_handler.go` and setting service code for:
+ - provider initial grant config
+ - mandatory-email-on-third-party-signup config
+ - payment source exclusivity config
+ - advanced scheduler toggle
+
+### Task 10. Backend verification checkpoint
+
+- [ ] Run targeted backend tests:
+ ```bash
+ cd backend
+ go test ./internal/repository -run 'TestUserProfileIdentity|TestAuthIdentityMigration'
+ go test ./internal/service -run 'TestAuthIdentityFlow|TestPendingAuthIdentity|TestOpenAIAccountScheduler'
+ go test ./internal/handler -run 'TestLinuxDo|TestOidc|TestWechat|TestPaymentWebhook'
+ go test ./...
+ ```
+- [ ] Commit checkpoint:
+ ```bash
+ git add backend
+ git commit -m "feat: rebuild auth identity backend flows"
+ ```
+
+## Phase 3: Frontend Third-Party Flow And Profile UX
+
+### Task 11. Rebuild callback flow UI around pending session decisions
+
+- [ ] Rebuild `frontend/src/components/auth/ThirdPartyAuthCallbackFlow.vue` so it:
+ - loads pending-session summary from backend
+ - shows provider nickname/avatar candidates
+ - lets user independently choose nickname replacement and avatar replacement
+ - handles create-new-account vs bind-existing-account
+ - enforces verified local email before completion when required
+ - handles “email already exists” by branching to bind-existing-account or change-email-and-create-new-account
+- [ ] Update:
+ - `frontend/src/views/auth/LinuxDoCallbackView.vue`
+ - `frontend/src/views/auth/OidcCallbackView.vue`
+ - `frontend/src/views/auth/WechatCallbackView.vue`
+ - `frontend/src/api/auth.ts`
+ - `frontend/src/stores/auth.ts`
+- [ ] Replace any token-fragment bootstrap with backend session completion or one-time exchange code flow.
+
+### Task 12. Rebuild profile account binding and avatar UX
+
+- [ ] Rebuild `frontend/src/components/user/profile/ProfileAccountBindingsCard.vue` to:
+ - show linked LinuxDo/OIDC/WeChat providers
+ - start bind/unbind flows
+ - show provider avatars and nicknames as reference only
+ - prevent unsafe disconnect when it would strand the account
+- [ ] Rebuild `frontend/src/components/user/profile/ProfileInfoCard.vue` and `frontend/src/views/user/ProfileView.vue` to:
+ - support avatar URL entry
+ - support data URL upload/compression preview
+ - support avatar delete
+ - clearly separate current profile nickname/avatar from provider-sourced suggested nickname/avatar
+
+### Task 13. Add frontend tests for rebuilt auth/profile flows
+
+- [ ] Add or update:
+ - `frontend/src/components/auth/__tests__/ThirdPartyAuthCallbackFlow.spec.ts`
+ - `frontend/src/components/auth/__tests__/LinuxDoCallbackView.spec.ts`
+ - `frontend/src/components/auth/__tests__/WechatCallbackView.spec.ts`
+ - `frontend/src/components/user/profile/__tests__/ProfileAccountBindingsCard.spec.ts`
+ - `frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts`
+- [ ] Cover:
+ - email-required branch
+ - email-conflict branch
+ - bind-existing-account with re-auth prompt
+ - nickname replacement only
+ - avatar replacement only
+ - neither replacement
+ - avatar delete after prior provider adoption
+
+## Phase 4: Payment Routing Rebuild
+
+### Task 14. Normalize payment routing backend
+
+- [ ] Rebuild `backend/internal/service/payment_config_service.go` to expose a stable method-routing contract:
+ - frontend visible methods remain `alipay` and `wxpay`
+ - admin chooses which provider instance serves each method
+ - runtime validation guarantees only one active source per visible method
+- [ ] Rebuild `backend/internal/service/payment_order.go` and `backend/internal/service/payment_order_lifecycle.go` so order creation snapshots:
+ - visible method
+ - selected provider instance id
+ - provider type
+ - provider capability mode
+- [ ] Rebuild `backend/internal/handler/payment_handler.go` for UX rules:
+ - Alipay PC: QR page
+ - Alipay mobile: direct jump
+ - WeChat PC: QR page
+ - WeChat H5 in WeChat: MP/JSAPI first, fallback to H5
+ - WeChat H5 outside WeChat: H5 or “open in WeChat” prompt when unavailable
+- [ ] Never derive canonical return URL from `Referer`; use configured or signed internal callback targets only.
+
+### Task 15. Make fulfillment and reconciliation provider-instance-safe
+
+- [ ] Rebuild `backend/internal/handler/payment_webhook_handler.go` and `backend/internal/service/payment_fulfillment.go` so:
+ - verification uses the order’s original provider instance
+ - webhook processing is idempotent by provider event id and internal order id
+ - missed webhook recovery uses server-side provider query, not frontend success return
+- [ ] Harden `frontend/src/views/user/PaymentResultView.vue` and `frontend/src/api/payment.ts` so result polling uses an authenticated order lookup or signed opaque token, not a raw public `out_trade_no` query.
+
+### Task 16. Rebuild payment frontend views
+
+- [ ] Rebuild `frontend/src/views/user/PaymentView.vue`, `frontend/src/views/user/PaymentQRCodeView.vue`, and `frontend/src/stores/payment.ts` so:
+ - only two buttons are shown to user: `支付宝` and `微信支付`
+ - frontend does not leak official-vs-EasyPay distinction
+ - route-specific copy handles QR, jump, MP, H5 fallback correctly
+- [ ] Add or update:
+ - `frontend/src/views/user/__tests__/PaymentView.spec.ts`
+ - `frontend/src/views/user/__tests__/PaymentResultView.spec.ts`
+ - backend webhook/payment routing tests
+
+### Task 17. Payment verification checkpoint
+
+- [ ] Run:
+ ```bash
+ cd backend
+ go test ./internal/service -run 'TestPayment'
+ go test ./internal/handler -run 'TestPayment'
+ cd ../frontend
+ pnpm test:run src/views/user/__tests__/PaymentView.spec.ts src/views/user/__tests__/PaymentResultView.spec.ts
+ ```
+- [ ] Commit checkpoint:
+ ```bash
+ git add backend frontend
+ git commit -m "feat: rebuild payment routing foundation"
+ ```
+
+## Phase 5: Scheduler, Rollout, And Final Compatibility Pass
+
+### Task 18. Gate advanced scheduler behind explicit config
+
+- [ ] Update `backend/internal/service/openai_account_scheduler.go` and related admin setting surfaces so:
+ - advanced scheduler remains compiled and testable
+ - default runtime state is disabled
+ - enablement is explicit through admin settings
+ - legacy scheduling behavior remains default on upgrade
+- [ ] Add targeted coverage in `backend/internal/service/openai_account_scheduler_test.go`.
+
+### Task 19. Complete compatibility and rollout safety checks
+
+- [ ] Add migration/repository tests covering:
+ - historical email-only user login after upgrade
+ - historical LinuxDo user login after upgrade
+ - historical synthetic-email LinuxDo user login after upgrade
+ - no retroactive grant replay during migration
+ - first-bind grant fires once only when enabled
+ - email identity dual-write stays consistent
+ - bind-existing-account requires password and TOTP where configured
+- [ ] Add deploy sequencing note to release docs or internal runbook:
+ 1. deploy schema and backfill release.
+ 2. inspect migration report for unmatched rows.
+ 3. deploy backend identity/payment compatibility code.
+ 4. deploy frontend callback/profile/payment UI.
+ 5. enable strict email-required signup or provider bind grants only after metrics are healthy.
+
+### Task 20. Final verification and handoff
+
+- [ ] Run final backend verification:
+ ```bash
+ cd backend
+ go test ./...
+ ```
+- [ ] Run targeted frontend verification:
+ ```bash
+ cd frontend
+ pnpm test:run \
+ src/components/auth/__tests__/ThirdPartyAuthCallbackFlow.spec.ts \
+ src/components/auth/__tests__/LinuxDoCallbackView.spec.ts \
+ src/components/auth/__tests__/WechatCallbackView.spec.ts \
+ src/components/user/profile/__tests__/ProfileAccountBindingsCard.spec.ts \
+ src/components/user/profile/__tests__/ProfileInfoCard.spec.ts \
+ src/views/user/__tests__/PaymentView.spec.ts \
+ src/views/user/__tests__/PaymentResultView.spec.ts
+ ```
+- [ ] Run focused manual smoke checks:
+ - email login with existing account
+ - LinuxDo existing-account login after migration
+ - third-party first login create-new-account path
+ - third-party first login bind-existing-account path
+ - first third-party bind with optional nickname/avatar replacement
+ - PC Alipay QR
+ - mobile Alipay jump
+ - PC WeChat QR
+ - WeChat H5 MP/JSAPI path
+ - non-WeChat H5 fallback path
+- [ ] Commit final checkpoint:
+ ```bash
+ git add docs backend frontend
+ git commit -m "feat: rebuild auth identity and payment foundation"
+ ```
+
+## Review Checklist
+
+- [ ] No flow still relies on provider email equality for account linking.
+- [ ] No flow still creates third-party users through plain email registration helpers.
+- [ ] No callback still returns first-party bearer tokens in URL fragments.
+- [ ] No payment result view trusts provider return page as authoritative fulfillment.
+- [ ] No webhook verification path selects provider credentials from “currently active config” instead of the order snapshot.
+- [ ] Existing email users and historical LinuxDo users are covered by migration tests.
+- [ ] Avatar adoption and deletion semantics are explicit and reversible.
+- [ ] Grant timing is source-aware and one-time only.
+
From 584ded2182e2b1225a67e55d8a5485bbbdc35658 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 14:41:12 +0800
Subject: [PATCH 019/189] docs: harden auth identity payment design
---
...-04-20-auth-identity-payment-foundation.md | 87 ++++++++--
...auth-identity-payment-foundation-design.md | 149 +++++++++++++++---
2 files changed, 199 insertions(+), 37 deletions(-)
diff --git a/docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md b/docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md
index e8fde9c0..2d44e058 100644
--- a/docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md
+++ b/docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md
@@ -2,7 +2,7 @@
> **For agentic workers:** REQUIRED SUB-SKILL: Use `superpowers:subagent-driven-development` (recommended) or `superpowers:executing-plans` to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
-**Goal:** Rebuild the auth identity, profile binding, payment routing, and OpenAI advanced scheduler foundation on top of a clean `origin/main` branch while preserving historical compatibility for existing email users and historical LinuxDo users.
+**Goal:** Rebuild the auth identity, profile binding, payment routing, and OpenAI advanced scheduler foundation on top of a clean `origin/main` branch while preserving historical compatibility for existing email users, existing LinuxDo users, historical LinuxDo/WeChat/OIDC synthetic-email users, and historical WeChat `openid`-only records.
**Architecture:** A unified identity foundation centered on durable provider subjects (`email`, `linuxdo`, `oidc`, `wechat`) and transactional pending-auth sessions; backend-owned payment source routing behind stable frontend methods (`alipay`, `wxpay`); compatibility-first migration/backfill before feature enablement.
**Tech Stack:** Go, Gin, Ent, PostgreSQL, Redis, Vue 3, Pinia, TypeScript, Vitest, pnpm.
@@ -10,8 +10,9 @@
## Non-Negotiable Product Rules
-- [ ] Preserve login continuity for existing email users and historical LinuxDo users.
-- [ ] During migration, backfill historical LinuxDo synthetic-email users into explicit LinuxDo identities before first post-upgrade login.
+- [ ] Preserve login continuity for existing email users, existing LinuxDo users, and historically migrated third-party users.
+- [ ] During migration, backfill historical LinuxDo/WeChat/OIDC synthetic-email users into explicit third-party identities before first post-upgrade login whenever deterministic recovery is possible.
+- [ ] During migration, surface historical WeChat `openid`-only records through explicit migration reports and remediation rules; do not silently reinterpret them as valid canonical identities.
- [ ] Keep existing email login and add third-party login/bind for `linuxdo`, `oidc`, and `wechat`.
- [ ] On first third-party login:
- identity exists: direct login.
@@ -37,6 +38,11 @@
- non-WeChat browser uses Open/QR login.
- canonical identity uses `unionid`.
- when `unionid` is unavailable, fail the login/bind flow under the approved option-1 policy.
+- [ ] OIDC rules:
+ - browser authorization-code flow always uses PKCE `S256`.
+ - discovery issuer and ID token `iss` must match exactly.
+ - `userinfo.sub` must match ID token `sub` when UserInfo is used.
+ - upstream `email_verified` does not satisfy local email verification.
- [ ] Payment UI rules:
- user-facing methods stay `支付宝` and `微信支付`.
- backend decides whether each method routes to official provider instance or EasyPay.
@@ -49,20 +55,26 @@
- WeChat H5: MP/JSAPI first, fallback to H5 pay.
- non-WeChat H5: H5 pay, or prompt to open in WeChat when unavailable.
- [ ] Payment success pages are informational only; actual fulfillment depends on webhook or server-side reconciliation.
+- [ ] WeChat in-app payment requiring `openid` must use a dedicated server-backed payment OAuth resume flow rather than frontend-only recovery state.
- [ ] OpenAI advanced scheduler is available but default-disabled.
## Hard Technical Constraints From Audit
- [ ] Browser-based third-party auth must use Authorization Code + PKCE `S256`.
+- [ ] PKCE must not be admin-configurable off for browser authorization-code providers.
- [ ] OIDC identity primary key must be `(issuer, subject)`, not email.
- [ ] Email equality must never auto-link accounts.
- [ ] Bind-existing-account must require explicit local re-authentication and TOTP verification when enabled.
+- [ ] Bind-current-user must originate from an already-authenticated local user and preserve explicit bind intent across callback completion.
- [ ] OAuth redirect URI must be fixed server config, exact-match, and never derived from user input.
- [ ] User-supplied redirect may only choose a normalized same-origin internal route after completion.
- [ ] WeChat canonical identity must be `unionid`; `openid` remains channel/app-scoped support data only.
-- [ ] Every payment order must snapshot the selected provider instance and reuse that exact instance for callback verification, reconciliation, refund, and audit.
+- [ ] Every canonical identity uniqueness rule must include provider namespace (`provider_key`) consistently.
+- [ ] Callback completion must use backend session completion or a one-time opaque exchange code that is short-lived, one-time, browser-session-bound, `POST`-redeemed, and unusable as a bearer token.
+- [ ] Every payment order must snapshot the selected provider instance plus the order-time verification inputs required for callback verification, reconciliation, refund, and audit.
- [ ] Frontend must not receive first-party bearer tokens through callback URL fragments in the rebuilt flow.
- [ ] Public payment result polling must not expose order data by raw `out_trade_no` alone; use authenticated lookup or signed opaque result token.
+- [ ] WeChat Pay webhook handling must verify signature, decrypt payload, and compare `appid`, `mchid`, `out_trade_no`, `amount`, `currency`, and provider trade state against the order snapshot before fulfillment.
## Baseline Notes
@@ -100,6 +112,8 @@
- [ ] `backend/internal/service/payment_order.go`
- [ ] `backend/internal/service/payment_order_lifecycle.go`
- [ ] `backend/internal/service/payment_fulfillment.go`
+- [ ] `backend/internal/service/payment_resume_service.go`
+- [ ] `backend/internal/service/payment_resume_service_test.go`
- [ ] `backend/internal/service/openai_account_scheduler.go`
- [ ] `backend/internal/handler/auth_pending_identity_flow.go`
- [ ] `backend/internal/handler/auth_linuxdo_oauth.go`
@@ -148,9 +162,10 @@
- `users.last_active_at`
- grant-tracking columns/tables required to prevent double-award
- [ ] Add uniqueness/index rules:
- - one canonical identity per `(provider, provider_subject)`
+ - one canonical identity per `(provider, provider_key, provider_subject)`
- one channel record per `(provider, provider_channel, provider_app_id, provider_channel_subject)`
- one adoption decision per pending session
+- [ ] Model `pending_auth_sessions` so immutable upstream claims and mutable local flow state are stored separately; do not reintroduce a mixed `metadata` catch-all.
- [ ] Preserve null-safe compatibility defaults so historical rows remain readable before backfill finishes.
- [ ] Add explicit rollback blocks only where safe; never repeat the destructive pattern observed in old `112_update_pending_auth_sessions.sql`.
@@ -160,8 +175,10 @@
- existing email users into `auth_identities(provider=email, provider_subject=normalized_email)`
- historical LinuxDo users into `auth_identities(provider=linuxdo, provider_subject=linuxdo_subject)`
- historical synthetic-email LinuxDo users into explicit LinuxDo identity rows by parsing legacy email mode and legacy provider metadata
+ - historical synthetic-email WeChat users into explicit WeChat identities where `unionid` or equivalent deterministic provider identity is recoverable
+ - historical synthetic-email OIDC users into explicit OIDC identities where deterministic provider identity is recoverable
- profile/channel rows from historical `user_external_identities`-style data when present in upgraded databases
-- [ ] Write migration report output in `backend/internal/repository/auth_identity_migration_report.go` so production can inspect unmatched rows instead of silently skipping them.
+- [ ] Write migration report output in `backend/internal/repository/auth_identity_migration_report.go` so production can inspect unmatched rows, `openid`-only WeChat rows, and non-deterministic synthetic-email rows instead of silently skipping them.
- [ ] Set `signup_source` and provider provenance when recoverable from historical data. Do not flatten everything to `email`.
### Task 3. Provider default grant and scheduler config migration
@@ -173,6 +190,7 @@
- profile avatar storage columns/settings
- [ ] Implement `backend/migrations/111_payment_routing_and_scheduler_flags.sql` for:
- stable payment method to provider-instance routing
+ - visible-method normalization from historical `supported_types`, `payment_mode`, and legacy aliases such as `wxpay_direct`
- admin exclusivity flags for `alipay` and `wxpay`
- advanced scheduler enable flag defaulting to disabled
@@ -213,6 +231,7 @@
- update `last_login_at` and `last_active_at`
- [ ] Add repository contract coverage in `backend/internal/repository/user_profile_identity_repo_contract_test.go`.
- [ ] Enforce dual-write for email registration/login so `users.email` and `auth_identities(provider=email, ...)` stay consistent from this phase onward.
+- [ ] Add repository coverage proving `last_login_at` and `last_active_at` use the required field names and are not silently replaced by derived `last_used_at` logic.
### Task 6. Rebuild transactional pending-auth service
@@ -223,8 +242,15 @@
- bind pending identity to existing account after password/TOTP re-auth
- apply configured provider defaults on the correct trigger only once
- store provider nickname/avatar candidates and user opt-in replacement decisions independently
+- [ ] Implement callback completion so pending auth can finish through backend session completion or a one-time exchange code:
+ - short TTL
+ - one-time use
+ - browser-session binding
+ - `POST` redemption only
+ - safe mixed-version bridge to legacy pending-token aliases during rollout
- [ ] Keep pending session payload normalized:
- provider identity fields live in typed columns/JSON structure
+ - mutable local progression lives separately from immutable upstream claims
- avoid the old branch’s mixed `metadata` and `upstream_identity_payload` ambiguity
- [ ] Do not call plain email registration helpers from this flow. The old feature branch bug where pending third-party signup fell back to `RegisterWithVerification` must not reappear.
@@ -236,11 +262,13 @@
- `backend/internal/handler/auth_wechat_oauth.go`
- [ ] For OIDC:
- require PKCE `S256`, `state`, and `nonce`
- - validate `iss`, `aud`, optional `azp`, `exp`, `nonce`
+ - validate discovery issuer, `iss`, `aud`, optional `azp`, `exp`, and `nonce`
+ - verify `userinfo.sub == id_token.sub` when UserInfo is used
- persist canonical identity as `(issuer, sub)`
- [ ] For WeChat:
- MP flow in WeChat UA
- Open/QR flow outside WeChat UA
+ - website login uses authorization-code flow and persists channel/app binding
- persist channel identity by `(channel, appid, openid)`
- persist canonical identity by `unionid`
- hard-fail when `unionid` is absent under the approved product policy
@@ -253,6 +281,10 @@
- submit verified email
- choose create-new-account or bind-existing-account
- submit nickname/avatar replacement choices
+- [ ] Make bind-existing-account and bind-current-user flows explicit:
+ - no automatic linking on matching email
+ - fresh password/TOTP proof is scoped to the intended target account only
+ - no automatic metadata merge beyond explicitly selected nickname/avatar adoption
- [ ] Update `backend/internal/handler/auth_handler.go` and `backend/internal/handler/user_handler.go` to expose:
- current bindings summary
- start-bind endpoints for LinuxDo/OIDC/WeChat
@@ -312,6 +344,7 @@
- `frontend/src/api/auth.ts`
- `frontend/src/stores/auth.ts`
- [ ] Replace any token-fragment bootstrap with backend session completion or one-time exchange code flow.
+- [ ] During rollout, keep temporary compatibility readers for legacy pending-token aliases behind a bounded bridge contract and explicit removal step.
### Task 12. Rebuild profile account binding and avatar UX
@@ -351,11 +384,17 @@
- frontend visible methods remain `alipay` and `wxpay`
- admin chooses which provider instance serves each method
- runtime validation guarantees only one active source per visible method
+- [ ] Add migration logic and tests to normalize historical provider-instance config:
+ - `supported_types`
+ - `payment_mode`
+ - legacy aliases such as `wxpay_direct`
+ - historical limit config
- [ ] Rebuild `backend/internal/service/payment_order.go` and `backend/internal/service/payment_order_lifecycle.go` so order creation snapshots:
- visible method
- selected provider instance id
- provider type
- provider capability mode
+ - verification-critical provider fields needed for later callback/query/refund validation
- [ ] Rebuild `backend/internal/handler/payment_handler.go` for UX rules:
- Alipay PC: QR page
- Alipay mobile: direct jump
@@ -363,6 +402,11 @@
- WeChat H5 in WeChat: MP/JSAPI first, fallback to H5
- WeChat H5 outside WeChat: H5 or “open in WeChat” prompt when unavailable
- [ ] Never derive canonical return URL from `Referer`; use configured or signed internal callback targets only.
+- [ ] Implement `backend/internal/service/payment_resume_service.go` so WeChat in-app payment OAuth resume is server-backed rather than localStorage-backed:
+ - create `oauth_required` resume context
+ - persist amount/order_type/plan_id/visible method/redirect/state
+ - redeem callback into same-origin internal resume target
+ - expire and consume resume context safely
### Task 15. Make fulfillment and reconciliation provider-instance-safe
@@ -370,6 +414,12 @@
- verification uses the order’s original provider instance
- webhook processing is idempotent by provider event id and internal order id
- missed webhook recovery uses server-side provider query, not frontend success return
+- [ ] For WeChat Pay specifically, enforce:
+ - fixed HTTPS `notify_url` with no query params
+ - no dependency on user login state
+ - signature verification before decrypt
+ - APIv3 decrypt before business parsing
+ - comparison of `appid`, `mchid`, `out_trade_no`, `amount`, `currency`, and trade state against the order snapshot
- [ ] Harden `frontend/src/views/user/PaymentResultView.vue` and `frontend/src/api/payment.ts` so result polling uses an authenticated order lookup or signed opaque token, not a raw public `out_trade_no` query.
### Task 16. Rebuild payment frontend views
@@ -378,6 +428,10 @@
- only two buttons are shown to user: `支付宝` and `微信支付`
- frontend does not leak official-vs-EasyPay distinction
- route-specific copy handles QR, jump, MP, H5 fallback correctly
+- [ ] Rebuild WeChat in-app payment resume UX around the server-backed resume context:
+ - handle `oauth_required`
+ - continue from same-origin resume target
+ - avoid long-lived localStorage as the source of truth
- [ ] Add or update:
- `frontend/src/views/user/__tests__/PaymentView.spec.ts`
- `frontend/src/views/user/__tests__/PaymentResultView.spec.ts`
@@ -416,16 +470,22 @@
- historical email-only user login after upgrade
- historical LinuxDo user login after upgrade
- historical synthetic-email LinuxDo user login after upgrade
+ - historical synthetic-email WeChat user login after upgrade
+ - historical synthetic-email OIDC user login after upgrade
+ - historical WeChat `openid`-only rows are reported or explicitly remediated
- no retroactive grant replay during migration
- first-bind grant fires once only when enabled
- email identity dual-write stays consistent
- bind-existing-account requires password and TOTP where configured
+ - mixed-version callback token bridge works during rollout and is removable afterward
+ - historical payment config is normalized into visible-method routing without refund/query regression
- [ ] Add deploy sequencing note to release docs or internal runbook:
1. deploy schema and backfill release.
2. inspect migration report for unmatched rows.
- 3. deploy backend identity/payment compatibility code.
- 4. deploy frontend callback/profile/payment UI.
- 5. enable strict email-required signup or provider bind grants only after metrics are healthy.
+ 3. deploy backend identity/payment compatibility code with exchange bridge and legacy token aliases still enabled.
+ 4. deploy frontend callback/profile/payment UI using session completion, exchange code, and server-backed WeChat payment resume.
+ 5. remove legacy callback/token parsing after mixed-version window closes.
+ 6. enable strict email-required signup or provider bind grants only after metrics are healthy.
### Task 20. Final verification and handoff
@@ -449,6 +509,8 @@
- [ ] Run focused manual smoke checks:
- email login with existing account
- LinuxDo existing-account login after migration
+ - WeChat synthetic-email account login after migration
+ - OIDC synthetic-email account login after migration
- third-party first login create-new-account path
- third-party first login bind-existing-account path
- first third-party bind with optional nickname/avatar replacement
@@ -456,6 +518,7 @@
- mobile Alipay jump
- PC WeChat QR
- WeChat H5 MP/JSAPI path
+ - WeChat in-app OAuth resume path
- non-WeChat H5 fallback path
- [ ] Commit final checkpoint:
```bash
@@ -468,9 +531,9 @@
- [ ] No flow still relies on provider email equality for account linking.
- [ ] No flow still creates third-party users through plain email registration helpers.
- [ ] No callback still returns first-party bearer tokens in URL fragments.
+- [ ] No callback completion path can be replayed as a bearer token substitute.
- [ ] No payment result view trusts provider return page as authoritative fulfillment.
- [ ] No webhook verification path selects provider credentials from “currently active config” instead of the order snapshot.
-- [ ] Existing email users and historical LinuxDo users are covered by migration tests.
+- [ ] Existing email users, historical LinuxDo/WeChat/OIDC users, and `openid`-only WeChat remediation cases are covered by migration tests.
- [ ] Avatar adoption and deletion semantics are explicit and reversible.
- [ ] Grant timing is source-aware and one-time only.
-
diff --git a/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md b/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
index bfa250d9..790861b7 100644
--- a/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
+++ b/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
@@ -20,10 +20,10 @@ This design includes:
- Profile binding management and avatar upload/delete
- Source-based initial grants for balance, concurrency, and subscriptions
- User management support for `last_login_at` and `last_active_at` sorting
-- Unified payment display methods (`alipay`, `wechat`) mapped to a single active backend source each
+- Unified payment display methods (`alipay`, `wxpay`) mapped to a single active backend source each
- Alipay and WeChat UX routing rules across PC, mobile, H5, and WeChat environments
- Admin settings for auth providers, source defaults, payment sources, and OpenAI advanced scheduling
-- Incremental migration and compatibility for existing email users and historical LinuxDo synthetic-email users
+- Incremental migration and compatibility for existing email users, existing LinuxDo users, historical LinuxDo/WeChat/OIDC synthetic-email users, and historical WeChat `openid`-only identity records
This design does not treat unrelated upstream merges, docs churn, or license changes from the old branch as required scope.
@@ -32,9 +32,11 @@ This design does not treat unrelated upstream merges, docs churn, or license cha
### Auth and identity
- Existing email users remain valid and continue to log in with no manual action.
+- Existing LinuxDo, OIDC, and WeChat users represented by historical third-party or synthetic-email data must remain recoverable during migration.
- Third-party first login behavior:
- Existing bound identity: direct login
- Missing identity: start first-login flow
+- Browser-based third-party authorization-code login always uses PKCE `S256`; this is not an admin-toggleable feature.
- If `force_email_on_third_party_signup` is disabled, a first-login user may create an account without binding an email.
- If `force_email_on_third_party_signup` is enabled, the user must provide an email.
- If the provided and verified email already exists:
@@ -43,11 +45,25 @@ This design does not treat unrelated upstream merges, docs churn, or license cha
- allow "change email and continue registration"
- do not allow bypassing the email requirement
- Upstream provider email verification is not trusted as a local bound email.
+- Matching upstream email must never auto-link to an existing local account.
+- Linking to an existing local account is allowed only when:
+ - the user explicitly chooses that target account
+ - the target account passes fresh local re-authentication
+ - required TOTP verification succeeds
+- New third-party bind initiated from profile must start from an already logged-in local account and preserve explicit bind intent end-to-end.
+- `redirect_to` may only represent a normalized same-origin internal route. It must never contain a third-party URL and must never be derived from `Referer`.
+- OIDC validation rules:
+ - canonical identity key is `issuer + sub`
+ - discovery issuer and ID token `iss` must match exactly
+ - `userinfo.sub` must match ID token `sub` when UserInfo is used
+ - upstream `email_verified` may improve UX copy but does not satisfy local email-binding requirements
- WeChat login chooses channel by environment:
- in WeChat environment: `mp`
- outside WeChat: `open`
- WeChat primary identity key is `unionid`.
- If a WeChat login/bind flow cannot produce `unionid`, the flow fails and no fallback `openid` identity is created.
+- Historical WeChat records that only contain `openid` are treated as migration-remediation cases, not as a valid long-term canonical identity model.
+- WeChat website login uses authorization code flow, random `state`, and the provider channel/app binding must be persisted alongside the resolved identity.
### Profile adoption
@@ -85,7 +101,7 @@ This design does not treat unrelated upstream merges, docs churn, or license cha
- Frontend shows only two display methods:
- `alipay`
- - `wechat`
+ - `wxpay`
- Users never choose between official providers and EasyPay explicitly.
- Backend allows only one active source per display method at a time.
- Alipay UX:
@@ -96,6 +112,10 @@ This design does not treat unrelated upstream merges, docs churn, or license cha
- non-WeChat H5: prefer H5 pay; if unavailable, tell the user to open in WeChat
- WeChat environment: prefer MP/JSAPI pay; if unavailable, fall back to H5 pay
- Payment success is confirmed by backend order state, webhook, and/or query, not only frontend return.
+- Frontend-visible labels remain `支付宝` and `微信支付`, while internal visible-method identifiers remain `alipay` and `wxpay`.
+- Public result pages must not verify order state by exposing raw `out_trade_no`; they use authenticated lookup or a signed opaque result token instead.
+- Payment callback or return URLs must be fixed same-origin internal targets. They must not be inferred from `Referer`.
+- WeChat payment webhook handling must use a fixed HTTPS `notify_url` with no query parameters and must not depend on user login state.
### OpenAI advanced scheduling
@@ -105,9 +125,9 @@ This design does not treat unrelated upstream merges, docs churn, or license cha
## Architecture
-Keep `users` as the account owner table and move login identities, channel mappings, pending auth state, and first-bind grant idempotency into dedicated tables and services. Keep email login working while progressively introducing unified identity reads and writes.
+Keep `users` as the account owner table and move login identities, channel mappings, pending auth state, callback completion state, and first-bind grant idempotency into dedicated tables and services. Keep email login working while progressively introducing unified identity reads and writes.
-Payment uses a similar split between user-visible display methods and backend provider sources. Frontend works only with stable display methods while backend resolves to the currently active source and capability matrix.
+Payment uses a similar split between user-visible display methods and backend provider sources. Frontend works only with stable display methods while backend resolves to the currently active source and capability matrix, and stores enough order-time snapshot data to survive later provider-config changes.
Compatibility is a first-class concern: migrations are additive, reads are compatibility-aware, and rollout must tolerate existing `main` data and short-lived frontend/backend version skew.
@@ -148,9 +168,9 @@ Uniqueness:
Rules:
- email identity uses canonicalized local email
-- LinuxDo uses stable provider subject
-- OIDC uses stable issuer + subject
-- WeChat uses `unionid` as canonical subject
+- LinuxDo uses stable provider subject under the configured provider namespace
+- OIDC uses stable issuer + subject, with issuer namespace represented consistently through `provider_key` and `issuer`
+- WeChat uses `unionid` as canonical subject under the configured Open Platform namespace
### `auth_identity_channels`
@@ -189,9 +209,12 @@ Fields:
- `target_user_id`
- `redirect_to`
- `resolved_email`
-- `pending_password_hash`
-- `upstream_identity_payload`
-- `metadata`
+- `registration_password_hash`
+- `upstream_identity_claims`
+- `local_flow_state`
+- `browser_session_key`
+- `completion_code_hash`
+- `completion_code_expires_at`
- `email_verified_at`
- `password_verified_at`
- `totp_verified_at`
@@ -205,19 +228,33 @@ Responsibilities:
- persist nickname/avatar suggestions
- persist explicit adoption decisions
- survive navigation between auth pages
+- support mixed-version rollout through short-lived legacy token aliases when required
+
+Security rules:
+
+- callback completion uses backend session completion or a one-time exchange code
+- exchange codes are short-lived, one-time, bound to browser session and pending session, and redeemed via `POST`
+- exchange codes must not behave as bearer tokens and must not be logged, stored in URL fragments, or reused after redemption
+- `local_flow_state` stores mutable local progression only; immutable upstream claims remain in `upstream_identity_claims`
### `identity_adoption_decisions`
-Persists user adoption preference for a specific identity.
+Persists user adoption preference collected during a pending-auth flow and resolved onto the bound identity.
Fields:
+- `pending_auth_session_id`
- `identity_id`
- `adopt_display_name`
- `adopt_avatar`
- `decided_at`
- timestamps
+Rules:
+
+- one adoption-decision row exists per pending session
+- `identity_id` is filled once final account creation or bind succeeds
+
### `user_avatars`
Stores the currently effective custom avatar.
@@ -265,6 +302,7 @@ WeChat-specific rule:
- `openid` never becomes the primary stored identity key
- if only `openid` is available, login/bind fails with a configuration/identity error
+- historical `openid`-only records must be reported and either remediated during migration or explicitly blocked from silent auto-upgrade
## Core Flows
@@ -285,6 +323,7 @@ WeChat-specific rule:
- Create `pending_auth_session`
- Frontend callback flow decides next action
+- Pending session creation stores immutable upstream claims separately from mutable local progress fields
Branches:
@@ -303,6 +342,7 @@ On new account creation:
- create `users` row
- create canonical third-party identity
+- create or update canonical email identity when local email binding succeeds
- apply source signup grants
- apply adoption choices if selected
@@ -310,21 +350,34 @@ On new account creation:
- current user starts bind flow
- callback resolves to `bind_current_user`
+- bind intent is tied to the initiating local user session and cannot be re-targeted by email match
- bind canonical identity to current user
- if configured and first bind for that provider, apply first-bind grants
- present nickname/avatar replacement choice
### Bind existing account during first-login flow
+- user explicitly selects bind-existing-account
- verify password for existing account
- if account requires TOTP, verify TOTP
- bind canonical identity to target account
- optionally apply first-bind grants
- present nickname/avatar replacement choice
+- no automatic profile or metadata merge occurs beyond explicitly selected nickname/avatar replacement
+
+### Callback completion and exchange flow
+
+- third-party callback never returns first-party bearer tokens in URL fragments
+- callback completion uses either:
+ - backend session completion tied to the initiating browser session
+ - one-time opaque exchange code redeemed by `POST`
+- mixed-version rollout may temporarily emit legacy pending token aliases in addition to the new completion path
+- legacy alias support is transitional and bounded to rollout windows only
### WeChat login and channel mapping
- environment chooses `mp` or `open`
+- website login uses authorization-code flow with provider-configured app/channel binding
- callback must resolve to `unionid`
- channel `openid` is optionally recorded in `auth_identity_channels`
- failure to obtain `unionid` aborts flow
@@ -344,7 +397,7 @@ On new account creation:
### User-visible methods
- `alipay`
-- `wechat`
+- `wxpay`
### Backend source abstraction
@@ -357,6 +410,12 @@ Each display method maps to exactly one active configured backend source:
Frontend submits display method only. Backend resolves display method to active source and capability set.
+### Legacy payment-config normalization
+
+- existing provider-instance `supported_types`, legacy aliases such as `wxpay_direct`, and per-type limit structures are migrated into the visible-method model
+- migration preserves historical payment capability and refund semantics
+- the system keeps one normalized visible-method mapping per provider instance for rollout and audit
+
### Alipay routing
- PC: create QR-oriented result and show QR in page
@@ -372,11 +431,25 @@ Frontend submits display method only. Backend resolves display method to active
- prefer MP/JSAPI
- if unavailable, fall back to H5 pay
+### WeChat payment OAuth recovery
+
+- if WeChat in-app payment requires `openid` and the current request does not already hold it, backend returns an `oauth_required` response instead of guessing
+- backend creates a server-backed payment-resume context containing:
+ - target visible method
+ - amount/order type/plan context
+ - redirect target
+ - anti-replay state
+- backend redirects through a dedicated WeChat payment OAuth start endpoint
+- callback exchanges the provider code server-side, stores `openid` in the payment-resume context, and returns a same-origin internal resume target
+- frontend resumes the original order flow through the resume context instead of trusting raw callback query state or long-lived local storage
+
### Payment completion
- frontend return restores context and UI state
- backend order state remains source of truth
- webhook and/or order query remain authoritative for fulfillment
+- order fulfillment validates webhook or query payload against order-time snapshot data including provider instance, merchant identifiers, amount, currency, and provider order references
+- result pages use authenticated lookup or signed opaque result tokens, never raw public `out_trade_no`
## Admin Configuration Model
@@ -420,6 +493,9 @@ Compatibility is mandatory, especially for:
- existing email users
- existing LinuxDo users
- historical LinuxDo synthetic-email accounts
+- historical WeChat synthetic-email accounts
+- historical OIDC synthetic-email accounts
+- historical WeChat `openid`-only records created by older branches
### Additive migrations
@@ -431,6 +507,8 @@ Compatibility is mandatory, especially for:
- backfill canonical `email` identities for valid existing email users
- backfill canonical `linuxdo` identities during migration for historical synthetic-email LinuxDo users
+- backfill canonical `wechat` and `oidc` identities when historical synthetic-email or `user_external_identities` data allows deterministic reconstruction
+- emit migration reports for historical WeChat `openid`-only records that cannot be safely promoted to canonical `unionid`
- backfill must be idempotent and repeatable
### Compatibility reads
@@ -438,7 +516,7 @@ Compatibility is mandatory, especially for:
During rollout:
- read new identity model first
-- where necessary, retain compatibility logic for existing email and historical LinuxDo synthetic-email recognition
+- where necessary, retain compatibility logic for existing email and historical LinuxDo/WeChat/OIDC synthetic-email recognition
### Grant idempotency
@@ -453,17 +531,20 @@ Retain transitional support for legacy/new request and response shapes where nee
- `pending_oauth_token`
- old callback parsing expectations
- historical profile field mappings
+- legacy callback fragment readers during the bounded rollout window
### Settings and payment compatibility
- preserve existing payment configs and order semantics from `main`
- add new settings incrementally
- avoid rewriting the entire settings schema in one cutover
+- preserve legacy provider-instance capabilities by explicitly mapping historical `supported_types`, `payment_mode`, and limit config into normalized visible-method routing
### Rolling upgrade tolerance
- do not assume simultaneous frontend/backend deployment
- new backend must tolerate short-lived older frontend request shapes
+- rollout must define the deployment order and removal point for legacy callback token parsing and legacy payment resume parsing
## Testing Strategy
@@ -509,9 +590,13 @@ Retain transitional support for legacy/new request and response shapes where nee
- existing email users
- historical LinuxDo synthetic-email users
+- historical WeChat synthetic-email users
+- historical OIDC synthetic-email users
+- historical WeChat `openid`-only records reported or remediated correctly
- historical payment config
- legacy auth payload field names
- historical payment result handling
+- mixed-version callback token bridge behavior
## Implementation Phases
@@ -528,9 +613,12 @@ Retain transitional support for legacy/new request and response shapes where nee
Implementation must follow current primary-source guidance:
- OAuth 2.0 Security BCP (RFC 9700): strict redirect handling, state protection, mix-up resistant design
-- PKCE (RFC 7636): use on authorization code flows where applicable
+- PKCE (RFC 7636): require `S256` on browser authorization-code flows
- OpenID Connect Core: stable issuer/subject handling for OIDC identities
- Account linking best practice: require explicit user confirmation or re-authentication before linking to existing accounts
+- WeChat UnionID and website-login guidance: treat `unionid` as canonical cross-channel subject and persist channel/app binding with website login responses
+- WeChat Pay webhook guidance: verify signatures, decrypt payloads, and confirm merchant/order/amount fields against order-time state before fulfillment
+- Payment success-page guidance: custom success pages are informational and must not be the only fulfillment trigger
References:
@@ -538,6 +626,10 @@ References:
- RFC 7636:
- OpenID Connect Core 1.0:
- Auth0 account linking guidance:
+- WeChat UnionID guidance:
+- WeChat website login guidance:
+- WeChat Pay callback/signature guidance:
+- Stripe Checkout fulfillment guidance:
## Audit Synthesis
@@ -553,7 +645,7 @@ The clean rebuild direction is not to copy either existing branch directly.
- `personal-dev-branch` has the better real-world closure:
- LinuxDo and WeChat callback flows are more operationally complete
- profile binding and avatar UX is more complete
- - historical synthetic-email users are recognized and recovered in live flows
+ - historical synthetic-email users across multiple providers are recognized and recovered in live flows
- WeChat payment OAuth and recovery behavior is more complete
- Primary-source guidance supplies hard constraints for OAuth/OIDC, account linking, WeChat identity handling, and payment completion semantics.
@@ -584,6 +676,7 @@ Keep these operational flow ideas from `personal-dev-branch`:
- profile bindings UX and “cannot disconnect last usable login method” rule
- separate WeChat login OAuth and WeChat payment OAuth entry points
- historical synthetic-email recognition logic as a migration bridge
+- explicit WeChat payment OAuth recovery protocol as a product requirement, but reimplemented with server-backed resume state
### Adapt
@@ -595,6 +688,7 @@ These areas must be reimplemented with the same intent but stricter boundaries:
- WeChat payment recovery state must move from frontend-only storage to server-backed continuation state
- avatar adoption fetches must be security-hardened and failure-visible
- pending-auth payload modeling must clearly separate immutable upstream payload from mutable local metadata
+- callback completion must use a real exchange/session model instead of fragment-delivered bearer tokens
- profile binding/avatar DTOs must be simplified to one authoritative backend contract instead of sprawling frontend fallback parsing
- admin settings should preserve capability while reducing duplicated or transitional config branches
@@ -614,7 +708,7 @@ The audit and source review establish these hard constraints:
### Auth
-- all authorization-code providers use PKCE where applicable
+- all browser authorization-code providers use PKCE `S256` and do not expose an admin-off switch
- callback handling uses strict `redirect_uri` discipline and state validation
- OIDC identity key is `issuer + sub`
- existing-account linking after email conflict must require explicit user action plus local-account verification
@@ -624,7 +718,8 @@ The audit and source review establish these hard constraints:
- existing email users must continue to work with no manual intervention
- existing LinuxDo users must not split into duplicate accounts
-- historical LinuxDo synthetic-email users must be backfilled into canonical LinuxDo identities during migration, not only lazily on next login
+- historical LinuxDo/WeChat/OIDC synthetic-email users must be backfilled into canonical identities during migration when deterministic recovery is possible
+- historical WeChat `openid`-only records must be surfaced through migration reporting and explicit remediation rules
- migration backfills must not trigger signup or first-bind grants
- legacy `pending_auth_token` and `pending_oauth_token` contracts must remain accepted during rollout
- legacy auth/public setting aliases needed by older frontend builds must remain available during rollout
@@ -634,7 +729,9 @@ The audit and source review establish these hard constraints:
- frontend return pages do not determine final payment success
- backend order state, webhook processing, and/or provider status query remain authoritative
-- each visible method (`alipay`, `wechat`) may have only one active backend source at a time
+- each visible method (`alipay`, `wxpay`) may have only one active backend source at a time
+- public result pages must not expose raw `out_trade_no` lookup
+- WeChat Pay callback handling must verify signature, decrypt payload, and compare order fields against order-time snapshot data
## Known Risks To Eliminate In Implementation
@@ -649,6 +746,7 @@ These are specifically observed problems in the existing branches that the clean
- WeChat payment recovery in `personal-dev-branch` is frontend-local and not robust across tabs or concurrent attempts
- the existing pending-auth migration update is too destructive to reuse unchanged in a safer rollout
- historical provider provenance should not be permanently flattened to `signup_source = email`
+- design/plan drift can reintroduce ambiguous identity uniqueness or ambiguous adoption-decision ownership if not aligned before implementation
## Rollout Gates
@@ -656,9 +754,10 @@ The rebuild is not ready for rollout until all of these are satisfied:
1. Identity schema and migration chain are linearized and production-safe.
2. Email identity backfill is complete and idempotent.
-3. Historical LinuxDo synthetic-email backfill to canonical LinuxDo identity is complete and idempotent.
-4. `signup_source` backfill is accurate for known historical provider-created users.
-5. Dual token acceptance and required legacy field aliases are present.
-6. Existing payment configs are normalized and verified against current frontend-visible capabilities.
-7. New frontend flows are verified against mixed-version backend compatibility windows.
-8. Duplicate-account creation, first-bind grants, and payment route selection have regression coverage.
+3. Historical LinuxDo/WeChat/OIDC synthetic-email backfill to canonical identity is complete where deterministic, and non-recoverable rows are reported.
+4. Historical WeChat `openid`-only rows are either remediated or explicitly blocked with operator-visible reporting.
+5. `signup_source` backfill is accurate for known historical provider-created users.
+6. Dual token acceptance, exchange bridge behavior, and required legacy field aliases are present for the bounded rollout window.
+7. Existing payment configs are normalized and verified against current frontend-visible capabilities.
+8. New frontend flows are verified against mixed-version backend compatibility windows.
+9. Duplicate-account creation, first-bind grants, and payment route selection have regression coverage.
From d3d42677311518b10b738e4180aefd484125ecac Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 16:23:42 +0800
Subject: [PATCH 020/189] fix: harden oidc callback security
---
backend/internal/config/config_test.go | 2 +-
backend/internal/handler/auth_oidc_oauth.go | 203 ++++++++++++------
.../internal/handler/auth_oidc_oauth_test.go | 15 --
frontend/src/views/admin/SettingsView.vue | 10 +-
4 files changed, 143 insertions(+), 87 deletions(-)
diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go
index cf58316c..964dbb88 100644
--- a/backend/internal/config/config_test.go
+++ b/backend/internal/config/config_test.go
@@ -334,7 +334,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
cfg.LinuxDo.ClientSecret = "test-secret"
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
- cfg.LinuxDo.UsePKCE = false
+ cfg.LinuxDo.UsePKCE = true
cfg.LinuxDo.FrontendRedirectURL = "javascript:alert(1)"
err = cfg.Validate()
diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go
index 9d24df88..37ef6833 100644
--- a/backend/internal/handler/auth_oidc_oauth.go
+++ b/backend/internal/handler/auth_oidc_oauth.go
@@ -127,30 +127,34 @@ func (h *AuthHandler) OIDCOAuthStart(c *gin.Context) {
redirectTo = oidcOAuthDefaultRedirectTo
}
+ browserSessionKey, err := generateOAuthPendingBrowserSession()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err))
+ return
+ }
+
secureCookie := isRequestHTTPS(c)
oidcSetCookie(c, oidcOAuthStateCookieName, encodeCookieValue(state), oidcOAuthCookieMaxAgeSec, secureCookie)
oidcSetCookie(c, oidcOAuthRedirectCookie, encodeCookieValue(redirectTo), oidcOAuthCookieMaxAgeSec, secureCookie)
+ setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie)
+ clearOAuthPendingSessionCookie(c, secureCookie)
codeChallenge := ""
- if cfg.UsePKCE {
- verifier, genErr := oauth.GenerateCodeVerifier()
- if genErr != nil {
- response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(genErr))
- return
- }
- codeChallenge = oauth.GenerateCodeChallenge(verifier)
- oidcSetCookie(c, oidcOAuthVerifierCookie, encodeCookieValue(verifier), oidcOAuthCookieMaxAgeSec, secureCookie)
+ verifier, genErr := oauth.GenerateCodeVerifier()
+ if genErr != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(genErr))
+ return
}
+ codeChallenge = oauth.GenerateCodeChallenge(verifier)
+ oidcSetCookie(c, oidcOAuthVerifierCookie, encodeCookieValue(verifier), oidcOAuthCookieMaxAgeSec, secureCookie)
nonce := ""
- if cfg.ValidateIDToken {
- nonce, err = oauth.GenerateState()
- if err != nil {
- response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_NONCE_GEN_FAILED", "failed to generate oauth nonce").WithCause(err))
- return
- }
- oidcSetCookie(c, oidcOAuthNonceCookie, encodeCookieValue(nonce), oidcOAuthCookieMaxAgeSec, secureCookie)
+ nonce, err = oauth.GenerateState()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_NONCE_GEN_FAILED", "failed to generate oauth nonce").WithCause(err))
+ return
}
+ oidcSetCookie(c, oidcOAuthNonceCookie, encodeCookieValue(nonce), oidcOAuthCookieMaxAgeSec, secureCookie)
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
@@ -212,23 +216,24 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
if redirectTo == "" {
redirectTo = oidcOAuthDefaultRedirectTo
}
+ browserSessionKey, _ := readOAuthPendingBrowserCookie(c)
+ if strings.TrimSpace(browserSessionKey) == "" {
+ redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "")
+ return
+ }
codeVerifier := ""
- if cfg.UsePKCE {
- codeVerifier, _ = readCookieDecoded(c, oidcOAuthVerifierCookie)
- if codeVerifier == "" {
- redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
- return
- }
+ codeVerifier, _ = readCookieDecoded(c, oidcOAuthVerifierCookie)
+ if codeVerifier == "" {
+ redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
+ return
}
expectedNonce := ""
- if cfg.ValidateIDToken {
- expectedNonce, _ = readCookieDecoded(c, oidcOAuthNonceCookie)
- if expectedNonce == "" {
- redirectOAuthError(c, frontendCallback, "missing_nonce", "missing oauth nonce", "")
- return
- }
+ expectedNonce, _ = readCookieDecoded(c, oidcOAuthNonceCookie)
+ if expectedNonce == "" {
+ redirectOAuthError(c, frontendCallback, "missing_nonce", "missing oauth nonce", "")
+ return
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
@@ -258,7 +263,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
- if cfg.ValidateIDToken && strings.TrimSpace(tokenResp.IDToken) == "" {
+ if strings.TrimSpace(tokenResp.IDToken) == "" {
redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "")
return
}
@@ -304,9 +309,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
}
+ if userInfoClaims.Subject != "" && idClaims.Subject != "" && strings.TrimSpace(userInfoClaims.Subject) != strings.TrimSpace(idClaims.Subject) {
+ redirectOAuthError(c, frontendCallback, "subject_mismatch", "userinfo subject does not match id_token", "")
+ return
+ }
identityKey := oidcIdentityKey(issuer, subject)
- email := oidcSelectLoginEmail(userInfoClaims.Email, idClaims.Email, identityKey)
+ email := oidcSyntheticEmailFromIdentityKey(identityKey)
username := firstNonEmpty(
userInfoClaims.Username,
idClaims.PreferredUsername,
@@ -318,34 +327,73 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
- pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
- if tokenErr != nil {
- redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: "login",
+ Identity: service.PendingAuthIdentityKey{
+ ProviderType: "oidc",
+ ProviderKey: issuer,
+ ProviderSubject: subject,
+ },
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: map[string]any{
+ "email": email,
+ "username": username,
+ "subject": subject,
+ "issuer": issuer,
+ "email_verified": emailVerified != nil && *emailVerified,
+ "provider_fallback": strings.TrimSpace(cfg.ProviderName),
+ },
+ CompletionResponse: map[string]any{
+ "error": "invitation_required",
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}
- fragment := url.Values{}
- fragment.Set("error", "invitation_required")
- fragment.Set("pending_oauth_token", pendingToken)
- fragment.Set("redirect", redirectTo)
- redirectWithFragment(c, frontendCallback, fragment)
+ redirectToFrontendCallback(c, frontendCallback)
return
}
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return
}
- fragment := url.Values{}
- fragment.Set("access_token", tokenPair.AccessToken)
- fragment.Set("refresh_token", tokenPair.RefreshToken)
- fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
- fragment.Set("token_type", "Bearer")
- fragment.Set("redirect", redirectTo)
- redirectWithFragment(c, frontendCallback, fragment)
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: "login",
+ Identity: service.PendingAuthIdentityKey{
+ ProviderType: "oidc",
+ ProviderKey: issuer,
+ ProviderSubject: subject,
+ },
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: map[string]any{
+ "email": email,
+ "username": username,
+ "subject": subject,
+ "issuer": issuer,
+ "email_verified": emailVerified != nil && *emailVerified,
+ "provider_fallback": strings.TrimSpace(cfg.ProviderName),
+ },
+ CompletionResponse: map[string]any{
+ "access_token": tokenPair.AccessToken,
+ "refresh_token": tokenPair.RefreshToken,
+ "expires_in": tokenPair.ExpiresIn,
+ "token_type": "Bearer",
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
}
type completeOIDCOAuthRequest struct {
- PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
- InvitationCode string `json:"invitation_code" binding:"required"`
+ InvitationCode string `json:"invitation_code" binding:"required"`
}
// CompleteOIDCOAuthRegistration completes a pending OAuth registration by validating
@@ -358,9 +406,38 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
return
}
- email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
+ secureCookie := isRequestHTTPS(c)
+ sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil {
- c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
+ return
+ }
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
+ return
+ }
+ pendingSvc, err := h.pendingIdentityService()
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ email := strings.TrimSpace(session.ResolvedEmail)
+ username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
+ if email == "" || username == "" {
+ response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid"))
return
}
@@ -369,6 +446,14 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, err)
+ return
+ }
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
c.JSON(http.StatusOK, gin.H{
"access_token": tokenPair.AccessToken,
@@ -405,9 +490,7 @@ func oidcExchangeCode(
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
- if cfg.UsePKCE {
- form.Set("code_verifier", codeVerifier)
- }
+ form.Set("code_verifier", codeVerifier)
r := client.R().
SetContext(ctx).
@@ -592,13 +675,9 @@ func buildOIDCAuthorizeURL(cfg config.OIDCConnectConfig, state, nonce, codeChall
q.Set("scope", cfg.Scopes)
}
q.Set("state", state)
- if strings.TrimSpace(nonce) != "" {
- q.Set("nonce", nonce)
- }
- if cfg.UsePKCE {
- q.Set("code_challenge", codeChallenge)
- q.Set("code_challenge_method", "S256")
- }
+ q.Set("nonce", nonce)
+ q.Set("code_challenge", codeChallenge)
+ q.Set("code_challenge_method", "S256")
u.RawQuery = q.Encode()
return u.String(), nil
@@ -831,14 +910,6 @@ func oidcSyntheticEmailFromIdentityKey(identityKey string) string {
return "oidc-" + hex.EncodeToString(sum[:16]) + service.OIDCConnectSyntheticEmailDomain
}
-func oidcSelectLoginEmail(userInfoEmail, idTokenEmail, identityKey string) string {
- email := strings.TrimSpace(firstNonEmpty(userInfoEmail, idTokenEmail))
- if email != "" {
- return email
- }
- return oidcSyntheticEmailFromIdentityKey(identityKey)
-}
-
func oidcFallbackUsername(subject string) string {
subject = strings.TrimSpace(subject)
if subject == "" {
diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go
index a161aa77..a4cf776a 100644
--- a/backend/internal/handler/auth_oidc_oauth_test.go
+++ b/backend/internal/handler/auth_oidc_oauth_test.go
@@ -30,26 +30,11 @@ func TestOIDCSyntheticEmailStableAndDistinct(t *testing.T) {
require.Contains(t, e1, "@oidc-connect.invalid")
}
-func TestOIDCSelectLoginEmailPrefersRealEmail(t *testing.T) {
- identityKey := oidcIdentityKey("https://issuer.example.com", "subject-a")
-
- email := oidcSelectLoginEmail("user@example.com", "idtoken@example.com", identityKey)
- require.Equal(t, "user@example.com", email)
-
- email = oidcSelectLoginEmail("", "idtoken@example.com", identityKey)
- require.Equal(t, "idtoken@example.com", email)
-
- email = oidcSelectLoginEmail("", "", identityKey)
- require.Contains(t, email, "@oidc-connect.invalid")
- require.Equal(t, oidcSyntheticEmailFromIdentityKey(identityKey), email)
-}
-
func TestBuildOIDCAuthorizeURLIncludesNonceAndPKCE(t *testing.T) {
cfg := config.OIDCConnectConfig{
AuthorizeURL: "https://issuer.example.com/auth",
ClientID: "cid",
Scopes: "openid email profile",
- UsePKCE: true,
}
u, err := buildOIDCAuthorizeURL(cfg, "state123", "nonce123", "challenge123", "https://app.example.com/callback")
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index ee6a4c6d..8bfa0f2b 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -1382,7 +1382,7 @@
{{ t('admin.settings.oidc.usePkce') }}
-
+
@@ -1391,7 +1391,7 @@
{{ t('admin.settings.oidc.validateIdToken') }}
-
+
@@ -3024,7 +3024,7 @@ const form = reactive
({
oidc_connect_redirect_url: '',
oidc_connect_frontend_redirect_url: '/auth/oidc/callback',
oidc_connect_token_auth_method: 'client_secret_post',
- oidc_connect_use_pkce: false,
+ oidc_connect_use_pkce: true,
oidc_connect_validate_id_token: true,
oidc_connect_allowed_signing_algs: 'RS256,ES256,PS256',
oidc_connect_clock_skew_seconds: 120,
@@ -3613,8 +3613,8 @@ async function saveSettings() {
oidc_connect_redirect_url: form.oidc_connect_redirect_url,
oidc_connect_frontend_redirect_url: form.oidc_connect_frontend_redirect_url,
oidc_connect_token_auth_method: form.oidc_connect_token_auth_method,
- oidc_connect_use_pkce: form.oidc_connect_use_pkce,
- oidc_connect_validate_id_token: form.oidc_connect_validate_id_token,
+ oidc_connect_use_pkce: true,
+ oidc_connect_validate_id_token: true,
oidc_connect_allowed_signing_algs: form.oidc_connect_allowed_signing_algs,
oidc_connect_clock_skew_seconds: form.oidc_connect_clock_skew_seconds,
oidc_connect_require_email_verified: form.oidc_connect_require_email_verified,
From fbd0a2e3c488720025a3408e9db234407d8aef9b Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 16:27:23 +0800
Subject: [PATCH 021/189] feat: carry suggested third-party profile through
pending oauth
---
.../internal/handler/auth_linuxdo_oauth.go | 201 +++++++++----
.../handler/auth_linuxdo_oauth_test.go | 12 +-
.../handler/auth_oauth_pending_flow.go | 263 ++++++++++++++++++
.../handler/auth_oauth_pending_flow_test.go | 40 +++
backend/internal/handler/auth_oidc_oauth.go | 47 +++-
.../internal/handler/auth_oidc_oauth_test.go | 20 ++
frontend/src/api/auth.ts | 24 +-
7 files changed, 534 insertions(+), 73 deletions(-)
create mode 100644 backend/internal/handler/auth_oauth_pending_flow.go
create mode 100644 backend/internal/handler/auth_oauth_pending_flow_test.go
diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go
index 0c7c2da7..2f182642 100644
--- a/backend/internal/handler/auth_linuxdo_oauth.go
+++ b/backend/internal/handler/auth_linuxdo_oauth.go
@@ -87,20 +87,25 @@ func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) {
redirectTo = linuxDoOAuthDefaultRedirectTo
}
+ browserSessionKey, err := generateOAuthPendingBrowserSession()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err))
+ return
+ }
+
secureCookie := isRequestHTTPS(c)
setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie)
setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie)
+ setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie)
+ clearOAuthPendingSessionCookie(c, secureCookie)
- codeChallenge := ""
- if cfg.UsePKCE {
- verifier, err := oauth.GenerateCodeVerifier()
- if err != nil {
- response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err))
- return
- }
- codeChallenge = oauth.GenerateCodeChallenge(verifier)
- setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie)
+ verifier, err := oauth.GenerateCodeVerifier()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err))
+ return
}
+ codeChallenge := oauth.GenerateCodeChallenge(verifier)
+ setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie)
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
@@ -161,14 +166,16 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
if redirectTo == "" {
redirectTo = linuxDoOAuthDefaultRedirectTo
}
+ browserSessionKey, _ := readOAuthPendingBrowserCookie(c)
+ if strings.TrimSpace(browserSessionKey) == "" {
+ redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "")
+ return
+ }
- codeVerifier := ""
- if cfg.UsePKCE {
- codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie)
- if codeVerifier == "" {
- redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
- return
- }
+ codeVerifier, _ := readCookieDecoded(c, linuxDoOAuthVerifierCookie)
+ if codeVerifier == "" {
+ redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
+ return
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
@@ -198,7 +205,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
return
}
- email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp)
+ email, username, subject, displayName, avatarURL, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp)
if err != nil {
log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err)
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "")
@@ -215,16 +222,32 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
- pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
- if tokenErr != nil {
- redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: "login",
+ Identity: service.PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: subject,
+ },
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: map[string]any{
+ "email": email,
+ "username": username,
+ "subject": subject,
+ "suggested_display_name": displayName,
+ "suggested_avatar_url": avatarURL,
+ },
+ CompletionResponse: map[string]any{
+ "error": "invitation_required",
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}
- fragment := url.Values{}
- fragment.Set("error", "invitation_required")
- fragment.Set("pending_oauth_token", pendingToken)
- fragment.Set("redirect", redirectTo)
- redirectWithFragment(c, frontendCallback, fragment)
+ redirectToFrontendCallback(c, frontendCallback)
return
}
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
@@ -232,18 +255,39 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
return
}
- fragment := url.Values{}
- fragment.Set("access_token", tokenPair.AccessToken)
- fragment.Set("refresh_token", tokenPair.RefreshToken)
- fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
- fragment.Set("token_type", "Bearer")
- fragment.Set("redirect", redirectTo)
- redirectWithFragment(c, frontendCallback, fragment)
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: "login",
+ Identity: service.PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: subject,
+ },
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: map[string]any{
+ "email": email,
+ "username": username,
+ "subject": subject,
+ "suggested_display_name": displayName,
+ "suggested_avatar_url": avatarURL,
+ },
+ CompletionResponse: map[string]any{
+ "access_token": tokenPair.AccessToken,
+ "refresh_token": tokenPair.RefreshToken,
+ "expires_in": tokenPair.ExpiresIn,
+ "token_type": "Bearer",
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
}
type completeLinuxDoOAuthRequest struct {
- PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
- InvitationCode string `json:"invitation_code" binding:"required"`
+ InvitationCode string `json:"invitation_code" binding:"required"`
}
// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating
@@ -256,9 +300,38 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
return
}
- email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
+ secureCookie := isRequestHTTPS(c)
+ sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil {
- c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
+ return
+ }
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
+ return
+ }
+ pendingSvc, err := h.pendingIdentityService()
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ email := strings.TrimSpace(session.ResolvedEmail)
+ username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
+ if email == "" || username == "" {
+ response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid"))
return
}
@@ -267,6 +340,14 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, err)
+ return
+ }
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
c.JSON(http.StatusOK, gin.H{
"access_token": tokenPair.AccessToken,
@@ -303,9 +384,7 @@ func linuxDoExchangeCode(
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
- if cfg.UsePKCE {
- form.Set("code_verifier", codeVerifier)
- }
+ form.Set("code_verifier", codeVerifier)
r := client.R().
SetContext(ctx).
@@ -353,11 +432,11 @@ func linuxDoFetchUserInfo(
ctx context.Context,
cfg config.LinuxDoConnectConfig,
token *linuxDoTokenResponse,
-) (email string, username string, subject string, err error) {
+) (email string, username string, subject string, displayName string, avatarURL string, err error) {
client := req.C().SetTimeout(30 * time.Second)
authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken)
if err != nil {
- return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err)
+ return "", "", "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err)
}
resp, err := client.R().
@@ -366,16 +445,16 @@ func linuxDoFetchUserInfo(
SetHeader("Authorization", authorization).
Get(cfg.UserInfoURL)
if err != nil {
- return "", "", "", fmt.Errorf("request userinfo: %w", err)
+ return "", "", "", "", "", fmt.Errorf("request userinfo: %w", err)
}
if !resp.IsSuccessState() {
- return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode)
+ return "", "", "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode)
}
return linuxDoParseUserInfo(resp.String(), cfg)
}
-func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) {
+func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, displayName string, avatarURL string, err error) {
email = firstNonEmpty(
getGJSON(body, cfg.UserInfoEmailPath),
getGJSON(body, "email"),
@@ -400,12 +479,29 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s
getGJSON(body, "user.id"),
)
+ displayName = firstNonEmpty(
+ getGJSON(body, "name"),
+ getGJSON(body, "nickname"),
+ getGJSON(body, "display_name"),
+ getGJSON(body, "user.name"),
+ getGJSON(body, "user.username"),
+ username,
+ )
+ avatarURL = firstNonEmpty(
+ getGJSON(body, "avatar_url"),
+ getGJSON(body, "avatar"),
+ getGJSON(body, "picture"),
+ getGJSON(body, "profile_image_url"),
+ getGJSON(body, "user.avatar"),
+ getGJSON(body, "user.avatar_url"),
+ )
+
subject = strings.TrimSpace(subject)
if subject == "" {
- return "", "", "", errors.New("userinfo missing id field")
+ return "", "", "", "", "", errors.New("userinfo missing id field")
}
if !isSafeLinuxDoSubject(subject) {
- return "", "", "", errors.New("userinfo returned invalid id field")
+ return "", "", "", "", "", errors.New("userinfo returned invalid id field")
}
email = strings.TrimSpace(email)
@@ -418,8 +514,13 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s
if username == "" {
username = "linuxdo_" + subject
}
+ displayName = strings.TrimSpace(displayName)
+ if displayName == "" {
+ displayName = username
+ }
+ avatarURL = strings.TrimSpace(avatarURL)
- return email, username, subject, nil
+ return email, username, subject, displayName, avatarURL, nil
}
func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) {
@@ -436,10 +537,8 @@ func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, cod
q.Set("scope", cfg.Scopes)
}
q.Set("state", state)
- if cfg.UsePKCE {
- q.Set("code_challenge", codeChallenge)
- q.Set("code_challenge_method", "S256")
- }
+ q.Set("code_challenge", codeChallenge)
+ q.Set("code_challenge_method", "S256")
u.RawQuery = q.Encode()
return u.String(), nil
diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go
index ff169c52..90bc10d1 100644
--- a/backend/internal/handler/auth_linuxdo_oauth_test.go
+++ b/backend/internal/handler/auth_linuxdo_oauth_test.go
@@ -41,11 +41,13 @@ func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) {
UserInfoURL: "https://connect.linux.do/api/user",
}
- email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg)
+ email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":123,"username":"alice","name":"Alice","avatar_url":"https://cdn.example/avatar.png"}`, cfg)
require.NoError(t, err)
require.Equal(t, "123", subject)
require.Equal(t, "alice", username)
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
+ require.Equal(t, "Alice", displayName)
+ require.Equal(t, "https://cdn.example/avatar.png", avatarURL)
}
func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) {
@@ -53,11 +55,13 @@ func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) {
UserInfoURL: "https://connect.linux.do/api/user",
}
- email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg)
+ email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg)
require.NoError(t, err)
require.Equal(t, "123", subject)
require.Equal(t, "linuxdo_123", username)
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
+ require.Equal(t, "linuxdo_123", displayName)
+ require.Equal(t, "", avatarURL)
}
func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) {
@@ -65,11 +69,11 @@ func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) {
UserInfoURL: "https://connect.linux.do/api/user",
}
- _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg)
+ _, _, _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg)
require.Error(t, err)
tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1)
- _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg)
+ _, _, _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg)
require.Error(t, err)
}
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
new file mode 100644
index 00000000..a758c0b9
--- /dev/null
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -0,0 +1,263 @@
+package handler
+
+import (
+ "net/http"
+ "net/url"
+ "strings"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ oauthPendingBrowserCookiePath = "/api/v1/auth/oauth"
+ oauthPendingBrowserCookieName = "oauth_pending_browser_session"
+ oauthPendingSessionCookiePath = "/api/v1/auth/oauth/pending"
+ oauthPendingSessionCookieName = "oauth_pending_session"
+ oauthPendingCookieMaxAgeSec = 10 * 60
+
+ oauthCompletionResponseKey = "completion_response"
+)
+
+type oauthPendingSessionPayload struct {
+ Intent string
+ Identity service.PendingAuthIdentityKey
+ ResolvedEmail string
+ RedirectTo string
+ BrowserSessionKey string
+ UpstreamIdentityClaims map[string]any
+ CompletionResponse map[string]any
+}
+
+func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) {
+ if h == nil || h.authService == nil || h.authService.EntClient() == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+ return service.NewAuthPendingIdentityService(h.authService.EntClient()), nil
+}
+
+func generateOAuthPendingBrowserSession() (string, error) {
+ return oauth.GenerateState()
+}
+
+func setOAuthPendingBrowserCookie(c *gin.Context, sessionKey string, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthPendingBrowserCookieName,
+ Value: encodeCookieValue(sessionKey),
+ Path: oauthPendingBrowserCookiePath,
+ MaxAge: oauthPendingCookieMaxAgeSec,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func clearOAuthPendingBrowserCookie(c *gin.Context, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthPendingBrowserCookieName,
+ Value: "",
+ Path: oauthPendingBrowserCookiePath,
+ MaxAge: -1,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func readOAuthPendingBrowserCookie(c *gin.Context) (string, error) {
+ return readCookieDecoded(c, oauthPendingBrowserCookieName)
+}
+
+func setOAuthPendingSessionCookie(c *gin.Context, sessionToken string, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthPendingSessionCookieName,
+ Value: encodeCookieValue(sessionToken),
+ Path: oauthPendingSessionCookiePath,
+ MaxAge: oauthPendingCookieMaxAgeSec,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func clearOAuthPendingSessionCookie(c *gin.Context, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthPendingSessionCookieName,
+ Value: "",
+ Path: oauthPendingSessionCookiePath,
+ MaxAge: -1,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func readOAuthPendingSessionCookie(c *gin.Context) (string, error) {
+ return readCookieDecoded(c, oauthPendingSessionCookieName)
+}
+
+func redirectToFrontendCallback(c *gin.Context, frontendCallback string) {
+ u, err := url.Parse(frontendCallback)
+ if err != nil {
+ c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
+ return
+ }
+ if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") {
+ c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
+ return
+ }
+ u.Fragment = ""
+ c.Header("Cache-Control", "no-store")
+ c.Header("Pragma", "no-cache")
+ c.Redirect(http.StatusFound, u.String())
+}
+
+func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPendingSessionPayload) error {
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ return err
+ }
+
+ session, err := svc.CreatePendingSession(c.Request.Context(), service.CreatePendingAuthSessionInput{
+ Intent: strings.TrimSpace(payload.Intent),
+ Identity: payload.Identity,
+ ResolvedEmail: strings.TrimSpace(payload.ResolvedEmail),
+ RedirectTo: strings.TrimSpace(payload.RedirectTo),
+ BrowserSessionKey: strings.TrimSpace(payload.BrowserSessionKey),
+ UpstreamIdentityClaims: payload.UpstreamIdentityClaims,
+ LocalFlowState: map[string]any{
+ oauthCompletionResponseKey: payload.CompletionResponse,
+ },
+ })
+ if err != nil {
+ return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err)
+ }
+
+ setOAuthPendingSessionCookie(c, session.SessionToken, isRequestHTTPS(c))
+ return nil
+}
+
+func readCompletionResponse(session map[string]any) (map[string]any, bool) {
+ if len(session) == 0 {
+ return nil, false
+ }
+ value, ok := session[oauthCompletionResponseKey]
+ if !ok {
+ return nil, false
+ }
+ result, ok := value.(map[string]any)
+ if !ok {
+ return nil, false
+ }
+ return result, true
+}
+
+func pendingSessionStringValue(values map[string]any, key string) string {
+ if len(values) == 0 {
+ return ""
+ }
+ raw, ok := values[key]
+ if !ok {
+ return ""
+ }
+ value, ok := raw.(string)
+ if !ok {
+ return ""
+ }
+ return strings.TrimSpace(value)
+}
+
+func pendingSessionWantsInvitation(payload map[string]any) bool {
+ return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required")
+}
+
+func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) {
+ if len(payload) == 0 || len(upstream) == 0 {
+ return
+ }
+
+ displayName := pendingSessionStringValue(upstream, "suggested_display_name")
+ avatarURL := pendingSessionStringValue(upstream, "suggested_avatar_url")
+
+ if displayName != "" {
+ if _, exists := payload["suggested_display_name"]; !exists {
+ payload["suggested_display_name"] = displayName
+ }
+ }
+ if avatarURL != "" {
+ if _, exists := payload["suggested_avatar_url"]; !exists {
+ payload["suggested_avatar_url"] = avatarURL
+ }
+ }
+ if displayName != "" || avatarURL != "" {
+ payload["adoption_required"] = true
+ }
+}
+
+// ExchangePendingOAuthCompletion redeems a pending OAuth browser session into a frontend-safe payload.
+// POST /api/v1/auth/oauth/pending/exchange
+func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
+ secureCookie := isRequestHTTPS(c)
+ clearCookies := func() {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ }
+
+ sessionToken, err := readOAuthPendingSessionCookie(c)
+ if err != nil || strings.TrimSpace(sessionToken) == "" {
+ clearCookies()
+ response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
+ return
+ }
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil || strings.TrimSpace(browserSessionKey) == "" {
+ clearCookies()
+ response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
+ return
+ }
+
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+ if err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ payload, ok := readCompletionResponse(session.LocalFlowState)
+ if !ok {
+ clearCookies()
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid"))
+ return
+ }
+ if strings.TrimSpace(session.RedirectTo) != "" {
+ if _, exists := payload["redirect"]; !exists {
+ payload["redirect"] = session.RedirectTo
+ }
+ }
+ applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims)
+
+ if pendingSessionWantsInvitation(payload) {
+ response.Success(c, payload)
+ return
+ }
+
+ if _, err := svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ clearCookies()
+ response.Success(c, payload)
+}
diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go
new file mode 100644
index 00000000..5517bae2
--- /dev/null
+++ b/backend/internal/handler/auth_oauth_pending_flow_test.go
@@ -0,0 +1,40 @@
+package handler
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestApplySuggestedProfileToCompletionResponse(t *testing.T) {
+ payload := map[string]any{
+ "access_token": "token",
+ }
+ upstream := map[string]any{
+ "suggested_display_name": "Alice",
+ "suggested_avatar_url": "https://cdn.example/avatar.png",
+ }
+
+ applySuggestedProfileToCompletionResponse(payload, upstream)
+
+ require.Equal(t, "Alice", payload["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"])
+ require.Equal(t, true, payload["adoption_required"])
+}
+
+func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t *testing.T) {
+ payload := map[string]any{
+ "suggested_display_name": "Existing",
+ "adoption_required": false,
+ }
+ upstream := map[string]any{
+ "suggested_display_name": "Alice",
+ "suggested_avatar_url": "https://cdn.example/avatar.png",
+ }
+
+ applySuggestedProfileToCompletionResponse(payload, upstream)
+
+ require.Equal(t, "Existing", payload["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"])
+ require.Equal(t, true, payload["adoption_required"])
+}
diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go
index 37ef6833..e3694c8f 100644
--- a/backend/internal/handler/auth_oidc_oauth.go
+++ b/backend/internal/handler/auth_oidc_oauth.go
@@ -87,6 +87,8 @@ type oidcUserInfoClaims struct {
Username string
Subject string
EmailVerified *bool
+ DisplayName string
+ AvatarURL string
}
type oidcJWKSet struct {
@@ -338,12 +340,14 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: map[string]any{
- "email": email,
- "username": username,
- "subject": subject,
- "issuer": issuer,
- "email_verified": emailVerified != nil && *emailVerified,
- "provider_fallback": strings.TrimSpace(cfg.ProviderName),
+ "email": email,
+ "username": username,
+ "subject": subject,
+ "issuer": issuer,
+ "email_verified": emailVerified != nil && *emailVerified,
+ "provider_fallback": strings.TrimSpace(cfg.ProviderName),
+ "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
+ "suggested_avatar_url": userInfoClaims.AvatarURL,
},
CompletionResponse: map[string]any{
"error": "invitation_required",
@@ -371,12 +375,14 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: map[string]any{
- "email": email,
- "username": username,
- "subject": subject,
- "issuer": issuer,
- "email_verified": emailVerified != nil && *emailVerified,
- "provider_fallback": strings.TrimSpace(cfg.ProviderName),
+ "email": email,
+ "username": username,
+ "subject": subject,
+ "issuer": issuer,
+ "email_verified": emailVerified != nil && *emailVerified,
+ "provider_fallback": strings.TrimSpace(cfg.ProviderName),
+ "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
+ "suggested_avatar_url": userInfoClaims.AvatarURL,
},
CompletionResponse: map[string]any{
"access_token": tokenPair.AccessToken,
@@ -643,9 +649,26 @@ func oidcParseUserInfo(body string, cfg config.OIDCConnectConfig) *oidcUserInfoC
if verified, ok := getGJSONBool(body, "email_verified"); ok {
claims.EmailVerified = &verified
}
+ claims.DisplayName = firstNonEmpty(
+ getGJSON(body, "name"),
+ getGJSON(body, "nickname"),
+ getGJSON(body, "display_name"),
+ getGJSON(body, "preferred_username"),
+ getGJSON(body, "username"),
+ )
+ claims.AvatarURL = firstNonEmpty(
+ getGJSON(body, "picture"),
+ getGJSON(body, "avatar_url"),
+ getGJSON(body, "avatar"),
+ getGJSON(body, "profile_image_url"),
+ getGJSON(body, "user.avatar"),
+ getGJSON(body, "user.avatar_url"),
+ )
claims.Email = strings.TrimSpace(claims.Email)
claims.Username = strings.TrimSpace(claims.Username)
claims.Subject = strings.TrimSpace(claims.Subject)
+ claims.DisplayName = strings.TrimSpace(claims.DisplayName)
+ claims.AvatarURL = strings.TrimSpace(claims.AvatarURL)
return claims
}
diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go
index a4cf776a..c389db51 100644
--- a/backend/internal/handler/auth_oidc_oauth_test.go
+++ b/backend/internal/handler/auth_oidc_oauth_test.go
@@ -91,6 +91,26 @@ func TestOIDCParseAndValidateIDToken(t *testing.T) {
require.Error(t, err)
}
+func TestOIDCParseUserInfoIncludesSuggestedProfile(t *testing.T) {
+ cfg := config.OIDCConnectConfig{}
+
+ claims := oidcParseUserInfo(`{
+ "sub":"subject-1",
+ "preferred_username":"alice",
+ "name":"Alice Example",
+ "picture":"https://cdn.example/avatar.png",
+ "email":"alice@example.com",
+ "email_verified":true
+ }`, cfg)
+
+ require.Equal(t, "subject-1", claims.Subject)
+ require.Equal(t, "alice", claims.Username)
+ require.Equal(t, "Alice Example", claims.DisplayName)
+ require.Equal(t, "https://cdn.example/avatar.png", claims.AvatarURL)
+ require.NotNil(t, claims.EmailVerified)
+ require.True(t, *claims.EmailVerified)
+}
+
func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK {
n := base64.RawURLEncoding.EncodeToString(pub.N.Bytes())
e := base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pub.E)).Bytes())
diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts
index 837c4f4c..d7abcd6a 100644
--- a/frontend/src/api/auth.ts
+++ b/frontend/src/api/auth.ts
@@ -186,6 +186,18 @@ export interface RefreshTokenResponse {
token_type: string
}
+export interface PendingOAuthExchangeResponse {
+ access_token?: string
+ refresh_token?: string
+ expires_in?: number
+ token_type?: string
+ redirect?: string
+ error?: string
+ adoption_required?: boolean
+ suggested_display_name?: string
+ suggested_avatar_url?: string
+}
+
/**
* Refresh the access token using the refresh token
* @returns New token pair
@@ -337,12 +349,10 @@ export async function resetPassword(request: ResetPasswordRequest): Promise {
const { data } = await apiClient.post<{
@@ -351,7 +361,6 @@ export async function completeLinuxDoOAuthRegistration(
expires_in: number
token_type: string
}>('/auth/oauth/linuxdo/complete-registration', {
- pending_oauth_token: pendingOAuthToken,
invitation_code: invitationCode
})
return data
@@ -359,12 +368,10 @@ export async function completeLinuxDoOAuthRegistration(
/**
* Complete OIDC OAuth registration by supplying an invitation code
- * @param pendingOAuthToken - Short-lived JWT from the OAuth callback
* @param invitationCode - Invitation code entered by the user
* @returns Token pair on success
*/
export async function completeOIDCOAuthRegistration(
- pendingOAuthToken: string,
invitationCode: string
): Promise<{ access_token: string; refresh_token: string; expires_in: number; token_type: string }> {
const { data } = await apiClient.post<{
@@ -373,12 +380,16 @@ export async function completeOIDCOAuthRegistration(
expires_in: number
token_type: string
}>('/auth/oauth/oidc/complete-registration', {
- pending_oauth_token: pendingOAuthToken,
invitation_code: invitationCode
})
return data
}
+export async function exchangePendingOAuthCompletion(): Promise {
+ const { data } = await apiClient.post('/auth/oauth/pending/exchange', {})
+ return data
+}
+
export const authAPI = {
login,
login2FA,
@@ -402,6 +413,7 @@ export const authAPI = {
resetPassword,
refreshToken,
revokeAllSessions,
+ exchangePendingOAuthCompletion,
completeLinuxDoOAuthRegistration,
completeOIDCOAuthRegistration
}
From e9de839d8791151314873bad11ac9583dabd2738 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 17:39:57 +0800
Subject: [PATCH 022/189] feat: rebuild auth identity foundation flow
---
backend/ent/authidentity.go | 266 +
backend/ent/authidentity/authidentity.go | 209 +
backend/ent/authidentity/where.go | 600 +++
backend/ent/authidentity_create.go | 1036 ++++
backend/ent/authidentity_delete.go | 88 +
backend/ent/authidentity_query.go | 797 +++
backend/ent/authidentity_update.go | 923 ++++
backend/ent/authidentitychannel.go | 228 +
.../authidentitychannel.go | 153 +
backend/ent/authidentitychannel/where.go | 559 ++
backend/ent/authidentitychannel_create.go | 932 ++++
backend/ent/authidentitychannel_delete.go | 88 +
backend/ent/authidentitychannel_query.go | 643 +++
backend/ent/authidentitychannel_update.go | 581 ++
backend/ent/client.go | 876 ++-
backend/ent/ent.go | 60 +-
backend/ent/hook/hook.go | 48 +
backend/ent/identityadoptiondecision.go | 223 +
.../identityadoptiondecision.go | 159 +
backend/ent/identityadoptiondecision/where.go | 342 ++
.../ent/identityadoptiondecision_create.go | 843 +++
.../ent/identityadoptiondecision_delete.go | 88 +
backend/ent/identityadoptiondecision_query.go | 721 +++
.../ent/identityadoptiondecision_update.go | 532 ++
backend/ent/intercept/intercept.go | 120 +
backend/ent/migrate/schema.go | 216 +
backend/ent/mutation.go | 4687 ++++++++++++++++-
backend/ent/pendingauthsession.go | 399 ++
.../pendingauthsession/pendingauthsession.go | 279 +
backend/ent/pendingauthsession/where.go | 1262 +++++
backend/ent/pendingauthsession_create.go | 1815 +++++++
backend/ent/pendingauthsession_delete.go | 88 +
backend/ent/pendingauthsession_query.go | 717 +++
backend/ent/pendingauthsession_update.go | 1178 +++++
backend/ent/predicate/predicate.go | 12 +
backend/ent/runtime/runtime.go | 266 +-
backend/ent/schema/auth_identity.go | 93 +
backend/ent/schema/auth_identity_channel.go | 72 +
.../ent/schema/auth_identity_schema_test.go | 124 +
.../ent/schema/identity_adoption_decision.go | 70 +
backend/ent/schema/pending_auth_session.go | 134 +
backend/ent/schema/user.go | 13 +
backend/ent/tx.go | 12 +
backend/ent/user.go | 79 +-
backend/ent/user/user.go | 88 +
backend/ent/user/where.go | 226 +
backend/ent/user_create.go | 290 +
backend/ent/user_query.go | 155 +-
backend/ent/user_update.go | 474 ++
backend/internal/config/config.go | 15 +-
.../internal/handler/admin/setting_handler.go | 189 +-
...tting_handler_auth_source_defaults_test.go | 149 +
.../internal/handler/auth_linuxdo_oauth.go | 21 +-
.../handler/auth_linuxdo_oauth_test.go | 87 +
.../handler/auth_oauth_pending_flow.go | 325 ++
.../handler/auth_oauth_pending_flow_test.go | 457 ++
backend/internal/handler/auth_oidc_oauth.go | 21 +-
.../internal/handler/auth_oidc_oauth_test.go | 84 +
backend/internal/handler/auth_wechat_oauth.go | 618 +++
.../handler/auth_wechat_oauth_test.go | 411 ++
backend/internal/handler/dto/settings.go | 1 +
.../handler/payment_webhook_handler.go | 2 +-
.../handler/payment_webhook_handler_test.go | 34 +
backend/internal/handler/setting_handler.go | 1 +
backend/internal/handler/user_handler.go | 30 +-
backend/internal/handler/user_handler_test.go | 136 +
backend/internal/repository/api_key_repo.go | 6 +
.../auth_identity_migration_report.go | 148 +
.../repository/user_profile_identity_repo.go | 544 ++
...ser_profile_identity_repo_contract_test.go | 428 ++
backend/internal/repository/user_repo.go | 44 +
.../user_repo_sort_integration_test.go | 83 +
backend/internal/server/api_contract_test.go | 4 +-
backend/internal/server/routes/auth.go | 14 +
.../service/admin_service_apikey_test.go | 9 +
.../service/admin_service_delete_test.go | 12 +
.../service/auth_pending_identity_service.go | 326 ++
.../auth_pending_identity_service_test.go | 224 +
backend/internal/service/auth_service.go | 164 +-
.../auth_service_identity_sync_test.go | 153 +
.../auth_service_pending_oauth_test.go | 146 -
backend/internal/service/domain_constants.go | 26 +
.../service/openai_account_scheduler.go | 77 +-
.../service/openai_account_scheduler_test.go | 430 +-
...enai_account_scheduler_ws_snapshot_test.go | 7 +-
.../service/payment_config_service.go | 96 +-
.../service/payment_config_service_test.go | 186 +
.../service/payment_resume_service.go | 248 +
.../service/payment_resume_service_test.go | 240 +
backend/internal/service/payment_service.go | 40 +-
backend/internal/service/setting_service.go | 256 +-
...tting_service_auth_source_defaults_test.go | 136 +
backend/internal/service/settings_view.go | 1 +
backend/internal/service/user.go | 34 +-
backend/internal/service/user_service.go | 153 +-
backend/internal/service/user_service_test.go | 180 +-
.../108_auth_identity_foundation_core.sql | 141 +
.../109_auth_identity_compat_backfill.sql | 125 +
...nding_auth_and_provider_default_grants.sql | 60 +
...11_payment_routing_and_scheduler_flags.sql | 8 +
.../api/__tests__/auth-oauth-adoption.spec.ts | 60 +
.../settings.authSourceDefaults.spec.ts | 118 +
frontend/src/api/admin/settings.ts | 127 +
frontend/src/api/auth.ts | 41 +-
.../components/auth/WechatOAuthSection.vue | 53 +
.../auth/__tests__/WechatOAuthSection.spec.ts | 74 +
.../components/payment/PaymentStatusPanel.vue | 29 +-
.../payment/__tests__/paymentFlow.spec.ts | 163 +
.../src/components/payment/paymentFlow.ts | 197 +
.../src/router/__tests__/wechat-route.spec.ts | 55 +
frontend/src/router/index.ts | 9 +
frontend/src/stores/app.ts | 1 +
frontend/src/types/index.ts | 1 +
frontend/src/views/admin/SettingsView.vue | 497 +-
.../src/views/auth/LinuxDoCallbackView.vue | 263 +-
frontend/src/views/auth/LoginView.vue | 10 +-
frontend/src/views/auth/OidcCallbackView.vue | 267 +-
frontend/src/views/auth/RegisterView.vue | 10 +-
.../src/views/auth/WechatCallbackView.vue | 361 ++
.../__tests__/LinuxDoCallbackView.spec.ts | 180 +
.../auth/__tests__/OidcCallbackView.spec.ts | 191 +
.../auth/__tests__/WechatCallbackView.spec.ts | 241 +
frontend/src/views/user/PaymentView.vue | 229 +-
123 files changed, 33599 insertions(+), 772 deletions(-)
create mode 100644 backend/ent/authidentity.go
create mode 100644 backend/ent/authidentity/authidentity.go
create mode 100644 backend/ent/authidentity/where.go
create mode 100644 backend/ent/authidentity_create.go
create mode 100644 backend/ent/authidentity_delete.go
create mode 100644 backend/ent/authidentity_query.go
create mode 100644 backend/ent/authidentity_update.go
create mode 100644 backend/ent/authidentitychannel.go
create mode 100644 backend/ent/authidentitychannel/authidentitychannel.go
create mode 100644 backend/ent/authidentitychannel/where.go
create mode 100644 backend/ent/authidentitychannel_create.go
create mode 100644 backend/ent/authidentitychannel_delete.go
create mode 100644 backend/ent/authidentitychannel_query.go
create mode 100644 backend/ent/authidentitychannel_update.go
create mode 100644 backend/ent/identityadoptiondecision.go
create mode 100644 backend/ent/identityadoptiondecision/identityadoptiondecision.go
create mode 100644 backend/ent/identityadoptiondecision/where.go
create mode 100644 backend/ent/identityadoptiondecision_create.go
create mode 100644 backend/ent/identityadoptiondecision_delete.go
create mode 100644 backend/ent/identityadoptiondecision_query.go
create mode 100644 backend/ent/identityadoptiondecision_update.go
create mode 100644 backend/ent/pendingauthsession.go
create mode 100644 backend/ent/pendingauthsession/pendingauthsession.go
create mode 100644 backend/ent/pendingauthsession/where.go
create mode 100644 backend/ent/pendingauthsession_create.go
create mode 100644 backend/ent/pendingauthsession_delete.go
create mode 100644 backend/ent/pendingauthsession_query.go
create mode 100644 backend/ent/pendingauthsession_update.go
create mode 100644 backend/ent/schema/auth_identity.go
create mode 100644 backend/ent/schema/auth_identity_channel.go
create mode 100644 backend/ent/schema/auth_identity_schema_test.go
create mode 100644 backend/ent/schema/identity_adoption_decision.go
create mode 100644 backend/ent/schema/pending_auth_session.go
create mode 100644 backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
create mode 100644 backend/internal/handler/auth_wechat_oauth.go
create mode 100644 backend/internal/handler/auth_wechat_oauth_test.go
create mode 100644 backend/internal/handler/user_handler_test.go
create mode 100644 backend/internal/repository/auth_identity_migration_report.go
create mode 100644 backend/internal/repository/user_profile_identity_repo.go
create mode 100644 backend/internal/repository/user_profile_identity_repo_contract_test.go
create mode 100644 backend/internal/service/auth_pending_identity_service.go
create mode 100644 backend/internal/service/auth_pending_identity_service_test.go
create mode 100644 backend/internal/service/auth_service_identity_sync_test.go
delete mode 100644 backend/internal/service/auth_service_pending_oauth_test.go
create mode 100644 backend/internal/service/payment_resume_service.go
create mode 100644 backend/internal/service/payment_resume_service_test.go
create mode 100644 backend/internal/service/setting_service_auth_source_defaults_test.go
create mode 100644 backend/migrations/108_auth_identity_foundation_core.sql
create mode 100644 backend/migrations/109_auth_identity_compat_backfill.sql
create mode 100644 backend/migrations/110_pending_auth_and_provider_default_grants.sql
create mode 100644 backend/migrations/111_payment_routing_and_scheduler_flags.sql
create mode 100644 frontend/src/api/__tests__/auth-oauth-adoption.spec.ts
create mode 100644 frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts
create mode 100644 frontend/src/components/auth/WechatOAuthSection.vue
create mode 100644 frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts
create mode 100644 frontend/src/components/payment/__tests__/paymentFlow.spec.ts
create mode 100644 frontend/src/components/payment/paymentFlow.ts
create mode 100644 frontend/src/router/__tests__/wechat-route.spec.ts
create mode 100644 frontend/src/views/auth/WechatCallbackView.vue
create mode 100644 frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
create mode 100644 frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
create mode 100644 frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
diff --git a/backend/ent/authidentity.go b/backend/ent/authidentity.go
new file mode 100644
index 00000000..5ccfcf19
--- /dev/null
+++ b/backend/ent/authidentity.go
@@ -0,0 +1,266 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AuthIdentity is the model entity for the AuthIdentity schema.
+type AuthIdentity struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // UserID holds the value of the "user_id" field.
+ UserID int64 `json:"user_id,omitempty"`
+ // ProviderType holds the value of the "provider_type" field.
+ ProviderType string `json:"provider_type,omitempty"`
+ // ProviderKey holds the value of the "provider_key" field.
+ ProviderKey string `json:"provider_key,omitempty"`
+ // ProviderSubject holds the value of the "provider_subject" field.
+ ProviderSubject string `json:"provider_subject,omitempty"`
+ // VerifiedAt holds the value of the "verified_at" field.
+ VerifiedAt *time.Time `json:"verified_at,omitempty"`
+ // Issuer holds the value of the "issuer" field.
+ Issuer *string `json:"issuer,omitempty"`
+ // Metadata holds the value of the "metadata" field.
+ Metadata map[string]interface{} `json:"metadata,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the AuthIdentityQuery when eager-loading is set.
+ Edges AuthIdentityEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// AuthIdentityEdges holds the relations/edges for other nodes in the graph.
+type AuthIdentityEdges struct {
+ // User holds the value of the user edge.
+ User *User `json:"user,omitempty"`
+ // Channels holds the value of the channels edge.
+ Channels []*AuthIdentityChannel `json:"channels,omitempty"`
+ // AdoptionDecisions holds the value of the adoption_decisions edge.
+ AdoptionDecisions []*IdentityAdoptionDecision `json:"adoption_decisions,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [3]bool
+}
+
+// UserOrErr returns the User value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e AuthIdentityEdges) UserOrErr() (*User, error) {
+ if e.User != nil {
+ return e.User, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: user.Label}
+ }
+ return nil, &NotLoadedError{edge: "user"}
+}
+
+// ChannelsOrErr returns the Channels value or an error if the edge
+// was not loaded in eager-loading.
+func (e AuthIdentityEdges) ChannelsOrErr() ([]*AuthIdentityChannel, error) {
+ if e.loadedTypes[1] {
+ return e.Channels, nil
+ }
+ return nil, &NotLoadedError{edge: "channels"}
+}
+
+// AdoptionDecisionsOrErr returns the AdoptionDecisions value or an error if the edge
+// was not loaded in eager-loading.
+func (e AuthIdentityEdges) AdoptionDecisionsOrErr() ([]*IdentityAdoptionDecision, error) {
+ if e.loadedTypes[2] {
+ return e.AdoptionDecisions, nil
+ }
+ return nil, &NotLoadedError{edge: "adoption_decisions"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*AuthIdentity) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case authidentity.FieldMetadata:
+ values[i] = new([]byte)
+ case authidentity.FieldID, authidentity.FieldUserID:
+ values[i] = new(sql.NullInt64)
+ case authidentity.FieldProviderType, authidentity.FieldProviderKey, authidentity.FieldProviderSubject, authidentity.FieldIssuer:
+ values[i] = new(sql.NullString)
+ case authidentity.FieldCreatedAt, authidentity.FieldUpdatedAt, authidentity.FieldVerifiedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the AuthIdentity fields.
+func (_m *AuthIdentity) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case authidentity.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case authidentity.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case authidentity.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case authidentity.FieldUserID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field user_id", values[i])
+ } else if value.Valid {
+ _m.UserID = value.Int64
+ }
+ case authidentity.FieldProviderType:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_type", values[i])
+ } else if value.Valid {
+ _m.ProviderType = value.String
+ }
+ case authidentity.FieldProviderKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_key", values[i])
+ } else if value.Valid {
+ _m.ProviderKey = value.String
+ }
+ case authidentity.FieldProviderSubject:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_subject", values[i])
+ } else if value.Valid {
+ _m.ProviderSubject = value.String
+ }
+ case authidentity.FieldVerifiedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field verified_at", values[i])
+ } else if value.Valid {
+ _m.VerifiedAt = new(time.Time)
+ *_m.VerifiedAt = value.Time
+ }
+ case authidentity.FieldIssuer:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field issuer", values[i])
+ } else if value.Valid {
+ _m.Issuer = new(string)
+ *_m.Issuer = value.String
+ }
+ case authidentity.FieldMetadata:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field metadata", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.Metadata); err != nil {
+ return fmt.Errorf("unmarshal field metadata: %w", err)
+ }
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentity.
+// This includes values selected through modifiers, order, etc.
+func (_m *AuthIdentity) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryUser queries the "user" edge of the AuthIdentity entity.
+func (_m *AuthIdentity) QueryUser() *UserQuery {
+ return NewAuthIdentityClient(_m.config).QueryUser(_m)
+}
+
+// QueryChannels queries the "channels" edge of the AuthIdentity entity.
+func (_m *AuthIdentity) QueryChannels() *AuthIdentityChannelQuery {
+ return NewAuthIdentityClient(_m.config).QueryChannels(_m)
+}
+
+// QueryAdoptionDecisions queries the "adoption_decisions" edge of the AuthIdentity entity.
+func (_m *AuthIdentity) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery {
+ return NewAuthIdentityClient(_m.config).QueryAdoptionDecisions(_m)
+}
+
+// Update returns a builder for updating this AuthIdentity.
+// Note that you need to call AuthIdentity.Unwrap() before calling this method if this AuthIdentity
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *AuthIdentity) Update() *AuthIdentityUpdateOne {
+ return NewAuthIdentityClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the AuthIdentity entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *AuthIdentity) Unwrap() *AuthIdentity {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: AuthIdentity is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *AuthIdentity) String() string {
+ var builder strings.Builder
+ builder.WriteString("AuthIdentity(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("user_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.UserID))
+ builder.WriteString(", ")
+ builder.WriteString("provider_type=")
+ builder.WriteString(_m.ProviderType)
+ builder.WriteString(", ")
+ builder.WriteString("provider_key=")
+ builder.WriteString(_m.ProviderKey)
+ builder.WriteString(", ")
+ builder.WriteString("provider_subject=")
+ builder.WriteString(_m.ProviderSubject)
+ builder.WriteString(", ")
+ if v := _m.VerifiedAt; v != nil {
+ builder.WriteString("verified_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.Issuer; v != nil {
+ builder.WriteString("issuer=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ builder.WriteString("metadata=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Metadata))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// AuthIdentities is a parsable slice of AuthIdentity.
+type AuthIdentities []*AuthIdentity
diff --git a/backend/ent/authidentity/authidentity.go b/backend/ent/authidentity/authidentity.go
new file mode 100644
index 00000000..c90be759
--- /dev/null
+++ b/backend/ent/authidentity/authidentity.go
@@ -0,0 +1,209 @@
+// Code generated by ent, DO NOT EDIT.
+
+package authidentity
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the authidentity type in the database.
+ Label = "auth_identity"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldUserID holds the string denoting the user_id field in the database.
+ FieldUserID = "user_id"
+ // FieldProviderType holds the string denoting the provider_type field in the database.
+ FieldProviderType = "provider_type"
+ // FieldProviderKey holds the string denoting the provider_key field in the database.
+ FieldProviderKey = "provider_key"
+ // FieldProviderSubject holds the string denoting the provider_subject field in the database.
+ FieldProviderSubject = "provider_subject"
+ // FieldVerifiedAt holds the string denoting the verified_at field in the database.
+ FieldVerifiedAt = "verified_at"
+ // FieldIssuer holds the string denoting the issuer field in the database.
+ FieldIssuer = "issuer"
+ // FieldMetadata holds the string denoting the metadata field in the database.
+ FieldMetadata = "metadata"
+ // EdgeUser holds the string denoting the user edge name in mutations.
+ EdgeUser = "user"
+ // EdgeChannels holds the string denoting the channels edge name in mutations.
+ EdgeChannels = "channels"
+ // EdgeAdoptionDecisions holds the string denoting the adoption_decisions edge name in mutations.
+ EdgeAdoptionDecisions = "adoption_decisions"
+ // Table holds the table name of the authidentity in the database.
+ Table = "auth_identities"
+ // UserTable is the table that holds the user relation/edge.
+ UserTable = "auth_identities"
+ // UserInverseTable is the table name for the User entity.
+ // It exists in this package in order to avoid circular dependency with the "user" package.
+ UserInverseTable = "users"
+ // UserColumn is the table column denoting the user relation/edge.
+ UserColumn = "user_id"
+ // ChannelsTable is the table that holds the channels relation/edge.
+ ChannelsTable = "auth_identity_channels"
+ // ChannelsInverseTable is the table name for the AuthIdentityChannel entity.
+ // It exists in this package in order to avoid circular dependency with the "authidentitychannel" package.
+ ChannelsInverseTable = "auth_identity_channels"
+ // ChannelsColumn is the table column denoting the channels relation/edge.
+ ChannelsColumn = "identity_id"
+ // AdoptionDecisionsTable is the table that holds the adoption_decisions relation/edge.
+ AdoptionDecisionsTable = "identity_adoption_decisions"
+ // AdoptionDecisionsInverseTable is the table name for the IdentityAdoptionDecision entity.
+ // It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package.
+ AdoptionDecisionsInverseTable = "identity_adoption_decisions"
+ // AdoptionDecisionsColumn is the table column denoting the adoption_decisions relation/edge.
+ AdoptionDecisionsColumn = "identity_id"
+)
+
+// Columns holds all SQL columns for authidentity fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldUserID,
+ FieldProviderType,
+ FieldProviderKey,
+ FieldProviderSubject,
+ FieldVerifiedAt,
+ FieldIssuer,
+ FieldMetadata,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ ProviderTypeValidator func(string) error
+ // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ ProviderKeyValidator func(string) error
+ // ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
+ ProviderSubjectValidator func(string) error
+ // DefaultMetadata holds the default value on creation for the "metadata" field.
+ DefaultMetadata func() map[string]interface{}
+)
+
+// OrderOption defines the ordering options for the AuthIdentity queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByUserID orders the results by the user_id field.
+func ByUserID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUserID, opts...).ToFunc()
+}
+
+// ByProviderType orders the results by the provider_type field.
+func ByProviderType(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderType, opts...).ToFunc()
+}
+
+// ByProviderKey orders the results by the provider_key field.
+func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
+}
+
+// ByProviderSubject orders the results by the provider_subject field.
+func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderSubject, opts...).ToFunc()
+}
+
+// ByVerifiedAt orders the results by the verified_at field.
+func ByVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldVerifiedAt, opts...).ToFunc()
+}
+
+// ByIssuer orders the results by the issuer field.
+func ByIssuer(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIssuer, opts...).ToFunc()
+}
+
+// ByUserField orders the results by user field.
+func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...))
+ }
+}
+
+// ByChannelsCount orders the results by channels count.
+func ByChannelsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newChannelsStep(), opts...)
+ }
+}
+
+// ByChannels orders the results by channels terms.
+func ByChannels(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newChannelsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+
+// ByAdoptionDecisionsCount orders the results by adoption_decisions count.
+func ByAdoptionDecisionsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newAdoptionDecisionsStep(), opts...)
+ }
+}
+
+// ByAdoptionDecisions orders the results by adoption_decisions terms.
+func ByAdoptionDecisions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+func newUserStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(UserInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
+ )
+}
+func newChannelsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(ChannelsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn),
+ )
+}
+func newAdoptionDecisionsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(AdoptionDecisionsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn),
+ )
+}
diff --git a/backend/ent/authidentity/where.go b/backend/ent/authidentity/where.go
new file mode 100644
index 00000000..3dbf3178
--- /dev/null
+++ b/backend/ent/authidentity/where.go
@@ -0,0 +1,600 @@
+// Code generated by ent, DO NOT EDIT.
+
+package authidentity
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
+func UserID(v int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v))
+}
+
+// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ.
+func ProviderType(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
+func ProviderKey(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderSubject applies equality check predicate on the "provider_subject" field. It's identical to ProviderSubjectEQ.
+func ProviderSubject(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v))
+}
+
+// VerifiedAt applies equality check predicate on the "verified_at" field. It's identical to VerifiedAtEQ.
+func VerifiedAt(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v))
+}
+
+// Issuer applies equality check predicate on the "issuer" field. It's identical to IssuerEQ.
+func Issuer(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// UserIDEQ applies the EQ predicate on the "user_id" field.
+func UserIDEQ(v int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v))
+}
+
+// UserIDNEQ applies the NEQ predicate on the "user_id" field.
+func UserIDNEQ(v int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldUserID, v))
+}
+
+// UserIDIn applies the In predicate on the "user_id" field.
+func UserIDIn(vs ...int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldUserID, vs...))
+}
+
+// UserIDNotIn applies the NotIn predicate on the "user_id" field.
+func UserIDNotIn(vs ...int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldUserID, vs...))
+}
+
+// ProviderTypeEQ applies the EQ predicate on the "provider_type" field.
+func ProviderTypeEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field.
+func ProviderTypeNEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderType, v))
+}
+
+// ProviderTypeIn applies the In predicate on the "provider_type" field.
+func ProviderTypeIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field.
+func ProviderTypeNotIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeGT applies the GT predicate on the "provider_type" field.
+func ProviderTypeGT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldProviderType, v))
+}
+
+// ProviderTypeGTE applies the GTE predicate on the "provider_type" field.
+func ProviderTypeGTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldProviderType, v))
+}
+
+// ProviderTypeLT applies the LT predicate on the "provider_type" field.
+func ProviderTypeLT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldProviderType, v))
+}
+
+// ProviderTypeLTE applies the LTE predicate on the "provider_type" field.
+func ProviderTypeLTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldProviderType, v))
+}
+
+// ProviderTypeContains applies the Contains predicate on the "provider_type" field.
+func ProviderTypeContains(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContains(FieldProviderType, v))
+}
+
+// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field.
+func ProviderTypeHasPrefix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderType, v))
+}
+
+// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field.
+func ProviderTypeHasSuffix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderType, v))
+}
+
+// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field.
+func ProviderTypeEqualFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderType, v))
+}
+
+// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field.
+func ProviderTypeContainsFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderType, v))
+}
+
+// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
+func ProviderKeyEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
+func ProviderKeyNEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyIn applies the In predicate on the "provider_key" field.
+func ProviderKeyIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
+func ProviderKeyNotIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyGT applies the GT predicate on the "provider_key" field.
+func ProviderKeyGT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldProviderKey, v))
+}
+
+// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
+func ProviderKeyGTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldProviderKey, v))
+}
+
+// ProviderKeyLT applies the LT predicate on the "provider_key" field.
+func ProviderKeyLT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldProviderKey, v))
+}
+
+// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
+func ProviderKeyLTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldProviderKey, v))
+}
+
+// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
+func ProviderKeyContains(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContains(FieldProviderKey, v))
+}
+
+// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
+func ProviderKeyHasPrefix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderKey, v))
+}
+
+// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
+func ProviderKeyHasSuffix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderKey, v))
+}
+
+// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
+func ProviderKeyEqualFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderKey, v))
+}
+
+// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
+func ProviderKeyContainsFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderKey, v))
+}
+
+// ProviderSubjectEQ applies the EQ predicate on the "provider_subject" field.
+func ProviderSubjectEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v))
+}
+
+// ProviderSubjectNEQ applies the NEQ predicate on the "provider_subject" field.
+func ProviderSubjectNEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderSubject, v))
+}
+
+// ProviderSubjectIn applies the In predicate on the "provider_subject" field.
+func ProviderSubjectIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldProviderSubject, vs...))
+}
+
+// ProviderSubjectNotIn applies the NotIn predicate on the "provider_subject" field.
+func ProviderSubjectNotIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderSubject, vs...))
+}
+
+// ProviderSubjectGT applies the GT predicate on the "provider_subject" field.
+func ProviderSubjectGT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldProviderSubject, v))
+}
+
+// ProviderSubjectGTE applies the GTE predicate on the "provider_subject" field.
+func ProviderSubjectGTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldProviderSubject, v))
+}
+
+// ProviderSubjectLT applies the LT predicate on the "provider_subject" field.
+func ProviderSubjectLT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldProviderSubject, v))
+}
+
+// ProviderSubjectLTE applies the LTE predicate on the "provider_subject" field.
+func ProviderSubjectLTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldProviderSubject, v))
+}
+
+// ProviderSubjectContains applies the Contains predicate on the "provider_subject" field.
+func ProviderSubjectContains(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContains(FieldProviderSubject, v))
+}
+
+// ProviderSubjectHasPrefix applies the HasPrefix predicate on the "provider_subject" field.
+func ProviderSubjectHasPrefix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderSubject, v))
+}
+
+// ProviderSubjectHasSuffix applies the HasSuffix predicate on the "provider_subject" field.
+func ProviderSubjectHasSuffix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderSubject, v))
+}
+
+// ProviderSubjectEqualFold applies the EqualFold predicate on the "provider_subject" field.
+func ProviderSubjectEqualFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderSubject, v))
+}
+
+// ProviderSubjectContainsFold applies the ContainsFold predicate on the "provider_subject" field.
+func ProviderSubjectContainsFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderSubject, v))
+}
+
+// VerifiedAtEQ applies the EQ predicate on the "verified_at" field.
+func VerifiedAtEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v))
+}
+
+// VerifiedAtNEQ applies the NEQ predicate on the "verified_at" field.
+func VerifiedAtNEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldVerifiedAt, v))
+}
+
+// VerifiedAtIn applies the In predicate on the "verified_at" field.
+func VerifiedAtIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldVerifiedAt, vs...))
+}
+
+// VerifiedAtNotIn applies the NotIn predicate on the "verified_at" field.
+func VerifiedAtNotIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldVerifiedAt, vs...))
+}
+
+// VerifiedAtGT applies the GT predicate on the "verified_at" field.
+func VerifiedAtGT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldVerifiedAt, v))
+}
+
+// VerifiedAtGTE applies the GTE predicate on the "verified_at" field.
+func VerifiedAtGTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldVerifiedAt, v))
+}
+
+// VerifiedAtLT applies the LT predicate on the "verified_at" field.
+func VerifiedAtLT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldVerifiedAt, v))
+}
+
+// VerifiedAtLTE applies the LTE predicate on the "verified_at" field.
+func VerifiedAtLTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldVerifiedAt, v))
+}
+
+// VerifiedAtIsNil applies the IsNil predicate on the "verified_at" field.
+func VerifiedAtIsNil() predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIsNull(FieldVerifiedAt))
+}
+
+// VerifiedAtNotNil applies the NotNil predicate on the "verified_at" field.
+func VerifiedAtNotNil() predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotNull(FieldVerifiedAt))
+}
+
+// IssuerEQ applies the EQ predicate on the "issuer" field.
+func IssuerEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v))
+}
+
+// IssuerNEQ applies the NEQ predicate on the "issuer" field.
+func IssuerNEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldIssuer, v))
+}
+
+// IssuerIn applies the In predicate on the "issuer" field.
+func IssuerIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldIssuer, vs...))
+}
+
+// IssuerNotIn applies the NotIn predicate on the "issuer" field.
+func IssuerNotIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldIssuer, vs...))
+}
+
+// IssuerGT applies the GT predicate on the "issuer" field.
+func IssuerGT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldIssuer, v))
+}
+
+// IssuerGTE applies the GTE predicate on the "issuer" field.
+func IssuerGTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldIssuer, v))
+}
+
+// IssuerLT applies the LT predicate on the "issuer" field.
+func IssuerLT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldIssuer, v))
+}
+
+// IssuerLTE applies the LTE predicate on the "issuer" field.
+func IssuerLTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldIssuer, v))
+}
+
+// IssuerContains applies the Contains predicate on the "issuer" field.
+func IssuerContains(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContains(FieldIssuer, v))
+}
+
+// IssuerHasPrefix applies the HasPrefix predicate on the "issuer" field.
+func IssuerHasPrefix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasPrefix(FieldIssuer, v))
+}
+
+// IssuerHasSuffix applies the HasSuffix predicate on the "issuer" field.
+func IssuerHasSuffix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasSuffix(FieldIssuer, v))
+}
+
+// IssuerIsNil applies the IsNil predicate on the "issuer" field.
+func IssuerIsNil() predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIsNull(FieldIssuer))
+}
+
+// IssuerNotNil applies the NotNil predicate on the "issuer" field.
+func IssuerNotNil() predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotNull(FieldIssuer))
+}
+
+// IssuerEqualFold applies the EqualFold predicate on the "issuer" field.
+func IssuerEqualFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEqualFold(FieldIssuer, v))
+}
+
+// IssuerContainsFold applies the ContainsFold predicate on the "issuer" field.
+func IssuerContainsFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContainsFold(FieldIssuer, v))
+}
+
+// HasUser applies the HasEdge predicate on the "user" edge.
+func HasUser() predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates).
+func HasUserWith(preds ...predicate.User) predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := newUserStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasChannels applies the HasEdge predicate on the "channels" edge.
+func HasChannels() predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasChannelsWith applies the HasEdge predicate on the "channels" edge with a given conditions (other predicates).
+func HasChannelsWith(preds ...predicate.AuthIdentityChannel) predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := newChannelsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasAdoptionDecisions applies the HasEdge predicate on the "adoption_decisions" edge.
+func HasAdoptionDecisions() predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasAdoptionDecisionsWith applies the HasEdge predicate on the "adoption_decisions" edge with a given conditions (other predicates).
+func HasAdoptionDecisionsWith(preds ...predicate.IdentityAdoptionDecision) predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := newAdoptionDecisionsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.AuthIdentity) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.AuthIdentity) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.AuthIdentity) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.NotPredicates(p))
+}
diff --git a/backend/ent/authidentity_create.go b/backend/ent/authidentity_create.go
new file mode 100644
index 00000000..e287705c
--- /dev/null
+++ b/backend/ent/authidentity_create.go
@@ -0,0 +1,1036 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AuthIdentityCreate is the builder for creating a AuthIdentity entity.
+type AuthIdentityCreate struct {
+ config
+ mutation *AuthIdentityMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *AuthIdentityCreate) SetCreatedAt(v time.Time) *AuthIdentityCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *AuthIdentityCreate) SetNillableCreatedAt(v *time.Time) *AuthIdentityCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *AuthIdentityCreate) SetUpdatedAt(v time.Time) *AuthIdentityCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *AuthIdentityCreate) SetNillableUpdatedAt(v *time.Time) *AuthIdentityCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetUserID sets the "user_id" field.
+func (_c *AuthIdentityCreate) SetUserID(v int64) *AuthIdentityCreate {
+ _c.mutation.SetUserID(v)
+ return _c
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_c *AuthIdentityCreate) SetProviderType(v string) *AuthIdentityCreate {
+ _c.mutation.SetProviderType(v)
+ return _c
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_c *AuthIdentityCreate) SetProviderKey(v string) *AuthIdentityCreate {
+ _c.mutation.SetProviderKey(v)
+ return _c
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_c *AuthIdentityCreate) SetProviderSubject(v string) *AuthIdentityCreate {
+ _c.mutation.SetProviderSubject(v)
+ return _c
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (_c *AuthIdentityCreate) SetVerifiedAt(v time.Time) *AuthIdentityCreate {
+ _c.mutation.SetVerifiedAt(v)
+ return _c
+}
+
+// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil.
+func (_c *AuthIdentityCreate) SetNillableVerifiedAt(v *time.Time) *AuthIdentityCreate {
+ if v != nil {
+ _c.SetVerifiedAt(*v)
+ }
+ return _c
+}
+
+// SetIssuer sets the "issuer" field.
+func (_c *AuthIdentityCreate) SetIssuer(v string) *AuthIdentityCreate {
+ _c.mutation.SetIssuer(v)
+ return _c
+}
+
+// SetNillableIssuer sets the "issuer" field if the given value is not nil.
+func (_c *AuthIdentityCreate) SetNillableIssuer(v *string) *AuthIdentityCreate {
+ if v != nil {
+ _c.SetIssuer(*v)
+ }
+ return _c
+}
+
+// SetMetadata sets the "metadata" field.
+func (_c *AuthIdentityCreate) SetMetadata(v map[string]interface{}) *AuthIdentityCreate {
+ _c.mutation.SetMetadata(v)
+ return _c
+}
+
+// SetUser sets the "user" edge to the User entity.
+func (_c *AuthIdentityCreate) SetUser(v *User) *AuthIdentityCreate {
+ return _c.SetUserID(v.ID)
+}
+
+// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs.
+func (_c *AuthIdentityCreate) AddChannelIDs(ids ...int64) *AuthIdentityCreate {
+ _c.mutation.AddChannelIDs(ids...)
+ return _c
+}
+
+// AddChannels adds the "channels" edges to the AuthIdentityChannel entity.
+func (_c *AuthIdentityCreate) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddChannelIDs(ids...)
+}
+
+// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
+func (_c *AuthIdentityCreate) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityCreate {
+ _c.mutation.AddAdoptionDecisionIDs(ids...)
+ return _c
+}
+
+// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_c *AuthIdentityCreate) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddAdoptionDecisionIDs(ids...)
+}
+
+// Mutation returns the AuthIdentityMutation object of the builder.
+func (_c *AuthIdentityCreate) Mutation() *AuthIdentityMutation {
+ return _c.mutation
+}
+
+// Save creates the AuthIdentity in the database.
+func (_c *AuthIdentityCreate) Save(ctx context.Context) (*AuthIdentity, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *AuthIdentityCreate) SaveX(ctx context.Context) *AuthIdentity {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AuthIdentityCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AuthIdentityCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *AuthIdentityCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := authidentity.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := authidentity.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.Metadata(); !ok {
+ v := authidentity.DefaultMetadata()
+ _c.mutation.SetMetadata(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *AuthIdentityCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AuthIdentity.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AuthIdentity.updated_at"`)}
+ }
+ if _, ok := _c.mutation.UserID(); !ok {
+ return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "AuthIdentity.user_id"`)}
+ }
+ if _, ok := _c.mutation.ProviderType(); !ok {
+ return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "AuthIdentity.provider_type"`)}
+ }
+ if v, ok := _c.mutation.ProviderType(); ok {
+ if err := authidentity.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderKey(); !ok {
+ return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "AuthIdentity.provider_key"`)}
+ }
+ if v, ok := _c.mutation.ProviderKey(); ok {
+ if err := authidentity.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderSubject(); !ok {
+ return &ValidationError{Name: "provider_subject", err: errors.New(`ent: missing required field "AuthIdentity.provider_subject"`)}
+ }
+ if v, ok := _c.mutation.ProviderSubject(); ok {
+ if err := authidentity.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Metadata(); !ok {
+ return &ValidationError{Name: "metadata", err: errors.New(`ent: missing required field "AuthIdentity.metadata"`)}
+ }
+ if len(_c.mutation.UserIDs()) == 0 {
+ return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "AuthIdentity.user"`)}
+ }
+ return nil
+}
+
+func (_c *AuthIdentityCreate) sqlSave(ctx context.Context) (*AuthIdentity, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *AuthIdentityCreate) createSpec() (*AuthIdentity, *sqlgraph.CreateSpec) {
+ var (
+ _node = &AuthIdentity{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(authidentity.Table, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(authidentity.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.ProviderType(); ok {
+ _spec.SetField(authidentity.FieldProviderType, field.TypeString, value)
+ _node.ProviderType = value
+ }
+ if value, ok := _c.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value)
+ _node.ProviderKey = value
+ }
+ if value, ok := _c.mutation.ProviderSubject(); ok {
+ _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value)
+ _node.ProviderSubject = value
+ }
+ if value, ok := _c.mutation.VerifiedAt(); ok {
+ _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value)
+ _node.VerifiedAt = &value
+ }
+ if value, ok := _c.mutation.Issuer(); ok {
+ _spec.SetField(authidentity.FieldIssuer, field.TypeString, value)
+ _node.Issuer = &value
+ }
+ if value, ok := _c.mutation.Metadata(); ok {
+ _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value)
+ _node.Metadata = value
+ }
+ if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.UserID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.ChannelsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.AuthIdentity.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AuthIdentityUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *AuthIdentityCreate) OnConflict(opts ...sql.ConflictOption) *AuthIdentityUpsertOne {
+ _c.conflict = opts
+ return &AuthIdentityUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AuthIdentityCreate) OnConflictColumns(columns ...string) *AuthIdentityUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AuthIdentityUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // AuthIdentityUpsertOne is the builder for "upsert"-ing
+ // one AuthIdentity node.
+ AuthIdentityUpsertOne struct {
+ create *AuthIdentityCreate
+ }
+
+ // AuthIdentityUpsert is the "OnConflict" setter.
+ AuthIdentityUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityUpsert) SetUpdatedAt(v time.Time) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateUpdatedAt() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldUpdatedAt)
+ return u
+}
+
+// SetUserID sets the "user_id" field.
+func (u *AuthIdentityUpsert) SetUserID(v int64) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldUserID, v)
+ return u
+}
+
+// UpdateUserID sets the "user_id" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateUserID() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldUserID)
+ return u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityUpsert) SetProviderType(v string) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldProviderType, v)
+ return u
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateProviderType() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldProviderType)
+ return u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityUpsert) SetProviderKey(v string) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldProviderKey, v)
+ return u
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateProviderKey() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldProviderKey)
+ return u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *AuthIdentityUpsert) SetProviderSubject(v string) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldProviderSubject, v)
+ return u
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateProviderSubject() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldProviderSubject)
+ return u
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (u *AuthIdentityUpsert) SetVerifiedAt(v time.Time) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldVerifiedAt, v)
+ return u
+}
+
+// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateVerifiedAt() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldVerifiedAt)
+ return u
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (u *AuthIdentityUpsert) ClearVerifiedAt() *AuthIdentityUpsert {
+ u.SetNull(authidentity.FieldVerifiedAt)
+ return u
+}
+
+// SetIssuer sets the "issuer" field.
+func (u *AuthIdentityUpsert) SetIssuer(v string) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldIssuer, v)
+ return u
+}
+
+// UpdateIssuer sets the "issuer" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateIssuer() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldIssuer)
+ return u
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (u *AuthIdentityUpsert) ClearIssuer() *AuthIdentityUpsert {
+ u.SetNull(authidentity.FieldIssuer)
+ return u
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityUpsert) SetMetadata(v map[string]interface{}) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldMetadata, v)
+ return u
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateMetadata() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldMetadata)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AuthIdentityUpsertOne) UpdateNewValues() *AuthIdentityUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(authidentity.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AuthIdentityUpsertOne) Ignore() *AuthIdentityUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AuthIdentityUpsertOne) DoNothing() *AuthIdentityUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AuthIdentityCreate.OnConflict
+// documentation for more info.
+func (u *AuthIdentityUpsertOne) Update(set func(*AuthIdentityUpsert)) *AuthIdentityUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AuthIdentityUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityUpsertOne) SetUpdatedAt(v time.Time) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateUpdatedAt() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetUserID sets the "user_id" field.
+func (u *AuthIdentityUpsertOne) SetUserID(v int64) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetUserID(v)
+ })
+}
+
+// UpdateUserID sets the "user_id" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateUserID() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateUserID()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityUpsertOne) SetProviderType(v string) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateProviderType() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityUpsertOne) SetProviderKey(v string) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateProviderKey() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *AuthIdentityUpsertOne) SetProviderSubject(v string) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderSubject(v)
+ })
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateProviderSubject() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderSubject()
+ })
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (u *AuthIdentityUpsertOne) SetVerifiedAt(v time.Time) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetVerifiedAt(v)
+ })
+}
+
+// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateVerifiedAt() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateVerifiedAt()
+ })
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (u *AuthIdentityUpsertOne) ClearVerifiedAt() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.ClearVerifiedAt()
+ })
+}
+
+// SetIssuer sets the "issuer" field.
+func (u *AuthIdentityUpsertOne) SetIssuer(v string) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetIssuer(v)
+ })
+}
+
+// UpdateIssuer sets the "issuer" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateIssuer() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateIssuer()
+ })
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (u *AuthIdentityUpsertOne) ClearIssuer() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.ClearIssuer()
+ })
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityUpsertOne) SetMetadata(v map[string]interface{}) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetMetadata(v)
+ })
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateMetadata() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateMetadata()
+ })
+}
+
+// Exec executes the query.
+func (u *AuthIdentityUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AuthIdentityCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AuthIdentityUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *AuthIdentityUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *AuthIdentityUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// AuthIdentityCreateBulk is the builder for creating many AuthIdentity entities in bulk.
+type AuthIdentityCreateBulk struct {
+ config
+ err error
+ builders []*AuthIdentityCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the AuthIdentity entities in the database.
+func (_c *AuthIdentityCreateBulk) Save(ctx context.Context) ([]*AuthIdentity, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*AuthIdentity, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*AuthIdentityMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *AuthIdentityCreateBulk) SaveX(ctx context.Context) []*AuthIdentity {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AuthIdentityCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AuthIdentityCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.AuthIdentity.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AuthIdentityUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *AuthIdentityCreateBulk) OnConflict(opts ...sql.ConflictOption) *AuthIdentityUpsertBulk {
+ _c.conflict = opts
+ return &AuthIdentityUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AuthIdentityCreateBulk) OnConflictColumns(columns ...string) *AuthIdentityUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AuthIdentityUpsertBulk{
+ create: _c,
+ }
+}
+
+// AuthIdentityUpsertBulk is the builder for "upsert"-ing
+// a bulk of AuthIdentity nodes.
+type AuthIdentityUpsertBulk struct {
+ create *AuthIdentityCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AuthIdentityUpsertBulk) UpdateNewValues() *AuthIdentityUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(authidentity.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AuthIdentityUpsertBulk) Ignore() *AuthIdentityUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AuthIdentityUpsertBulk) DoNothing() *AuthIdentityUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AuthIdentityCreateBulk.OnConflict
+// documentation for more info.
+func (u *AuthIdentityUpsertBulk) Update(set func(*AuthIdentityUpsert)) *AuthIdentityUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AuthIdentityUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityUpsertBulk) SetUpdatedAt(v time.Time) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateUpdatedAt() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetUserID sets the "user_id" field.
+func (u *AuthIdentityUpsertBulk) SetUserID(v int64) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetUserID(v)
+ })
+}
+
+// UpdateUserID sets the "user_id" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateUserID() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateUserID()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityUpsertBulk) SetProviderType(v string) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateProviderType() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityUpsertBulk) SetProviderKey(v string) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateProviderKey() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *AuthIdentityUpsertBulk) SetProviderSubject(v string) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderSubject(v)
+ })
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateProviderSubject() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderSubject()
+ })
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (u *AuthIdentityUpsertBulk) SetVerifiedAt(v time.Time) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetVerifiedAt(v)
+ })
+}
+
+// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateVerifiedAt() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateVerifiedAt()
+ })
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (u *AuthIdentityUpsertBulk) ClearVerifiedAt() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.ClearVerifiedAt()
+ })
+}
+
+// SetIssuer sets the "issuer" field.
+func (u *AuthIdentityUpsertBulk) SetIssuer(v string) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetIssuer(v)
+ })
+}
+
+// UpdateIssuer sets the "issuer" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateIssuer() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateIssuer()
+ })
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (u *AuthIdentityUpsertBulk) ClearIssuer() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.ClearIssuer()
+ })
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityUpsertBulk) SetMetadata(v map[string]interface{}) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetMetadata(v)
+ })
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateMetadata() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateMetadata()
+ })
+}
+
+// Exec executes the query.
+func (u *AuthIdentityUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AuthIdentityCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AuthIdentityCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AuthIdentityUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/authidentity_delete.go b/backend/ent/authidentity_delete.go
new file mode 100644
index 00000000..4f1f6f3c
--- /dev/null
+++ b/backend/ent/authidentity_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AuthIdentityDelete is the builder for deleting a AuthIdentity entity.
+type AuthIdentityDelete struct {
+ config
+ hooks []Hook
+ mutation *AuthIdentityMutation
+}
+
+// Where appends a list predicates to the AuthIdentityDelete builder.
+func (_d *AuthIdentityDelete) Where(ps ...predicate.AuthIdentity) *AuthIdentityDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *AuthIdentityDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AuthIdentityDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *AuthIdentityDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(authidentity.Table, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// AuthIdentityDeleteOne is the builder for deleting a single AuthIdentity entity.
+type AuthIdentityDeleteOne struct {
+ _d *AuthIdentityDelete
+}
+
+// Where appends a list predicates to the AuthIdentityDelete builder.
+func (_d *AuthIdentityDeleteOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *AuthIdentityDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{authidentity.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AuthIdentityDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/authidentity_query.go b/backend/ent/authidentity_query.go
new file mode 100644
index 00000000..ff27ef3c
--- /dev/null
+++ b/backend/ent/authidentity_query.go
@@ -0,0 +1,797 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "database/sql/driver"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AuthIdentityQuery is the builder for querying AuthIdentity entities.
+type AuthIdentityQuery struct {
+ config
+ ctx *QueryContext
+ order []authidentity.OrderOption
+ inters []Interceptor
+ predicates []predicate.AuthIdentity
+ withUser *UserQuery
+ withChannels *AuthIdentityChannelQuery
+ withAdoptionDecisions *IdentityAdoptionDecisionQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the AuthIdentityQuery builder.
+func (_q *AuthIdentityQuery) Where(ps ...predicate.AuthIdentity) *AuthIdentityQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *AuthIdentityQuery) Limit(limit int) *AuthIdentityQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *AuthIdentityQuery) Offset(offset int) *AuthIdentityQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *AuthIdentityQuery) Unique(unique bool) *AuthIdentityQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *AuthIdentityQuery) Order(o ...authidentity.OrderOption) *AuthIdentityQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryUser chains the current query on the "user" edge.
+func (_q *AuthIdentityQuery) QueryUser() *UserQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, selector),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryChannels chains the current query on the "channels" edge.
+func (_q *AuthIdentityQuery) QueryChannels() *AuthIdentityChannelQuery {
+ query := (&AuthIdentityChannelClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, selector),
+ sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryAdoptionDecisions chains the current query on the "adoption_decisions" edge.
+func (_q *AuthIdentityQuery) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, selector),
+ sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first AuthIdentity entity from the query.
+// Returns a *NotFoundError when no AuthIdentity was found.
+func (_q *AuthIdentityQuery) First(ctx context.Context) (*AuthIdentity, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{authidentity.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *AuthIdentityQuery) FirstX(ctx context.Context) *AuthIdentity {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first AuthIdentity ID from the query.
+// Returns a *NotFoundError when no AuthIdentity ID was found.
+func (_q *AuthIdentityQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{authidentity.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *AuthIdentityQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single AuthIdentity entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one AuthIdentity entity is found.
+// Returns a *NotFoundError when no AuthIdentity entities are found.
+func (_q *AuthIdentityQuery) Only(ctx context.Context) (*AuthIdentity, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{authidentity.Label}
+ default:
+ return nil, &NotSingularError{authidentity.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *AuthIdentityQuery) OnlyX(ctx context.Context) *AuthIdentity {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only AuthIdentity ID in the query.
+// Returns a *NotSingularError when more than one AuthIdentity ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *AuthIdentityQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{authidentity.Label}
+ default:
+ err = &NotSingularError{authidentity.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *AuthIdentityQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of AuthIdentities.
+func (_q *AuthIdentityQuery) All(ctx context.Context) ([]*AuthIdentity, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*AuthIdentity, *AuthIdentityQuery]()
+ return withInterceptors[[]*AuthIdentity](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *AuthIdentityQuery) AllX(ctx context.Context) []*AuthIdentity {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of AuthIdentity IDs.
+func (_q *AuthIdentityQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(authidentity.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *AuthIdentityQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *AuthIdentityQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *AuthIdentityQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *AuthIdentityQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *AuthIdentityQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the AuthIdentityQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *AuthIdentityQuery) Clone() *AuthIdentityQuery {
+ if _q == nil {
+ return nil
+ }
+ return &AuthIdentityQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]authidentity.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.AuthIdentity{}, _q.predicates...),
+ withUser: _q.withUser.Clone(),
+ withChannels: _q.withChannels.Clone(),
+ withAdoptionDecisions: _q.withAdoptionDecisions.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithUser tells the query-builder to eager-load the nodes that are connected to
+// the "user" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AuthIdentityQuery) WithUser(opts ...func(*UserQuery)) *AuthIdentityQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withUser = query
+ return _q
+}
+
+// WithChannels tells the query-builder to eager-load the nodes that are connected to
+// the "channels" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AuthIdentityQuery) WithChannels(opts ...func(*AuthIdentityChannelQuery)) *AuthIdentityQuery {
+ query := (&AuthIdentityChannelClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withChannels = query
+ return _q
+}
+
+// WithAdoptionDecisions tells the query-builder to eager-load the nodes that are connected to
+// the "adoption_decisions" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AuthIdentityQuery) WithAdoptionDecisions(opts ...func(*IdentityAdoptionDecisionQuery)) *AuthIdentityQuery {
+ query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withAdoptionDecisions = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.AuthIdentity.Query().
+// GroupBy(authidentity.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *AuthIdentityQuery) GroupBy(field string, fields ...string) *AuthIdentityGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &AuthIdentityGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = authidentity.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.AuthIdentity.Query().
+// Select(authidentity.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *AuthIdentityQuery) Select(fields ...string) *AuthIdentitySelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &AuthIdentitySelect{AuthIdentityQuery: _q}
+ sbuild.label = authidentity.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a AuthIdentitySelect configured with the given aggregations.
+func (_q *AuthIdentityQuery) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *AuthIdentityQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !authidentity.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *AuthIdentityQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentity, error) {
+ var (
+ nodes = []*AuthIdentity{}
+ _spec = _q.querySpec()
+ loadedTypes = [3]bool{
+ _q.withUser != nil,
+ _q.withChannels != nil,
+ _q.withAdoptionDecisions != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*AuthIdentity).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &AuthIdentity{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withUser; query != nil {
+ if err := _q.loadUser(ctx, query, nodes, nil,
+ func(n *AuthIdentity, e *User) { n.Edges.User = e }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withChannels; query != nil {
+ if err := _q.loadChannels(ctx, query, nodes,
+ func(n *AuthIdentity) { n.Edges.Channels = []*AuthIdentityChannel{} },
+ func(n *AuthIdentity, e *AuthIdentityChannel) { n.Edges.Channels = append(n.Edges.Channels, e) }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withAdoptionDecisions; query != nil {
+ if err := _q.loadAdoptionDecisions(ctx, query, nodes,
+ func(n *AuthIdentity) { n.Edges.AdoptionDecisions = []*IdentityAdoptionDecision{} },
+ func(n *AuthIdentity, e *IdentityAdoptionDecision) {
+ n.Edges.AdoptionDecisions = append(n.Edges.AdoptionDecisions, e)
+ }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *AuthIdentityQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *User)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*AuthIdentity)
+ for i := range nodes {
+ fk := nodes[i].UserID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(user.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+func (_q *AuthIdentityQuery) loadChannels(ctx context.Context, query *AuthIdentityChannelQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *AuthIdentityChannel)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*AuthIdentity)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(authidentitychannel.FieldIdentityID)
+ }
+ query.Where(predicate.AuthIdentityChannel(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(authidentity.ChannelsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.IdentityID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+func (_q *AuthIdentityQuery) loadAdoptionDecisions(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *IdentityAdoptionDecision)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*AuthIdentity)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(identityadoptiondecision.FieldIdentityID)
+ }
+ query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(authidentity.AdoptionDecisionsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.IdentityID
+ if fk == nil {
+ return fmt.Errorf(`foreign-key "identity_id" is nil for node %v`, n.ID)
+ }
+ node, ok := nodeids[*fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, *fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+
+func (_q *AuthIdentityQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *AuthIdentityQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID)
+ for i := range fields {
+ if fields[i] != authidentity.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withUser != nil {
+ _spec.Node.AddColumnOnce(authidentity.FieldUserID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *AuthIdentityQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(authidentity.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = authidentity.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *AuthIdentityQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *AuthIdentityQuery) ForShare(opts ...sql.LockOption) *AuthIdentityQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// AuthIdentityGroupBy is the group-by builder for AuthIdentity entities.
+type AuthIdentityGroupBy struct {
+ selector
+ build *AuthIdentityQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *AuthIdentityGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *AuthIdentityGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentityGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *AuthIdentityGroupBy) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// AuthIdentitySelect is the builder for selecting fields of AuthIdentity entities.
+type AuthIdentitySelect struct {
+ *AuthIdentityQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *AuthIdentitySelect) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *AuthIdentitySelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentitySelect](ctx, _s.AuthIdentityQuery, _s, _s.inters, v)
+}
+
+func (_s *AuthIdentitySelect) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/authidentity_update.go b/backend/ent/authidentity_update.go
new file mode 100644
index 00000000..c457470b
--- /dev/null
+++ b/backend/ent/authidentity_update.go
@@ -0,0 +1,923 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AuthIdentityUpdate is the builder for updating AuthIdentity entities.
+type AuthIdentityUpdate struct {
+ config
+ hooks []Hook
+ mutation *AuthIdentityMutation
+}
+
+// Where appends a list predicates to the AuthIdentityUpdate builder.
+func (_u *AuthIdentityUpdate) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AuthIdentityUpdate) SetUpdatedAt(v time.Time) *AuthIdentityUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetUserID sets the "user_id" field.
+func (_u *AuthIdentityUpdate) SetUserID(v int64) *AuthIdentityUpdate {
+ _u.mutation.SetUserID(v)
+ return _u
+}
+
+// SetNillableUserID sets the "user_id" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableUserID(v *int64) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetUserID(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *AuthIdentityUpdate) SetProviderType(v string) *AuthIdentityUpdate {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableProviderType(v *string) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *AuthIdentityUpdate) SetProviderKey(v string) *AuthIdentityUpdate {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableProviderKey(v *string) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_u *AuthIdentityUpdate) SetProviderSubject(v string) *AuthIdentityUpdate {
+ _u.mutation.SetProviderSubject(v)
+ return _u
+}
+
+// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableProviderSubject(v *string) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetProviderSubject(*v)
+ }
+ return _u
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (_u *AuthIdentityUpdate) SetVerifiedAt(v time.Time) *AuthIdentityUpdate {
+ _u.mutation.SetVerifiedAt(v)
+ return _u
+}
+
+// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (_u *AuthIdentityUpdate) ClearVerifiedAt() *AuthIdentityUpdate {
+ _u.mutation.ClearVerifiedAt()
+ return _u
+}
+
+// SetIssuer sets the "issuer" field.
+func (_u *AuthIdentityUpdate) SetIssuer(v string) *AuthIdentityUpdate {
+ _u.mutation.SetIssuer(v)
+ return _u
+}
+
+// SetNillableIssuer sets the "issuer" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableIssuer(v *string) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetIssuer(*v)
+ }
+ return _u
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (_u *AuthIdentityUpdate) ClearIssuer() *AuthIdentityUpdate {
+ _u.mutation.ClearIssuer()
+ return _u
+}
+
+// SetMetadata sets the "metadata" field.
+func (_u *AuthIdentityUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityUpdate {
+ _u.mutation.SetMetadata(v)
+ return _u
+}
+
+// SetUser sets the "user" edge to the User entity.
+func (_u *AuthIdentityUpdate) SetUser(v *User) *AuthIdentityUpdate {
+ return _u.SetUserID(v.ID)
+}
+
+// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs.
+func (_u *AuthIdentityUpdate) AddChannelIDs(ids ...int64) *AuthIdentityUpdate {
+ _u.mutation.AddChannelIDs(ids...)
+ return _u
+}
+
+// AddChannels adds the "channels" edges to the AuthIdentityChannel entity.
+func (_u *AuthIdentityUpdate) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddChannelIDs(ids...)
+}
+
+// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
+func (_u *AuthIdentityUpdate) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate {
+ _u.mutation.AddAdoptionDecisionIDs(ids...)
+ return _u
+}
+
+// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_u *AuthIdentityUpdate) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAdoptionDecisionIDs(ids...)
+}
+
+// Mutation returns the AuthIdentityMutation object of the builder.
+func (_u *AuthIdentityUpdate) Mutation() *AuthIdentityMutation {
+ return _u.mutation
+}
+
+// ClearUser clears the "user" edge to the User entity.
+func (_u *AuthIdentityUpdate) ClearUser() *AuthIdentityUpdate {
+ _u.mutation.ClearUser()
+ return _u
+}
+
+// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity.
+func (_u *AuthIdentityUpdate) ClearChannels() *AuthIdentityUpdate {
+ _u.mutation.ClearChannels()
+ return _u
+}
+
+// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs.
+func (_u *AuthIdentityUpdate) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdate {
+ _u.mutation.RemoveChannelIDs(ids...)
+ return _u
+}
+
+// RemoveChannels removes "channels" edges to AuthIdentityChannel entities.
+func (_u *AuthIdentityUpdate) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveChannelIDs(ids...)
+}
+
+// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_u *AuthIdentityUpdate) ClearAdoptionDecisions() *AuthIdentityUpdate {
+ _u.mutation.ClearAdoptionDecisions()
+ return _u
+}
+
+// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs.
+func (_u *AuthIdentityUpdate) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate {
+ _u.mutation.RemoveAdoptionDecisionIDs(ids...)
+ return _u
+}
+
+// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities.
+func (_u *AuthIdentityUpdate) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAdoptionDecisionIDs(ids...)
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *AuthIdentityUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AuthIdentityUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *AuthIdentityUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AuthIdentityUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AuthIdentityUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := authidentity.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AuthIdentityUpdate) check() error {
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := authidentity.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := authidentity.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderSubject(); ok {
+ if err := authidentity.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)}
+ }
+ }
+ if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`)
+ }
+ return nil
+}
+
+func (_u *AuthIdentityUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(authidentity.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderSubject(); ok {
+ _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.VerifiedAt(); ok {
+ _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.VerifiedAtCleared() {
+ _spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.Issuer(); ok {
+ _spec.SetField(authidentity.FieldIssuer, field.TypeString, value)
+ }
+ if _u.mutation.IssuerCleared() {
+ _spec.ClearField(authidentity.FieldIssuer, field.TypeString)
+ }
+ if value, ok := _u.mutation.Metadata(); ok {
+ _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value)
+ }
+ if _u.mutation.UserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.ChannelsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.AdoptionDecisionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{authidentity.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// AuthIdentityUpdateOne is the builder for updating a single AuthIdentity entity.
+type AuthIdentityUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *AuthIdentityMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AuthIdentityUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetUserID sets the "user_id" field.
+func (_u *AuthIdentityUpdateOne) SetUserID(v int64) *AuthIdentityUpdateOne {
+ _u.mutation.SetUserID(v)
+ return _u
+}
+
+// SetNillableUserID sets the "user_id" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableUserID(v *int64) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetUserID(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *AuthIdentityUpdateOne) SetProviderType(v string) *AuthIdentityUpdateOne {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableProviderType(v *string) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *AuthIdentityUpdateOne) SetProviderKey(v string) *AuthIdentityUpdateOne {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_u *AuthIdentityUpdateOne) SetProviderSubject(v string) *AuthIdentityUpdateOne {
+ _u.mutation.SetProviderSubject(v)
+ return _u
+}
+
+// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableProviderSubject(v *string) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetProviderSubject(*v)
+ }
+ return _u
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (_u *AuthIdentityUpdateOne) SetVerifiedAt(v time.Time) *AuthIdentityUpdateOne {
+ _u.mutation.SetVerifiedAt(v)
+ return _u
+}
+
+// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (_u *AuthIdentityUpdateOne) ClearVerifiedAt() *AuthIdentityUpdateOne {
+ _u.mutation.ClearVerifiedAt()
+ return _u
+}
+
+// SetIssuer sets the "issuer" field.
+func (_u *AuthIdentityUpdateOne) SetIssuer(v string) *AuthIdentityUpdateOne {
+ _u.mutation.SetIssuer(v)
+ return _u
+}
+
+// SetNillableIssuer sets the "issuer" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableIssuer(v *string) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetIssuer(*v)
+ }
+ return _u
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (_u *AuthIdentityUpdateOne) ClearIssuer() *AuthIdentityUpdateOne {
+ _u.mutation.ClearIssuer()
+ return _u
+}
+
+// SetMetadata sets the "metadata" field.
+func (_u *AuthIdentityUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityUpdateOne {
+ _u.mutation.SetMetadata(v)
+ return _u
+}
+
+// SetUser sets the "user" edge to the User entity.
+func (_u *AuthIdentityUpdateOne) SetUser(v *User) *AuthIdentityUpdateOne {
+ return _u.SetUserID(v.ID)
+}
+
+// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs.
+func (_u *AuthIdentityUpdateOne) AddChannelIDs(ids ...int64) *AuthIdentityUpdateOne {
+ _u.mutation.AddChannelIDs(ids...)
+ return _u
+}
+
+// AddChannels adds the "channels" edges to the AuthIdentityChannel entity.
+func (_u *AuthIdentityUpdateOne) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddChannelIDs(ids...)
+}
+
+// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
+func (_u *AuthIdentityUpdateOne) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne {
+ _u.mutation.AddAdoptionDecisionIDs(ids...)
+ return _u
+}
+
+// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_u *AuthIdentityUpdateOne) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAdoptionDecisionIDs(ids...)
+}
+
+// Mutation returns the AuthIdentityMutation object of the builder.
+func (_u *AuthIdentityUpdateOne) Mutation() *AuthIdentityMutation {
+ return _u.mutation
+}
+
+// ClearUser clears the "user" edge to the User entity.
+func (_u *AuthIdentityUpdateOne) ClearUser() *AuthIdentityUpdateOne {
+ _u.mutation.ClearUser()
+ return _u
+}
+
+// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity.
+func (_u *AuthIdentityUpdateOne) ClearChannels() *AuthIdentityUpdateOne {
+ _u.mutation.ClearChannels()
+ return _u
+}
+
+// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs.
+func (_u *AuthIdentityUpdateOne) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdateOne {
+ _u.mutation.RemoveChannelIDs(ids...)
+ return _u
+}
+
+// RemoveChannels removes "channels" edges to AuthIdentityChannel entities.
+func (_u *AuthIdentityUpdateOne) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveChannelIDs(ids...)
+}
+
+// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_u *AuthIdentityUpdateOne) ClearAdoptionDecisions() *AuthIdentityUpdateOne {
+ _u.mutation.ClearAdoptionDecisions()
+ return _u
+}
+
+// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs.
+func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne {
+ _u.mutation.RemoveAdoptionDecisionIDs(ids...)
+ return _u
+}
+
+// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities.
+func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAdoptionDecisionIDs(ids...)
+}
+
+// Where appends a list predicates to the AuthIdentityUpdate builder.
+func (_u *AuthIdentityUpdateOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *AuthIdentityUpdateOne) Select(field string, fields ...string) *AuthIdentityUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated AuthIdentity entity.
+func (_u *AuthIdentityUpdateOne) Save(ctx context.Context) (*AuthIdentity, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AuthIdentityUpdateOne) SaveX(ctx context.Context) *AuthIdentity {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *AuthIdentityUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AuthIdentityUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AuthIdentityUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := authidentity.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AuthIdentityUpdateOne) check() error {
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := authidentity.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := authidentity.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderSubject(); ok {
+ if err := authidentity.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)}
+ }
+ }
+ if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`)
+ }
+ return nil
+}
+
+func (_u *AuthIdentityUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentity, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentity.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID)
+ for _, f := range fields {
+ if !authidentity.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != authidentity.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(authidentity.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderSubject(); ok {
+ _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.VerifiedAt(); ok {
+ _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.VerifiedAtCleared() {
+ _spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.Issuer(); ok {
+ _spec.SetField(authidentity.FieldIssuer, field.TypeString, value)
+ }
+ if _u.mutation.IssuerCleared() {
+ _spec.ClearField(authidentity.FieldIssuer, field.TypeString)
+ }
+ if value, ok := _u.mutation.Metadata(); ok {
+ _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value)
+ }
+ if _u.mutation.UserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.ChannelsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.AdoptionDecisionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &AuthIdentity{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{authidentity.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/authidentitychannel.go b/backend/ent/authidentitychannel.go
new file mode 100644
index 00000000..1ff3e5d1
--- /dev/null
+++ b/backend/ent/authidentitychannel.go
@@ -0,0 +1,228 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+)
+
+// AuthIdentityChannel is the model entity for the AuthIdentityChannel schema.
+type AuthIdentityChannel struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // IdentityID holds the value of the "identity_id" field.
+ IdentityID int64 `json:"identity_id,omitempty"`
+ // ProviderType holds the value of the "provider_type" field.
+ ProviderType string `json:"provider_type,omitempty"`
+ // ProviderKey holds the value of the "provider_key" field.
+ ProviderKey string `json:"provider_key,omitempty"`
+ // Channel holds the value of the "channel" field.
+ Channel string `json:"channel,omitempty"`
+ // ChannelAppID holds the value of the "channel_app_id" field.
+ ChannelAppID string `json:"channel_app_id,omitempty"`
+ // ChannelSubject holds the value of the "channel_subject" field.
+ ChannelSubject string `json:"channel_subject,omitempty"`
+ // Metadata holds the value of the "metadata" field.
+ Metadata map[string]interface{} `json:"metadata,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the AuthIdentityChannelQuery when eager-loading is set.
+ Edges AuthIdentityChannelEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// AuthIdentityChannelEdges holds the relations/edges for other nodes in the graph.
+type AuthIdentityChannelEdges struct {
+ // Identity holds the value of the identity edge.
+ Identity *AuthIdentity `json:"identity,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [1]bool
+}
+
+// IdentityOrErr returns the Identity value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e AuthIdentityChannelEdges) IdentityOrErr() (*AuthIdentity, error) {
+ if e.Identity != nil {
+ return e.Identity, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: authidentity.Label}
+ }
+ return nil, &NotLoadedError{edge: "identity"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*AuthIdentityChannel) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case authidentitychannel.FieldMetadata:
+ values[i] = new([]byte)
+ case authidentitychannel.FieldID, authidentitychannel.FieldIdentityID:
+ values[i] = new(sql.NullInt64)
+ case authidentitychannel.FieldProviderType, authidentitychannel.FieldProviderKey, authidentitychannel.FieldChannel, authidentitychannel.FieldChannelAppID, authidentitychannel.FieldChannelSubject:
+ values[i] = new(sql.NullString)
+ case authidentitychannel.FieldCreatedAt, authidentitychannel.FieldUpdatedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the AuthIdentityChannel fields.
+func (_m *AuthIdentityChannel) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case authidentitychannel.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case authidentitychannel.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case authidentitychannel.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case authidentitychannel.FieldIdentityID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field identity_id", values[i])
+ } else if value.Valid {
+ _m.IdentityID = value.Int64
+ }
+ case authidentitychannel.FieldProviderType:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_type", values[i])
+ } else if value.Valid {
+ _m.ProviderType = value.String
+ }
+ case authidentitychannel.FieldProviderKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_key", values[i])
+ } else if value.Valid {
+ _m.ProviderKey = value.String
+ }
+ case authidentitychannel.FieldChannel:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field channel", values[i])
+ } else if value.Valid {
+ _m.Channel = value.String
+ }
+ case authidentitychannel.FieldChannelAppID:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field channel_app_id", values[i])
+ } else if value.Valid {
+ _m.ChannelAppID = value.String
+ }
+ case authidentitychannel.FieldChannelSubject:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field channel_subject", values[i])
+ } else if value.Valid {
+ _m.ChannelSubject = value.String
+ }
+ case authidentitychannel.FieldMetadata:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field metadata", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.Metadata); err != nil {
+ return fmt.Errorf("unmarshal field metadata: %w", err)
+ }
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentityChannel.
+// This includes values selected through modifiers, order, etc.
+func (_m *AuthIdentityChannel) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryIdentity queries the "identity" edge of the AuthIdentityChannel entity.
+func (_m *AuthIdentityChannel) QueryIdentity() *AuthIdentityQuery {
+ return NewAuthIdentityChannelClient(_m.config).QueryIdentity(_m)
+}
+
+// Update returns a builder for updating this AuthIdentityChannel.
+// Note that you need to call AuthIdentityChannel.Unwrap() before calling this method if this AuthIdentityChannel
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *AuthIdentityChannel) Update() *AuthIdentityChannelUpdateOne {
+ return NewAuthIdentityChannelClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the AuthIdentityChannel entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *AuthIdentityChannel) Unwrap() *AuthIdentityChannel {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: AuthIdentityChannel is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *AuthIdentityChannel) String() string {
+ var builder strings.Builder
+ builder.WriteString("AuthIdentityChannel(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("identity_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.IdentityID))
+ builder.WriteString(", ")
+ builder.WriteString("provider_type=")
+ builder.WriteString(_m.ProviderType)
+ builder.WriteString(", ")
+ builder.WriteString("provider_key=")
+ builder.WriteString(_m.ProviderKey)
+ builder.WriteString(", ")
+ builder.WriteString("channel=")
+ builder.WriteString(_m.Channel)
+ builder.WriteString(", ")
+ builder.WriteString("channel_app_id=")
+ builder.WriteString(_m.ChannelAppID)
+ builder.WriteString(", ")
+ builder.WriteString("channel_subject=")
+ builder.WriteString(_m.ChannelSubject)
+ builder.WriteString(", ")
+ builder.WriteString("metadata=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Metadata))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// AuthIdentityChannels is a parsable slice of AuthIdentityChannel.
+type AuthIdentityChannels []*AuthIdentityChannel
diff --git a/backend/ent/authidentitychannel/authidentitychannel.go b/backend/ent/authidentitychannel/authidentitychannel.go
new file mode 100644
index 00000000..7dcc98bb
--- /dev/null
+++ b/backend/ent/authidentitychannel/authidentitychannel.go
@@ -0,0 +1,153 @@
+// Code generated by ent, DO NOT EDIT.
+
+package authidentitychannel
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the authidentitychannel type in the database.
+ Label = "auth_identity_channel"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldIdentityID holds the string denoting the identity_id field in the database.
+ FieldIdentityID = "identity_id"
+ // FieldProviderType holds the string denoting the provider_type field in the database.
+ FieldProviderType = "provider_type"
+ // FieldProviderKey holds the string denoting the provider_key field in the database.
+ FieldProviderKey = "provider_key"
+ // FieldChannel holds the string denoting the channel field in the database.
+ FieldChannel = "channel"
+ // FieldChannelAppID holds the string denoting the channel_app_id field in the database.
+ FieldChannelAppID = "channel_app_id"
+ // FieldChannelSubject holds the string denoting the channel_subject field in the database.
+ FieldChannelSubject = "channel_subject"
+ // FieldMetadata holds the string denoting the metadata field in the database.
+ FieldMetadata = "metadata"
+ // EdgeIdentity holds the string denoting the identity edge name in mutations.
+ EdgeIdentity = "identity"
+ // Table holds the table name of the authidentitychannel in the database.
+ Table = "auth_identity_channels"
+ // IdentityTable is the table that holds the identity relation/edge.
+ IdentityTable = "auth_identity_channels"
+ // IdentityInverseTable is the table name for the AuthIdentity entity.
+ // It exists in this package in order to avoid circular dependency with the "authidentity" package.
+ IdentityInverseTable = "auth_identities"
+ // IdentityColumn is the table column denoting the identity relation/edge.
+ IdentityColumn = "identity_id"
+)
+
+// Columns holds all SQL columns for authidentitychannel fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldIdentityID,
+ FieldProviderType,
+ FieldProviderKey,
+ FieldChannel,
+ FieldChannelAppID,
+ FieldChannelSubject,
+ FieldMetadata,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ ProviderTypeValidator func(string) error
+ // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ ProviderKeyValidator func(string) error
+ // ChannelValidator is a validator for the "channel" field. It is called by the builders before save.
+ ChannelValidator func(string) error
+ // ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save.
+ ChannelAppIDValidator func(string) error
+ // ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save.
+ ChannelSubjectValidator func(string) error
+ // DefaultMetadata holds the default value on creation for the "metadata" field.
+ DefaultMetadata func() map[string]interface{}
+)
+
+// OrderOption defines the ordering options for the AuthIdentityChannel queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByIdentityID orders the results by the identity_id field.
+func ByIdentityID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIdentityID, opts...).ToFunc()
+}
+
+// ByProviderType orders the results by the provider_type field.
+func ByProviderType(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderType, opts...).ToFunc()
+}
+
+// ByProviderKey orders the results by the provider_key field.
+func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
+}
+
+// ByChannel orders the results by the channel field.
+func ByChannel(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldChannel, opts...).ToFunc()
+}
+
+// ByChannelAppID orders the results by the channel_app_id field.
+func ByChannelAppID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldChannelAppID, opts...).ToFunc()
+}
+
+// ByChannelSubject orders the results by the channel_subject field.
+func ByChannelSubject(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldChannelSubject, opts...).ToFunc()
+}
+
+// ByIdentityField orders the results by identity field.
+func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newIdentityStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(IdentityInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
+ )
+}
diff --git a/backend/ent/authidentitychannel/where.go b/backend/ent/authidentitychannel/where.go
new file mode 100644
index 00000000..827dc384
--- /dev/null
+++ b/backend/ent/authidentitychannel/where.go
@@ -0,0 +1,559 @@
+// Code generated by ent, DO NOT EDIT.
+
+package authidentitychannel
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// IdentityID applies equality check predicate on the "identity_id" field. It's identical to IdentityIDEQ.
+func IdentityID(v int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v))
+}
+
+// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ.
+func ProviderType(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
+func ProviderKey(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// Channel applies equality check predicate on the "channel" field. It's identical to ChannelEQ.
+func Channel(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v))
+}
+
+// ChannelAppID applies equality check predicate on the "channel_app_id" field. It's identical to ChannelAppIDEQ.
+func ChannelAppID(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v))
+}
+
+// ChannelSubject applies equality check predicate on the "channel_subject" field. It's identical to ChannelSubjectEQ.
+func ChannelSubject(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// IdentityIDEQ applies the EQ predicate on the "identity_id" field.
+func IdentityIDEQ(v int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v))
+}
+
+// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field.
+func IdentityIDNEQ(v int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldIdentityID, v))
+}
+
+// IdentityIDIn applies the In predicate on the "identity_id" field.
+func IdentityIDIn(vs ...int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldIdentityID, vs...))
+}
+
+// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field.
+func IdentityIDNotIn(vs ...int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldIdentityID, vs...))
+}
+
+// ProviderTypeEQ applies the EQ predicate on the "provider_type" field.
+func ProviderTypeEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field.
+func ProviderTypeNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderType, v))
+}
+
+// ProviderTypeIn applies the In predicate on the "provider_type" field.
+func ProviderTypeIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field.
+func ProviderTypeNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeGT applies the GT predicate on the "provider_type" field.
+func ProviderTypeGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderType, v))
+}
+
+// ProviderTypeGTE applies the GTE predicate on the "provider_type" field.
+func ProviderTypeGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderType, v))
+}
+
+// ProviderTypeLT applies the LT predicate on the "provider_type" field.
+func ProviderTypeLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderType, v))
+}
+
+// ProviderTypeLTE applies the LTE predicate on the "provider_type" field.
+func ProviderTypeLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderType, v))
+}
+
+// ProviderTypeContains applies the Contains predicate on the "provider_type" field.
+func ProviderTypeContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderType, v))
+}
+
+// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field.
+func ProviderTypeHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderType, v))
+}
+
+// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field.
+func ProviderTypeHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderType, v))
+}
+
+// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field.
+func ProviderTypeEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderType, v))
+}
+
+// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field.
+func ProviderTypeContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderType, v))
+}
+
+// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
+func ProviderKeyEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
+func ProviderKeyNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyIn applies the In predicate on the "provider_key" field.
+func ProviderKeyIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
+func ProviderKeyNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyGT applies the GT predicate on the "provider_key" field.
+func ProviderKeyGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderKey, v))
+}
+
+// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
+func ProviderKeyGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderKey, v))
+}
+
+// ProviderKeyLT applies the LT predicate on the "provider_key" field.
+func ProviderKeyLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderKey, v))
+}
+
+// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
+func ProviderKeyLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderKey, v))
+}
+
+// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
+func ProviderKeyContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderKey, v))
+}
+
+// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
+func ProviderKeyHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderKey, v))
+}
+
+// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
+func ProviderKeyHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderKey, v))
+}
+
+// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
+func ProviderKeyEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderKey, v))
+}
+
+// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
+func ProviderKeyContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderKey, v))
+}
+
+// ChannelEQ applies the EQ predicate on the "channel" field.
+func ChannelEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v))
+}
+
+// ChannelNEQ applies the NEQ predicate on the "channel" field.
+func ChannelNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannel, v))
+}
+
+// ChannelIn applies the In predicate on the "channel" field.
+func ChannelIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannel, vs...))
+}
+
+// ChannelNotIn applies the NotIn predicate on the "channel" field.
+func ChannelNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannel, vs...))
+}
+
+// ChannelGT applies the GT predicate on the "channel" field.
+func ChannelGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannel, v))
+}
+
+// ChannelGTE applies the GTE predicate on the "channel" field.
+func ChannelGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannel, v))
+}
+
+// ChannelLT applies the LT predicate on the "channel" field.
+func ChannelLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannel, v))
+}
+
+// ChannelLTE applies the LTE predicate on the "channel" field.
+func ChannelLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannel, v))
+}
+
+// ChannelContains applies the Contains predicate on the "channel" field.
+func ChannelContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannel, v))
+}
+
+// ChannelHasPrefix applies the HasPrefix predicate on the "channel" field.
+func ChannelHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannel, v))
+}
+
+// ChannelHasSuffix applies the HasSuffix predicate on the "channel" field.
+func ChannelHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannel, v))
+}
+
+// ChannelEqualFold applies the EqualFold predicate on the "channel" field.
+func ChannelEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannel, v))
+}
+
+// ChannelContainsFold applies the ContainsFold predicate on the "channel" field.
+func ChannelContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannel, v))
+}
+
+// ChannelAppIDEQ applies the EQ predicate on the "channel_app_id" field.
+func ChannelAppIDEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v))
+}
+
+// ChannelAppIDNEQ applies the NEQ predicate on the "channel_app_id" field.
+func ChannelAppIDNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelAppID, v))
+}
+
+// ChannelAppIDIn applies the In predicate on the "channel_app_id" field.
+func ChannelAppIDIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelAppID, vs...))
+}
+
+// ChannelAppIDNotIn applies the NotIn predicate on the "channel_app_id" field.
+func ChannelAppIDNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelAppID, vs...))
+}
+
+// ChannelAppIDGT applies the GT predicate on the "channel_app_id" field.
+func ChannelAppIDGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelAppID, v))
+}
+
+// ChannelAppIDGTE applies the GTE predicate on the "channel_app_id" field.
+func ChannelAppIDGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelAppID, v))
+}
+
+// ChannelAppIDLT applies the LT predicate on the "channel_app_id" field.
+func ChannelAppIDLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelAppID, v))
+}
+
+// ChannelAppIDLTE applies the LTE predicate on the "channel_app_id" field.
+func ChannelAppIDLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelAppID, v))
+}
+
+// ChannelAppIDContains applies the Contains predicate on the "channel_app_id" field.
+func ChannelAppIDContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelAppID, v))
+}
+
+// ChannelAppIDHasPrefix applies the HasPrefix predicate on the "channel_app_id" field.
+func ChannelAppIDHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelAppID, v))
+}
+
+// ChannelAppIDHasSuffix applies the HasSuffix predicate on the "channel_app_id" field.
+func ChannelAppIDHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelAppID, v))
+}
+
+// ChannelAppIDEqualFold applies the EqualFold predicate on the "channel_app_id" field.
+func ChannelAppIDEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelAppID, v))
+}
+
+// ChannelAppIDContainsFold applies the ContainsFold predicate on the "channel_app_id" field.
+func ChannelAppIDContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelAppID, v))
+}
+
+// ChannelSubjectEQ applies the EQ predicate on the "channel_subject" field.
+func ChannelSubjectEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v))
+}
+
+// ChannelSubjectNEQ applies the NEQ predicate on the "channel_subject" field.
+func ChannelSubjectNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelSubject, v))
+}
+
+// ChannelSubjectIn applies the In predicate on the "channel_subject" field.
+func ChannelSubjectIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelSubject, vs...))
+}
+
+// ChannelSubjectNotIn applies the NotIn predicate on the "channel_subject" field.
+func ChannelSubjectNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelSubject, vs...))
+}
+
+// ChannelSubjectGT applies the GT predicate on the "channel_subject" field.
+func ChannelSubjectGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelSubject, v))
+}
+
+// ChannelSubjectGTE applies the GTE predicate on the "channel_subject" field.
+func ChannelSubjectGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelSubject, v))
+}
+
+// ChannelSubjectLT applies the LT predicate on the "channel_subject" field.
+func ChannelSubjectLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelSubject, v))
+}
+
+// ChannelSubjectLTE applies the LTE predicate on the "channel_subject" field.
+func ChannelSubjectLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelSubject, v))
+}
+
+// ChannelSubjectContains applies the Contains predicate on the "channel_subject" field.
+func ChannelSubjectContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelSubject, v))
+}
+
+// ChannelSubjectHasPrefix applies the HasPrefix predicate on the "channel_subject" field.
+func ChannelSubjectHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelSubject, v))
+}
+
+// ChannelSubjectHasSuffix applies the HasSuffix predicate on the "channel_subject" field.
+func ChannelSubjectHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelSubject, v))
+}
+
+// ChannelSubjectEqualFold applies the EqualFold predicate on the "channel_subject" field.
+func ChannelSubjectEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelSubject, v))
+}
+
+// ChannelSubjectContainsFold applies the ContainsFold predicate on the "channel_subject" field.
+func ChannelSubjectContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelSubject, v))
+}
+
+// HasIdentity applies the HasEdge predicate on the "identity" edge.
+func HasIdentity() predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates).
+func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(func(s *sql.Selector) {
+ step := newIdentityStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.AuthIdentityChannel) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.NotPredicates(p))
+}
diff --git a/backend/ent/authidentitychannel_create.go b/backend/ent/authidentitychannel_create.go
new file mode 100644
index 00000000..4ce28479
--- /dev/null
+++ b/backend/ent/authidentitychannel_create.go
@@ -0,0 +1,932 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+)
+
+// AuthIdentityChannelCreate is the builder for creating a AuthIdentityChannel entity.
+type AuthIdentityChannelCreate struct {
+ config
+ mutation *AuthIdentityChannelMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *AuthIdentityChannelCreate) SetCreatedAt(v time.Time) *AuthIdentityChannelCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *AuthIdentityChannelCreate) SetNillableCreatedAt(v *time.Time) *AuthIdentityChannelCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *AuthIdentityChannelCreate) SetUpdatedAt(v time.Time) *AuthIdentityChannelCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *AuthIdentityChannelCreate) SetNillableUpdatedAt(v *time.Time) *AuthIdentityChannelCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_c *AuthIdentityChannelCreate) SetIdentityID(v int64) *AuthIdentityChannelCreate {
+ _c.mutation.SetIdentityID(v)
+ return _c
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_c *AuthIdentityChannelCreate) SetProviderType(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetProviderType(v)
+ return _c
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_c *AuthIdentityChannelCreate) SetProviderKey(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetProviderKey(v)
+ return _c
+}
+
+// SetChannel sets the "channel" field.
+func (_c *AuthIdentityChannelCreate) SetChannel(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetChannel(v)
+ return _c
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (_c *AuthIdentityChannelCreate) SetChannelAppID(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetChannelAppID(v)
+ return _c
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (_c *AuthIdentityChannelCreate) SetChannelSubject(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetChannelSubject(v)
+ return _c
+}
+
+// SetMetadata sets the "metadata" field.
+func (_c *AuthIdentityChannelCreate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelCreate {
+ _c.mutation.SetMetadata(v)
+ return _c
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_c *AuthIdentityChannelCreate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelCreate {
+ return _c.SetIdentityID(v.ID)
+}
+
+// Mutation returns the AuthIdentityChannelMutation object of the builder.
+func (_c *AuthIdentityChannelCreate) Mutation() *AuthIdentityChannelMutation {
+ return _c.mutation
+}
+
+// Save creates the AuthIdentityChannel in the database.
+func (_c *AuthIdentityChannelCreate) Save(ctx context.Context) (*AuthIdentityChannel, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *AuthIdentityChannelCreate) SaveX(ctx context.Context) *AuthIdentityChannel {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AuthIdentityChannelCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AuthIdentityChannelCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *AuthIdentityChannelCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := authidentitychannel.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := authidentitychannel.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.Metadata(); !ok {
+ v := authidentitychannel.DefaultMetadata()
+ _c.mutation.SetMetadata(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *AuthIdentityChannelCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.updated_at"`)}
+ }
+ if _, ok := _c.mutation.IdentityID(); !ok {
+ return &ValidationError{Name: "identity_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.identity_id"`)}
+ }
+ if _, ok := _c.mutation.ProviderType(); !ok {
+ return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_type"`)}
+ }
+ if v, ok := _c.mutation.ProviderType(); ok {
+ if err := authidentitychannel.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderKey(); !ok {
+ return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_key"`)}
+ }
+ if v, ok := _c.mutation.ProviderKey(); ok {
+ if err := authidentitychannel.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Channel(); !ok {
+ return &ValidationError{Name: "channel", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel"`)}
+ }
+ if v, ok := _c.mutation.Channel(); ok {
+ if err := authidentitychannel.ChannelValidator(v); err != nil {
+ return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ChannelAppID(); !ok {
+ return &ValidationError{Name: "channel_app_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_app_id"`)}
+ }
+ if v, ok := _c.mutation.ChannelAppID(); ok {
+ if err := authidentitychannel.ChannelAppIDValidator(v); err != nil {
+ return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ChannelSubject(); !ok {
+ return &ValidationError{Name: "channel_subject", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_subject"`)}
+ }
+ if v, ok := _c.mutation.ChannelSubject(); ok {
+ if err := authidentitychannel.ChannelSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Metadata(); !ok {
+ return &ValidationError{Name: "metadata", err: errors.New(`ent: missing required field "AuthIdentityChannel.metadata"`)}
+ }
+ if len(_c.mutation.IdentityIDs()) == 0 {
+ return &ValidationError{Name: "identity", err: errors.New(`ent: missing required edge "AuthIdentityChannel.identity"`)}
+ }
+ return nil
+}
+
+func (_c *AuthIdentityChannelCreate) sqlSave(ctx context.Context) (*AuthIdentityChannel, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *AuthIdentityChannelCreate) createSpec() (*AuthIdentityChannel, *sqlgraph.CreateSpec) {
+ var (
+ _node = &AuthIdentityChannel{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(authidentitychannel.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.ProviderType(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value)
+ _node.ProviderType = value
+ }
+ if value, ok := _c.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value)
+ _node.ProviderKey = value
+ }
+ if value, ok := _c.mutation.Channel(); ok {
+ _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value)
+ _node.Channel = value
+ }
+ if value, ok := _c.mutation.ChannelAppID(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value)
+ _node.ChannelAppID = value
+ }
+ if value, ok := _c.mutation.ChannelSubject(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value)
+ _node.ChannelSubject = value
+ }
+ if value, ok := _c.mutation.Metadata(); ok {
+ _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value)
+ _node.Metadata = value
+ }
+ if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.IdentityID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.AuthIdentityChannel.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AuthIdentityChannelUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *AuthIdentityChannelCreate) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertOne {
+ _c.conflict = opts
+ return &AuthIdentityChannelUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AuthIdentityChannelCreate) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AuthIdentityChannelUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // AuthIdentityChannelUpsertOne is the builder for "upsert"-ing
+ // one AuthIdentityChannel node.
+ AuthIdentityChannelUpsertOne struct {
+ create *AuthIdentityChannelCreate
+ }
+
+ // AuthIdentityChannelUpsert is the "OnConflict" setter.
+ AuthIdentityChannelUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityChannelUpsert) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateUpdatedAt() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldUpdatedAt)
+ return u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *AuthIdentityChannelUpsert) SetIdentityID(v int64) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldIdentityID, v)
+ return u
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateIdentityID() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldIdentityID)
+ return u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityChannelUpsert) SetProviderType(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldProviderType, v)
+ return u
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateProviderType() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldProviderType)
+ return u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityChannelUpsert) SetProviderKey(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldProviderKey, v)
+ return u
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateProviderKey() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldProviderKey)
+ return u
+}
+
+// SetChannel sets the "channel" field.
+func (u *AuthIdentityChannelUpsert) SetChannel(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldChannel, v)
+ return u
+}
+
+// UpdateChannel sets the "channel" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateChannel() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldChannel)
+ return u
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (u *AuthIdentityChannelUpsert) SetChannelAppID(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldChannelAppID, v)
+ return u
+}
+
+// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateChannelAppID() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldChannelAppID)
+ return u
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (u *AuthIdentityChannelUpsert) SetChannelSubject(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldChannelSubject, v)
+ return u
+}
+
+// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateChannelSubject() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldChannelSubject)
+ return u
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityChannelUpsert) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldMetadata, v)
+ return u
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateMetadata() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldMetadata)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AuthIdentityChannelUpsertOne) UpdateNewValues() *AuthIdentityChannelUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(authidentitychannel.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AuthIdentityChannelUpsertOne) Ignore() *AuthIdentityChannelUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AuthIdentityChannelUpsertOne) DoNothing() *AuthIdentityChannelUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreate.OnConflict
+// documentation for more info.
+func (u *AuthIdentityChannelUpsertOne) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AuthIdentityChannelUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityChannelUpsertOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateUpdatedAt() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *AuthIdentityChannelUpsertOne) SetIdentityID(v int64) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetIdentityID(v)
+ })
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateIdentityID() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateIdentityID()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityChannelUpsertOne) SetProviderType(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateProviderType() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityChannelUpsertOne) SetProviderKey(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateProviderKey() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetChannel sets the "channel" field.
+func (u *AuthIdentityChannelUpsertOne) SetChannel(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannel(v)
+ })
+}
+
+// UpdateChannel sets the "channel" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateChannel() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannel()
+ })
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (u *AuthIdentityChannelUpsertOne) SetChannelAppID(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannelAppID(v)
+ })
+}
+
+// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateChannelAppID() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannelAppID()
+ })
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (u *AuthIdentityChannelUpsertOne) SetChannelSubject(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannelSubject(v)
+ })
+}
+
+// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateChannelSubject() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannelSubject()
+ })
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityChannelUpsertOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetMetadata(v)
+ })
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateMetadata() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateMetadata()
+ })
+}
+
+// Exec executes the query.
+func (u *AuthIdentityChannelUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AuthIdentityChannelCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AuthIdentityChannelUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *AuthIdentityChannelUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *AuthIdentityChannelUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// AuthIdentityChannelCreateBulk is the builder for creating many AuthIdentityChannel entities in bulk.
+type AuthIdentityChannelCreateBulk struct {
+ config
+ err error
+ builders []*AuthIdentityChannelCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the AuthIdentityChannel entities in the database.
+func (_c *AuthIdentityChannelCreateBulk) Save(ctx context.Context) ([]*AuthIdentityChannel, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*AuthIdentityChannel, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*AuthIdentityChannelMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *AuthIdentityChannelCreateBulk) SaveX(ctx context.Context) []*AuthIdentityChannel {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AuthIdentityChannelCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AuthIdentityChannelCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.AuthIdentityChannel.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AuthIdentityChannelUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *AuthIdentityChannelCreateBulk) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertBulk {
+ _c.conflict = opts
+ return &AuthIdentityChannelUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AuthIdentityChannelCreateBulk) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AuthIdentityChannelUpsertBulk{
+ create: _c,
+ }
+}
+
+// AuthIdentityChannelUpsertBulk is the builder for "upsert"-ing
+// a bulk of AuthIdentityChannel nodes.
+type AuthIdentityChannelUpsertBulk struct {
+ create *AuthIdentityChannelCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AuthIdentityChannelUpsertBulk) UpdateNewValues() *AuthIdentityChannelUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(authidentitychannel.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AuthIdentityChannelUpsertBulk) Ignore() *AuthIdentityChannelUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AuthIdentityChannelUpsertBulk) DoNothing() *AuthIdentityChannelUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreateBulk.OnConflict
+// documentation for more info.
+func (u *AuthIdentityChannelUpsertBulk) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AuthIdentityChannelUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityChannelUpsertBulk) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateUpdatedAt() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *AuthIdentityChannelUpsertBulk) SetIdentityID(v int64) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetIdentityID(v)
+ })
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateIdentityID() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateIdentityID()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityChannelUpsertBulk) SetProviderType(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateProviderType() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityChannelUpsertBulk) SetProviderKey(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateProviderKey() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetChannel sets the "channel" field.
+func (u *AuthIdentityChannelUpsertBulk) SetChannel(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannel(v)
+ })
+}
+
+// UpdateChannel sets the "channel" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateChannel() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannel()
+ })
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (u *AuthIdentityChannelUpsertBulk) SetChannelAppID(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannelAppID(v)
+ })
+}
+
+// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateChannelAppID() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannelAppID()
+ })
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (u *AuthIdentityChannelUpsertBulk) SetChannelSubject(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannelSubject(v)
+ })
+}
+
+// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateChannelSubject() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannelSubject()
+ })
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityChannelUpsertBulk) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetMetadata(v)
+ })
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateMetadata() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateMetadata()
+ })
+}
+
+// Exec executes the query.
+func (u *AuthIdentityChannelUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AuthIdentityChannelCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AuthIdentityChannelCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AuthIdentityChannelUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/authidentitychannel_delete.go b/backend/ent/authidentitychannel_delete.go
new file mode 100644
index 00000000..1a4acac5
--- /dev/null
+++ b/backend/ent/authidentitychannel_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AuthIdentityChannelDelete is the builder for deleting a AuthIdentityChannel entity.
+type AuthIdentityChannelDelete struct {
+ config
+ hooks []Hook
+ mutation *AuthIdentityChannelMutation
+}
+
+// Where appends a list predicates to the AuthIdentityChannelDelete builder.
+func (_d *AuthIdentityChannelDelete) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *AuthIdentityChannelDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AuthIdentityChannelDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *AuthIdentityChannelDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// AuthIdentityChannelDeleteOne is the builder for deleting a single AuthIdentityChannel entity.
+type AuthIdentityChannelDeleteOne struct {
+ _d *AuthIdentityChannelDelete
+}
+
+// Where appends a list predicates to the AuthIdentityChannelDelete builder.
+func (_d *AuthIdentityChannelDeleteOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *AuthIdentityChannelDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{authidentitychannel.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AuthIdentityChannelDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/authidentitychannel_query.go b/backend/ent/authidentitychannel_query.go
new file mode 100644
index 00000000..7a202b7f
--- /dev/null
+++ b/backend/ent/authidentitychannel_query.go
@@ -0,0 +1,643 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AuthIdentityChannelQuery is the builder for querying AuthIdentityChannel entities.
+type AuthIdentityChannelQuery struct {
+ config
+ ctx *QueryContext
+ order []authidentitychannel.OrderOption
+ inters []Interceptor
+ predicates []predicate.AuthIdentityChannel
+ withIdentity *AuthIdentityQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the AuthIdentityChannelQuery builder.
+func (_q *AuthIdentityChannelQuery) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *AuthIdentityChannelQuery) Limit(limit int) *AuthIdentityChannelQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *AuthIdentityChannelQuery) Offset(offset int) *AuthIdentityChannelQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *AuthIdentityChannelQuery) Unique(unique bool) *AuthIdentityChannelQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *AuthIdentityChannelQuery) Order(o ...authidentitychannel.OrderOption) *AuthIdentityChannelQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryIdentity chains the current query on the "identity" edge.
+func (_q *AuthIdentityChannelQuery) QueryIdentity() *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, selector),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first AuthIdentityChannel entity from the query.
+// Returns a *NotFoundError when no AuthIdentityChannel was found.
+func (_q *AuthIdentityChannelQuery) First(ctx context.Context) (*AuthIdentityChannel, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{authidentitychannel.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) FirstX(ctx context.Context) *AuthIdentityChannel {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first AuthIdentityChannel ID from the query.
+// Returns a *NotFoundError when no AuthIdentityChannel ID was found.
+func (_q *AuthIdentityChannelQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{authidentitychannel.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single AuthIdentityChannel entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one AuthIdentityChannel entity is found.
+// Returns a *NotFoundError when no AuthIdentityChannel entities are found.
+func (_q *AuthIdentityChannelQuery) Only(ctx context.Context) (*AuthIdentityChannel, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{authidentitychannel.Label}
+ default:
+ return nil, &NotSingularError{authidentitychannel.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) OnlyX(ctx context.Context) *AuthIdentityChannel {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only AuthIdentityChannel ID in the query.
+// Returns a *NotSingularError when more than one AuthIdentityChannel ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *AuthIdentityChannelQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{authidentitychannel.Label}
+ default:
+ err = &NotSingularError{authidentitychannel.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of AuthIdentityChannels.
+func (_q *AuthIdentityChannelQuery) All(ctx context.Context) ([]*AuthIdentityChannel, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*AuthIdentityChannel, *AuthIdentityChannelQuery]()
+ return withInterceptors[[]*AuthIdentityChannel](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) AllX(ctx context.Context) []*AuthIdentityChannel {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of AuthIdentityChannel IDs.
+func (_q *AuthIdentityChannelQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(authidentitychannel.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *AuthIdentityChannelQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityChannelQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *AuthIdentityChannelQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the AuthIdentityChannelQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *AuthIdentityChannelQuery) Clone() *AuthIdentityChannelQuery {
+ if _q == nil {
+ return nil
+ }
+ return &AuthIdentityChannelQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]authidentitychannel.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.AuthIdentityChannel{}, _q.predicates...),
+ withIdentity: _q.withIdentity.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithIdentity tells the query-builder to eager-load the nodes that are connected to
+// the "identity" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AuthIdentityChannelQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *AuthIdentityChannelQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withIdentity = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.AuthIdentityChannel.Query().
+// GroupBy(authidentitychannel.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *AuthIdentityChannelQuery) GroupBy(field string, fields ...string) *AuthIdentityChannelGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &AuthIdentityChannelGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = authidentitychannel.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.AuthIdentityChannel.Query().
+// Select(authidentitychannel.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *AuthIdentityChannelQuery) Select(fields ...string) *AuthIdentityChannelSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &AuthIdentityChannelSelect{AuthIdentityChannelQuery: _q}
+ sbuild.label = authidentitychannel.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a AuthIdentityChannelSelect configured with the given aggregations.
+func (_q *AuthIdentityChannelQuery) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *AuthIdentityChannelQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !authidentitychannel.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *AuthIdentityChannelQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentityChannel, error) {
+ var (
+ nodes = []*AuthIdentityChannel{}
+ _spec = _q.querySpec()
+ loadedTypes = [1]bool{
+ _q.withIdentity != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*AuthIdentityChannel).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &AuthIdentityChannel{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withIdentity; query != nil {
+ if err := _q.loadIdentity(ctx, query, nodes, nil,
+ func(n *AuthIdentityChannel, e *AuthIdentity) { n.Edges.Identity = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *AuthIdentityChannelQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*AuthIdentityChannel, init func(*AuthIdentityChannel), assign func(*AuthIdentityChannel, *AuthIdentity)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*AuthIdentityChannel)
+ for i := range nodes {
+ fk := nodes[i].IdentityID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(authidentity.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+
+func (_q *AuthIdentityChannelQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *AuthIdentityChannelQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID)
+ for i := range fields {
+ if fields[i] != authidentitychannel.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withIdentity != nil {
+ _spec.Node.AddColumnOnce(authidentitychannel.FieldIdentityID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *AuthIdentityChannelQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(authidentitychannel.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = authidentitychannel.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *AuthIdentityChannelQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityChannelQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *AuthIdentityChannelQuery) ForShare(opts ...sql.LockOption) *AuthIdentityChannelQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// AuthIdentityChannelGroupBy is the group-by builder for AuthIdentityChannel entities.
+type AuthIdentityChannelGroupBy struct {
+ selector
+ build *AuthIdentityChannelQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *AuthIdentityChannelGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *AuthIdentityChannelGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *AuthIdentityChannelGroupBy) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// AuthIdentityChannelSelect is the builder for selecting fields of AuthIdentityChannel entities.
+type AuthIdentityChannelSelect struct {
+ *AuthIdentityChannelQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *AuthIdentityChannelSelect) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *AuthIdentityChannelSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelSelect](ctx, _s.AuthIdentityChannelQuery, _s, _s.inters, v)
+}
+
+func (_s *AuthIdentityChannelSelect) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/authidentitychannel_update.go b/backend/ent/authidentitychannel_update.go
new file mode 100644
index 00000000..b550c454
--- /dev/null
+++ b/backend/ent/authidentitychannel_update.go
@@ -0,0 +1,581 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AuthIdentityChannelUpdate is the builder for updating AuthIdentityChannel entities.
+type AuthIdentityChannelUpdate struct {
+ config
+ hooks []Hook
+ mutation *AuthIdentityChannelMutation
+}
+
+// Where appends a list predicates to the AuthIdentityChannelUpdate builder.
+func (_u *AuthIdentityChannelUpdate) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AuthIdentityChannelUpdate) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_u *AuthIdentityChannelUpdate) SetIdentityID(v int64) *AuthIdentityChannelUpdate {
+ _u.mutation.SetIdentityID(v)
+ return _u
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetIdentityID(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *AuthIdentityChannelUpdate) SetProviderType(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableProviderType(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *AuthIdentityChannelUpdate) SetProviderKey(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetChannel sets the "channel" field.
+func (_u *AuthIdentityChannelUpdate) SetChannel(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetChannel(v)
+ return _u
+}
+
+// SetNillableChannel sets the "channel" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableChannel(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetChannel(*v)
+ }
+ return _u
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (_u *AuthIdentityChannelUpdate) SetChannelAppID(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetChannelAppID(v)
+ return _u
+}
+
+// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetChannelAppID(*v)
+ }
+ return _u
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (_u *AuthIdentityChannelUpdate) SetChannelSubject(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetChannelSubject(v)
+ return _u
+}
+
+// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetChannelSubject(*v)
+ }
+ return _u
+}
+
+// SetMetadata sets the "metadata" field.
+func (_u *AuthIdentityChannelUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdate {
+ _u.mutation.SetMetadata(v)
+ return _u
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_u *AuthIdentityChannelUpdate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdate {
+ return _u.SetIdentityID(v.ID)
+}
+
+// Mutation returns the AuthIdentityChannelMutation object of the builder.
+func (_u *AuthIdentityChannelUpdate) Mutation() *AuthIdentityChannelMutation {
+ return _u.mutation
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (_u *AuthIdentityChannelUpdate) ClearIdentity() *AuthIdentityChannelUpdate {
+ _u.mutation.ClearIdentity()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *AuthIdentityChannelUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AuthIdentityChannelUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *AuthIdentityChannelUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AuthIdentityChannelUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AuthIdentityChannelUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := authidentitychannel.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AuthIdentityChannelUpdate) check() error {
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := authidentitychannel.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := authidentitychannel.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Channel(); ok {
+ if err := authidentitychannel.ChannelValidator(v); err != nil {
+ return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ChannelAppID(); ok {
+ if err := authidentitychannel.ChannelAppIDValidator(v); err != nil {
+ return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ChannelSubject(); ok {
+ if err := authidentitychannel.ChannelSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)}
+ }
+ }
+ if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`)
+ }
+ return nil
+}
+
+func (_u *AuthIdentityChannelUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Channel(); ok {
+ _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ChannelAppID(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ChannelSubject(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Metadata(); ok {
+ _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value)
+ }
+ if _u.mutation.IdentityCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{authidentitychannel.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// AuthIdentityChannelUpdateOne is the builder for updating a single AuthIdentityChannel entity.
+type AuthIdentityChannelUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *AuthIdentityChannelMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AuthIdentityChannelUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_u *AuthIdentityChannelUpdateOne) SetIdentityID(v int64) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetIdentityID(v)
+ return _u
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetIdentityID(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *AuthIdentityChannelUpdateOne) SetProviderType(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderType(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *AuthIdentityChannelUpdateOne) SetProviderKey(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetChannel sets the "channel" field.
+func (_u *AuthIdentityChannelUpdateOne) SetChannel(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetChannel(v)
+ return _u
+}
+
+// SetNillableChannel sets the "channel" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableChannel(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetChannel(*v)
+ }
+ return _u
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (_u *AuthIdentityChannelUpdateOne) SetChannelAppID(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetChannelAppID(v)
+ return _u
+}
+
+// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetChannelAppID(*v)
+ }
+ return _u
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (_u *AuthIdentityChannelUpdateOne) SetChannelSubject(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetChannelSubject(v)
+ return _u
+}
+
+// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetChannelSubject(*v)
+ }
+ return _u
+}
+
+// SetMetadata sets the "metadata" field.
+func (_u *AuthIdentityChannelUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetMetadata(v)
+ return _u
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_u *AuthIdentityChannelUpdateOne) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdateOne {
+ return _u.SetIdentityID(v.ID)
+}
+
+// Mutation returns the AuthIdentityChannelMutation object of the builder.
+func (_u *AuthIdentityChannelUpdateOne) Mutation() *AuthIdentityChannelMutation {
+ return _u.mutation
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (_u *AuthIdentityChannelUpdateOne) ClearIdentity() *AuthIdentityChannelUpdateOne {
+ _u.mutation.ClearIdentity()
+ return _u
+}
+
+// Where appends a list predicates to the AuthIdentityChannelUpdate builder.
+func (_u *AuthIdentityChannelUpdateOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *AuthIdentityChannelUpdateOne) Select(field string, fields ...string) *AuthIdentityChannelUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated AuthIdentityChannel entity.
+func (_u *AuthIdentityChannelUpdateOne) Save(ctx context.Context) (*AuthIdentityChannel, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AuthIdentityChannelUpdateOne) SaveX(ctx context.Context) *AuthIdentityChannel {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *AuthIdentityChannelUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AuthIdentityChannelUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AuthIdentityChannelUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := authidentitychannel.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AuthIdentityChannelUpdateOne) check() error {
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := authidentitychannel.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := authidentitychannel.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Channel(); ok {
+ if err := authidentitychannel.ChannelValidator(v); err != nil {
+ return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ChannelAppID(); ok {
+ if err := authidentitychannel.ChannelAppIDValidator(v); err != nil {
+ return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ChannelSubject(); ok {
+ if err := authidentitychannel.ChannelSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)}
+ }
+ }
+ if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`)
+ }
+ return nil
+}
+
+func (_u *AuthIdentityChannelUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentityChannel, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentityChannel.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID)
+ for _, f := range fields {
+ if !authidentitychannel.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != authidentitychannel.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Channel(); ok {
+ _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ChannelAppID(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ChannelSubject(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Metadata(); ok {
+ _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value)
+ }
+ if _u.mutation.IdentityCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &AuthIdentityChannel{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{authidentitychannel.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/client.go b/backend/ent/client.go
index e52e015a..b02f519b 100644
--- a/backend/ent/client.go
+++ b/backend/ent/client.go
@@ -20,12 +20,16 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
@@ -60,18 +64,26 @@ type Client struct {
Announcement *AnnouncementClient
// AnnouncementRead is the client for interacting with the AnnouncementRead builders.
AnnouncementRead *AnnouncementReadClient
+ // AuthIdentity is the client for interacting with the AuthIdentity builders.
+ AuthIdentity *AuthIdentityClient
+ // AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders.
+ AuthIdentityChannel *AuthIdentityChannelClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// IdempotencyRecord is the client for interacting with the IdempotencyRecord builders.
IdempotencyRecord *IdempotencyRecordClient
+ // IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders.
+ IdentityAdoptionDecision *IdentityAdoptionDecisionClient
// PaymentAuditLog is the client for interacting with the PaymentAuditLog builders.
PaymentAuditLog *PaymentAuditLogClient
// PaymentOrder is the client for interacting with the PaymentOrder builders.
PaymentOrder *PaymentOrderClient
// PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders.
PaymentProviderInstance *PaymentProviderInstanceClient
+ // PendingAuthSession is the client for interacting with the PendingAuthSession builders.
+ PendingAuthSession *PendingAuthSessionClient
// PromoCode is the client for interacting with the PromoCode builders.
PromoCode *PromoCodeClient
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
@@ -118,12 +130,16 @@ func (c *Client) init() {
c.AccountGroup = NewAccountGroupClient(c.config)
c.Announcement = NewAnnouncementClient(c.config)
c.AnnouncementRead = NewAnnouncementReadClient(c.config)
+ c.AuthIdentity = NewAuthIdentityClient(c.config)
+ c.AuthIdentityChannel = NewAuthIdentityChannelClient(c.config)
c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config)
c.Group = NewGroupClient(c.config)
c.IdempotencyRecord = NewIdempotencyRecordClient(c.config)
+ c.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(c.config)
c.PaymentAuditLog = NewPaymentAuditLogClient(c.config)
c.PaymentOrder = NewPaymentOrderClient(c.config)
c.PaymentProviderInstance = NewPaymentProviderInstanceClient(c.config)
+ c.PendingAuthSession = NewPendingAuthSessionClient(c.config)
c.PromoCode = NewPromoCodeClient(c.config)
c.PromoCodeUsage = NewPromoCodeUsageClient(c.config)
c.Proxy = NewProxyClient(c.config)
@@ -229,34 +245,38 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
cfg := c.config
cfg.driver = tx
return &Tx{
- ctx: ctx,
- config: cfg,
- APIKey: NewAPIKeyClient(cfg),
- Account: NewAccountClient(cfg),
- AccountGroup: NewAccountGroupClient(cfg),
- Announcement: NewAnnouncementClient(cfg),
- AnnouncementRead: NewAnnouncementReadClient(cfg),
- ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
- Group: NewGroupClient(cfg),
- IdempotencyRecord: NewIdempotencyRecordClient(cfg),
- PaymentAuditLog: NewPaymentAuditLogClient(cfg),
- PaymentOrder: NewPaymentOrderClient(cfg),
- PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
- PromoCode: NewPromoCodeClient(cfg),
- PromoCodeUsage: NewPromoCodeUsageClient(cfg),
- Proxy: NewProxyClient(cfg),
- RedeemCode: NewRedeemCodeClient(cfg),
- SecuritySecret: NewSecuritySecretClient(cfg),
- Setting: NewSettingClient(cfg),
- SubscriptionPlan: NewSubscriptionPlanClient(cfg),
- TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
- UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
- UsageLog: NewUsageLogClient(cfg),
- User: NewUserClient(cfg),
- UserAllowedGroup: NewUserAllowedGroupClient(cfg),
- UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
- UserAttributeValue: NewUserAttributeValueClient(cfg),
- UserSubscription: NewUserSubscriptionClient(cfg),
+ ctx: ctx,
+ config: cfg,
+ APIKey: NewAPIKeyClient(cfg),
+ Account: NewAccountClient(cfg),
+ AccountGroup: NewAccountGroupClient(cfg),
+ Announcement: NewAnnouncementClient(cfg),
+ AnnouncementRead: NewAnnouncementReadClient(cfg),
+ AuthIdentity: NewAuthIdentityClient(cfg),
+ AuthIdentityChannel: NewAuthIdentityChannelClient(cfg),
+ ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
+ Group: NewGroupClient(cfg),
+ IdempotencyRecord: NewIdempotencyRecordClient(cfg),
+ IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg),
+ PaymentAuditLog: NewPaymentAuditLogClient(cfg),
+ PaymentOrder: NewPaymentOrderClient(cfg),
+ PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
+ PendingAuthSession: NewPendingAuthSessionClient(cfg),
+ PromoCode: NewPromoCodeClient(cfg),
+ PromoCodeUsage: NewPromoCodeUsageClient(cfg),
+ Proxy: NewProxyClient(cfg),
+ RedeemCode: NewRedeemCodeClient(cfg),
+ SecuritySecret: NewSecuritySecretClient(cfg),
+ Setting: NewSettingClient(cfg),
+ SubscriptionPlan: NewSubscriptionPlanClient(cfg),
+ TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
+ UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
+ UsageLog: NewUsageLogClient(cfg),
+ User: NewUserClient(cfg),
+ UserAllowedGroup: NewUserAllowedGroupClient(cfg),
+ UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
+ UserAttributeValue: NewUserAttributeValueClient(cfg),
+ UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@@ -274,34 +294,38 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
cfg := c.config
cfg.driver = &txDriver{tx: tx, drv: c.driver}
return &Tx{
- ctx: ctx,
- config: cfg,
- APIKey: NewAPIKeyClient(cfg),
- Account: NewAccountClient(cfg),
- AccountGroup: NewAccountGroupClient(cfg),
- Announcement: NewAnnouncementClient(cfg),
- AnnouncementRead: NewAnnouncementReadClient(cfg),
- ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
- Group: NewGroupClient(cfg),
- IdempotencyRecord: NewIdempotencyRecordClient(cfg),
- PaymentAuditLog: NewPaymentAuditLogClient(cfg),
- PaymentOrder: NewPaymentOrderClient(cfg),
- PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
- PromoCode: NewPromoCodeClient(cfg),
- PromoCodeUsage: NewPromoCodeUsageClient(cfg),
- Proxy: NewProxyClient(cfg),
- RedeemCode: NewRedeemCodeClient(cfg),
- SecuritySecret: NewSecuritySecretClient(cfg),
- Setting: NewSettingClient(cfg),
- SubscriptionPlan: NewSubscriptionPlanClient(cfg),
- TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
- UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
- UsageLog: NewUsageLogClient(cfg),
- User: NewUserClient(cfg),
- UserAllowedGroup: NewUserAllowedGroupClient(cfg),
- UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
- UserAttributeValue: NewUserAttributeValueClient(cfg),
- UserSubscription: NewUserSubscriptionClient(cfg),
+ ctx: ctx,
+ config: cfg,
+ APIKey: NewAPIKeyClient(cfg),
+ Account: NewAccountClient(cfg),
+ AccountGroup: NewAccountGroupClient(cfg),
+ Announcement: NewAnnouncementClient(cfg),
+ AnnouncementRead: NewAnnouncementReadClient(cfg),
+ AuthIdentity: NewAuthIdentityClient(cfg),
+ AuthIdentityChannel: NewAuthIdentityChannelClient(cfg),
+ ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
+ Group: NewGroupClient(cfg),
+ IdempotencyRecord: NewIdempotencyRecordClient(cfg),
+ IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg),
+ PaymentAuditLog: NewPaymentAuditLogClient(cfg),
+ PaymentOrder: NewPaymentOrderClient(cfg),
+ PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
+ PendingAuthSession: NewPendingAuthSessionClient(cfg),
+ PromoCode: NewPromoCodeClient(cfg),
+ PromoCodeUsage: NewPromoCodeUsageClient(cfg),
+ Proxy: NewProxyClient(cfg),
+ RedeemCode: NewRedeemCodeClient(cfg),
+ SecuritySecret: NewSecuritySecretClient(cfg),
+ Setting: NewSettingClient(cfg),
+ SubscriptionPlan: NewSubscriptionPlanClient(cfg),
+ TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
+ UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
+ UsageLog: NewUsageLogClient(cfg),
+ User: NewUserClient(cfg),
+ UserAllowedGroup: NewUserAllowedGroupClient(cfg),
+ UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
+ UserAttributeValue: NewUserAttributeValueClient(cfg),
+ UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@@ -332,11 +356,12 @@ func (c *Client) Close() error {
func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
- c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog,
- c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage,
- c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan,
- c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User,
- c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
+ c.AuthIdentity, c.AuthIdentityChannel, c.ErrorPassthroughRule, c.Group,
+ c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog,
+ c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode,
+ c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
+ c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
+ c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Use(hooks...)
@@ -348,11 +373,12 @@ func (c *Client) Use(hooks ...Hook) {
func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
- c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog,
- c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage,
- c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan,
- c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User,
- c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
+ c.AuthIdentity, c.AuthIdentityChannel, c.ErrorPassthroughRule, c.Group,
+ c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog,
+ c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode,
+ c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
+ c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
+ c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Intercept(interceptors...)
@@ -372,18 +398,26 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.Announcement.mutate(ctx, m)
case *AnnouncementReadMutation:
return c.AnnouncementRead.mutate(ctx, m)
+ case *AuthIdentityMutation:
+ return c.AuthIdentity.mutate(ctx, m)
+ case *AuthIdentityChannelMutation:
+ return c.AuthIdentityChannel.mutate(ctx, m)
case *ErrorPassthroughRuleMutation:
return c.ErrorPassthroughRule.mutate(ctx, m)
case *GroupMutation:
return c.Group.mutate(ctx, m)
case *IdempotencyRecordMutation:
return c.IdempotencyRecord.mutate(ctx, m)
+ case *IdentityAdoptionDecisionMutation:
+ return c.IdentityAdoptionDecision.mutate(ctx, m)
case *PaymentAuditLogMutation:
return c.PaymentAuditLog.mutate(ctx, m)
case *PaymentOrderMutation:
return c.PaymentOrder.mutate(ctx, m)
case *PaymentProviderInstanceMutation:
return c.PaymentProviderInstance.mutate(ctx, m)
+ case *PendingAuthSessionMutation:
+ return c.PendingAuthSession.mutate(ctx, m)
case *PromoCodeMutation:
return c.PromoCode.mutate(ctx, m)
case *PromoCodeUsageMutation:
@@ -1231,6 +1265,336 @@ func (c *AnnouncementReadClient) mutate(ctx context.Context, m *AnnouncementRead
}
}
+// AuthIdentityClient is a client for the AuthIdentity schema.
+type AuthIdentityClient struct {
+ config
+}
+
+// NewAuthIdentityClient returns a client for the AuthIdentity from the given config.
+func NewAuthIdentityClient(c config) *AuthIdentityClient {
+ return &AuthIdentityClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `authidentity.Hooks(f(g(h())))`.
+func (c *AuthIdentityClient) Use(hooks ...Hook) {
+ c.hooks.AuthIdentity = append(c.hooks.AuthIdentity, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `authidentity.Intercept(f(g(h())))`.
+func (c *AuthIdentityClient) Intercept(interceptors ...Interceptor) {
+ c.inters.AuthIdentity = append(c.inters.AuthIdentity, interceptors...)
+}
+
+// Create returns a builder for creating a AuthIdentity entity.
+func (c *AuthIdentityClient) Create() *AuthIdentityCreate {
+ mutation := newAuthIdentityMutation(c.config, OpCreate)
+ return &AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of AuthIdentity entities.
+func (c *AuthIdentityClient) CreateBulk(builders ...*AuthIdentityCreate) *AuthIdentityCreateBulk {
+ return &AuthIdentityCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *AuthIdentityClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityCreate, int)) *AuthIdentityCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &AuthIdentityCreateBulk{err: fmt.Errorf("calling to AuthIdentityClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*AuthIdentityCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &AuthIdentityCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for AuthIdentity.
+func (c *AuthIdentityClient) Update() *AuthIdentityUpdate {
+ mutation := newAuthIdentityMutation(c.config, OpUpdate)
+ return &AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *AuthIdentityClient) UpdateOne(_m *AuthIdentity) *AuthIdentityUpdateOne {
+ mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentity(_m))
+ return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *AuthIdentityClient) UpdateOneID(id int64) *AuthIdentityUpdateOne {
+ mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentityID(id))
+ return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for AuthIdentity.
+func (c *AuthIdentityClient) Delete() *AuthIdentityDelete {
+ mutation := newAuthIdentityMutation(c.config, OpDelete)
+ return &AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *AuthIdentityClient) DeleteOne(_m *AuthIdentity) *AuthIdentityDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *AuthIdentityClient) DeleteOneID(id int64) *AuthIdentityDeleteOne {
+ builder := c.Delete().Where(authidentity.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &AuthIdentityDeleteOne{builder}
+}
+
+// Query returns a query builder for AuthIdentity.
+func (c *AuthIdentityClient) Query() *AuthIdentityQuery {
+ return &AuthIdentityQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeAuthIdentity},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a AuthIdentity entity by its id.
+func (c *AuthIdentityClient) Get(ctx context.Context, id int64) (*AuthIdentity, error) {
+ return c.Query().Where(authidentity.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *AuthIdentityClient) GetX(ctx context.Context, id int64) *AuthIdentity {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryUser queries the user edge of a AuthIdentity.
+func (c *AuthIdentityClient) QueryUser(_m *AuthIdentity) *UserQuery {
+ query := (&UserClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, id),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryChannels queries the channels edge of a AuthIdentity.
+func (c *AuthIdentityClient) QueryChannels(_m *AuthIdentity) *AuthIdentityChannelQuery {
+ query := (&AuthIdentityChannelClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, id),
+ sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryAdoptionDecisions queries the adoption_decisions edge of a AuthIdentity.
+func (c *AuthIdentityClient) QueryAdoptionDecisions(_m *AuthIdentity) *IdentityAdoptionDecisionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, id),
+ sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *AuthIdentityClient) Hooks() []Hook {
+ return c.hooks.AuthIdentity
+}
+
+// Interceptors returns the client interceptors.
+func (c *AuthIdentityClient) Interceptors() []Interceptor {
+ return c.inters.AuthIdentity
+}
+
+func (c *AuthIdentityClient) mutate(ctx context.Context, m *AuthIdentityMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown AuthIdentity mutation op: %q", m.Op())
+ }
+}
+
+// AuthIdentityChannelClient is a client for the AuthIdentityChannel schema.
+type AuthIdentityChannelClient struct {
+ config
+}
+
+// NewAuthIdentityChannelClient returns a client for the AuthIdentityChannel from the given config.
+func NewAuthIdentityChannelClient(c config) *AuthIdentityChannelClient {
+ return &AuthIdentityChannelClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `authidentitychannel.Hooks(f(g(h())))`.
+func (c *AuthIdentityChannelClient) Use(hooks ...Hook) {
+ c.hooks.AuthIdentityChannel = append(c.hooks.AuthIdentityChannel, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `authidentitychannel.Intercept(f(g(h())))`.
+func (c *AuthIdentityChannelClient) Intercept(interceptors ...Interceptor) {
+ c.inters.AuthIdentityChannel = append(c.inters.AuthIdentityChannel, interceptors...)
+}
+
+// Create returns a builder for creating a AuthIdentityChannel entity.
+func (c *AuthIdentityChannelClient) Create() *AuthIdentityChannelCreate {
+ mutation := newAuthIdentityChannelMutation(c.config, OpCreate)
+ return &AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of AuthIdentityChannel entities.
+func (c *AuthIdentityChannelClient) CreateBulk(builders ...*AuthIdentityChannelCreate) *AuthIdentityChannelCreateBulk {
+ return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *AuthIdentityChannelClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityChannelCreate, int)) *AuthIdentityChannelCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &AuthIdentityChannelCreateBulk{err: fmt.Errorf("calling to AuthIdentityChannelClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*AuthIdentityChannelCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for AuthIdentityChannel.
+func (c *AuthIdentityChannelClient) Update() *AuthIdentityChannelUpdate {
+ mutation := newAuthIdentityChannelMutation(c.config, OpUpdate)
+ return &AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *AuthIdentityChannelClient) UpdateOne(_m *AuthIdentityChannel) *AuthIdentityChannelUpdateOne {
+ mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannel(_m))
+ return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *AuthIdentityChannelClient) UpdateOneID(id int64) *AuthIdentityChannelUpdateOne {
+ mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannelID(id))
+ return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for AuthIdentityChannel.
+func (c *AuthIdentityChannelClient) Delete() *AuthIdentityChannelDelete {
+ mutation := newAuthIdentityChannelMutation(c.config, OpDelete)
+ return &AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *AuthIdentityChannelClient) DeleteOne(_m *AuthIdentityChannel) *AuthIdentityChannelDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *AuthIdentityChannelClient) DeleteOneID(id int64) *AuthIdentityChannelDeleteOne {
+ builder := c.Delete().Where(authidentitychannel.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &AuthIdentityChannelDeleteOne{builder}
+}
+
+// Query returns a query builder for AuthIdentityChannel.
+func (c *AuthIdentityChannelClient) Query() *AuthIdentityChannelQuery {
+ return &AuthIdentityChannelQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeAuthIdentityChannel},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a AuthIdentityChannel entity by its id.
+func (c *AuthIdentityChannelClient) Get(ctx context.Context, id int64) (*AuthIdentityChannel, error) {
+ return c.Query().Where(authidentitychannel.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *AuthIdentityChannelClient) GetX(ctx context.Context, id int64) *AuthIdentityChannel {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryIdentity queries the identity edge of a AuthIdentityChannel.
+func (c *AuthIdentityChannelClient) QueryIdentity(_m *AuthIdentityChannel) *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, id),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *AuthIdentityChannelClient) Hooks() []Hook {
+ return c.hooks.AuthIdentityChannel
+}
+
+// Interceptors returns the client interceptors.
+func (c *AuthIdentityChannelClient) Interceptors() []Interceptor {
+ return c.inters.AuthIdentityChannel
+}
+
+func (c *AuthIdentityChannelClient) mutate(ctx context.Context, m *AuthIdentityChannelMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown AuthIdentityChannel mutation op: %q", m.Op())
+ }
+}
+
// ErrorPassthroughRuleClient is a client for the ErrorPassthroughRule schema.
type ErrorPassthroughRuleClient struct {
config
@@ -1760,6 +2124,171 @@ func (c *IdempotencyRecordClient) mutate(ctx context.Context, m *IdempotencyReco
}
}
+// IdentityAdoptionDecisionClient is a client for the IdentityAdoptionDecision schema.
+type IdentityAdoptionDecisionClient struct {
+ config
+}
+
+// NewIdentityAdoptionDecisionClient returns a client for the IdentityAdoptionDecision from the given config.
+func NewIdentityAdoptionDecisionClient(c config) *IdentityAdoptionDecisionClient {
+ return &IdentityAdoptionDecisionClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `identityadoptiondecision.Hooks(f(g(h())))`.
+func (c *IdentityAdoptionDecisionClient) Use(hooks ...Hook) {
+ c.hooks.IdentityAdoptionDecision = append(c.hooks.IdentityAdoptionDecision, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `identityadoptiondecision.Intercept(f(g(h())))`.
+func (c *IdentityAdoptionDecisionClient) Intercept(interceptors ...Interceptor) {
+ c.inters.IdentityAdoptionDecision = append(c.inters.IdentityAdoptionDecision, interceptors...)
+}
+
+// Create returns a builder for creating a IdentityAdoptionDecision entity.
+func (c *IdentityAdoptionDecisionClient) Create() *IdentityAdoptionDecisionCreate {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpCreate)
+ return &IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of IdentityAdoptionDecision entities.
+func (c *IdentityAdoptionDecisionClient) CreateBulk(builders ...*IdentityAdoptionDecisionCreate) *IdentityAdoptionDecisionCreateBulk {
+ return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *IdentityAdoptionDecisionClient) MapCreateBulk(slice any, setFunc func(*IdentityAdoptionDecisionCreate, int)) *IdentityAdoptionDecisionCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &IdentityAdoptionDecisionCreateBulk{err: fmt.Errorf("calling to IdentityAdoptionDecisionClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*IdentityAdoptionDecisionCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) Update() *IdentityAdoptionDecisionUpdate {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdate)
+ return &IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *IdentityAdoptionDecisionClient) UpdateOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecision(_m))
+ return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *IdentityAdoptionDecisionClient) UpdateOneID(id int64) *IdentityAdoptionDecisionUpdateOne {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecisionID(id))
+ return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) Delete() *IdentityAdoptionDecisionDelete {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpDelete)
+ return &IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *IdentityAdoptionDecisionClient) DeleteOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *IdentityAdoptionDecisionClient) DeleteOneID(id int64) *IdentityAdoptionDecisionDeleteOne {
+ builder := c.Delete().Where(identityadoptiondecision.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &IdentityAdoptionDecisionDeleteOne{builder}
+}
+
+// Query returns a query builder for IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) Query() *IdentityAdoptionDecisionQuery {
+ return &IdentityAdoptionDecisionQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeIdentityAdoptionDecision},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a IdentityAdoptionDecision entity by its id.
+func (c *IdentityAdoptionDecisionClient) Get(ctx context.Context, id int64) (*IdentityAdoptionDecision, error) {
+ return c.Query().Where(identityadoptiondecision.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *IdentityAdoptionDecisionClient) GetX(ctx context.Context, id int64) *IdentityAdoptionDecision {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryPendingAuthSession queries the pending_auth_session edge of a IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) QueryPendingAuthSession(_m *IdentityAdoptionDecision) *PendingAuthSessionQuery {
+ query := (&PendingAuthSessionClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id),
+ sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryIdentity queries the identity edge of a IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) QueryIdentity(_m *IdentityAdoptionDecision) *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *IdentityAdoptionDecisionClient) Hooks() []Hook {
+ return c.hooks.IdentityAdoptionDecision
+}
+
+// Interceptors returns the client interceptors.
+func (c *IdentityAdoptionDecisionClient) Interceptors() []Interceptor {
+ return c.inters.IdentityAdoptionDecision
+}
+
+func (c *IdentityAdoptionDecisionClient) mutate(ctx context.Context, m *IdentityAdoptionDecisionMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown IdentityAdoptionDecision mutation op: %q", m.Op())
+ }
+}
+
// PaymentAuditLogClient is a client for the PaymentAuditLog schema.
type PaymentAuditLogClient struct {
config
@@ -2175,6 +2704,171 @@ func (c *PaymentProviderInstanceClient) mutate(ctx context.Context, m *PaymentPr
}
}
+// PendingAuthSessionClient is a client for the PendingAuthSession schema.
+type PendingAuthSessionClient struct {
+ config
+}
+
+// NewPendingAuthSessionClient returns a client for the PendingAuthSession from the given config.
+func NewPendingAuthSessionClient(c config) *PendingAuthSessionClient {
+ return &PendingAuthSessionClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `pendingauthsession.Hooks(f(g(h())))`.
+func (c *PendingAuthSessionClient) Use(hooks ...Hook) {
+ c.hooks.PendingAuthSession = append(c.hooks.PendingAuthSession, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `pendingauthsession.Intercept(f(g(h())))`.
+func (c *PendingAuthSessionClient) Intercept(interceptors ...Interceptor) {
+ c.inters.PendingAuthSession = append(c.inters.PendingAuthSession, interceptors...)
+}
+
+// Create returns a builder for creating a PendingAuthSession entity.
+func (c *PendingAuthSessionClient) Create() *PendingAuthSessionCreate {
+ mutation := newPendingAuthSessionMutation(c.config, OpCreate)
+ return &PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of PendingAuthSession entities.
+func (c *PendingAuthSessionClient) CreateBulk(builders ...*PendingAuthSessionCreate) *PendingAuthSessionCreateBulk {
+ return &PendingAuthSessionCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *PendingAuthSessionClient) MapCreateBulk(slice any, setFunc func(*PendingAuthSessionCreate, int)) *PendingAuthSessionCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &PendingAuthSessionCreateBulk{err: fmt.Errorf("calling to PendingAuthSessionClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*PendingAuthSessionCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &PendingAuthSessionCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for PendingAuthSession.
+func (c *PendingAuthSessionClient) Update() *PendingAuthSessionUpdate {
+ mutation := newPendingAuthSessionMutation(c.config, OpUpdate)
+ return &PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *PendingAuthSessionClient) UpdateOne(_m *PendingAuthSession) *PendingAuthSessionUpdateOne {
+ mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSession(_m))
+ return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *PendingAuthSessionClient) UpdateOneID(id int64) *PendingAuthSessionUpdateOne {
+ mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSessionID(id))
+ return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for PendingAuthSession.
+func (c *PendingAuthSessionClient) Delete() *PendingAuthSessionDelete {
+ mutation := newPendingAuthSessionMutation(c.config, OpDelete)
+ return &PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *PendingAuthSessionClient) DeleteOne(_m *PendingAuthSession) *PendingAuthSessionDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *PendingAuthSessionClient) DeleteOneID(id int64) *PendingAuthSessionDeleteOne {
+ builder := c.Delete().Where(pendingauthsession.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &PendingAuthSessionDeleteOne{builder}
+}
+
+// Query returns a query builder for PendingAuthSession.
+func (c *PendingAuthSessionClient) Query() *PendingAuthSessionQuery {
+ return &PendingAuthSessionQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypePendingAuthSession},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a PendingAuthSession entity by its id.
+func (c *PendingAuthSessionClient) Get(ctx context.Context, id int64) (*PendingAuthSession, error) {
+ return c.Query().Where(pendingauthsession.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *PendingAuthSessionClient) GetX(ctx context.Context, id int64) *PendingAuthSession {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryTargetUser queries the target_user edge of a PendingAuthSession.
+func (c *PendingAuthSessionClient) QueryTargetUser(_m *PendingAuthSession) *UserQuery {
+ query := (&UserClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryAdoptionDecision queries the adoption_decision edge of a PendingAuthSession.
+func (c *PendingAuthSessionClient) QueryAdoptionDecision(_m *PendingAuthSession) *IdentityAdoptionDecisionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id),
+ sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *PendingAuthSessionClient) Hooks() []Hook {
+ return c.hooks.PendingAuthSession
+}
+
+// Interceptors returns the client interceptors.
+func (c *PendingAuthSessionClient) Interceptors() []Interceptor {
+ return c.inters.PendingAuthSession
+}
+
+func (c *PendingAuthSessionClient) mutate(ctx context.Context, m *PendingAuthSessionMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown PendingAuthSession mutation op: %q", m.Op())
+ }
+}
+
// PromoCodeClient is a client for the PromoCode schema.
type PromoCodeClient struct {
config
@@ -3951,6 +4645,38 @@ func (c *UserClient) QueryPaymentOrders(_m *User) *PaymentOrderQuery {
return query
}
+// QueryAuthIdentities queries the auth_identities edge of a User.
+func (c *UserClient) QueryAuthIdentities(_m *User) *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, id),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryPendingAuthSessions queries the pending_auth_sessions edge of a User.
+func (c *UserClient) QueryPendingAuthSessions(_m *User) *PendingAuthSessionQuery {
+ query := (&PendingAuthSessionClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, id),
+ sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
// QueryUserAllowedGroups queries the user_allowed_groups edge of a User.
func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: c.config}).Query()
@@ -4628,18 +5354,20 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
// hooks and interceptors per client, for fast access.
type (
hooks struct {
- APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
- ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder,
- PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
- SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
+ APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity,
+ AuthIdentityChannel, ErrorPassthroughRule, Group, IdempotencyRecord,
+ IdentityAdoptionDecision, PaymentAuditLog, PaymentOrder,
+ PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy,
+ RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
UserAttributeValue, UserSubscription []ent.Hook
}
inters struct {
- APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
- ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder,
- PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
- SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
+ APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity,
+ AuthIdentityChannel, ErrorPassthroughRule, Group, IdempotencyRecord,
+ IdentityAdoptionDecision, PaymentAuditLog, PaymentOrder,
+ PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy,
+ RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
UserAttributeValue, UserSubscription []ent.Interceptor
}
diff --git a/backend/ent/ent.go b/backend/ent/ent.go
index 96ed5e03..339e5369 100644
--- a/backend/ent/ent.go
+++ b/backend/ent/ent.go
@@ -17,12 +17,16 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
@@ -98,32 +102,36 @@ var (
func checkColumn(t, c string) error {
initCheck.Do(func() {
columnCheck = sql.NewColumnCheck(map[string]func(string) bool{
- apikey.Table: apikey.ValidColumn,
- account.Table: account.ValidColumn,
- accountgroup.Table: accountgroup.ValidColumn,
- announcement.Table: announcement.ValidColumn,
- announcementread.Table: announcementread.ValidColumn,
- errorpassthroughrule.Table: errorpassthroughrule.ValidColumn,
- group.Table: group.ValidColumn,
- idempotencyrecord.Table: idempotencyrecord.ValidColumn,
- paymentauditlog.Table: paymentauditlog.ValidColumn,
- paymentorder.Table: paymentorder.ValidColumn,
- paymentproviderinstance.Table: paymentproviderinstance.ValidColumn,
- promocode.Table: promocode.ValidColumn,
- promocodeusage.Table: promocodeusage.ValidColumn,
- proxy.Table: proxy.ValidColumn,
- redeemcode.Table: redeemcode.ValidColumn,
- securitysecret.Table: securitysecret.ValidColumn,
- setting.Table: setting.ValidColumn,
- subscriptionplan.Table: subscriptionplan.ValidColumn,
- tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn,
- usagecleanuptask.Table: usagecleanuptask.ValidColumn,
- usagelog.Table: usagelog.ValidColumn,
- user.Table: user.ValidColumn,
- userallowedgroup.Table: userallowedgroup.ValidColumn,
- userattributedefinition.Table: userattributedefinition.ValidColumn,
- userattributevalue.Table: userattributevalue.ValidColumn,
- usersubscription.Table: usersubscription.ValidColumn,
+ apikey.Table: apikey.ValidColumn,
+ account.Table: account.ValidColumn,
+ accountgroup.Table: accountgroup.ValidColumn,
+ announcement.Table: announcement.ValidColumn,
+ announcementread.Table: announcementread.ValidColumn,
+ authidentity.Table: authidentity.ValidColumn,
+ authidentitychannel.Table: authidentitychannel.ValidColumn,
+ errorpassthroughrule.Table: errorpassthroughrule.ValidColumn,
+ group.Table: group.ValidColumn,
+ idempotencyrecord.Table: idempotencyrecord.ValidColumn,
+ identityadoptiondecision.Table: identityadoptiondecision.ValidColumn,
+ paymentauditlog.Table: paymentauditlog.ValidColumn,
+ paymentorder.Table: paymentorder.ValidColumn,
+ paymentproviderinstance.Table: paymentproviderinstance.ValidColumn,
+ pendingauthsession.Table: pendingauthsession.ValidColumn,
+ promocode.Table: promocode.ValidColumn,
+ promocodeusage.Table: promocodeusage.ValidColumn,
+ proxy.Table: proxy.ValidColumn,
+ redeemcode.Table: redeemcode.ValidColumn,
+ securitysecret.Table: securitysecret.ValidColumn,
+ setting.Table: setting.ValidColumn,
+ subscriptionplan.Table: subscriptionplan.ValidColumn,
+ tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn,
+ usagecleanuptask.Table: usagecleanuptask.ValidColumn,
+ usagelog.Table: usagelog.ValidColumn,
+ user.Table: user.ValidColumn,
+ userallowedgroup.Table: userallowedgroup.ValidColumn,
+ userattributedefinition.Table: userattributedefinition.ValidColumn,
+ userattributevalue.Table: userattributevalue.ValidColumn,
+ usersubscription.Table: usersubscription.ValidColumn,
})
})
return columnCheck(t, c)
diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go
index 199dacea..46ac02bc 100644
--- a/backend/ent/hook/hook.go
+++ b/backend/ent/hook/hook.go
@@ -69,6 +69,30 @@ func (f AnnouncementReadFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.V
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementReadMutation", m)
}
+// The AuthIdentityFunc type is an adapter to allow the use of ordinary
+// function as AuthIdentity mutator.
+type AuthIdentityFunc func(context.Context, *ent.AuthIdentityMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f AuthIdentityFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.AuthIdentityMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityMutation", m)
+}
+
+// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary
+// function as AuthIdentityChannel mutator.
+type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f AuthIdentityChannelFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.AuthIdentityChannelMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityChannelMutation", m)
+}
+
// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary
// function as ErrorPassthroughRule mutator.
type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleMutation) (ent.Value, error)
@@ -105,6 +129,18 @@ func (f IdempotencyRecordFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdempotencyRecordMutation", m)
}
+// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary
+// function as IdentityAdoptionDecision mutator.
+type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f IdentityAdoptionDecisionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.IdentityAdoptionDecisionMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdentityAdoptionDecisionMutation", m)
+}
+
// The PaymentAuditLogFunc type is an adapter to allow the use of ordinary
// function as PaymentAuditLog mutator.
type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogMutation) (ent.Value, error)
@@ -141,6 +177,18 @@ func (f PaymentProviderInstanceFunc) Mutate(ctx context.Context, m ent.Mutation)
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PaymentProviderInstanceMutation", m)
}
+// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary
+// function as PendingAuthSession mutator.
+type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f PendingAuthSessionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.PendingAuthSessionMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PendingAuthSessionMutation", m)
+}
+
// The PromoCodeFunc type is an adapter to allow the use of ordinary
// function as PromoCode mutator.
type PromoCodeFunc func(context.Context, *ent.PromoCodeMutation) (ent.Value, error)
diff --git a/backend/ent/identityadoptiondecision.go b/backend/ent/identityadoptiondecision.go
new file mode 100644
index 00000000..ecaee65c
--- /dev/null
+++ b/backend/ent/identityadoptiondecision.go
@@ -0,0 +1,223 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+)
+
+// IdentityAdoptionDecision is the model entity for the IdentityAdoptionDecision schema.
+type IdentityAdoptionDecision struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // PendingAuthSessionID holds the value of the "pending_auth_session_id" field.
+ PendingAuthSessionID int64 `json:"pending_auth_session_id,omitempty"`
+ // IdentityID holds the value of the "identity_id" field.
+ IdentityID *int64 `json:"identity_id,omitempty"`
+ // AdoptDisplayName holds the value of the "adopt_display_name" field.
+ AdoptDisplayName bool `json:"adopt_display_name,omitempty"`
+ // AdoptAvatar holds the value of the "adopt_avatar" field.
+ AdoptAvatar bool `json:"adopt_avatar,omitempty"`
+ // DecidedAt holds the value of the "decided_at" field.
+ DecidedAt time.Time `json:"decided_at,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the IdentityAdoptionDecisionQuery when eager-loading is set.
+ Edges IdentityAdoptionDecisionEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// IdentityAdoptionDecisionEdges holds the relations/edges for other nodes in the graph.
+type IdentityAdoptionDecisionEdges struct {
+ // PendingAuthSession holds the value of the pending_auth_session edge.
+ PendingAuthSession *PendingAuthSession `json:"pending_auth_session,omitempty"`
+ // Identity holds the value of the identity edge.
+ Identity *AuthIdentity `json:"identity,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [2]bool
+}
+
+// PendingAuthSessionOrErr returns the PendingAuthSession value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e IdentityAdoptionDecisionEdges) PendingAuthSessionOrErr() (*PendingAuthSession, error) {
+ if e.PendingAuthSession != nil {
+ return e.PendingAuthSession, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: pendingauthsession.Label}
+ }
+ return nil, &NotLoadedError{edge: "pending_auth_session"}
+}
+
+// IdentityOrErr returns the Identity value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e IdentityAdoptionDecisionEdges) IdentityOrErr() (*AuthIdentity, error) {
+ if e.Identity != nil {
+ return e.Identity, nil
+ } else if e.loadedTypes[1] {
+ return nil, &NotFoundError{label: authidentity.Label}
+ }
+ return nil, &NotLoadedError{edge: "identity"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*IdentityAdoptionDecision) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case identityadoptiondecision.FieldAdoptDisplayName, identityadoptiondecision.FieldAdoptAvatar:
+ values[i] = new(sql.NullBool)
+ case identityadoptiondecision.FieldID, identityadoptiondecision.FieldPendingAuthSessionID, identityadoptiondecision.FieldIdentityID:
+ values[i] = new(sql.NullInt64)
+ case identityadoptiondecision.FieldCreatedAt, identityadoptiondecision.FieldUpdatedAt, identityadoptiondecision.FieldDecidedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the IdentityAdoptionDecision fields.
+func (_m *IdentityAdoptionDecision) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case identityadoptiondecision.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case identityadoptiondecision.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case identityadoptiondecision.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field pending_auth_session_id", values[i])
+ } else if value.Valid {
+ _m.PendingAuthSessionID = value.Int64
+ }
+ case identityadoptiondecision.FieldIdentityID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field identity_id", values[i])
+ } else if value.Valid {
+ _m.IdentityID = new(int64)
+ *_m.IdentityID = value.Int64
+ }
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field adopt_display_name", values[i])
+ } else if value.Valid {
+ _m.AdoptDisplayName = value.Bool
+ }
+ case identityadoptiondecision.FieldAdoptAvatar:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field adopt_avatar", values[i])
+ } else if value.Valid {
+ _m.AdoptAvatar = value.Bool
+ }
+ case identityadoptiondecision.FieldDecidedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field decided_at", values[i])
+ } else if value.Valid {
+ _m.DecidedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the IdentityAdoptionDecision.
+// This includes values selected through modifiers, order, etc.
+func (_m *IdentityAdoptionDecision) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryPendingAuthSession queries the "pending_auth_session" edge of the IdentityAdoptionDecision entity.
+func (_m *IdentityAdoptionDecision) QueryPendingAuthSession() *PendingAuthSessionQuery {
+ return NewIdentityAdoptionDecisionClient(_m.config).QueryPendingAuthSession(_m)
+}
+
+// QueryIdentity queries the "identity" edge of the IdentityAdoptionDecision entity.
+func (_m *IdentityAdoptionDecision) QueryIdentity() *AuthIdentityQuery {
+ return NewIdentityAdoptionDecisionClient(_m.config).QueryIdentity(_m)
+}
+
+// Update returns a builder for updating this IdentityAdoptionDecision.
+// Note that you need to call IdentityAdoptionDecision.Unwrap() before calling this method if this IdentityAdoptionDecision
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *IdentityAdoptionDecision) Update() *IdentityAdoptionDecisionUpdateOne {
+ return NewIdentityAdoptionDecisionClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the IdentityAdoptionDecision entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *IdentityAdoptionDecision) Unwrap() *IdentityAdoptionDecision {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: IdentityAdoptionDecision is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *IdentityAdoptionDecision) String() string {
+ var builder strings.Builder
+ builder.WriteString("IdentityAdoptionDecision(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("pending_auth_session_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.PendingAuthSessionID))
+ builder.WriteString(", ")
+ if v := _m.IdentityID; v != nil {
+ builder.WriteString("identity_id=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("adopt_display_name=")
+ builder.WriteString(fmt.Sprintf("%v", _m.AdoptDisplayName))
+ builder.WriteString(", ")
+ builder.WriteString("adopt_avatar=")
+ builder.WriteString(fmt.Sprintf("%v", _m.AdoptAvatar))
+ builder.WriteString(", ")
+ builder.WriteString("decided_at=")
+ builder.WriteString(_m.DecidedAt.Format(time.ANSIC))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// IdentityAdoptionDecisions is a parsable slice of IdentityAdoptionDecision.
+type IdentityAdoptionDecisions []*IdentityAdoptionDecision
diff --git a/backend/ent/identityadoptiondecision/identityadoptiondecision.go b/backend/ent/identityadoptiondecision/identityadoptiondecision.go
new file mode 100644
index 00000000..93adaf73
--- /dev/null
+++ b/backend/ent/identityadoptiondecision/identityadoptiondecision.go
@@ -0,0 +1,159 @@
+// Code generated by ent, DO NOT EDIT.
+
+package identityadoptiondecision
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the identityadoptiondecision type in the database.
+ Label = "identity_adoption_decision"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldPendingAuthSessionID holds the string denoting the pending_auth_session_id field in the database.
+ FieldPendingAuthSessionID = "pending_auth_session_id"
+ // FieldIdentityID holds the string denoting the identity_id field in the database.
+ FieldIdentityID = "identity_id"
+ // FieldAdoptDisplayName holds the string denoting the adopt_display_name field in the database.
+ FieldAdoptDisplayName = "adopt_display_name"
+ // FieldAdoptAvatar holds the string denoting the adopt_avatar field in the database.
+ FieldAdoptAvatar = "adopt_avatar"
+ // FieldDecidedAt holds the string denoting the decided_at field in the database.
+ FieldDecidedAt = "decided_at"
+ // EdgePendingAuthSession holds the string denoting the pending_auth_session edge name in mutations.
+ EdgePendingAuthSession = "pending_auth_session"
+ // EdgeIdentity holds the string denoting the identity edge name in mutations.
+ EdgeIdentity = "identity"
+ // Table holds the table name of the identityadoptiondecision in the database.
+ Table = "identity_adoption_decisions"
+ // PendingAuthSessionTable is the table that holds the pending_auth_session relation/edge.
+ PendingAuthSessionTable = "identity_adoption_decisions"
+ // PendingAuthSessionInverseTable is the table name for the PendingAuthSession entity.
+ // It exists in this package in order to avoid circular dependency with the "pendingauthsession" package.
+ PendingAuthSessionInverseTable = "pending_auth_sessions"
+ // PendingAuthSessionColumn is the table column denoting the pending_auth_session relation/edge.
+ PendingAuthSessionColumn = "pending_auth_session_id"
+ // IdentityTable is the table that holds the identity relation/edge.
+ IdentityTable = "identity_adoption_decisions"
+ // IdentityInverseTable is the table name for the AuthIdentity entity.
+ // It exists in this package in order to avoid circular dependency with the "authidentity" package.
+ IdentityInverseTable = "auth_identities"
+ // IdentityColumn is the table column denoting the identity relation/edge.
+ IdentityColumn = "identity_id"
+)
+
+// Columns holds all SQL columns for identityadoptiondecision fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldPendingAuthSessionID,
+ FieldIdentityID,
+ FieldAdoptDisplayName,
+ FieldAdoptAvatar,
+ FieldDecidedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // DefaultAdoptDisplayName holds the default value on creation for the "adopt_display_name" field.
+ DefaultAdoptDisplayName bool
+ // DefaultAdoptAvatar holds the default value on creation for the "adopt_avatar" field.
+ DefaultAdoptAvatar bool
+ // DefaultDecidedAt holds the default value on creation for the "decided_at" field.
+ DefaultDecidedAt func() time.Time
+)
+
+// OrderOption defines the ordering options for the IdentityAdoptionDecision queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByPendingAuthSessionID orders the results by the pending_auth_session_id field.
+func ByPendingAuthSessionID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPendingAuthSessionID, opts...).ToFunc()
+}
+
+// ByIdentityID orders the results by the identity_id field.
+func ByIdentityID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIdentityID, opts...).ToFunc()
+}
+
+// ByAdoptDisplayName orders the results by the adopt_display_name field.
+func ByAdoptDisplayName(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAdoptDisplayName, opts...).ToFunc()
+}
+
+// ByAdoptAvatar orders the results by the adopt_avatar field.
+func ByAdoptAvatar(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAdoptAvatar, opts...).ToFunc()
+}
+
+// ByDecidedAt orders the results by the decided_at field.
+func ByDecidedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldDecidedAt, opts...).ToFunc()
+}
+
+// ByPendingAuthSessionField orders the results by pending_auth_session field.
+func ByPendingAuthSessionField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionStep(), sql.OrderByField(field, opts...))
+ }
+}
+
+// ByIdentityField orders the results by identity field.
+func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newPendingAuthSessionStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(PendingAuthSessionInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn),
+ )
+}
+func newIdentityStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(IdentityInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
+ )
+}
diff --git a/backend/ent/identityadoptiondecision/where.go b/backend/ent/identityadoptiondecision/where.go
new file mode 100644
index 00000000..1968f175
--- /dev/null
+++ b/backend/ent/identityadoptiondecision/where.go
@@ -0,0 +1,342 @@
+// Code generated by ent, DO NOT EDIT.
+
+package identityadoptiondecision
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// PendingAuthSessionID applies equality check predicate on the "pending_auth_session_id" field. It's identical to PendingAuthSessionIDEQ.
+func PendingAuthSessionID(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v))
+}
+
+// IdentityID applies equality check predicate on the "identity_id" field. It's identical to IdentityIDEQ.
+func IdentityID(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v))
+}
+
+// AdoptDisplayName applies equality check predicate on the "adopt_display_name" field. It's identical to AdoptDisplayNameEQ.
+func AdoptDisplayName(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v))
+}
+
+// AdoptAvatar applies equality check predicate on the "adopt_avatar" field. It's identical to AdoptAvatarEQ.
+func AdoptAvatar(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v))
+}
+
+// DecidedAt applies equality check predicate on the "decided_at" field. It's identical to DecidedAtEQ.
+func DecidedAt(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// PendingAuthSessionIDEQ applies the EQ predicate on the "pending_auth_session_id" field.
+func PendingAuthSessionIDEQ(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v))
+}
+
+// PendingAuthSessionIDNEQ applies the NEQ predicate on the "pending_auth_session_id" field.
+func PendingAuthSessionIDNEQ(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldPendingAuthSessionID, v))
+}
+
+// PendingAuthSessionIDIn applies the In predicate on the "pending_auth_session_id" field.
+func PendingAuthSessionIDIn(vs ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldPendingAuthSessionID, vs...))
+}
+
+// PendingAuthSessionIDNotIn applies the NotIn predicate on the "pending_auth_session_id" field.
+func PendingAuthSessionIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldPendingAuthSessionID, vs...))
+}
+
+// IdentityIDEQ applies the EQ predicate on the "identity_id" field.
+func IdentityIDEQ(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v))
+}
+
+// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field.
+func IdentityIDNEQ(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldIdentityID, v))
+}
+
+// IdentityIDIn applies the In predicate on the "identity_id" field.
+func IdentityIDIn(vs ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldIdentityID, vs...))
+}
+
+// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field.
+func IdentityIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldIdentityID, vs...))
+}
+
+// IdentityIDIsNil applies the IsNil predicate on the "identity_id" field.
+func IdentityIDIsNil() predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIsNull(FieldIdentityID))
+}
+
+// IdentityIDNotNil applies the NotNil predicate on the "identity_id" field.
+func IdentityIDNotNil() predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotNull(FieldIdentityID))
+}
+
+// AdoptDisplayNameEQ applies the EQ predicate on the "adopt_display_name" field.
+func AdoptDisplayNameEQ(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v))
+}
+
+// AdoptDisplayNameNEQ applies the NEQ predicate on the "adopt_display_name" field.
+func AdoptDisplayNameNEQ(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptDisplayName, v))
+}
+
+// AdoptAvatarEQ applies the EQ predicate on the "adopt_avatar" field.
+func AdoptAvatarEQ(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v))
+}
+
+// AdoptAvatarNEQ applies the NEQ predicate on the "adopt_avatar" field.
+func AdoptAvatarNEQ(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptAvatar, v))
+}
+
+// DecidedAtEQ applies the EQ predicate on the "decided_at" field.
+func DecidedAtEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v))
+}
+
+// DecidedAtNEQ applies the NEQ predicate on the "decided_at" field.
+func DecidedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldDecidedAt, v))
+}
+
+// DecidedAtIn applies the In predicate on the "decided_at" field.
+func DecidedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldDecidedAt, vs...))
+}
+
+// DecidedAtNotIn applies the NotIn predicate on the "decided_at" field.
+func DecidedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldDecidedAt, vs...))
+}
+
+// DecidedAtGT applies the GT predicate on the "decided_at" field.
+func DecidedAtGT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldDecidedAt, v))
+}
+
+// DecidedAtGTE applies the GTE predicate on the "decided_at" field.
+func DecidedAtGTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldDecidedAt, v))
+}
+
+// DecidedAtLT applies the LT predicate on the "decided_at" field.
+func DecidedAtLT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldDecidedAt, v))
+}
+
+// DecidedAtLTE applies the LTE predicate on the "decided_at" field.
+func DecidedAtLTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldDecidedAt, v))
+}
+
+// HasPendingAuthSession applies the HasEdge predicate on the "pending_auth_session" edge.
+func HasPendingAuthSession() predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasPendingAuthSessionWith applies the HasEdge predicate on the "pending_auth_session" edge with a given conditions (other predicates).
+func HasPendingAuthSessionWith(preds ...predicate.PendingAuthSession) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ step := newPendingAuthSessionStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasIdentity applies the HasEdge predicate on the "identity" edge.
+func HasIdentity() predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates).
+func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ step := newIdentityStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.NotPredicates(p))
+}
diff --git a/backend/ent/identityadoptiondecision_create.go b/backend/ent/identityadoptiondecision_create.go
new file mode 100644
index 00000000..491ba9f9
--- /dev/null
+++ b/backend/ent/identityadoptiondecision_create.go
@@ -0,0 +1,843 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+)
+
+// IdentityAdoptionDecisionCreate is the builder for creating a IdentityAdoptionDecision entity.
+type IdentityAdoptionDecisionCreate struct {
+ config
+ mutation *IdentityAdoptionDecisionMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *IdentityAdoptionDecisionCreate) SetCreatedAt(v time.Time) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableCreatedAt(v *time.Time) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *IdentityAdoptionDecisionCreate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableUpdatedAt(v *time.Time) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetPendingAuthSessionID(v)
+ return _c
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_c *IdentityAdoptionDecisionCreate) SetIdentityID(v int64) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetIdentityID(v)
+ return _c
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetIdentityID(*v)
+ }
+ return _c
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (_c *IdentityAdoptionDecisionCreate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetAdoptDisplayName(v)
+ return _c
+}
+
+// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetAdoptDisplayName(*v)
+ }
+ return _c
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (_c *IdentityAdoptionDecisionCreate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetAdoptAvatar(v)
+ return _c
+}
+
+// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetAdoptAvatar(*v)
+ }
+ return _c
+}
+
+// SetDecidedAt sets the "decided_at" field.
+func (_c *IdentityAdoptionDecisionCreate) SetDecidedAt(v time.Time) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetDecidedAt(v)
+ return _c
+}
+
+// SetNillableDecidedAt sets the "decided_at" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableDecidedAt(v *time.Time) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetDecidedAt(*v)
+ }
+ return _c
+}
+
+// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionCreate {
+ return _c.SetPendingAuthSessionID(v.ID)
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_c *IdentityAdoptionDecisionCreate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionCreate {
+ return _c.SetIdentityID(v.ID)
+}
+
+// Mutation returns the IdentityAdoptionDecisionMutation object of the builder.
+func (_c *IdentityAdoptionDecisionCreate) Mutation() *IdentityAdoptionDecisionMutation {
+ return _c.mutation
+}
+
+// Save creates the IdentityAdoptionDecision in the database.
+func (_c *IdentityAdoptionDecisionCreate) Save(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *IdentityAdoptionDecisionCreate) SaveX(ctx context.Context) *IdentityAdoptionDecision {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *IdentityAdoptionDecisionCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *IdentityAdoptionDecisionCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *IdentityAdoptionDecisionCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := identityadoptiondecision.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := identityadoptiondecision.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.AdoptDisplayName(); !ok {
+ v := identityadoptiondecision.DefaultAdoptDisplayName
+ _c.mutation.SetAdoptDisplayName(v)
+ }
+ if _, ok := _c.mutation.AdoptAvatar(); !ok {
+ v := identityadoptiondecision.DefaultAdoptAvatar
+ _c.mutation.SetAdoptAvatar(v)
+ }
+ if _, ok := _c.mutation.DecidedAt(); !ok {
+ v := identityadoptiondecision.DefaultDecidedAt()
+ _c.mutation.SetDecidedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *IdentityAdoptionDecisionCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.updated_at"`)}
+ }
+ if _, ok := _c.mutation.PendingAuthSessionID(); !ok {
+ return &ValidationError{Name: "pending_auth_session_id", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.pending_auth_session_id"`)}
+ }
+ if _, ok := _c.mutation.AdoptDisplayName(); !ok {
+ return &ValidationError{Name: "adopt_display_name", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_display_name"`)}
+ }
+ if _, ok := _c.mutation.AdoptAvatar(); !ok {
+ return &ValidationError{Name: "adopt_avatar", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_avatar"`)}
+ }
+ if _, ok := _c.mutation.DecidedAt(); !ok {
+ return &ValidationError{Name: "decided_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.decided_at"`)}
+ }
+ if len(_c.mutation.PendingAuthSessionIDs()) == 0 {
+ return &ValidationError{Name: "pending_auth_session", err: errors.New(`ent: missing required edge "IdentityAdoptionDecision.pending_auth_session"`)}
+ }
+ return nil
+}
+
+func (_c *IdentityAdoptionDecisionCreate) sqlSave(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *IdentityAdoptionDecisionCreate) createSpec() (*IdentityAdoptionDecision, *sqlgraph.CreateSpec) {
+ var (
+ _node = &IdentityAdoptionDecision{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.AdoptDisplayName(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value)
+ _node.AdoptDisplayName = value
+ }
+ if value, ok := _c.mutation.AdoptAvatar(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value)
+ _node.AdoptAvatar = value
+ }
+ if value, ok := _c.mutation.DecidedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldDecidedAt, field.TypeTime, value)
+ _node.DecidedAt = value
+ }
+ if nodes := _c.mutation.PendingAuthSessionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.PendingAuthSessionID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.IdentityID = &nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.IdentityAdoptionDecision.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.IdentityAdoptionDecisionUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *IdentityAdoptionDecisionCreate) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertOne {
+ _c.conflict = opts
+ return &IdentityAdoptionDecisionUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *IdentityAdoptionDecisionCreate) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &IdentityAdoptionDecisionUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // IdentityAdoptionDecisionUpsertOne is the builder for "upsert"-ing
+ // one IdentityAdoptionDecision node.
+ IdentityAdoptionDecisionUpsertOne struct {
+ create *IdentityAdoptionDecisionCreate
+ }
+
+ // IdentityAdoptionDecisionUpsert is the "OnConflict" setter.
+ IdentityAdoptionDecisionUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *IdentityAdoptionDecisionUpsert) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldUpdatedAt)
+ return u
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (u *IdentityAdoptionDecisionUpsert) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldPendingAuthSessionID, v)
+ return u
+}
+
+// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldPendingAuthSessionID)
+ return u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsert) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldIdentityID, v)
+ return u
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdateIdentityID() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldIdentityID)
+ return u
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsert) ClearIdentityID() *IdentityAdoptionDecisionUpsert {
+ u.SetNull(identityadoptiondecision.FieldIdentityID)
+ return u
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (u *IdentityAdoptionDecisionUpsert) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldAdoptDisplayName, v)
+ return u
+}
+
+// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldAdoptDisplayName)
+ return u
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (u *IdentityAdoptionDecisionUpsert) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldAdoptAvatar, v)
+ return u
+}
+
+// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldAdoptAvatar)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateNewValues() *IdentityAdoptionDecisionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(identityadoptiondecision.FieldCreatedAt)
+ }
+ if _, exists := u.create.mutation.DecidedAt(); exists {
+ s.SetIgnore(identityadoptiondecision.FieldDecidedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *IdentityAdoptionDecisionUpsertOne) Ignore() *IdentityAdoptionDecisionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *IdentityAdoptionDecisionUpsertOne) DoNothing() *IdentityAdoptionDecisionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreate.OnConflict
+// documentation for more info.
+func (u *IdentityAdoptionDecisionUpsertOne) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&IdentityAdoptionDecisionUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetPendingAuthSessionID(v)
+ })
+}
+
+// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdatePendingAuthSessionID()
+ })
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetIdentityID(v)
+ })
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateIdentityID() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateIdentityID()
+ })
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsertOne) ClearIdentityID() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.ClearIdentityID()
+ })
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetAdoptDisplayName(v)
+ })
+}
+
+// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateAdoptDisplayName()
+ })
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetAdoptAvatar(v)
+ })
+}
+
+// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateAdoptAvatar()
+ })
+}
+
+// Exec executes the query.
+func (u *IdentityAdoptionDecisionUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for IdentityAdoptionDecisionCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *IdentityAdoptionDecisionUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *IdentityAdoptionDecisionUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *IdentityAdoptionDecisionUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// IdentityAdoptionDecisionCreateBulk is the builder for creating many IdentityAdoptionDecision entities in bulk.
+type IdentityAdoptionDecisionCreateBulk struct {
+ config
+ err error
+ builders []*IdentityAdoptionDecisionCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the IdentityAdoptionDecision entities in the database.
+func (_c *IdentityAdoptionDecisionCreateBulk) Save(ctx context.Context) ([]*IdentityAdoptionDecision, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*IdentityAdoptionDecision, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*IdentityAdoptionDecisionMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *IdentityAdoptionDecisionCreateBulk) SaveX(ctx context.Context) []*IdentityAdoptionDecision {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *IdentityAdoptionDecisionCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *IdentityAdoptionDecisionCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.IdentityAdoptionDecision.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.IdentityAdoptionDecisionUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *IdentityAdoptionDecisionCreateBulk) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertBulk {
+ _c.conflict = opts
+ return &IdentityAdoptionDecisionUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *IdentityAdoptionDecisionCreateBulk) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &IdentityAdoptionDecisionUpsertBulk{
+ create: _c,
+ }
+}
+
+// IdentityAdoptionDecisionUpsertBulk is the builder for "upsert"-ing
+// a bulk of IdentityAdoptionDecision nodes.
+type IdentityAdoptionDecisionUpsertBulk struct {
+ create *IdentityAdoptionDecisionCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateNewValues() *IdentityAdoptionDecisionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(identityadoptiondecision.FieldCreatedAt)
+ }
+ if _, exists := b.mutation.DecidedAt(); exists {
+ s.SetIgnore(identityadoptiondecision.FieldDecidedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *IdentityAdoptionDecisionUpsertBulk) Ignore() *IdentityAdoptionDecisionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *IdentityAdoptionDecisionUpsertBulk) DoNothing() *IdentityAdoptionDecisionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreateBulk.OnConflict
+// documentation for more info.
+func (u *IdentityAdoptionDecisionUpsertBulk) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&IdentityAdoptionDecisionUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetPendingAuthSessionID(v)
+ })
+}
+
+// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdatePendingAuthSessionID()
+ })
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetIdentityID(v)
+ })
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateIdentityID() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateIdentityID()
+ })
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) ClearIdentityID() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.ClearIdentityID()
+ })
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetAdoptDisplayName(v)
+ })
+}
+
+// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateAdoptDisplayName()
+ })
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetAdoptAvatar(v)
+ })
+}
+
+// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateAdoptAvatar()
+ })
+}
+
+// Exec executes the query.
+func (u *IdentityAdoptionDecisionUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the IdentityAdoptionDecisionCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for IdentityAdoptionDecisionCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *IdentityAdoptionDecisionUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/identityadoptiondecision_delete.go b/backend/ent/identityadoptiondecision_delete.go
new file mode 100644
index 00000000..ef3d328d
--- /dev/null
+++ b/backend/ent/identityadoptiondecision_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// IdentityAdoptionDecisionDelete is the builder for deleting a IdentityAdoptionDecision entity.
+type IdentityAdoptionDecisionDelete struct {
+ config
+ hooks []Hook
+ mutation *IdentityAdoptionDecisionMutation
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder.
+func (_d *IdentityAdoptionDecisionDelete) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *IdentityAdoptionDecisionDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *IdentityAdoptionDecisionDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *IdentityAdoptionDecisionDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// IdentityAdoptionDecisionDeleteOne is the builder for deleting a single IdentityAdoptionDecision entity.
+type IdentityAdoptionDecisionDeleteOne struct {
+ _d *IdentityAdoptionDecisionDelete
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder.
+func (_d *IdentityAdoptionDecisionDeleteOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *IdentityAdoptionDecisionDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{identityadoptiondecision.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *IdentityAdoptionDecisionDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/identityadoptiondecision_query.go b/backend/ent/identityadoptiondecision_query.go
new file mode 100644
index 00000000..4082d8ee
--- /dev/null
+++ b/backend/ent/identityadoptiondecision_query.go
@@ -0,0 +1,721 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// IdentityAdoptionDecisionQuery is the builder for querying IdentityAdoptionDecision entities.
+type IdentityAdoptionDecisionQuery struct {
+ config
+ ctx *QueryContext
+ order []identityadoptiondecision.OrderOption
+ inters []Interceptor
+ predicates []predicate.IdentityAdoptionDecision
+ withPendingAuthSession *PendingAuthSessionQuery
+ withIdentity *AuthIdentityQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the IdentityAdoptionDecisionQuery builder.
+func (_q *IdentityAdoptionDecisionQuery) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *IdentityAdoptionDecisionQuery) Limit(limit int) *IdentityAdoptionDecisionQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *IdentityAdoptionDecisionQuery) Offset(offset int) *IdentityAdoptionDecisionQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *IdentityAdoptionDecisionQuery) Unique(unique bool) *IdentityAdoptionDecisionQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *IdentityAdoptionDecisionQuery) Order(o ...identityadoptiondecision.OrderOption) *IdentityAdoptionDecisionQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryPendingAuthSession chains the current query on the "pending_auth_session" edge.
+func (_q *IdentityAdoptionDecisionQuery) QueryPendingAuthSession() *PendingAuthSessionQuery {
+ query := (&PendingAuthSessionClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector),
+ sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryIdentity chains the current query on the "identity" edge.
+func (_q *IdentityAdoptionDecisionQuery) QueryIdentity() *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first IdentityAdoptionDecision entity from the query.
+// Returns a *NotFoundError when no IdentityAdoptionDecision was found.
+func (_q *IdentityAdoptionDecisionQuery) First(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{identityadoptiondecision.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) FirstX(ctx context.Context) *IdentityAdoptionDecision {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first IdentityAdoptionDecision ID from the query.
+// Returns a *NotFoundError when no IdentityAdoptionDecision ID was found.
+func (_q *IdentityAdoptionDecisionQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{identityadoptiondecision.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single IdentityAdoptionDecision entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one IdentityAdoptionDecision entity is found.
+// Returns a *NotFoundError when no IdentityAdoptionDecision entities are found.
+func (_q *IdentityAdoptionDecisionQuery) Only(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{identityadoptiondecision.Label}
+ default:
+ return nil, &NotSingularError{identityadoptiondecision.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) OnlyX(ctx context.Context) *IdentityAdoptionDecision {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only IdentityAdoptionDecision ID in the query.
+// Returns a *NotSingularError when more than one IdentityAdoptionDecision ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *IdentityAdoptionDecisionQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{identityadoptiondecision.Label}
+ default:
+ err = &NotSingularError{identityadoptiondecision.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of IdentityAdoptionDecisions.
+func (_q *IdentityAdoptionDecisionQuery) All(ctx context.Context) ([]*IdentityAdoptionDecision, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*IdentityAdoptionDecision, *IdentityAdoptionDecisionQuery]()
+ return withInterceptors[[]*IdentityAdoptionDecision](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) AllX(ctx context.Context) []*IdentityAdoptionDecision {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of IdentityAdoptionDecision IDs.
+func (_q *IdentityAdoptionDecisionQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(identityadoptiondecision.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *IdentityAdoptionDecisionQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*IdentityAdoptionDecisionQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *IdentityAdoptionDecisionQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the IdentityAdoptionDecisionQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *IdentityAdoptionDecisionQuery) Clone() *IdentityAdoptionDecisionQuery {
+ if _q == nil {
+ return nil
+ }
+ return &IdentityAdoptionDecisionQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]identityadoptiondecision.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.IdentityAdoptionDecision{}, _q.predicates...),
+ withPendingAuthSession: _q.withPendingAuthSession.Clone(),
+ withIdentity: _q.withIdentity.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithPendingAuthSession tells the query-builder to eager-load the nodes that are connected to
+// the "pending_auth_session" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *IdentityAdoptionDecisionQuery) WithPendingAuthSession(opts ...func(*PendingAuthSessionQuery)) *IdentityAdoptionDecisionQuery {
+ query := (&PendingAuthSessionClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withPendingAuthSession = query
+ return _q
+}
+
+// WithIdentity tells the query-builder to eager-load the nodes that are connected to
+// the "identity" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *IdentityAdoptionDecisionQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *IdentityAdoptionDecisionQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withIdentity = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.IdentityAdoptionDecision.Query().
+// GroupBy(identityadoptiondecision.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *IdentityAdoptionDecisionQuery) GroupBy(field string, fields ...string) *IdentityAdoptionDecisionGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &IdentityAdoptionDecisionGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = identityadoptiondecision.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.IdentityAdoptionDecision.Query().
+// Select(identityadoptiondecision.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *IdentityAdoptionDecisionQuery) Select(fields ...string) *IdentityAdoptionDecisionSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &IdentityAdoptionDecisionSelect{IdentityAdoptionDecisionQuery: _q}
+ sbuild.label = identityadoptiondecision.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a IdentityAdoptionDecisionSelect configured with the given aggregations.
+func (_q *IdentityAdoptionDecisionQuery) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *IdentityAdoptionDecisionQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !identityadoptiondecision.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *IdentityAdoptionDecisionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*IdentityAdoptionDecision, error) {
+ var (
+ nodes = []*IdentityAdoptionDecision{}
+ _spec = _q.querySpec()
+ loadedTypes = [2]bool{
+ _q.withPendingAuthSession != nil,
+ _q.withIdentity != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*IdentityAdoptionDecision).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &IdentityAdoptionDecision{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withPendingAuthSession; query != nil {
+ if err := _q.loadPendingAuthSession(ctx, query, nodes, nil,
+ func(n *IdentityAdoptionDecision, e *PendingAuthSession) { n.Edges.PendingAuthSession = e }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withIdentity; query != nil {
+ if err := _q.loadIdentity(ctx, query, nodes, nil,
+ func(n *IdentityAdoptionDecision, e *AuthIdentity) { n.Edges.Identity = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *IdentityAdoptionDecisionQuery) loadPendingAuthSession(ctx context.Context, query *PendingAuthSessionQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *PendingAuthSession)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*IdentityAdoptionDecision)
+ for i := range nodes {
+ fk := nodes[i].PendingAuthSessionID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(pendingauthsession.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "pending_auth_session_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+func (_q *IdentityAdoptionDecisionQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *AuthIdentity)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*IdentityAdoptionDecision)
+ for i := range nodes {
+ if nodes[i].IdentityID == nil {
+ continue
+ }
+ fk := *nodes[i].IdentityID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(authidentity.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+
+func (_q *IdentityAdoptionDecisionQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *IdentityAdoptionDecisionQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID)
+ for i := range fields {
+ if fields[i] != identityadoptiondecision.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withPendingAuthSession != nil {
+ _spec.Node.AddColumnOnce(identityadoptiondecision.FieldPendingAuthSessionID)
+ }
+ if _q.withIdentity != nil {
+ _spec.Node.AddColumnOnce(identityadoptiondecision.FieldIdentityID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *IdentityAdoptionDecisionQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(identityadoptiondecision.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = identityadoptiondecision.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *IdentityAdoptionDecisionQuery) ForUpdate(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *IdentityAdoptionDecisionQuery) ForShare(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// IdentityAdoptionDecisionGroupBy is the group-by builder for IdentityAdoptionDecision entities.
+type IdentityAdoptionDecisionGroupBy struct {
+ selector
+ build *IdentityAdoptionDecisionQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *IdentityAdoptionDecisionGroupBy) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *IdentityAdoptionDecisionGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *IdentityAdoptionDecisionGroupBy) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// IdentityAdoptionDecisionSelect is the builder for selecting fields of IdentityAdoptionDecision entities.
+type IdentityAdoptionDecisionSelect struct {
+ *IdentityAdoptionDecisionQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *IdentityAdoptionDecisionSelect) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *IdentityAdoptionDecisionSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionSelect](ctx, _s.IdentityAdoptionDecisionQuery, _s, _s.inters, v)
+}
+
+func (_s *IdentityAdoptionDecisionSelect) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/identityadoptiondecision_update.go b/backend/ent/identityadoptiondecision_update.go
new file mode 100644
index 00000000..0ca21d27
--- /dev/null
+++ b/backend/ent/identityadoptiondecision_update.go
@@ -0,0 +1,532 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// IdentityAdoptionDecisionUpdate is the builder for updating IdentityAdoptionDecision entities.
+type IdentityAdoptionDecisionUpdate struct {
+ config
+ hooks []Hook
+ mutation *IdentityAdoptionDecisionMutation
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder.
+func (_u *IdentityAdoptionDecisionUpdate) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetPendingAuthSessionID(v)
+ return _u
+}
+
+// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdate) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdate {
+ if v != nil {
+ _u.SetPendingAuthSessionID(*v)
+ }
+ return _u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetIdentityID(v)
+ return _u
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdate {
+ if v != nil {
+ _u.SetIdentityID(*v)
+ }
+ return _u
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (_u *IdentityAdoptionDecisionUpdate) ClearIdentityID() *IdentityAdoptionDecisionUpdate {
+ _u.mutation.ClearIdentityID()
+ return _u
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetAdoptDisplayName(v)
+ return _u
+}
+
+// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdate {
+ if v != nil {
+ _u.SetAdoptDisplayName(*v)
+ }
+ return _u
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetAdoptAvatar(v)
+ return _u
+}
+
+// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdate {
+ if v != nil {
+ _u.SetAdoptAvatar(*v)
+ }
+ return _u
+}
+
+// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdate {
+ return _u.SetPendingAuthSessionID(v.ID)
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_u *IdentityAdoptionDecisionUpdate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdate {
+ return _u.SetIdentityID(v.ID)
+}
+
+// Mutation returns the IdentityAdoptionDecisionMutation object of the builder.
+func (_u *IdentityAdoptionDecisionUpdate) Mutation() *IdentityAdoptionDecisionMutation {
+ return _u.mutation
+}
+
+// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_u *IdentityAdoptionDecisionUpdate) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdate {
+ _u.mutation.ClearPendingAuthSession()
+ return _u
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (_u *IdentityAdoptionDecisionUpdate) ClearIdentity() *IdentityAdoptionDecisionUpdate {
+ _u.mutation.ClearIdentity()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *IdentityAdoptionDecisionUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *IdentityAdoptionDecisionUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *IdentityAdoptionDecisionUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *IdentityAdoptionDecisionUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *IdentityAdoptionDecisionUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := identityadoptiondecision.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *IdentityAdoptionDecisionUpdate) check() error {
+ if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`)
+ }
+ return nil
+}
+
+func (_u *IdentityAdoptionDecisionUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.AdoptDisplayName(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.AdoptAvatar(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value)
+ }
+ if _u.mutation.PendingAuthSessionCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.IdentityCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{identityadoptiondecision.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// IdentityAdoptionDecisionUpdateOne is the builder for updating a single IdentityAdoptionDecision entity.
+type IdentityAdoptionDecisionUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *IdentityAdoptionDecisionMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetPendingAuthSessionID(v)
+ return _u
+}
+
+// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdateOne {
+ if v != nil {
+ _u.SetPendingAuthSessionID(*v)
+ }
+ return _u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetIdentityID(v)
+ return _u
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdateOne {
+ if v != nil {
+ _u.SetIdentityID(*v)
+ }
+ return _u
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentityID() *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.ClearIdentityID()
+ return _u
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetAdoptDisplayName(v)
+ return _u
+}
+
+// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdateOne {
+ if v != nil {
+ _u.SetAdoptDisplayName(*v)
+ }
+ return _u
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetAdoptAvatar(v)
+ return _u
+}
+
+// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdateOne {
+ if v != nil {
+ _u.SetAdoptAvatar(*v)
+ }
+ return _u
+}
+
+// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdateOne {
+ return _u.SetPendingAuthSessionID(v.ID)
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdateOne {
+ return _u.SetIdentityID(v.ID)
+}
+
+// Mutation returns the IdentityAdoptionDecisionMutation object of the builder.
+func (_u *IdentityAdoptionDecisionUpdateOne) Mutation() *IdentityAdoptionDecisionMutation {
+ return _u.mutation
+}
+
+// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.ClearPendingAuthSession()
+ return _u
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentity() *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.ClearIdentity()
+ return _u
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder.
+func (_u *IdentityAdoptionDecisionUpdateOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *IdentityAdoptionDecisionUpdateOne) Select(field string, fields ...string) *IdentityAdoptionDecisionUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated IdentityAdoptionDecision entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) Save(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *IdentityAdoptionDecisionUpdateOne) SaveX(ctx context.Context) *IdentityAdoptionDecision {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *IdentityAdoptionDecisionUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *IdentityAdoptionDecisionUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := identityadoptiondecision.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *IdentityAdoptionDecisionUpdateOne) check() error {
+ if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`)
+ }
+ return nil
+}
+
+func (_u *IdentityAdoptionDecisionUpdateOne) sqlSave(ctx context.Context) (_node *IdentityAdoptionDecision, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "IdentityAdoptionDecision.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID)
+ for _, f := range fields {
+ if !identityadoptiondecision.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != identityadoptiondecision.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.AdoptDisplayName(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.AdoptAvatar(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value)
+ }
+ if _u.mutation.PendingAuthSessionCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.IdentityCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &IdentityAdoptionDecision{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{identityadoptiondecision.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go
index 8d8320bb..157c5122 100644
--- a/backend/ent/intercept/intercept.go
+++ b/backend/ent/intercept/intercept.go
@@ -13,12 +13,16 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
@@ -228,6 +232,60 @@ func (f TraverseAnnouncementRead) Traverse(ctx context.Context, q ent.Query) err
return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q)
}
+// The AuthIdentityFunc type is an adapter to allow the use of ordinary function as a Querier.
+type AuthIdentityFunc func(context.Context, *ent.AuthIdentityQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f AuthIdentityFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.AuthIdentityQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q)
+}
+
+// The TraverseAuthIdentity type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseAuthIdentity func(context.Context, *ent.AuthIdentityQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseAuthIdentity) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseAuthIdentity) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.AuthIdentityQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q)
+}
+
+// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary function as a Querier.
+type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f AuthIdentityChannelFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.AuthIdentityChannelQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q)
+}
+
+// The TraverseAuthIdentityChannel type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseAuthIdentityChannel func(context.Context, *ent.AuthIdentityChannelQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseAuthIdentityChannel) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseAuthIdentityChannel) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.AuthIdentityChannelQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q)
+}
+
// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary function as a Querier.
type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleQuery) (ent.Value, error)
@@ -309,6 +367,33 @@ func (f TraverseIdempotencyRecord) Traverse(ctx context.Context, q ent.Query) er
return fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q)
}
+// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary function as a Querier.
+type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f IdentityAdoptionDecisionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q)
+}
+
+// The TraverseIdentityAdoptionDecision type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseIdentityAdoptionDecision func(context.Context, *ent.IdentityAdoptionDecisionQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseIdentityAdoptionDecision) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseIdentityAdoptionDecision) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q)
+}
+
// The PaymentAuditLogFunc type is an adapter to allow the use of ordinary function as a Querier.
type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogQuery) (ent.Value, error)
@@ -390,6 +475,33 @@ func (f TraversePaymentProviderInstance) Traverse(ctx context.Context, q ent.Que
return fmt.Errorf("unexpected query type %T. expect *ent.PaymentProviderInstanceQuery", q)
}
+// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary function as a Querier.
+type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f PendingAuthSessionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.PendingAuthSessionQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.PendingAuthSessionQuery", q)
+}
+
+// The TraversePendingAuthSession type is an adapter to allow the use of ordinary function as Traverser.
+type TraversePendingAuthSession func(context.Context, *ent.PendingAuthSessionQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraversePendingAuthSession) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraversePendingAuthSession) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.PendingAuthSessionQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.PendingAuthSessionQuery", q)
+}
+
// The PromoCodeFunc type is an adapter to allow the use of ordinary function as a Querier.
type PromoCodeFunc func(context.Context, *ent.PromoCodeQuery) (ent.Value, error)
@@ -808,18 +920,26 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil
case *ent.AnnouncementReadQuery:
return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil
+ case *ent.AuthIdentityQuery:
+ return &query[*ent.AuthIdentityQuery, predicate.AuthIdentity, authidentity.OrderOption]{typ: ent.TypeAuthIdentity, tq: q}, nil
+ case *ent.AuthIdentityChannelQuery:
+ return &query[*ent.AuthIdentityChannelQuery, predicate.AuthIdentityChannel, authidentitychannel.OrderOption]{typ: ent.TypeAuthIdentityChannel, tq: q}, nil
case *ent.ErrorPassthroughRuleQuery:
return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil
case *ent.GroupQuery:
return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil
case *ent.IdempotencyRecordQuery:
return &query[*ent.IdempotencyRecordQuery, predicate.IdempotencyRecord, idempotencyrecord.OrderOption]{typ: ent.TypeIdempotencyRecord, tq: q}, nil
+ case *ent.IdentityAdoptionDecisionQuery:
+ return &query[*ent.IdentityAdoptionDecisionQuery, predicate.IdentityAdoptionDecision, identityadoptiondecision.OrderOption]{typ: ent.TypeIdentityAdoptionDecision, tq: q}, nil
case *ent.PaymentAuditLogQuery:
return &query[*ent.PaymentAuditLogQuery, predicate.PaymentAuditLog, paymentauditlog.OrderOption]{typ: ent.TypePaymentAuditLog, tq: q}, nil
case *ent.PaymentOrderQuery:
return &query[*ent.PaymentOrderQuery, predicate.PaymentOrder, paymentorder.OrderOption]{typ: ent.TypePaymentOrder, tq: q}, nil
case *ent.PaymentProviderInstanceQuery:
return &query[*ent.PaymentProviderInstanceQuery, predicate.PaymentProviderInstance, paymentproviderinstance.OrderOption]{typ: ent.TypePaymentProviderInstance, tq: q}, nil
+ case *ent.PendingAuthSessionQuery:
+ return &query[*ent.PendingAuthSessionQuery, predicate.PendingAuthSession, pendingauthsession.OrderOption]{typ: ent.TypePendingAuthSession, tq: q}, nil
case *ent.PromoCodeQuery:
return &query[*ent.PromoCodeQuery, predicate.PromoCode, promocode.OrderOption]{typ: ent.TypePromoCode, tq: q}, nil
case *ent.PromoCodeUsageQuery:
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index 68bdbf55..bf41e73b 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -338,6 +338,89 @@ var (
},
},
}
+ // AuthIdentitiesColumns holds the columns for the "auth_identities" table.
+ AuthIdentitiesColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "provider_type", Type: field.TypeString, Size: 20},
+ {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "issuer", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "user_id", Type: field.TypeInt64},
+ }
+ // AuthIdentitiesTable holds the schema information for the "auth_identities" table.
+ AuthIdentitiesTable = &schema.Table{
+ Name: "auth_identities",
+ Columns: AuthIdentitiesColumns,
+ PrimaryKey: []*schema.Column{AuthIdentitiesColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "auth_identities_users_auth_identities",
+ Columns: []*schema.Column{AuthIdentitiesColumns[9]},
+ RefColumns: []*schema.Column{UsersColumns[0]},
+ OnDelete: schema.NoAction,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "authidentity_provider_type_provider_key_provider_subject",
+ Unique: true,
+ Columns: []*schema.Column{AuthIdentitiesColumns[3], AuthIdentitiesColumns[4], AuthIdentitiesColumns[5]},
+ },
+ {
+ Name: "authidentity_user_id",
+ Unique: false,
+ Columns: []*schema.Column{AuthIdentitiesColumns[9]},
+ },
+ {
+ Name: "authidentity_user_id_provider_type",
+ Unique: false,
+ Columns: []*schema.Column{AuthIdentitiesColumns[9], AuthIdentitiesColumns[3]},
+ },
+ },
+ }
+ // AuthIdentityChannelsColumns holds the columns for the "auth_identity_channels" table.
+ AuthIdentityChannelsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "provider_type", Type: field.TypeString, Size: 20},
+ {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "channel", Type: field.TypeString, Size: 20},
+ {Name: "channel_app_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "channel_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "identity_id", Type: field.TypeInt64},
+ }
+ // AuthIdentityChannelsTable holds the schema information for the "auth_identity_channels" table.
+ AuthIdentityChannelsTable = &schema.Table{
+ Name: "auth_identity_channels",
+ Columns: AuthIdentityChannelsColumns,
+ PrimaryKey: []*schema.Column{AuthIdentityChannelsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "auth_identity_channels_auth_identities_channels",
+ Columns: []*schema.Column{AuthIdentityChannelsColumns[9]},
+ RefColumns: []*schema.Column{AuthIdentitiesColumns[0]},
+ OnDelete: schema.NoAction,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "authidentitychannel_provider_type_provider_key_channel_channel_app_id_channel_subject",
+ Unique: true,
+ Columns: []*schema.Column{AuthIdentityChannelsColumns[3], AuthIdentityChannelsColumns[4], AuthIdentityChannelsColumns[5], AuthIdentityChannelsColumns[6], AuthIdentityChannelsColumns[7]},
+ },
+ {
+ Name: "authidentitychannel_identity_id",
+ Unique: false,
+ Columns: []*schema.Column{AuthIdentityChannelsColumns[9]},
+ },
+ },
+ }
// ErrorPassthroughRulesColumns holds the columns for the "error_passthrough_rules" table.
ErrorPassthroughRulesColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -485,6 +568,49 @@ var (
},
},
}
+ // IdentityAdoptionDecisionsColumns holds the columns for the "identity_adoption_decisions" table.
+ IdentityAdoptionDecisionsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "adopt_display_name", Type: field.TypeBool, Default: false},
+ {Name: "adopt_avatar", Type: field.TypeBool, Default: false},
+ {Name: "decided_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "identity_id", Type: field.TypeInt64, Nullable: true},
+ {Name: "pending_auth_session_id", Type: field.TypeInt64, Unique: true},
+ }
+ // IdentityAdoptionDecisionsTable holds the schema information for the "identity_adoption_decisions" table.
+ IdentityAdoptionDecisionsTable = &schema.Table{
+ Name: "identity_adoption_decisions",
+ Columns: IdentityAdoptionDecisionsColumns,
+ PrimaryKey: []*schema.Column{IdentityAdoptionDecisionsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "identity_adoption_decisions_auth_identities_adoption_decisions",
+ Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]},
+ RefColumns: []*schema.Column{AuthIdentitiesColumns[0]},
+ OnDelete: schema.SetNull,
+ },
+ {
+ Symbol: "identity_adoption_decisions_pending_auth_sessions_adoption_decision",
+ Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]},
+ RefColumns: []*schema.Column{PendingAuthSessionsColumns[0]},
+ OnDelete: schema.NoAction,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "identityadoptiondecision_pending_auth_session_id",
+ Unique: true,
+ Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]},
+ },
+ {
+ Name: "identityadoptiondecision_identity_id",
+ Unique: false,
+ Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]},
+ },
+ },
+ }
// PaymentAuditLogsColumns holds the columns for the "payment_audit_logs" table.
PaymentAuditLogsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -638,6 +764,72 @@ var (
},
},
}
+ // PendingAuthSessionsColumns holds the columns for the "pending_auth_sessions" table.
+ PendingAuthSessionsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "session_token", Type: field.TypeString, Size: 255},
+ {Name: "intent", Type: field.TypeString, Size: 40},
+ {Name: "provider_type", Type: field.TypeString, Size: 20},
+ {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "redirect_to", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "resolved_email", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "registration_password_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "upstream_identity_claims", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "local_flow_state", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "browser_session_key", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "completion_code_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "completion_code_expires_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "email_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "password_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "totp_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "expires_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "consumed_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "target_user_id", Type: field.TypeInt64, Nullable: true},
+ }
+ // PendingAuthSessionsTable holds the schema information for the "pending_auth_sessions" table.
+ PendingAuthSessionsTable = &schema.Table{
+ Name: "pending_auth_sessions",
+ Columns: PendingAuthSessionsColumns,
+ PrimaryKey: []*schema.Column{PendingAuthSessionsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "pending_auth_sessions_users_pending_auth_sessions",
+ Columns: []*schema.Column{PendingAuthSessionsColumns[21]},
+ RefColumns: []*schema.Column{UsersColumns[0]},
+ OnDelete: schema.SetNull,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "pendingauthsession_session_token",
+ Unique: true,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[3]},
+ },
+ {
+ Name: "pendingauthsession_target_user_id",
+ Unique: false,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[21]},
+ },
+ {
+ Name: "pendingauthsession_expires_at",
+ Unique: false,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[19]},
+ },
+ {
+ Name: "pendingauthsession_provider_type_provider_key_provider_subject",
+ Unique: false,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[5], PendingAuthSessionsColumns[6], PendingAuthSessionsColumns[7]},
+ },
+ {
+ Name: "pendingauthsession_completion_code_hash",
+ Unique: false,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[14]},
+ },
+ },
+ }
// PromoCodesColumns holds the columns for the "promo_codes" table.
PromoCodesColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -1079,6 +1271,9 @@ var (
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
+ {Name: "signup_source", Type: field.TypeString, Size: 20, Default: "email"},
+ {Name: "last_login_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "last_active_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "balance_notify_enabled", Type: field.TypeBool, Default: true},
{Name: "balance_notify_threshold_type", Type: field.TypeString, Default: "fixed"},
{Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
@@ -1318,12 +1513,16 @@ var (
AccountGroupsTable,
AnnouncementsTable,
AnnouncementReadsTable,
+ AuthIdentitiesTable,
+ AuthIdentityChannelsTable,
ErrorPassthroughRulesTable,
GroupsTable,
IdempotencyRecordsTable,
+ IdentityAdoptionDecisionsTable,
PaymentAuditLogsTable,
PaymentOrdersTable,
PaymentProviderInstancesTable,
+ PendingAuthSessionsTable,
PromoCodesTable,
PromoCodeUsagesTable,
ProxiesTable,
@@ -1365,6 +1564,14 @@ func init() {
AnnouncementReadsTable.Annotation = &entsql.Annotation{
Table: "announcement_reads",
}
+ AuthIdentitiesTable.ForeignKeys[0].RefTable = UsersTable
+ AuthIdentitiesTable.Annotation = &entsql.Annotation{
+ Table: "auth_identities",
+ }
+ AuthIdentityChannelsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable
+ AuthIdentityChannelsTable.Annotation = &entsql.Annotation{
+ Table: "auth_identity_channels",
+ }
ErrorPassthroughRulesTable.Annotation = &entsql.Annotation{
Table: "error_passthrough_rules",
}
@@ -1374,6 +1581,11 @@ func init() {
IdempotencyRecordsTable.Annotation = &entsql.Annotation{
Table: "idempotency_records",
}
+ IdentityAdoptionDecisionsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable
+ IdentityAdoptionDecisionsTable.ForeignKeys[1].RefTable = PendingAuthSessionsTable
+ IdentityAdoptionDecisionsTable.Annotation = &entsql.Annotation{
+ Table: "identity_adoption_decisions",
+ }
PaymentAuditLogsTable.Annotation = &entsql.Annotation{
Table: "payment_audit_logs",
}
@@ -1384,6 +1596,10 @@ func init() {
PaymentProviderInstancesTable.Annotation = &entsql.Annotation{
Table: "payment_provider_instances",
}
+ PendingAuthSessionsTable.ForeignKeys[0].RefTable = UsersTable
+ PendingAuthSessionsTable.Annotation = &entsql.Annotation{
+ Table: "pending_auth_sessions",
+ }
PromoCodesTable.Annotation = &entsql.Annotation{
Table: "promo_codes",
}
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index 524ccb92..12905c9a 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -17,12 +17,16 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
@@ -51,32 +55,36 @@ const (
OpUpdateOne = ent.OpUpdateOne
// Node types.
- TypeAPIKey = "APIKey"
- TypeAccount = "Account"
- TypeAccountGroup = "AccountGroup"
- TypeAnnouncement = "Announcement"
- TypeAnnouncementRead = "AnnouncementRead"
- TypeErrorPassthroughRule = "ErrorPassthroughRule"
- TypeGroup = "Group"
- TypeIdempotencyRecord = "IdempotencyRecord"
- TypePaymentAuditLog = "PaymentAuditLog"
- TypePaymentOrder = "PaymentOrder"
- TypePaymentProviderInstance = "PaymentProviderInstance"
- TypePromoCode = "PromoCode"
- TypePromoCodeUsage = "PromoCodeUsage"
- TypeProxy = "Proxy"
- TypeRedeemCode = "RedeemCode"
- TypeSecuritySecret = "SecuritySecret"
- TypeSetting = "Setting"
- TypeSubscriptionPlan = "SubscriptionPlan"
- TypeTLSFingerprintProfile = "TLSFingerprintProfile"
- TypeUsageCleanupTask = "UsageCleanupTask"
- TypeUsageLog = "UsageLog"
- TypeUser = "User"
- TypeUserAllowedGroup = "UserAllowedGroup"
- TypeUserAttributeDefinition = "UserAttributeDefinition"
- TypeUserAttributeValue = "UserAttributeValue"
- TypeUserSubscription = "UserSubscription"
+ TypeAPIKey = "APIKey"
+ TypeAccount = "Account"
+ TypeAccountGroup = "AccountGroup"
+ TypeAnnouncement = "Announcement"
+ TypeAnnouncementRead = "AnnouncementRead"
+ TypeAuthIdentity = "AuthIdentity"
+ TypeAuthIdentityChannel = "AuthIdentityChannel"
+ TypeErrorPassthroughRule = "ErrorPassthroughRule"
+ TypeGroup = "Group"
+ TypeIdempotencyRecord = "IdempotencyRecord"
+ TypeIdentityAdoptionDecision = "IdentityAdoptionDecision"
+ TypePaymentAuditLog = "PaymentAuditLog"
+ TypePaymentOrder = "PaymentOrder"
+ TypePaymentProviderInstance = "PaymentProviderInstance"
+ TypePendingAuthSession = "PendingAuthSession"
+ TypePromoCode = "PromoCode"
+ TypePromoCodeUsage = "PromoCodeUsage"
+ TypeProxy = "Proxy"
+ TypeRedeemCode = "RedeemCode"
+ TypeSecuritySecret = "SecuritySecret"
+ TypeSetting = "Setting"
+ TypeSubscriptionPlan = "SubscriptionPlan"
+ TypeTLSFingerprintProfile = "TLSFingerprintProfile"
+ TypeUsageCleanupTask = "UsageCleanupTask"
+ TypeUsageLog = "UsageLog"
+ TypeUser = "User"
+ TypeUserAllowedGroup = "UserAllowedGroup"
+ TypeUserAttributeDefinition = "UserAttributeDefinition"
+ TypeUserAttributeValue = "UserAttributeValue"
+ TypeUserSubscription = "UserSubscription"
)
// APIKeyMutation represents an operation that mutates the APIKey nodes in the graph.
@@ -6887,6 +6895,1845 @@ func (m *AnnouncementReadMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown AnnouncementRead edge %s", name)
}
+// AuthIdentityMutation represents an operation that mutates the AuthIdentity nodes in the graph.
+type AuthIdentityMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ provider_type *string
+ provider_key *string
+ provider_subject *string
+ verified_at *time.Time
+ issuer *string
+ metadata *map[string]interface{}
+ clearedFields map[string]struct{}
+ user *int64
+ cleareduser bool
+ channels map[int64]struct{}
+ removedchannels map[int64]struct{}
+ clearedchannels bool
+ adoption_decisions map[int64]struct{}
+ removedadoption_decisions map[int64]struct{}
+ clearedadoption_decisions bool
+ done bool
+ oldValue func(context.Context) (*AuthIdentity, error)
+ predicates []predicate.AuthIdentity
+}
+
+var _ ent.Mutation = (*AuthIdentityMutation)(nil)
+
+// authidentityOption allows management of the mutation configuration using functional options.
+type authidentityOption func(*AuthIdentityMutation)
+
+// newAuthIdentityMutation creates new mutation for the AuthIdentity entity.
+func newAuthIdentityMutation(c config, op Op, opts ...authidentityOption) *AuthIdentityMutation {
+ m := &AuthIdentityMutation{
+ config: c,
+ op: op,
+ typ: TypeAuthIdentity,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withAuthIdentityID sets the ID field of the mutation.
+func withAuthIdentityID(id int64) authidentityOption {
+ return func(m *AuthIdentityMutation) {
+ var (
+ err error
+ once sync.Once
+ value *AuthIdentity
+ )
+ m.oldValue = func(ctx context.Context) (*AuthIdentity, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().AuthIdentity.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withAuthIdentity sets the old AuthIdentity of the mutation.
+func withAuthIdentity(node *AuthIdentity) authidentityOption {
+ return func(m *AuthIdentityMutation) {
+ m.oldValue = func(context.Context) (*AuthIdentity, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m AuthIdentityMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m AuthIdentityMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *AuthIdentityMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *AuthIdentityMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().AuthIdentity.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *AuthIdentityMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *AuthIdentityMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *AuthIdentityMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *AuthIdentityMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *AuthIdentityMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *AuthIdentityMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetUserID sets the "user_id" field.
+func (m *AuthIdentityMutation) SetUserID(i int64) {
+ m.user = &i
+}
+
+// UserID returns the value of the "user_id" field in the mutation.
+func (m *AuthIdentityMutation) UserID() (r int64, exists bool) {
+ v := m.user
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUserID returns the old "user_id" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldUserID(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUserID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUserID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUserID: %w", err)
+ }
+ return oldValue.UserID, nil
+}
+
+// ResetUserID resets all changes to the "user_id" field.
+func (m *AuthIdentityMutation) ResetUserID() {
+ m.user = nil
+}
+
+// SetProviderType sets the "provider_type" field.
+func (m *AuthIdentityMutation) SetProviderType(s string) {
+ m.provider_type = &s
+}
+
+// ProviderType returns the value of the "provider_type" field in the mutation.
+func (m *AuthIdentityMutation) ProviderType() (r string, exists bool) {
+ v := m.provider_type
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderType returns the old "provider_type" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldProviderType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderType: %w", err)
+ }
+ return oldValue.ProviderType, nil
+}
+
+// ResetProviderType resets all changes to the "provider_type" field.
+func (m *AuthIdentityMutation) ResetProviderType() {
+ m.provider_type = nil
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (m *AuthIdentityMutation) SetProviderKey(s string) {
+ m.provider_key = &s
+}
+
+// ProviderKey returns the value of the "provider_key" field in the mutation.
+func (m *AuthIdentityMutation) ProviderKey() (r string, exists bool) {
+ v := m.provider_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderKey returns the old "provider_key" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldProviderKey(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
+ }
+ return oldValue.ProviderKey, nil
+}
+
+// ResetProviderKey resets all changes to the "provider_key" field.
+func (m *AuthIdentityMutation) ResetProviderKey() {
+ m.provider_key = nil
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (m *AuthIdentityMutation) SetProviderSubject(s string) {
+ m.provider_subject = &s
+}
+
+// ProviderSubject returns the value of the "provider_subject" field in the mutation.
+func (m *AuthIdentityMutation) ProviderSubject() (r string, exists bool) {
+ v := m.provider_subject
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderSubject returns the old "provider_subject" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldProviderSubject(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderSubject is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderSubject requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderSubject: %w", err)
+ }
+ return oldValue.ProviderSubject, nil
+}
+
+// ResetProviderSubject resets all changes to the "provider_subject" field.
+func (m *AuthIdentityMutation) ResetProviderSubject() {
+ m.provider_subject = nil
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (m *AuthIdentityMutation) SetVerifiedAt(t time.Time) {
+ m.verified_at = &t
+}
+
+// VerifiedAt returns the value of the "verified_at" field in the mutation.
+func (m *AuthIdentityMutation) VerifiedAt() (r time.Time, exists bool) {
+ v := m.verified_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldVerifiedAt returns the old "verified_at" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldVerifiedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldVerifiedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldVerifiedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldVerifiedAt: %w", err)
+ }
+ return oldValue.VerifiedAt, nil
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (m *AuthIdentityMutation) ClearVerifiedAt() {
+ m.verified_at = nil
+ m.clearedFields[authidentity.FieldVerifiedAt] = struct{}{}
+}
+
+// VerifiedAtCleared returns if the "verified_at" field was cleared in this mutation.
+func (m *AuthIdentityMutation) VerifiedAtCleared() bool {
+ _, ok := m.clearedFields[authidentity.FieldVerifiedAt]
+ return ok
+}
+
+// ResetVerifiedAt resets all changes to the "verified_at" field.
+func (m *AuthIdentityMutation) ResetVerifiedAt() {
+ m.verified_at = nil
+ delete(m.clearedFields, authidentity.FieldVerifiedAt)
+}
+
+// SetIssuer sets the "issuer" field.
+func (m *AuthIdentityMutation) SetIssuer(s string) {
+ m.issuer = &s
+}
+
+// Issuer returns the value of the "issuer" field in the mutation.
+func (m *AuthIdentityMutation) Issuer() (r string, exists bool) {
+ v := m.issuer
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldIssuer returns the old "issuer" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldIssuer(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIssuer is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIssuer requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIssuer: %w", err)
+ }
+ return oldValue.Issuer, nil
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (m *AuthIdentityMutation) ClearIssuer() {
+ m.issuer = nil
+ m.clearedFields[authidentity.FieldIssuer] = struct{}{}
+}
+
+// IssuerCleared returns if the "issuer" field was cleared in this mutation.
+func (m *AuthIdentityMutation) IssuerCleared() bool {
+ _, ok := m.clearedFields[authidentity.FieldIssuer]
+ return ok
+}
+
+// ResetIssuer resets all changes to the "issuer" field.
+func (m *AuthIdentityMutation) ResetIssuer() {
+ m.issuer = nil
+ delete(m.clearedFields, authidentity.FieldIssuer)
+}
+
+// SetMetadata sets the "metadata" field.
+func (m *AuthIdentityMutation) SetMetadata(value map[string]interface{}) {
+ m.metadata = &value
+}
+
+// Metadata returns the value of the "metadata" field in the mutation.
+func (m *AuthIdentityMutation) Metadata() (r map[string]interface{}, exists bool) {
+ v := m.metadata
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldMetadata returns the old "metadata" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldMetadata(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldMetadata is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldMetadata requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldMetadata: %w", err)
+ }
+ return oldValue.Metadata, nil
+}
+
+// ResetMetadata resets all changes to the "metadata" field.
+func (m *AuthIdentityMutation) ResetMetadata() {
+ m.metadata = nil
+}
+
+// ClearUser clears the "user" edge to the User entity.
+func (m *AuthIdentityMutation) ClearUser() {
+ m.cleareduser = true
+ m.clearedFields[authidentity.FieldUserID] = struct{}{}
+}
+
+// UserCleared reports if the "user" edge to the User entity was cleared.
+func (m *AuthIdentityMutation) UserCleared() bool {
+ return m.cleareduser
+}
+
+// UserIDs returns the "user" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// UserID instead. It exists only for internal usage by the builders.
+func (m *AuthIdentityMutation) UserIDs() (ids []int64) {
+ if id := m.user; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetUser resets all changes to the "user" edge.
+func (m *AuthIdentityMutation) ResetUser() {
+ m.user = nil
+ m.cleareduser = false
+}
+
+// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by ids.
+func (m *AuthIdentityMutation) AddChannelIDs(ids ...int64) {
+ if m.channels == nil {
+ m.channels = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.channels[ids[i]] = struct{}{}
+ }
+}
+
+// ClearChannels clears the "channels" edge to the AuthIdentityChannel entity.
+func (m *AuthIdentityMutation) ClearChannels() {
+ m.clearedchannels = true
+}
+
+// ChannelsCleared reports if the "channels" edge to the AuthIdentityChannel entity was cleared.
+func (m *AuthIdentityMutation) ChannelsCleared() bool {
+ return m.clearedchannels
+}
+
+// RemoveChannelIDs removes the "channels" edge to the AuthIdentityChannel entity by IDs.
+func (m *AuthIdentityMutation) RemoveChannelIDs(ids ...int64) {
+ if m.removedchannels == nil {
+ m.removedchannels = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.channels, ids[i])
+ m.removedchannels[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedChannels returns the removed IDs of the "channels" edge to the AuthIdentityChannel entity.
+func (m *AuthIdentityMutation) RemovedChannelsIDs() (ids []int64) {
+ for id := range m.removedchannels {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ChannelsIDs returns the "channels" edge IDs in the mutation.
+func (m *AuthIdentityMutation) ChannelsIDs() (ids []int64) {
+ for id := range m.channels {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetChannels resets all changes to the "channels" edge.
+func (m *AuthIdentityMutation) ResetChannels() {
+ m.channels = nil
+ m.clearedchannels = false
+ m.removedchannels = nil
+}
+
+// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by ids.
+func (m *AuthIdentityMutation) AddAdoptionDecisionIDs(ids ...int64) {
+ if m.adoption_decisions == nil {
+ m.adoption_decisions = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.adoption_decisions[ids[i]] = struct{}{}
+ }
+}
+
+// ClearAdoptionDecisions clears the "adoption_decisions" edge to the IdentityAdoptionDecision entity.
+func (m *AuthIdentityMutation) ClearAdoptionDecisions() {
+ m.clearedadoption_decisions = true
+}
+
+// AdoptionDecisionsCleared reports if the "adoption_decisions" edge to the IdentityAdoptionDecision entity was cleared.
+func (m *AuthIdentityMutation) AdoptionDecisionsCleared() bool {
+ return m.clearedadoption_decisions
+}
+
+// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
+func (m *AuthIdentityMutation) RemoveAdoptionDecisionIDs(ids ...int64) {
+ if m.removedadoption_decisions == nil {
+ m.removedadoption_decisions = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.adoption_decisions, ids[i])
+ m.removedadoption_decisions[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedAdoptionDecisions returns the removed IDs of the "adoption_decisions" edge to the IdentityAdoptionDecision entity.
+func (m *AuthIdentityMutation) RemovedAdoptionDecisionsIDs() (ids []int64) {
+ for id := range m.removedadoption_decisions {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// AdoptionDecisionsIDs returns the "adoption_decisions" edge IDs in the mutation.
+func (m *AuthIdentityMutation) AdoptionDecisionsIDs() (ids []int64) {
+ for id := range m.adoption_decisions {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetAdoptionDecisions resets all changes to the "adoption_decisions" edge.
+func (m *AuthIdentityMutation) ResetAdoptionDecisions() {
+ m.adoption_decisions = nil
+ m.clearedadoption_decisions = false
+ m.removedadoption_decisions = nil
+}
+
+// Where appends a list predicates to the AuthIdentityMutation builder.
+func (m *AuthIdentityMutation) Where(ps ...predicate.AuthIdentity) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the AuthIdentityMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *AuthIdentityMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.AuthIdentity, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *AuthIdentityMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *AuthIdentityMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (AuthIdentity).
+func (m *AuthIdentityMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *AuthIdentityMutation) Fields() []string {
+ fields := make([]string, 0, 9)
+ if m.created_at != nil {
+ fields = append(fields, authidentity.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, authidentity.FieldUpdatedAt)
+ }
+ if m.user != nil {
+ fields = append(fields, authidentity.FieldUserID)
+ }
+ if m.provider_type != nil {
+ fields = append(fields, authidentity.FieldProviderType)
+ }
+ if m.provider_key != nil {
+ fields = append(fields, authidentity.FieldProviderKey)
+ }
+ if m.provider_subject != nil {
+ fields = append(fields, authidentity.FieldProviderSubject)
+ }
+ if m.verified_at != nil {
+ fields = append(fields, authidentity.FieldVerifiedAt)
+ }
+ if m.issuer != nil {
+ fields = append(fields, authidentity.FieldIssuer)
+ }
+ if m.metadata != nil {
+ fields = append(fields, authidentity.FieldMetadata)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *AuthIdentityMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case authidentity.FieldCreatedAt:
+ return m.CreatedAt()
+ case authidentity.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case authidentity.FieldUserID:
+ return m.UserID()
+ case authidentity.FieldProviderType:
+ return m.ProviderType()
+ case authidentity.FieldProviderKey:
+ return m.ProviderKey()
+ case authidentity.FieldProviderSubject:
+ return m.ProviderSubject()
+ case authidentity.FieldVerifiedAt:
+ return m.VerifiedAt()
+ case authidentity.FieldIssuer:
+ return m.Issuer()
+ case authidentity.FieldMetadata:
+ return m.Metadata()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *AuthIdentityMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case authidentity.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case authidentity.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case authidentity.FieldUserID:
+ return m.OldUserID(ctx)
+ case authidentity.FieldProviderType:
+ return m.OldProviderType(ctx)
+ case authidentity.FieldProviderKey:
+ return m.OldProviderKey(ctx)
+ case authidentity.FieldProviderSubject:
+ return m.OldProviderSubject(ctx)
+ case authidentity.FieldVerifiedAt:
+ return m.OldVerifiedAt(ctx)
+ case authidentity.FieldIssuer:
+ return m.OldIssuer(ctx)
+ case authidentity.FieldMetadata:
+ return m.OldMetadata(ctx)
+ }
+ return nil, fmt.Errorf("unknown AuthIdentity field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AuthIdentityMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case authidentity.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case authidentity.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case authidentity.FieldUserID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUserID(v)
+ return nil
+ case authidentity.FieldProviderType:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderType(v)
+ return nil
+ case authidentity.FieldProviderKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderKey(v)
+ return nil
+ case authidentity.FieldProviderSubject:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderSubject(v)
+ return nil
+ case authidentity.FieldVerifiedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetVerifiedAt(v)
+ return nil
+ case authidentity.FieldIssuer:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIssuer(v)
+ return nil
+ case authidentity.FieldMetadata:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetMetadata(v)
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *AuthIdentityMutation) AddedFields() []string {
+ var fields []string
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *AuthIdentityMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AuthIdentityMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown AuthIdentity numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *AuthIdentityMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(authidentity.FieldVerifiedAt) {
+ fields = append(fields, authidentity.FieldVerifiedAt)
+ }
+ if m.FieldCleared(authidentity.FieldIssuer) {
+ fields = append(fields, authidentity.FieldIssuer)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *AuthIdentityMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *AuthIdentityMutation) ClearField(name string) error {
+ switch name {
+ case authidentity.FieldVerifiedAt:
+ m.ClearVerifiedAt()
+ return nil
+ case authidentity.FieldIssuer:
+ m.ClearIssuer()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *AuthIdentityMutation) ResetField(name string) error {
+ switch name {
+ case authidentity.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case authidentity.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case authidentity.FieldUserID:
+ m.ResetUserID()
+ return nil
+ case authidentity.FieldProviderType:
+ m.ResetProviderType()
+ return nil
+ case authidentity.FieldProviderKey:
+ m.ResetProviderKey()
+ return nil
+ case authidentity.FieldProviderSubject:
+ m.ResetProviderSubject()
+ return nil
+ case authidentity.FieldVerifiedAt:
+ m.ResetVerifiedAt()
+ return nil
+ case authidentity.FieldIssuer:
+ m.ResetIssuer()
+ return nil
+ case authidentity.FieldMetadata:
+ m.ResetMetadata()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *AuthIdentityMutation) AddedEdges() []string {
+ edges := make([]string, 0, 3)
+ if m.user != nil {
+ edges = append(edges, authidentity.EdgeUser)
+ }
+ if m.channels != nil {
+ edges = append(edges, authidentity.EdgeChannels)
+ }
+ if m.adoption_decisions != nil {
+ edges = append(edges, authidentity.EdgeAdoptionDecisions)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *AuthIdentityMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case authidentity.EdgeUser:
+ if id := m.user; id != nil {
+ return []ent.Value{*id}
+ }
+ case authidentity.EdgeChannels:
+ ids := make([]ent.Value, 0, len(m.channels))
+ for id := range m.channels {
+ ids = append(ids, id)
+ }
+ return ids
+ case authidentity.EdgeAdoptionDecisions:
+ ids := make([]ent.Value, 0, len(m.adoption_decisions))
+ for id := range m.adoption_decisions {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *AuthIdentityMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 3)
+ if m.removedchannels != nil {
+ edges = append(edges, authidentity.EdgeChannels)
+ }
+ if m.removedadoption_decisions != nil {
+ edges = append(edges, authidentity.EdgeAdoptionDecisions)
+ }
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *AuthIdentityMutation) RemovedIDs(name string) []ent.Value {
+ switch name {
+ case authidentity.EdgeChannels:
+ ids := make([]ent.Value, 0, len(m.removedchannels))
+ for id := range m.removedchannels {
+ ids = append(ids, id)
+ }
+ return ids
+ case authidentity.EdgeAdoptionDecisions:
+ ids := make([]ent.Value, 0, len(m.removedadoption_decisions))
+ for id := range m.removedadoption_decisions {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *AuthIdentityMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 3)
+ if m.cleareduser {
+ edges = append(edges, authidentity.EdgeUser)
+ }
+ if m.clearedchannels {
+ edges = append(edges, authidentity.EdgeChannels)
+ }
+ if m.clearedadoption_decisions {
+ edges = append(edges, authidentity.EdgeAdoptionDecisions)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *AuthIdentityMutation) EdgeCleared(name string) bool {
+ switch name {
+ case authidentity.EdgeUser:
+ return m.cleareduser
+ case authidentity.EdgeChannels:
+ return m.clearedchannels
+ case authidentity.EdgeAdoptionDecisions:
+ return m.clearedadoption_decisions
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *AuthIdentityMutation) ClearEdge(name string) error {
+ switch name {
+ case authidentity.EdgeUser:
+ m.ClearUser()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *AuthIdentityMutation) ResetEdge(name string) error {
+ switch name {
+ case authidentity.EdgeUser:
+ m.ResetUser()
+ return nil
+ case authidentity.EdgeChannels:
+ m.ResetChannels()
+ return nil
+ case authidentity.EdgeAdoptionDecisions:
+ m.ResetAdoptionDecisions()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity edge %s", name)
+}
+
+// AuthIdentityChannelMutation represents an operation that mutates the AuthIdentityChannel nodes in the graph.
+type AuthIdentityChannelMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ provider_type *string
+ provider_key *string
+ channel *string
+ channel_app_id *string
+ channel_subject *string
+ metadata *map[string]interface{}
+ clearedFields map[string]struct{}
+ identity *int64
+ clearedidentity bool
+ done bool
+ oldValue func(context.Context) (*AuthIdentityChannel, error)
+ predicates []predicate.AuthIdentityChannel
+}
+
+var _ ent.Mutation = (*AuthIdentityChannelMutation)(nil)
+
+// authidentitychannelOption allows management of the mutation configuration using functional options.
+type authidentitychannelOption func(*AuthIdentityChannelMutation)
+
+// newAuthIdentityChannelMutation creates new mutation for the AuthIdentityChannel entity.
+func newAuthIdentityChannelMutation(c config, op Op, opts ...authidentitychannelOption) *AuthIdentityChannelMutation {
+ m := &AuthIdentityChannelMutation{
+ config: c,
+ op: op,
+ typ: TypeAuthIdentityChannel,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withAuthIdentityChannelID sets the ID field of the mutation.
+func withAuthIdentityChannelID(id int64) authidentitychannelOption {
+ return func(m *AuthIdentityChannelMutation) {
+ var (
+ err error
+ once sync.Once
+ value *AuthIdentityChannel
+ )
+ m.oldValue = func(ctx context.Context) (*AuthIdentityChannel, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().AuthIdentityChannel.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withAuthIdentityChannel sets the old AuthIdentityChannel of the mutation.
+func withAuthIdentityChannel(node *AuthIdentityChannel) authidentitychannelOption {
+ return func(m *AuthIdentityChannelMutation) {
+ m.oldValue = func(context.Context) (*AuthIdentityChannel, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m AuthIdentityChannelMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m AuthIdentityChannelMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *AuthIdentityChannelMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *AuthIdentityChannelMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().AuthIdentityChannel.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *AuthIdentityChannelMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *AuthIdentityChannelMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *AuthIdentityChannelMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *AuthIdentityChannelMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *AuthIdentityChannelMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *AuthIdentityChannelMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (m *AuthIdentityChannelMutation) SetIdentityID(i int64) {
+ m.identity = &i
+}
+
+// IdentityID returns the value of the "identity_id" field in the mutation.
+func (m *AuthIdentityChannelMutation) IdentityID() (r int64, exists bool) {
+ v := m.identity
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldIdentityID returns the old "identity_id" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldIdentityID(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIdentityID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIdentityID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIdentityID: %w", err)
+ }
+ return oldValue.IdentityID, nil
+}
+
+// ResetIdentityID resets all changes to the "identity_id" field.
+func (m *AuthIdentityChannelMutation) ResetIdentityID() {
+ m.identity = nil
+}
+
+// SetProviderType sets the "provider_type" field.
+func (m *AuthIdentityChannelMutation) SetProviderType(s string) {
+ m.provider_type = &s
+}
+
+// ProviderType returns the value of the "provider_type" field in the mutation.
+func (m *AuthIdentityChannelMutation) ProviderType() (r string, exists bool) {
+ v := m.provider_type
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderType returns the old "provider_type" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldProviderType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderType: %w", err)
+ }
+ return oldValue.ProviderType, nil
+}
+
+// ResetProviderType resets all changes to the "provider_type" field.
+func (m *AuthIdentityChannelMutation) ResetProviderType() {
+ m.provider_type = nil
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (m *AuthIdentityChannelMutation) SetProviderKey(s string) {
+ m.provider_key = &s
+}
+
+// ProviderKey returns the value of the "provider_key" field in the mutation.
+func (m *AuthIdentityChannelMutation) ProviderKey() (r string, exists bool) {
+ v := m.provider_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderKey returns the old "provider_key" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldProviderKey(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
+ }
+ return oldValue.ProviderKey, nil
+}
+
+// ResetProviderKey resets all changes to the "provider_key" field.
+func (m *AuthIdentityChannelMutation) ResetProviderKey() {
+ m.provider_key = nil
+}
+
+// SetChannel sets the "channel" field.
+func (m *AuthIdentityChannelMutation) SetChannel(s string) {
+ m.channel = &s
+}
+
+// Channel returns the value of the "channel" field in the mutation.
+func (m *AuthIdentityChannelMutation) Channel() (r string, exists bool) {
+ v := m.channel
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldChannel returns the old "channel" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldChannel(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldChannel is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldChannel requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldChannel: %w", err)
+ }
+ return oldValue.Channel, nil
+}
+
+// ResetChannel resets all changes to the "channel" field.
+func (m *AuthIdentityChannelMutation) ResetChannel() {
+ m.channel = nil
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (m *AuthIdentityChannelMutation) SetChannelAppID(s string) {
+ m.channel_app_id = &s
+}
+
+// ChannelAppID returns the value of the "channel_app_id" field in the mutation.
+func (m *AuthIdentityChannelMutation) ChannelAppID() (r string, exists bool) {
+ v := m.channel_app_id
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldChannelAppID returns the old "channel_app_id" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldChannelAppID(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldChannelAppID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldChannelAppID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldChannelAppID: %w", err)
+ }
+ return oldValue.ChannelAppID, nil
+}
+
+// ResetChannelAppID resets all changes to the "channel_app_id" field.
+func (m *AuthIdentityChannelMutation) ResetChannelAppID() {
+ m.channel_app_id = nil
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (m *AuthIdentityChannelMutation) SetChannelSubject(s string) {
+ m.channel_subject = &s
+}
+
+// ChannelSubject returns the value of the "channel_subject" field in the mutation.
+func (m *AuthIdentityChannelMutation) ChannelSubject() (r string, exists bool) {
+ v := m.channel_subject
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldChannelSubject returns the old "channel_subject" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldChannelSubject(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldChannelSubject is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldChannelSubject requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldChannelSubject: %w", err)
+ }
+ return oldValue.ChannelSubject, nil
+}
+
+// ResetChannelSubject resets all changes to the "channel_subject" field.
+func (m *AuthIdentityChannelMutation) ResetChannelSubject() {
+ m.channel_subject = nil
+}
+
+// SetMetadata sets the "metadata" field.
+func (m *AuthIdentityChannelMutation) SetMetadata(value map[string]interface{}) {
+ m.metadata = &value
+}
+
+// Metadata returns the value of the "metadata" field in the mutation.
+func (m *AuthIdentityChannelMutation) Metadata() (r map[string]interface{}, exists bool) {
+ v := m.metadata
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldMetadata returns the old "metadata" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldMetadata(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldMetadata is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldMetadata requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldMetadata: %w", err)
+ }
+ return oldValue.Metadata, nil
+}
+
+// ResetMetadata resets all changes to the "metadata" field.
+func (m *AuthIdentityChannelMutation) ResetMetadata() {
+ m.metadata = nil
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (m *AuthIdentityChannelMutation) ClearIdentity() {
+ m.clearedidentity = true
+ m.clearedFields[authidentitychannel.FieldIdentityID] = struct{}{}
+}
+
+// IdentityCleared reports if the "identity" edge to the AuthIdentity entity was cleared.
+func (m *AuthIdentityChannelMutation) IdentityCleared() bool {
+ return m.clearedidentity
+}
+
+// IdentityIDs returns the "identity" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// IdentityID instead. It exists only for internal usage by the builders.
+func (m *AuthIdentityChannelMutation) IdentityIDs() (ids []int64) {
+ if id := m.identity; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetIdentity resets all changes to the "identity" edge.
+func (m *AuthIdentityChannelMutation) ResetIdentity() {
+ m.identity = nil
+ m.clearedidentity = false
+}
+
+// Where appends a list predicates to the AuthIdentityChannelMutation builder.
+func (m *AuthIdentityChannelMutation) Where(ps ...predicate.AuthIdentityChannel) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the AuthIdentityChannelMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *AuthIdentityChannelMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.AuthIdentityChannel, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *AuthIdentityChannelMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *AuthIdentityChannelMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (AuthIdentityChannel).
+func (m *AuthIdentityChannelMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *AuthIdentityChannelMutation) Fields() []string {
+ fields := make([]string, 0, 9)
+ if m.created_at != nil {
+ fields = append(fields, authidentitychannel.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, authidentitychannel.FieldUpdatedAt)
+ }
+ if m.identity != nil {
+ fields = append(fields, authidentitychannel.FieldIdentityID)
+ }
+ if m.provider_type != nil {
+ fields = append(fields, authidentitychannel.FieldProviderType)
+ }
+ if m.provider_key != nil {
+ fields = append(fields, authidentitychannel.FieldProviderKey)
+ }
+ if m.channel != nil {
+ fields = append(fields, authidentitychannel.FieldChannel)
+ }
+ if m.channel_app_id != nil {
+ fields = append(fields, authidentitychannel.FieldChannelAppID)
+ }
+ if m.channel_subject != nil {
+ fields = append(fields, authidentitychannel.FieldChannelSubject)
+ }
+ if m.metadata != nil {
+ fields = append(fields, authidentitychannel.FieldMetadata)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *AuthIdentityChannelMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case authidentitychannel.FieldCreatedAt:
+ return m.CreatedAt()
+ case authidentitychannel.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case authidentitychannel.FieldIdentityID:
+ return m.IdentityID()
+ case authidentitychannel.FieldProviderType:
+ return m.ProviderType()
+ case authidentitychannel.FieldProviderKey:
+ return m.ProviderKey()
+ case authidentitychannel.FieldChannel:
+ return m.Channel()
+ case authidentitychannel.FieldChannelAppID:
+ return m.ChannelAppID()
+ case authidentitychannel.FieldChannelSubject:
+ return m.ChannelSubject()
+ case authidentitychannel.FieldMetadata:
+ return m.Metadata()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *AuthIdentityChannelMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case authidentitychannel.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case authidentitychannel.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case authidentitychannel.FieldIdentityID:
+ return m.OldIdentityID(ctx)
+ case authidentitychannel.FieldProviderType:
+ return m.OldProviderType(ctx)
+ case authidentitychannel.FieldProviderKey:
+ return m.OldProviderKey(ctx)
+ case authidentitychannel.FieldChannel:
+ return m.OldChannel(ctx)
+ case authidentitychannel.FieldChannelAppID:
+ return m.OldChannelAppID(ctx)
+ case authidentitychannel.FieldChannelSubject:
+ return m.OldChannelSubject(ctx)
+ case authidentitychannel.FieldMetadata:
+ return m.OldMetadata(ctx)
+ }
+ return nil, fmt.Errorf("unknown AuthIdentityChannel field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AuthIdentityChannelMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case authidentitychannel.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case authidentitychannel.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case authidentitychannel.FieldIdentityID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIdentityID(v)
+ return nil
+ case authidentitychannel.FieldProviderType:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderType(v)
+ return nil
+ case authidentitychannel.FieldProviderKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderKey(v)
+ return nil
+ case authidentitychannel.FieldChannel:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetChannel(v)
+ return nil
+ case authidentitychannel.FieldChannelAppID:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetChannelAppID(v)
+ return nil
+ case authidentitychannel.FieldChannelSubject:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetChannelSubject(v)
+ return nil
+ case authidentitychannel.FieldMetadata:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetMetadata(v)
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentityChannel field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *AuthIdentityChannelMutation) AddedFields() []string {
+ var fields []string
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *AuthIdentityChannelMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AuthIdentityChannelMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown AuthIdentityChannel numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *AuthIdentityChannelMutation) ClearedFields() []string {
+ return nil
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *AuthIdentityChannelMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *AuthIdentityChannelMutation) ClearField(name string) error {
+ return fmt.Errorf("unknown AuthIdentityChannel nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *AuthIdentityChannelMutation) ResetField(name string) error {
+ switch name {
+ case authidentitychannel.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case authidentitychannel.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case authidentitychannel.FieldIdentityID:
+ m.ResetIdentityID()
+ return nil
+ case authidentitychannel.FieldProviderType:
+ m.ResetProviderType()
+ return nil
+ case authidentitychannel.FieldProviderKey:
+ m.ResetProviderKey()
+ return nil
+ case authidentitychannel.FieldChannel:
+ m.ResetChannel()
+ return nil
+ case authidentitychannel.FieldChannelAppID:
+ m.ResetChannelAppID()
+ return nil
+ case authidentitychannel.FieldChannelSubject:
+ m.ResetChannelSubject()
+ return nil
+ case authidentitychannel.FieldMetadata:
+ m.ResetMetadata()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentityChannel field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *AuthIdentityChannelMutation) AddedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.identity != nil {
+ edges = append(edges, authidentitychannel.EdgeIdentity)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *AuthIdentityChannelMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case authidentitychannel.EdgeIdentity:
+ if id := m.identity; id != nil {
+ return []ent.Value{*id}
+ }
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *AuthIdentityChannelMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 1)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *AuthIdentityChannelMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *AuthIdentityChannelMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.clearedidentity {
+ edges = append(edges, authidentitychannel.EdgeIdentity)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *AuthIdentityChannelMutation) EdgeCleared(name string) bool {
+ switch name {
+ case authidentitychannel.EdgeIdentity:
+ return m.clearedidentity
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *AuthIdentityChannelMutation) ClearEdge(name string) error {
+ switch name {
+ case authidentitychannel.EdgeIdentity:
+ m.ClearIdentity()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentityChannel unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *AuthIdentityChannelMutation) ResetEdge(name string) error {
+ switch name {
+ case authidentitychannel.EdgeIdentity:
+ m.ResetIdentity()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentityChannel edge %s", name)
+}
+
// ErrorPassthroughRuleMutation represents an operation that mutates the ErrorPassthroughRule nodes in the graph.
type ErrorPassthroughRuleMutation struct {
config
@@ -12191,6 +14038,781 @@ func (m *IdempotencyRecordMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown IdempotencyRecord edge %s", name)
}
+// IdentityAdoptionDecisionMutation represents an operation that mutates the IdentityAdoptionDecision nodes in the graph.
+type IdentityAdoptionDecisionMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ adopt_display_name *bool
+ adopt_avatar *bool
+ decided_at *time.Time
+ clearedFields map[string]struct{}
+ pending_auth_session *int64
+ clearedpending_auth_session bool
+ identity *int64
+ clearedidentity bool
+ done bool
+ oldValue func(context.Context) (*IdentityAdoptionDecision, error)
+ predicates []predicate.IdentityAdoptionDecision
+}
+
+var _ ent.Mutation = (*IdentityAdoptionDecisionMutation)(nil)
+
+// identityadoptiondecisionOption allows management of the mutation configuration using functional options.
+type identityadoptiondecisionOption func(*IdentityAdoptionDecisionMutation)
+
+// newIdentityAdoptionDecisionMutation creates new mutation for the IdentityAdoptionDecision entity.
+func newIdentityAdoptionDecisionMutation(c config, op Op, opts ...identityadoptiondecisionOption) *IdentityAdoptionDecisionMutation {
+ m := &IdentityAdoptionDecisionMutation{
+ config: c,
+ op: op,
+ typ: TypeIdentityAdoptionDecision,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withIdentityAdoptionDecisionID sets the ID field of the mutation.
+func withIdentityAdoptionDecisionID(id int64) identityadoptiondecisionOption {
+ return func(m *IdentityAdoptionDecisionMutation) {
+ var (
+ err error
+ once sync.Once
+ value *IdentityAdoptionDecision
+ )
+ m.oldValue = func(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().IdentityAdoptionDecision.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withIdentityAdoptionDecision sets the old IdentityAdoptionDecision of the mutation.
+func withIdentityAdoptionDecision(node *IdentityAdoptionDecision) identityadoptiondecisionOption {
+ return func(m *IdentityAdoptionDecisionMutation) {
+ m.oldValue = func(context.Context) (*IdentityAdoptionDecision, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m IdentityAdoptionDecisionMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m IdentityAdoptionDecisionMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *IdentityAdoptionDecisionMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *IdentityAdoptionDecisionMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().IdentityAdoptionDecision.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *IdentityAdoptionDecisionMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *IdentityAdoptionDecisionMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *IdentityAdoptionDecisionMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *IdentityAdoptionDecisionMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (m *IdentityAdoptionDecisionMutation) SetPendingAuthSessionID(i int64) {
+ m.pending_auth_session = &i
+}
+
+// PendingAuthSessionID returns the value of the "pending_auth_session_id" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionID() (r int64, exists bool) {
+ v := m.pending_auth_session
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPendingAuthSessionID returns the old "pending_auth_session_id" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldPendingAuthSessionID(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPendingAuthSessionID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPendingAuthSessionID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPendingAuthSessionID: %w", err)
+ }
+ return oldValue.PendingAuthSessionID, nil
+}
+
+// ResetPendingAuthSessionID resets all changes to the "pending_auth_session_id" field.
+func (m *IdentityAdoptionDecisionMutation) ResetPendingAuthSessionID() {
+ m.pending_auth_session = nil
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (m *IdentityAdoptionDecisionMutation) SetIdentityID(i int64) {
+ m.identity = &i
+}
+
+// IdentityID returns the value of the "identity_id" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) IdentityID() (r int64, exists bool) {
+ v := m.identity
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldIdentityID returns the old "identity_id" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldIdentityID(ctx context.Context) (v *int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIdentityID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIdentityID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIdentityID: %w", err)
+ }
+ return oldValue.IdentityID, nil
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (m *IdentityAdoptionDecisionMutation) ClearIdentityID() {
+ m.identity = nil
+ m.clearedFields[identityadoptiondecision.FieldIdentityID] = struct{}{}
+}
+
+// IdentityIDCleared returns if the "identity_id" field was cleared in this mutation.
+func (m *IdentityAdoptionDecisionMutation) IdentityIDCleared() bool {
+ _, ok := m.clearedFields[identityadoptiondecision.FieldIdentityID]
+ return ok
+}
+
+// ResetIdentityID resets all changes to the "identity_id" field.
+func (m *IdentityAdoptionDecisionMutation) ResetIdentityID() {
+ m.identity = nil
+ delete(m.clearedFields, identityadoptiondecision.FieldIdentityID)
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (m *IdentityAdoptionDecisionMutation) SetAdoptDisplayName(b bool) {
+ m.adopt_display_name = &b
+}
+
+// AdoptDisplayName returns the value of the "adopt_display_name" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) AdoptDisplayName() (r bool, exists bool) {
+ v := m.adopt_display_name
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldAdoptDisplayName returns the old "adopt_display_name" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldAdoptDisplayName(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAdoptDisplayName is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAdoptDisplayName requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAdoptDisplayName: %w", err)
+ }
+ return oldValue.AdoptDisplayName, nil
+}
+
+// ResetAdoptDisplayName resets all changes to the "adopt_display_name" field.
+func (m *IdentityAdoptionDecisionMutation) ResetAdoptDisplayName() {
+ m.adopt_display_name = nil
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (m *IdentityAdoptionDecisionMutation) SetAdoptAvatar(b bool) {
+ m.adopt_avatar = &b
+}
+
+// AdoptAvatar returns the value of the "adopt_avatar" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) AdoptAvatar() (r bool, exists bool) {
+ v := m.adopt_avatar
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldAdoptAvatar returns the old "adopt_avatar" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldAdoptAvatar(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAdoptAvatar is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAdoptAvatar requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAdoptAvatar: %w", err)
+ }
+ return oldValue.AdoptAvatar, nil
+}
+
+// ResetAdoptAvatar resets all changes to the "adopt_avatar" field.
+func (m *IdentityAdoptionDecisionMutation) ResetAdoptAvatar() {
+ m.adopt_avatar = nil
+}
+
+// SetDecidedAt sets the "decided_at" field.
+func (m *IdentityAdoptionDecisionMutation) SetDecidedAt(t time.Time) {
+ m.decided_at = &t
+}
+
+// DecidedAt returns the value of the "decided_at" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) DecidedAt() (r time.Time, exists bool) {
+ v := m.decided_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldDecidedAt returns the old "decided_at" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldDecidedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldDecidedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldDecidedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldDecidedAt: %w", err)
+ }
+ return oldValue.DecidedAt, nil
+}
+
+// ResetDecidedAt resets all changes to the "decided_at" field.
+func (m *IdentityAdoptionDecisionMutation) ResetDecidedAt() {
+ m.decided_at = nil
+}
+
+// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity.
+func (m *IdentityAdoptionDecisionMutation) ClearPendingAuthSession() {
+ m.clearedpending_auth_session = true
+ m.clearedFields[identityadoptiondecision.FieldPendingAuthSessionID] = struct{}{}
+}
+
+// PendingAuthSessionCleared reports if the "pending_auth_session" edge to the PendingAuthSession entity was cleared.
+func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionCleared() bool {
+ return m.clearedpending_auth_session
+}
+
+// PendingAuthSessionIDs returns the "pending_auth_session" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// PendingAuthSessionID instead. It exists only for internal usage by the builders.
+func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionIDs() (ids []int64) {
+ if id := m.pending_auth_session; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetPendingAuthSession resets all changes to the "pending_auth_session" edge.
+func (m *IdentityAdoptionDecisionMutation) ResetPendingAuthSession() {
+ m.pending_auth_session = nil
+ m.clearedpending_auth_session = false
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (m *IdentityAdoptionDecisionMutation) ClearIdentity() {
+ m.clearedidentity = true
+ m.clearedFields[identityadoptiondecision.FieldIdentityID] = struct{}{}
+}
+
+// IdentityCleared reports if the "identity" edge to the AuthIdentity entity was cleared.
+func (m *IdentityAdoptionDecisionMutation) IdentityCleared() bool {
+ return m.IdentityIDCleared() || m.clearedidentity
+}
+
+// IdentityIDs returns the "identity" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// IdentityID instead. It exists only for internal usage by the builders.
+func (m *IdentityAdoptionDecisionMutation) IdentityIDs() (ids []int64) {
+ if id := m.identity; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetIdentity resets all changes to the "identity" edge.
+func (m *IdentityAdoptionDecisionMutation) ResetIdentity() {
+ m.identity = nil
+ m.clearedidentity = false
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionMutation builder.
+func (m *IdentityAdoptionDecisionMutation) Where(ps ...predicate.IdentityAdoptionDecision) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the IdentityAdoptionDecisionMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *IdentityAdoptionDecisionMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.IdentityAdoptionDecision, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *IdentityAdoptionDecisionMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *IdentityAdoptionDecisionMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (IdentityAdoptionDecision).
+func (m *IdentityAdoptionDecisionMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *IdentityAdoptionDecisionMutation) Fields() []string {
+ fields := make([]string, 0, 7)
+ if m.created_at != nil {
+ fields = append(fields, identityadoptiondecision.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, identityadoptiondecision.FieldUpdatedAt)
+ }
+ if m.pending_auth_session != nil {
+ fields = append(fields, identityadoptiondecision.FieldPendingAuthSessionID)
+ }
+ if m.identity != nil {
+ fields = append(fields, identityadoptiondecision.FieldIdentityID)
+ }
+ if m.adopt_display_name != nil {
+ fields = append(fields, identityadoptiondecision.FieldAdoptDisplayName)
+ }
+ if m.adopt_avatar != nil {
+ fields = append(fields, identityadoptiondecision.FieldAdoptAvatar)
+ }
+ if m.decided_at != nil {
+ fields = append(fields, identityadoptiondecision.FieldDecidedAt)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *IdentityAdoptionDecisionMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case identityadoptiondecision.FieldCreatedAt:
+ return m.CreatedAt()
+ case identityadoptiondecision.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ return m.PendingAuthSessionID()
+ case identityadoptiondecision.FieldIdentityID:
+ return m.IdentityID()
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ return m.AdoptDisplayName()
+ case identityadoptiondecision.FieldAdoptAvatar:
+ return m.AdoptAvatar()
+ case identityadoptiondecision.FieldDecidedAt:
+ return m.DecidedAt()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *IdentityAdoptionDecisionMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case identityadoptiondecision.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case identityadoptiondecision.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ return m.OldPendingAuthSessionID(ctx)
+ case identityadoptiondecision.FieldIdentityID:
+ return m.OldIdentityID(ctx)
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ return m.OldAdoptDisplayName(ctx)
+ case identityadoptiondecision.FieldAdoptAvatar:
+ return m.OldAdoptAvatar(ctx)
+ case identityadoptiondecision.FieldDecidedAt:
+ return m.OldDecidedAt(ctx)
+ }
+ return nil, fmt.Errorf("unknown IdentityAdoptionDecision field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *IdentityAdoptionDecisionMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case identityadoptiondecision.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case identityadoptiondecision.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPendingAuthSessionID(v)
+ return nil
+ case identityadoptiondecision.FieldIdentityID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIdentityID(v)
+ return nil
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAdoptDisplayName(v)
+ return nil
+ case identityadoptiondecision.FieldAdoptAvatar:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAdoptAvatar(v)
+ return nil
+ case identityadoptiondecision.FieldDecidedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetDecidedAt(v)
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *IdentityAdoptionDecisionMutation) AddedFields() []string {
+ var fields []string
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *IdentityAdoptionDecisionMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *IdentityAdoptionDecisionMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(identityadoptiondecision.FieldIdentityID) {
+ fields = append(fields, identityadoptiondecision.FieldIdentityID)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *IdentityAdoptionDecisionMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) ClearField(name string) error {
+ switch name {
+ case identityadoptiondecision.FieldIdentityID:
+ m.ClearIdentityID()
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) ResetField(name string) error {
+ switch name {
+ case identityadoptiondecision.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case identityadoptiondecision.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ m.ResetPendingAuthSessionID()
+ return nil
+ case identityadoptiondecision.FieldIdentityID:
+ m.ResetIdentityID()
+ return nil
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ m.ResetAdoptDisplayName()
+ return nil
+ case identityadoptiondecision.FieldAdoptAvatar:
+ m.ResetAdoptAvatar()
+ return nil
+ case identityadoptiondecision.FieldDecidedAt:
+ m.ResetDecidedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *IdentityAdoptionDecisionMutation) AddedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.pending_auth_session != nil {
+ edges = append(edges, identityadoptiondecision.EdgePendingAuthSession)
+ }
+ if m.identity != nil {
+ edges = append(edges, identityadoptiondecision.EdgeIdentity)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *IdentityAdoptionDecisionMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case identityadoptiondecision.EdgePendingAuthSession:
+ if id := m.pending_auth_session; id != nil {
+ return []ent.Value{*id}
+ }
+ case identityadoptiondecision.EdgeIdentity:
+ if id := m.identity; id != nil {
+ return []ent.Value{*id}
+ }
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *IdentityAdoptionDecisionMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 2)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *IdentityAdoptionDecisionMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *IdentityAdoptionDecisionMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.clearedpending_auth_session {
+ edges = append(edges, identityadoptiondecision.EdgePendingAuthSession)
+ }
+ if m.clearedidentity {
+ edges = append(edges, identityadoptiondecision.EdgeIdentity)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *IdentityAdoptionDecisionMutation) EdgeCleared(name string) bool {
+ switch name {
+ case identityadoptiondecision.EdgePendingAuthSession:
+ return m.clearedpending_auth_session
+ case identityadoptiondecision.EdgeIdentity:
+ return m.clearedidentity
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) ClearEdge(name string) error {
+ switch name {
+ case identityadoptiondecision.EdgePendingAuthSession:
+ m.ClearPendingAuthSession()
+ return nil
+ case identityadoptiondecision.EdgeIdentity:
+ m.ClearIdentity()
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) ResetEdge(name string) error {
+ switch name {
+ case identityadoptiondecision.EdgePendingAuthSession:
+ m.ResetPendingAuthSession()
+ return nil
+ case identityadoptiondecision.EdgeIdentity:
+ m.ResetIdentity()
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision edge %s", name)
+}
+
// PaymentAuditLogMutation represents an operation that mutates the PaymentAuditLog nodes in the graph.
type PaymentAuditLogMutation struct {
config
@@ -16595,6 +19217,1645 @@ func (m *PaymentProviderInstanceMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown PaymentProviderInstance edge %s", name)
}
+// PendingAuthSessionMutation represents an operation that mutates the PendingAuthSession nodes in the graph.
+type PendingAuthSessionMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ session_token *string
+ intent *string
+ provider_type *string
+ provider_key *string
+ provider_subject *string
+ redirect_to *string
+ resolved_email *string
+ registration_password_hash *string
+ upstream_identity_claims *map[string]interface{}
+ local_flow_state *map[string]interface{}
+ browser_session_key *string
+ completion_code_hash *string
+ completion_code_expires_at *time.Time
+ email_verified_at *time.Time
+ password_verified_at *time.Time
+ totp_verified_at *time.Time
+ expires_at *time.Time
+ consumed_at *time.Time
+ clearedFields map[string]struct{}
+ target_user *int64
+ clearedtarget_user bool
+ adoption_decision *int64
+ clearedadoption_decision bool
+ done bool
+ oldValue func(context.Context) (*PendingAuthSession, error)
+ predicates []predicate.PendingAuthSession
+}
+
+var _ ent.Mutation = (*PendingAuthSessionMutation)(nil)
+
+// pendingauthsessionOption allows management of the mutation configuration using functional options.
+type pendingauthsessionOption func(*PendingAuthSessionMutation)
+
+// newPendingAuthSessionMutation creates new mutation for the PendingAuthSession entity.
+func newPendingAuthSessionMutation(c config, op Op, opts ...pendingauthsessionOption) *PendingAuthSessionMutation {
+ m := &PendingAuthSessionMutation{
+ config: c,
+ op: op,
+ typ: TypePendingAuthSession,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withPendingAuthSessionID sets the ID field of the mutation.
+func withPendingAuthSessionID(id int64) pendingauthsessionOption {
+ return func(m *PendingAuthSessionMutation) {
+ var (
+ err error
+ once sync.Once
+ value *PendingAuthSession
+ )
+ m.oldValue = func(ctx context.Context) (*PendingAuthSession, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().PendingAuthSession.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withPendingAuthSession sets the old PendingAuthSession of the mutation.
+func withPendingAuthSession(node *PendingAuthSession) pendingauthsessionOption {
+ return func(m *PendingAuthSessionMutation) {
+ m.oldValue = func(context.Context) (*PendingAuthSession, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m PendingAuthSessionMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m PendingAuthSessionMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *PendingAuthSessionMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *PendingAuthSessionMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().PendingAuthSession.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *PendingAuthSessionMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *PendingAuthSessionMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *PendingAuthSessionMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *PendingAuthSessionMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *PendingAuthSessionMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *PendingAuthSessionMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetSessionToken sets the "session_token" field.
+func (m *PendingAuthSessionMutation) SetSessionToken(s string) {
+ m.session_token = &s
+}
+
+// SessionToken returns the value of the "session_token" field in the mutation.
+func (m *PendingAuthSessionMutation) SessionToken() (r string, exists bool) {
+ v := m.session_token
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSessionToken returns the old "session_token" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldSessionToken(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSessionToken is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSessionToken requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSessionToken: %w", err)
+ }
+ return oldValue.SessionToken, nil
+}
+
+// ResetSessionToken resets all changes to the "session_token" field.
+func (m *PendingAuthSessionMutation) ResetSessionToken() {
+ m.session_token = nil
+}
+
+// SetIntent sets the "intent" field.
+func (m *PendingAuthSessionMutation) SetIntent(s string) {
+ m.intent = &s
+}
+
+// Intent returns the value of the "intent" field in the mutation.
+func (m *PendingAuthSessionMutation) Intent() (r string, exists bool) {
+ v := m.intent
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldIntent returns the old "intent" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldIntent(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIntent is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIntent requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIntent: %w", err)
+ }
+ return oldValue.Intent, nil
+}
+
+// ResetIntent resets all changes to the "intent" field.
+func (m *PendingAuthSessionMutation) ResetIntent() {
+ m.intent = nil
+}
+
+// SetProviderType sets the "provider_type" field.
+func (m *PendingAuthSessionMutation) SetProviderType(s string) {
+ m.provider_type = &s
+}
+
+// ProviderType returns the value of the "provider_type" field in the mutation.
+func (m *PendingAuthSessionMutation) ProviderType() (r string, exists bool) {
+ v := m.provider_type
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderType returns the old "provider_type" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldProviderType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderType: %w", err)
+ }
+ return oldValue.ProviderType, nil
+}
+
+// ResetProviderType resets all changes to the "provider_type" field.
+func (m *PendingAuthSessionMutation) ResetProviderType() {
+ m.provider_type = nil
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (m *PendingAuthSessionMutation) SetProviderKey(s string) {
+ m.provider_key = &s
+}
+
+// ProviderKey returns the value of the "provider_key" field in the mutation.
+func (m *PendingAuthSessionMutation) ProviderKey() (r string, exists bool) {
+ v := m.provider_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderKey returns the old "provider_key" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldProviderKey(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
+ }
+ return oldValue.ProviderKey, nil
+}
+
+// ResetProviderKey resets all changes to the "provider_key" field.
+func (m *PendingAuthSessionMutation) ResetProviderKey() {
+ m.provider_key = nil
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (m *PendingAuthSessionMutation) SetProviderSubject(s string) {
+ m.provider_subject = &s
+}
+
+// ProviderSubject returns the value of the "provider_subject" field in the mutation.
+func (m *PendingAuthSessionMutation) ProviderSubject() (r string, exists bool) {
+ v := m.provider_subject
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderSubject returns the old "provider_subject" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldProviderSubject(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderSubject is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderSubject requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderSubject: %w", err)
+ }
+ return oldValue.ProviderSubject, nil
+}
+
+// ResetProviderSubject resets all changes to the "provider_subject" field.
+func (m *PendingAuthSessionMutation) ResetProviderSubject() {
+ m.provider_subject = nil
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (m *PendingAuthSessionMutation) SetTargetUserID(i int64) {
+ m.target_user = &i
+}
+
+// TargetUserID returns the value of the "target_user_id" field in the mutation.
+func (m *PendingAuthSessionMutation) TargetUserID() (r int64, exists bool) {
+ v := m.target_user
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTargetUserID returns the old "target_user_id" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldTargetUserID(ctx context.Context) (v *int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTargetUserID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTargetUserID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTargetUserID: %w", err)
+ }
+ return oldValue.TargetUserID, nil
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (m *PendingAuthSessionMutation) ClearTargetUserID() {
+ m.target_user = nil
+ m.clearedFields[pendingauthsession.FieldTargetUserID] = struct{}{}
+}
+
+// TargetUserIDCleared returns if the "target_user_id" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) TargetUserIDCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldTargetUserID]
+ return ok
+}
+
+// ResetTargetUserID resets all changes to the "target_user_id" field.
+func (m *PendingAuthSessionMutation) ResetTargetUserID() {
+ m.target_user = nil
+ delete(m.clearedFields, pendingauthsession.FieldTargetUserID)
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (m *PendingAuthSessionMutation) SetRedirectTo(s string) {
+ m.redirect_to = &s
+}
+
+// RedirectTo returns the value of the "redirect_to" field in the mutation.
+func (m *PendingAuthSessionMutation) RedirectTo() (r string, exists bool) {
+ v := m.redirect_to
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRedirectTo returns the old "redirect_to" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldRedirectTo(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRedirectTo is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRedirectTo requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRedirectTo: %w", err)
+ }
+ return oldValue.RedirectTo, nil
+}
+
+// ResetRedirectTo resets all changes to the "redirect_to" field.
+func (m *PendingAuthSessionMutation) ResetRedirectTo() {
+ m.redirect_to = nil
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (m *PendingAuthSessionMutation) SetResolvedEmail(s string) {
+ m.resolved_email = &s
+}
+
+// ResolvedEmail returns the value of the "resolved_email" field in the mutation.
+func (m *PendingAuthSessionMutation) ResolvedEmail() (r string, exists bool) {
+ v := m.resolved_email
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldResolvedEmail returns the old "resolved_email" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldResolvedEmail(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldResolvedEmail is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldResolvedEmail requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldResolvedEmail: %w", err)
+ }
+ return oldValue.ResolvedEmail, nil
+}
+
+// ResetResolvedEmail resets all changes to the "resolved_email" field.
+func (m *PendingAuthSessionMutation) ResetResolvedEmail() {
+ m.resolved_email = nil
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (m *PendingAuthSessionMutation) SetRegistrationPasswordHash(s string) {
+ m.registration_password_hash = &s
+}
+
+// RegistrationPasswordHash returns the value of the "registration_password_hash" field in the mutation.
+func (m *PendingAuthSessionMutation) RegistrationPasswordHash() (r string, exists bool) {
+ v := m.registration_password_hash
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRegistrationPasswordHash returns the old "registration_password_hash" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldRegistrationPasswordHash(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRegistrationPasswordHash is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRegistrationPasswordHash requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRegistrationPasswordHash: %w", err)
+ }
+ return oldValue.RegistrationPasswordHash, nil
+}
+
+// ResetRegistrationPasswordHash resets all changes to the "registration_password_hash" field.
+func (m *PendingAuthSessionMutation) ResetRegistrationPasswordHash() {
+ m.registration_password_hash = nil
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (m *PendingAuthSessionMutation) SetUpstreamIdentityClaims(value map[string]interface{}) {
+ m.upstream_identity_claims = &value
+}
+
+// UpstreamIdentityClaims returns the value of the "upstream_identity_claims" field in the mutation.
+func (m *PendingAuthSessionMutation) UpstreamIdentityClaims() (r map[string]interface{}, exists bool) {
+ v := m.upstream_identity_claims
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpstreamIdentityClaims returns the old "upstream_identity_claims" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldUpstreamIdentityClaims(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpstreamIdentityClaims is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpstreamIdentityClaims requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpstreamIdentityClaims: %w", err)
+ }
+ return oldValue.UpstreamIdentityClaims, nil
+}
+
+// ResetUpstreamIdentityClaims resets all changes to the "upstream_identity_claims" field.
+func (m *PendingAuthSessionMutation) ResetUpstreamIdentityClaims() {
+ m.upstream_identity_claims = nil
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (m *PendingAuthSessionMutation) SetLocalFlowState(value map[string]interface{}) {
+ m.local_flow_state = &value
+}
+
+// LocalFlowState returns the value of the "local_flow_state" field in the mutation.
+func (m *PendingAuthSessionMutation) LocalFlowState() (r map[string]interface{}, exists bool) {
+ v := m.local_flow_state
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldLocalFlowState returns the old "local_flow_state" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldLocalFlowState(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldLocalFlowState is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldLocalFlowState requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldLocalFlowState: %w", err)
+ }
+ return oldValue.LocalFlowState, nil
+}
+
+// ResetLocalFlowState resets all changes to the "local_flow_state" field.
+func (m *PendingAuthSessionMutation) ResetLocalFlowState() {
+ m.local_flow_state = nil
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (m *PendingAuthSessionMutation) SetBrowserSessionKey(s string) {
+ m.browser_session_key = &s
+}
+
+// BrowserSessionKey returns the value of the "browser_session_key" field in the mutation.
+func (m *PendingAuthSessionMutation) BrowserSessionKey() (r string, exists bool) {
+ v := m.browser_session_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBrowserSessionKey returns the old "browser_session_key" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldBrowserSessionKey(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBrowserSessionKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBrowserSessionKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBrowserSessionKey: %w", err)
+ }
+ return oldValue.BrowserSessionKey, nil
+}
+
+// ResetBrowserSessionKey resets all changes to the "browser_session_key" field.
+func (m *PendingAuthSessionMutation) ResetBrowserSessionKey() {
+ m.browser_session_key = nil
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (m *PendingAuthSessionMutation) SetCompletionCodeHash(s string) {
+ m.completion_code_hash = &s
+}
+
+// CompletionCodeHash returns the value of the "completion_code_hash" field in the mutation.
+func (m *PendingAuthSessionMutation) CompletionCodeHash() (r string, exists bool) {
+ v := m.completion_code_hash
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCompletionCodeHash returns the old "completion_code_hash" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldCompletionCodeHash(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCompletionCodeHash is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCompletionCodeHash requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCompletionCodeHash: %w", err)
+ }
+ return oldValue.CompletionCodeHash, nil
+}
+
+// ResetCompletionCodeHash resets all changes to the "completion_code_hash" field.
+func (m *PendingAuthSessionMutation) ResetCompletionCodeHash() {
+ m.completion_code_hash = nil
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (m *PendingAuthSessionMutation) SetCompletionCodeExpiresAt(t time.Time) {
+ m.completion_code_expires_at = &t
+}
+
+// CompletionCodeExpiresAt returns the value of the "completion_code_expires_at" field in the mutation.
+func (m *PendingAuthSessionMutation) CompletionCodeExpiresAt() (r time.Time, exists bool) {
+ v := m.completion_code_expires_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCompletionCodeExpiresAt returns the old "completion_code_expires_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldCompletionCodeExpiresAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCompletionCodeExpiresAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCompletionCodeExpiresAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCompletionCodeExpiresAt: %w", err)
+ }
+ return oldValue.CompletionCodeExpiresAt, nil
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (m *PendingAuthSessionMutation) ClearCompletionCodeExpiresAt() {
+ m.completion_code_expires_at = nil
+ m.clearedFields[pendingauthsession.FieldCompletionCodeExpiresAt] = struct{}{}
+}
+
+// CompletionCodeExpiresAtCleared returns if the "completion_code_expires_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) CompletionCodeExpiresAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldCompletionCodeExpiresAt]
+ return ok
+}
+
+// ResetCompletionCodeExpiresAt resets all changes to the "completion_code_expires_at" field.
+func (m *PendingAuthSessionMutation) ResetCompletionCodeExpiresAt() {
+ m.completion_code_expires_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldCompletionCodeExpiresAt)
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (m *PendingAuthSessionMutation) SetEmailVerifiedAt(t time.Time) {
+ m.email_verified_at = &t
+}
+
+// EmailVerifiedAt returns the value of the "email_verified_at" field in the mutation.
+func (m *PendingAuthSessionMutation) EmailVerifiedAt() (r time.Time, exists bool) {
+ v := m.email_verified_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldEmailVerifiedAt returns the old "email_verified_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldEmailVerifiedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldEmailVerifiedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldEmailVerifiedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldEmailVerifiedAt: %w", err)
+ }
+ return oldValue.EmailVerifiedAt, nil
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (m *PendingAuthSessionMutation) ClearEmailVerifiedAt() {
+ m.email_verified_at = nil
+ m.clearedFields[pendingauthsession.FieldEmailVerifiedAt] = struct{}{}
+}
+
+// EmailVerifiedAtCleared returns if the "email_verified_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) EmailVerifiedAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldEmailVerifiedAt]
+ return ok
+}
+
+// ResetEmailVerifiedAt resets all changes to the "email_verified_at" field.
+func (m *PendingAuthSessionMutation) ResetEmailVerifiedAt() {
+ m.email_verified_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldEmailVerifiedAt)
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (m *PendingAuthSessionMutation) SetPasswordVerifiedAt(t time.Time) {
+ m.password_verified_at = &t
+}
+
+// PasswordVerifiedAt returns the value of the "password_verified_at" field in the mutation.
+func (m *PendingAuthSessionMutation) PasswordVerifiedAt() (r time.Time, exists bool) {
+ v := m.password_verified_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPasswordVerifiedAt returns the old "password_verified_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldPasswordVerifiedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPasswordVerifiedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPasswordVerifiedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPasswordVerifiedAt: %w", err)
+ }
+ return oldValue.PasswordVerifiedAt, nil
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (m *PendingAuthSessionMutation) ClearPasswordVerifiedAt() {
+ m.password_verified_at = nil
+ m.clearedFields[pendingauthsession.FieldPasswordVerifiedAt] = struct{}{}
+}
+
+// PasswordVerifiedAtCleared returns if the "password_verified_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) PasswordVerifiedAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldPasswordVerifiedAt]
+ return ok
+}
+
+// ResetPasswordVerifiedAt resets all changes to the "password_verified_at" field.
+func (m *PendingAuthSessionMutation) ResetPasswordVerifiedAt() {
+ m.password_verified_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldPasswordVerifiedAt)
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (m *PendingAuthSessionMutation) SetTotpVerifiedAt(t time.Time) {
+ m.totp_verified_at = &t
+}
+
+// TotpVerifiedAt returns the value of the "totp_verified_at" field in the mutation.
+func (m *PendingAuthSessionMutation) TotpVerifiedAt() (r time.Time, exists bool) {
+ v := m.totp_verified_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTotpVerifiedAt returns the old "totp_verified_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldTotpVerifiedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTotpVerifiedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTotpVerifiedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTotpVerifiedAt: %w", err)
+ }
+ return oldValue.TotpVerifiedAt, nil
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (m *PendingAuthSessionMutation) ClearTotpVerifiedAt() {
+ m.totp_verified_at = nil
+ m.clearedFields[pendingauthsession.FieldTotpVerifiedAt] = struct{}{}
+}
+
+// TotpVerifiedAtCleared returns if the "totp_verified_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) TotpVerifiedAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldTotpVerifiedAt]
+ return ok
+}
+
+// ResetTotpVerifiedAt resets all changes to the "totp_verified_at" field.
+func (m *PendingAuthSessionMutation) ResetTotpVerifiedAt() {
+ m.totp_verified_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldTotpVerifiedAt)
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (m *PendingAuthSessionMutation) SetExpiresAt(t time.Time) {
+ m.expires_at = &t
+}
+
+// ExpiresAt returns the value of the "expires_at" field in the mutation.
+func (m *PendingAuthSessionMutation) ExpiresAt() (r time.Time, exists bool) {
+ v := m.expires_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldExpiresAt returns the old "expires_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldExpiresAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err)
+ }
+ return oldValue.ExpiresAt, nil
+}
+
+// ResetExpiresAt resets all changes to the "expires_at" field.
+func (m *PendingAuthSessionMutation) ResetExpiresAt() {
+ m.expires_at = nil
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (m *PendingAuthSessionMutation) SetConsumedAt(t time.Time) {
+ m.consumed_at = &t
+}
+
+// ConsumedAt returns the value of the "consumed_at" field in the mutation.
+func (m *PendingAuthSessionMutation) ConsumedAt() (r time.Time, exists bool) {
+ v := m.consumed_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldConsumedAt returns the old "consumed_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldConsumedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldConsumedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldConsumedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldConsumedAt: %w", err)
+ }
+ return oldValue.ConsumedAt, nil
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (m *PendingAuthSessionMutation) ClearConsumedAt() {
+ m.consumed_at = nil
+ m.clearedFields[pendingauthsession.FieldConsumedAt] = struct{}{}
+}
+
+// ConsumedAtCleared returns if the "consumed_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) ConsumedAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldConsumedAt]
+ return ok
+}
+
+// ResetConsumedAt resets all changes to the "consumed_at" field.
+func (m *PendingAuthSessionMutation) ResetConsumedAt() {
+ m.consumed_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldConsumedAt)
+}
+
+// ClearTargetUser clears the "target_user" edge to the User entity.
+func (m *PendingAuthSessionMutation) ClearTargetUser() {
+ m.clearedtarget_user = true
+ m.clearedFields[pendingauthsession.FieldTargetUserID] = struct{}{}
+}
+
+// TargetUserCleared reports if the "target_user" edge to the User entity was cleared.
+func (m *PendingAuthSessionMutation) TargetUserCleared() bool {
+ return m.TargetUserIDCleared() || m.clearedtarget_user
+}
+
+// TargetUserIDs returns the "target_user" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// TargetUserID instead. It exists only for internal usage by the builders.
+func (m *PendingAuthSessionMutation) TargetUserIDs() (ids []int64) {
+ if id := m.target_user; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetTargetUser resets all changes to the "target_user" edge.
+func (m *PendingAuthSessionMutation) ResetTargetUser() {
+ m.target_user = nil
+ m.clearedtarget_user = false
+}
+
+// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by id.
+func (m *PendingAuthSessionMutation) SetAdoptionDecisionID(id int64) {
+ m.adoption_decision = &id
+}
+
+// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (m *PendingAuthSessionMutation) ClearAdoptionDecision() {
+ m.clearedadoption_decision = true
+}
+
+// AdoptionDecisionCleared reports if the "adoption_decision" edge to the IdentityAdoptionDecision entity was cleared.
+func (m *PendingAuthSessionMutation) AdoptionDecisionCleared() bool {
+ return m.clearedadoption_decision
+}
+
+// AdoptionDecisionID returns the "adoption_decision" edge ID in the mutation.
+func (m *PendingAuthSessionMutation) AdoptionDecisionID() (id int64, exists bool) {
+ if m.adoption_decision != nil {
+ return *m.adoption_decision, true
+ }
+ return
+}
+
+// AdoptionDecisionIDs returns the "adoption_decision" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// AdoptionDecisionID instead. It exists only for internal usage by the builders.
+func (m *PendingAuthSessionMutation) AdoptionDecisionIDs() (ids []int64) {
+ if id := m.adoption_decision; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetAdoptionDecision resets all changes to the "adoption_decision" edge.
+func (m *PendingAuthSessionMutation) ResetAdoptionDecision() {
+ m.adoption_decision = nil
+ m.clearedadoption_decision = false
+}
+
+// Where appends a list predicates to the PendingAuthSessionMutation builder.
+func (m *PendingAuthSessionMutation) Where(ps ...predicate.PendingAuthSession) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the PendingAuthSessionMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *PendingAuthSessionMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.PendingAuthSession, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *PendingAuthSessionMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *PendingAuthSessionMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (PendingAuthSession).
+func (m *PendingAuthSessionMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *PendingAuthSessionMutation) Fields() []string {
+ fields := make([]string, 0, 21)
+ if m.created_at != nil {
+ fields = append(fields, pendingauthsession.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, pendingauthsession.FieldUpdatedAt)
+ }
+ if m.session_token != nil {
+ fields = append(fields, pendingauthsession.FieldSessionToken)
+ }
+ if m.intent != nil {
+ fields = append(fields, pendingauthsession.FieldIntent)
+ }
+ if m.provider_type != nil {
+ fields = append(fields, pendingauthsession.FieldProviderType)
+ }
+ if m.provider_key != nil {
+ fields = append(fields, pendingauthsession.FieldProviderKey)
+ }
+ if m.provider_subject != nil {
+ fields = append(fields, pendingauthsession.FieldProviderSubject)
+ }
+ if m.target_user != nil {
+ fields = append(fields, pendingauthsession.FieldTargetUserID)
+ }
+ if m.redirect_to != nil {
+ fields = append(fields, pendingauthsession.FieldRedirectTo)
+ }
+ if m.resolved_email != nil {
+ fields = append(fields, pendingauthsession.FieldResolvedEmail)
+ }
+ if m.registration_password_hash != nil {
+ fields = append(fields, pendingauthsession.FieldRegistrationPasswordHash)
+ }
+ if m.upstream_identity_claims != nil {
+ fields = append(fields, pendingauthsession.FieldUpstreamIdentityClaims)
+ }
+ if m.local_flow_state != nil {
+ fields = append(fields, pendingauthsession.FieldLocalFlowState)
+ }
+ if m.browser_session_key != nil {
+ fields = append(fields, pendingauthsession.FieldBrowserSessionKey)
+ }
+ if m.completion_code_hash != nil {
+ fields = append(fields, pendingauthsession.FieldCompletionCodeHash)
+ }
+ if m.completion_code_expires_at != nil {
+ fields = append(fields, pendingauthsession.FieldCompletionCodeExpiresAt)
+ }
+ if m.email_verified_at != nil {
+ fields = append(fields, pendingauthsession.FieldEmailVerifiedAt)
+ }
+ if m.password_verified_at != nil {
+ fields = append(fields, pendingauthsession.FieldPasswordVerifiedAt)
+ }
+ if m.totp_verified_at != nil {
+ fields = append(fields, pendingauthsession.FieldTotpVerifiedAt)
+ }
+ if m.expires_at != nil {
+ fields = append(fields, pendingauthsession.FieldExpiresAt)
+ }
+ if m.consumed_at != nil {
+ fields = append(fields, pendingauthsession.FieldConsumedAt)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *PendingAuthSessionMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case pendingauthsession.FieldCreatedAt:
+ return m.CreatedAt()
+ case pendingauthsession.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case pendingauthsession.FieldSessionToken:
+ return m.SessionToken()
+ case pendingauthsession.FieldIntent:
+ return m.Intent()
+ case pendingauthsession.FieldProviderType:
+ return m.ProviderType()
+ case pendingauthsession.FieldProviderKey:
+ return m.ProviderKey()
+ case pendingauthsession.FieldProviderSubject:
+ return m.ProviderSubject()
+ case pendingauthsession.FieldTargetUserID:
+ return m.TargetUserID()
+ case pendingauthsession.FieldRedirectTo:
+ return m.RedirectTo()
+ case pendingauthsession.FieldResolvedEmail:
+ return m.ResolvedEmail()
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ return m.RegistrationPasswordHash()
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ return m.UpstreamIdentityClaims()
+ case pendingauthsession.FieldLocalFlowState:
+ return m.LocalFlowState()
+ case pendingauthsession.FieldBrowserSessionKey:
+ return m.BrowserSessionKey()
+ case pendingauthsession.FieldCompletionCodeHash:
+ return m.CompletionCodeHash()
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ return m.CompletionCodeExpiresAt()
+ case pendingauthsession.FieldEmailVerifiedAt:
+ return m.EmailVerifiedAt()
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ return m.PasswordVerifiedAt()
+ case pendingauthsession.FieldTotpVerifiedAt:
+ return m.TotpVerifiedAt()
+ case pendingauthsession.FieldExpiresAt:
+ return m.ExpiresAt()
+ case pendingauthsession.FieldConsumedAt:
+ return m.ConsumedAt()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *PendingAuthSessionMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case pendingauthsession.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case pendingauthsession.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case pendingauthsession.FieldSessionToken:
+ return m.OldSessionToken(ctx)
+ case pendingauthsession.FieldIntent:
+ return m.OldIntent(ctx)
+ case pendingauthsession.FieldProviderType:
+ return m.OldProviderType(ctx)
+ case pendingauthsession.FieldProviderKey:
+ return m.OldProviderKey(ctx)
+ case pendingauthsession.FieldProviderSubject:
+ return m.OldProviderSubject(ctx)
+ case pendingauthsession.FieldTargetUserID:
+ return m.OldTargetUserID(ctx)
+ case pendingauthsession.FieldRedirectTo:
+ return m.OldRedirectTo(ctx)
+ case pendingauthsession.FieldResolvedEmail:
+ return m.OldResolvedEmail(ctx)
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ return m.OldRegistrationPasswordHash(ctx)
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ return m.OldUpstreamIdentityClaims(ctx)
+ case pendingauthsession.FieldLocalFlowState:
+ return m.OldLocalFlowState(ctx)
+ case pendingauthsession.FieldBrowserSessionKey:
+ return m.OldBrowserSessionKey(ctx)
+ case pendingauthsession.FieldCompletionCodeHash:
+ return m.OldCompletionCodeHash(ctx)
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ return m.OldCompletionCodeExpiresAt(ctx)
+ case pendingauthsession.FieldEmailVerifiedAt:
+ return m.OldEmailVerifiedAt(ctx)
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ return m.OldPasswordVerifiedAt(ctx)
+ case pendingauthsession.FieldTotpVerifiedAt:
+ return m.OldTotpVerifiedAt(ctx)
+ case pendingauthsession.FieldExpiresAt:
+ return m.OldExpiresAt(ctx)
+ case pendingauthsession.FieldConsumedAt:
+ return m.OldConsumedAt(ctx)
+ }
+ return nil, fmt.Errorf("unknown PendingAuthSession field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *PendingAuthSessionMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case pendingauthsession.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case pendingauthsession.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case pendingauthsession.FieldSessionToken:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSessionToken(v)
+ return nil
+ case pendingauthsession.FieldIntent:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIntent(v)
+ return nil
+ case pendingauthsession.FieldProviderType:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderType(v)
+ return nil
+ case pendingauthsession.FieldProviderKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderKey(v)
+ return nil
+ case pendingauthsession.FieldProviderSubject:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderSubject(v)
+ return nil
+ case pendingauthsession.FieldTargetUserID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTargetUserID(v)
+ return nil
+ case pendingauthsession.FieldRedirectTo:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRedirectTo(v)
+ return nil
+ case pendingauthsession.FieldResolvedEmail:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetResolvedEmail(v)
+ return nil
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRegistrationPasswordHash(v)
+ return nil
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpstreamIdentityClaims(v)
+ return nil
+ case pendingauthsession.FieldLocalFlowState:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetLocalFlowState(v)
+ return nil
+ case pendingauthsession.FieldBrowserSessionKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBrowserSessionKey(v)
+ return nil
+ case pendingauthsession.FieldCompletionCodeHash:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCompletionCodeHash(v)
+ return nil
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCompletionCodeExpiresAt(v)
+ return nil
+ case pendingauthsession.FieldEmailVerifiedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetEmailVerifiedAt(v)
+ return nil
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPasswordVerifiedAt(v)
+ return nil
+ case pendingauthsession.FieldTotpVerifiedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTotpVerifiedAt(v)
+ return nil
+ case pendingauthsession.FieldExpiresAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetExpiresAt(v)
+ return nil
+ case pendingauthsession.FieldConsumedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetConsumedAt(v)
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *PendingAuthSessionMutation) AddedFields() []string {
+ var fields []string
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *PendingAuthSessionMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *PendingAuthSessionMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown PendingAuthSession numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *PendingAuthSessionMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(pendingauthsession.FieldTargetUserID) {
+ fields = append(fields, pendingauthsession.FieldTargetUserID)
+ }
+ if m.FieldCleared(pendingauthsession.FieldCompletionCodeExpiresAt) {
+ fields = append(fields, pendingauthsession.FieldCompletionCodeExpiresAt)
+ }
+ if m.FieldCleared(pendingauthsession.FieldEmailVerifiedAt) {
+ fields = append(fields, pendingauthsession.FieldEmailVerifiedAt)
+ }
+ if m.FieldCleared(pendingauthsession.FieldPasswordVerifiedAt) {
+ fields = append(fields, pendingauthsession.FieldPasswordVerifiedAt)
+ }
+ if m.FieldCleared(pendingauthsession.FieldTotpVerifiedAt) {
+ fields = append(fields, pendingauthsession.FieldTotpVerifiedAt)
+ }
+ if m.FieldCleared(pendingauthsession.FieldConsumedAt) {
+ fields = append(fields, pendingauthsession.FieldConsumedAt)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *PendingAuthSessionMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *PendingAuthSessionMutation) ClearField(name string) error {
+ switch name {
+ case pendingauthsession.FieldTargetUserID:
+ m.ClearTargetUserID()
+ return nil
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ m.ClearCompletionCodeExpiresAt()
+ return nil
+ case pendingauthsession.FieldEmailVerifiedAt:
+ m.ClearEmailVerifiedAt()
+ return nil
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ m.ClearPasswordVerifiedAt()
+ return nil
+ case pendingauthsession.FieldTotpVerifiedAt:
+ m.ClearTotpVerifiedAt()
+ return nil
+ case pendingauthsession.FieldConsumedAt:
+ m.ClearConsumedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *PendingAuthSessionMutation) ResetField(name string) error {
+ switch name {
+ case pendingauthsession.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case pendingauthsession.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case pendingauthsession.FieldSessionToken:
+ m.ResetSessionToken()
+ return nil
+ case pendingauthsession.FieldIntent:
+ m.ResetIntent()
+ return nil
+ case pendingauthsession.FieldProviderType:
+ m.ResetProviderType()
+ return nil
+ case pendingauthsession.FieldProviderKey:
+ m.ResetProviderKey()
+ return nil
+ case pendingauthsession.FieldProviderSubject:
+ m.ResetProviderSubject()
+ return nil
+ case pendingauthsession.FieldTargetUserID:
+ m.ResetTargetUserID()
+ return nil
+ case pendingauthsession.FieldRedirectTo:
+ m.ResetRedirectTo()
+ return nil
+ case pendingauthsession.FieldResolvedEmail:
+ m.ResetResolvedEmail()
+ return nil
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ m.ResetRegistrationPasswordHash()
+ return nil
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ m.ResetUpstreamIdentityClaims()
+ return nil
+ case pendingauthsession.FieldLocalFlowState:
+ m.ResetLocalFlowState()
+ return nil
+ case pendingauthsession.FieldBrowserSessionKey:
+ m.ResetBrowserSessionKey()
+ return nil
+ case pendingauthsession.FieldCompletionCodeHash:
+ m.ResetCompletionCodeHash()
+ return nil
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ m.ResetCompletionCodeExpiresAt()
+ return nil
+ case pendingauthsession.FieldEmailVerifiedAt:
+ m.ResetEmailVerifiedAt()
+ return nil
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ m.ResetPasswordVerifiedAt()
+ return nil
+ case pendingauthsession.FieldTotpVerifiedAt:
+ m.ResetTotpVerifiedAt()
+ return nil
+ case pendingauthsession.FieldExpiresAt:
+ m.ResetExpiresAt()
+ return nil
+ case pendingauthsession.FieldConsumedAt:
+ m.ResetConsumedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *PendingAuthSessionMutation) AddedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.target_user != nil {
+ edges = append(edges, pendingauthsession.EdgeTargetUser)
+ }
+ if m.adoption_decision != nil {
+ edges = append(edges, pendingauthsession.EdgeAdoptionDecision)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *PendingAuthSessionMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case pendingauthsession.EdgeTargetUser:
+ if id := m.target_user; id != nil {
+ return []ent.Value{*id}
+ }
+ case pendingauthsession.EdgeAdoptionDecision:
+ if id := m.adoption_decision; id != nil {
+ return []ent.Value{*id}
+ }
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *PendingAuthSessionMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 2)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *PendingAuthSessionMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *PendingAuthSessionMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.clearedtarget_user {
+ edges = append(edges, pendingauthsession.EdgeTargetUser)
+ }
+ if m.clearedadoption_decision {
+ edges = append(edges, pendingauthsession.EdgeAdoptionDecision)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *PendingAuthSessionMutation) EdgeCleared(name string) bool {
+ switch name {
+ case pendingauthsession.EdgeTargetUser:
+ return m.clearedtarget_user
+ case pendingauthsession.EdgeAdoptionDecision:
+ return m.clearedadoption_decision
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *PendingAuthSessionMutation) ClearEdge(name string) error {
+ switch name {
+ case pendingauthsession.EdgeTargetUser:
+ m.ClearTargetUser()
+ return nil
+ case pendingauthsession.EdgeAdoptionDecision:
+ m.ClearAdoptionDecision()
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *PendingAuthSessionMutation) ResetEdge(name string) error {
+ switch name {
+ case pendingauthsession.EdgeTargetUser:
+ m.ResetTargetUser()
+ return nil
+ case pendingauthsession.EdgeAdoptionDecision:
+ m.ResetAdoptionDecision()
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession edge %s", name)
+}
+
// PromoCodeMutation represents an operation that mutates the PromoCode nodes in the graph.
type PromoCodeMutation struct {
config
@@ -28264,6 +32525,9 @@ type UserMutation struct {
totp_secret_encrypted *string
totp_enabled *bool
totp_enabled_at *time.Time
+ signup_source *string
+ last_login_at *time.Time
+ last_active_at *time.Time
balance_notify_enabled *bool
balance_notify_threshold_type *string
balance_notify_threshold *float64
@@ -28302,6 +32566,12 @@ type UserMutation struct {
payment_orders map[int64]struct{}
removedpayment_orders map[int64]struct{}
clearedpayment_orders bool
+ auth_identities map[int64]struct{}
+ removedauth_identities map[int64]struct{}
+ clearedauth_identities bool
+ pending_auth_sessions map[int64]struct{}
+ removedpending_auth_sessions map[int64]struct{}
+ clearedpending_auth_sessions bool
done bool
oldValue func(context.Context) (*User, error)
predicates []predicate.User
@@ -28988,6 +33258,140 @@ func (m *UserMutation) ResetTotpEnabledAt() {
delete(m.clearedFields, user.FieldTotpEnabledAt)
}
+// SetSignupSource sets the "signup_source" field.
+func (m *UserMutation) SetSignupSource(s string) {
+ m.signup_source = &s
+}
+
+// SignupSource returns the value of the "signup_source" field in the mutation.
+func (m *UserMutation) SignupSource() (r string, exists bool) {
+ v := m.signup_source
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSignupSource returns the old "signup_source" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldSignupSource(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSignupSource is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSignupSource requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSignupSource: %w", err)
+ }
+ return oldValue.SignupSource, nil
+}
+
+// ResetSignupSource resets all changes to the "signup_source" field.
+func (m *UserMutation) ResetSignupSource() {
+ m.signup_source = nil
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (m *UserMutation) SetLastLoginAt(t time.Time) {
+ m.last_login_at = &t
+}
+
+// LastLoginAt returns the value of the "last_login_at" field in the mutation.
+func (m *UserMutation) LastLoginAt() (r time.Time, exists bool) {
+ v := m.last_login_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldLastLoginAt returns the old "last_login_at" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldLastLoginAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldLastLoginAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldLastLoginAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldLastLoginAt: %w", err)
+ }
+ return oldValue.LastLoginAt, nil
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (m *UserMutation) ClearLastLoginAt() {
+ m.last_login_at = nil
+ m.clearedFields[user.FieldLastLoginAt] = struct{}{}
+}
+
+// LastLoginAtCleared returns if the "last_login_at" field was cleared in this mutation.
+func (m *UserMutation) LastLoginAtCleared() bool {
+ _, ok := m.clearedFields[user.FieldLastLoginAt]
+ return ok
+}
+
+// ResetLastLoginAt resets all changes to the "last_login_at" field.
+func (m *UserMutation) ResetLastLoginAt() {
+ m.last_login_at = nil
+ delete(m.clearedFields, user.FieldLastLoginAt)
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (m *UserMutation) SetLastActiveAt(t time.Time) {
+ m.last_active_at = &t
+}
+
+// LastActiveAt returns the value of the "last_active_at" field in the mutation.
+func (m *UserMutation) LastActiveAt() (r time.Time, exists bool) {
+ v := m.last_active_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldLastActiveAt returns the old "last_active_at" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldLastActiveAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldLastActiveAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldLastActiveAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldLastActiveAt: %w", err)
+ }
+ return oldValue.LastActiveAt, nil
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (m *UserMutation) ClearLastActiveAt() {
+ m.last_active_at = nil
+ m.clearedFields[user.FieldLastActiveAt] = struct{}{}
+}
+
+// LastActiveAtCleared returns if the "last_active_at" field was cleared in this mutation.
+func (m *UserMutation) LastActiveAtCleared() bool {
+ _, ok := m.clearedFields[user.FieldLastActiveAt]
+ return ok
+}
+
+// ResetLastActiveAt resets all changes to the "last_active_at" field.
+func (m *UserMutation) ResetLastActiveAt() {
+ m.last_active_at = nil
+ delete(m.clearedFields, user.FieldLastActiveAt)
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (m *UserMutation) SetBalanceNotifyEnabled(b bool) {
m.balance_notify_enabled = &b
@@ -29762,6 +34166,114 @@ func (m *UserMutation) ResetPaymentOrders() {
m.removedpayment_orders = nil
}
+// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by ids.
+func (m *UserMutation) AddAuthIdentityIDs(ids ...int64) {
+ if m.auth_identities == nil {
+ m.auth_identities = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.auth_identities[ids[i]] = struct{}{}
+ }
+}
+
+// ClearAuthIdentities clears the "auth_identities" edge to the AuthIdentity entity.
+func (m *UserMutation) ClearAuthIdentities() {
+ m.clearedauth_identities = true
+}
+
+// AuthIdentitiesCleared reports if the "auth_identities" edge to the AuthIdentity entity was cleared.
+func (m *UserMutation) AuthIdentitiesCleared() bool {
+ return m.clearedauth_identities
+}
+
+// RemoveAuthIdentityIDs removes the "auth_identities" edge to the AuthIdentity entity by IDs.
+func (m *UserMutation) RemoveAuthIdentityIDs(ids ...int64) {
+ if m.removedauth_identities == nil {
+ m.removedauth_identities = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.auth_identities, ids[i])
+ m.removedauth_identities[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedAuthIdentities returns the removed IDs of the "auth_identities" edge to the AuthIdentity entity.
+func (m *UserMutation) RemovedAuthIdentitiesIDs() (ids []int64) {
+ for id := range m.removedauth_identities {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// AuthIdentitiesIDs returns the "auth_identities" edge IDs in the mutation.
+func (m *UserMutation) AuthIdentitiesIDs() (ids []int64) {
+ for id := range m.auth_identities {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetAuthIdentities resets all changes to the "auth_identities" edge.
+func (m *UserMutation) ResetAuthIdentities() {
+ m.auth_identities = nil
+ m.clearedauth_identities = false
+ m.removedauth_identities = nil
+}
+
+// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by ids.
+func (m *UserMutation) AddPendingAuthSessionIDs(ids ...int64) {
+ if m.pending_auth_sessions == nil {
+ m.pending_auth_sessions = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.pending_auth_sessions[ids[i]] = struct{}{}
+ }
+}
+
+// ClearPendingAuthSessions clears the "pending_auth_sessions" edge to the PendingAuthSession entity.
+func (m *UserMutation) ClearPendingAuthSessions() {
+ m.clearedpending_auth_sessions = true
+}
+
+// PendingAuthSessionsCleared reports if the "pending_auth_sessions" edge to the PendingAuthSession entity was cleared.
+func (m *UserMutation) PendingAuthSessionsCleared() bool {
+ return m.clearedpending_auth_sessions
+}
+
+// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
+func (m *UserMutation) RemovePendingAuthSessionIDs(ids ...int64) {
+ if m.removedpending_auth_sessions == nil {
+ m.removedpending_auth_sessions = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.pending_auth_sessions, ids[i])
+ m.removedpending_auth_sessions[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedPendingAuthSessions returns the removed IDs of the "pending_auth_sessions" edge to the PendingAuthSession entity.
+func (m *UserMutation) RemovedPendingAuthSessionsIDs() (ids []int64) {
+ for id := range m.removedpending_auth_sessions {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// PendingAuthSessionsIDs returns the "pending_auth_sessions" edge IDs in the mutation.
+func (m *UserMutation) PendingAuthSessionsIDs() (ids []int64) {
+ for id := range m.pending_auth_sessions {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetPendingAuthSessions resets all changes to the "pending_auth_sessions" edge.
+func (m *UserMutation) ResetPendingAuthSessions() {
+ m.pending_auth_sessions = nil
+ m.clearedpending_auth_sessions = false
+ m.removedpending_auth_sessions = nil
+}
+
// Where appends a list predicates to the UserMutation builder.
func (m *UserMutation) Where(ps ...predicate.User) {
m.predicates = append(m.predicates, ps...)
@@ -29796,7 +34308,7 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UserMutation) Fields() []string {
- fields := make([]string, 0, 19)
+ fields := make([]string, 0, 22)
if m.created_at != nil {
fields = append(fields, user.FieldCreatedAt)
}
@@ -29839,6 +34351,15 @@ func (m *UserMutation) Fields() []string {
if m.totp_enabled_at != nil {
fields = append(fields, user.FieldTotpEnabledAt)
}
+ if m.signup_source != nil {
+ fields = append(fields, user.FieldSignupSource)
+ }
+ if m.last_login_at != nil {
+ fields = append(fields, user.FieldLastLoginAt)
+ }
+ if m.last_active_at != nil {
+ fields = append(fields, user.FieldLastActiveAt)
+ }
if m.balance_notify_enabled != nil {
fields = append(fields, user.FieldBalanceNotifyEnabled)
}
@@ -29890,6 +34411,12 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
return m.TotpEnabled()
case user.FieldTotpEnabledAt:
return m.TotpEnabledAt()
+ case user.FieldSignupSource:
+ return m.SignupSource()
+ case user.FieldLastLoginAt:
+ return m.LastLoginAt()
+ case user.FieldLastActiveAt:
+ return m.LastActiveAt()
case user.FieldBalanceNotifyEnabled:
return m.BalanceNotifyEnabled()
case user.FieldBalanceNotifyThresholdType:
@@ -29937,6 +34464,12 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
return m.OldTotpEnabled(ctx)
case user.FieldTotpEnabledAt:
return m.OldTotpEnabledAt(ctx)
+ case user.FieldSignupSource:
+ return m.OldSignupSource(ctx)
+ case user.FieldLastLoginAt:
+ return m.OldLastLoginAt(ctx)
+ case user.FieldLastActiveAt:
+ return m.OldLastActiveAt(ctx)
case user.FieldBalanceNotifyEnabled:
return m.OldBalanceNotifyEnabled(ctx)
case user.FieldBalanceNotifyThresholdType:
@@ -30054,6 +34587,27 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
}
m.SetTotpEnabledAt(v)
return nil
+ case user.FieldSignupSource:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSignupSource(v)
+ return nil
+ case user.FieldLastLoginAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetLastLoginAt(v)
+ return nil
+ case user.FieldLastActiveAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetLastActiveAt(v)
+ return nil
case user.FieldBalanceNotifyEnabled:
v, ok := value.(bool)
if !ok {
@@ -30179,6 +34733,12 @@ func (m *UserMutation) ClearedFields() []string {
if m.FieldCleared(user.FieldTotpEnabledAt) {
fields = append(fields, user.FieldTotpEnabledAt)
}
+ if m.FieldCleared(user.FieldLastLoginAt) {
+ fields = append(fields, user.FieldLastLoginAt)
+ }
+ if m.FieldCleared(user.FieldLastActiveAt) {
+ fields = append(fields, user.FieldLastActiveAt)
+ }
if m.FieldCleared(user.FieldBalanceNotifyThreshold) {
fields = append(fields, user.FieldBalanceNotifyThreshold)
}
@@ -30205,6 +34765,12 @@ func (m *UserMutation) ClearField(name string) error {
case user.FieldTotpEnabledAt:
m.ClearTotpEnabledAt()
return nil
+ case user.FieldLastLoginAt:
+ m.ClearLastLoginAt()
+ return nil
+ case user.FieldLastActiveAt:
+ m.ClearLastActiveAt()
+ return nil
case user.FieldBalanceNotifyThreshold:
m.ClearBalanceNotifyThreshold()
return nil
@@ -30258,6 +34824,15 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldTotpEnabledAt:
m.ResetTotpEnabledAt()
return nil
+ case user.FieldSignupSource:
+ m.ResetSignupSource()
+ return nil
+ case user.FieldLastLoginAt:
+ m.ResetLastLoginAt()
+ return nil
+ case user.FieldLastActiveAt:
+ m.ResetLastActiveAt()
+ return nil
case user.FieldBalanceNotifyEnabled:
m.ResetBalanceNotifyEnabled()
return nil
@@ -30279,7 +34854,7 @@ func (m *UserMutation) ResetField(name string) error {
// AddedEdges returns all edge names that were set/added in this mutation.
func (m *UserMutation) AddedEdges() []string {
- edges := make([]string, 0, 10)
+ edges := make([]string, 0, 12)
if m.api_keys != nil {
edges = append(edges, user.EdgeAPIKeys)
}
@@ -30310,6 +34885,12 @@ func (m *UserMutation) AddedEdges() []string {
if m.payment_orders != nil {
edges = append(edges, user.EdgePaymentOrders)
}
+ if m.auth_identities != nil {
+ edges = append(edges, user.EdgeAuthIdentities)
+ }
+ if m.pending_auth_sessions != nil {
+ edges = append(edges, user.EdgePendingAuthSessions)
+ }
return edges
}
@@ -30377,13 +34958,25 @@ func (m *UserMutation) AddedIDs(name string) []ent.Value {
ids = append(ids, id)
}
return ids
+ case user.EdgeAuthIdentities:
+ ids := make([]ent.Value, 0, len(m.auth_identities))
+ for id := range m.auth_identities {
+ ids = append(ids, id)
+ }
+ return ids
+ case user.EdgePendingAuthSessions:
+ ids := make([]ent.Value, 0, len(m.pending_auth_sessions))
+ for id := range m.pending_auth_sessions {
+ ids = append(ids, id)
+ }
+ return ids
}
return nil
}
// RemovedEdges returns all edge names that were removed in this mutation.
func (m *UserMutation) RemovedEdges() []string {
- edges := make([]string, 0, 10)
+ edges := make([]string, 0, 12)
if m.removedapi_keys != nil {
edges = append(edges, user.EdgeAPIKeys)
}
@@ -30414,6 +35007,12 @@ func (m *UserMutation) RemovedEdges() []string {
if m.removedpayment_orders != nil {
edges = append(edges, user.EdgePaymentOrders)
}
+ if m.removedauth_identities != nil {
+ edges = append(edges, user.EdgeAuthIdentities)
+ }
+ if m.removedpending_auth_sessions != nil {
+ edges = append(edges, user.EdgePendingAuthSessions)
+ }
return edges
}
@@ -30481,13 +35080,25 @@ func (m *UserMutation) RemovedIDs(name string) []ent.Value {
ids = append(ids, id)
}
return ids
+ case user.EdgeAuthIdentities:
+ ids := make([]ent.Value, 0, len(m.removedauth_identities))
+ for id := range m.removedauth_identities {
+ ids = append(ids, id)
+ }
+ return ids
+ case user.EdgePendingAuthSessions:
+ ids := make([]ent.Value, 0, len(m.removedpending_auth_sessions))
+ for id := range m.removedpending_auth_sessions {
+ ids = append(ids, id)
+ }
+ return ids
}
return nil
}
// ClearedEdges returns all edge names that were cleared in this mutation.
func (m *UserMutation) ClearedEdges() []string {
- edges := make([]string, 0, 10)
+ edges := make([]string, 0, 12)
if m.clearedapi_keys {
edges = append(edges, user.EdgeAPIKeys)
}
@@ -30518,6 +35129,12 @@ func (m *UserMutation) ClearedEdges() []string {
if m.clearedpayment_orders {
edges = append(edges, user.EdgePaymentOrders)
}
+ if m.clearedauth_identities {
+ edges = append(edges, user.EdgeAuthIdentities)
+ }
+ if m.clearedpending_auth_sessions {
+ edges = append(edges, user.EdgePendingAuthSessions)
+ }
return edges
}
@@ -30545,6 +35162,10 @@ func (m *UserMutation) EdgeCleared(name string) bool {
return m.clearedpromo_code_usages
case user.EdgePaymentOrders:
return m.clearedpayment_orders
+ case user.EdgeAuthIdentities:
+ return m.clearedauth_identities
+ case user.EdgePendingAuthSessions:
+ return m.clearedpending_auth_sessions
}
return false
}
@@ -30591,6 +35212,12 @@ func (m *UserMutation) ResetEdge(name string) error {
case user.EdgePaymentOrders:
m.ResetPaymentOrders()
return nil
+ case user.EdgeAuthIdentities:
+ m.ResetAuthIdentities()
+ return nil
+ case user.EdgePendingAuthSessions:
+ m.ResetPendingAuthSessions()
+ return nil
}
return fmt.Errorf("unknown User edge %s", name)
}
diff --git a/backend/ent/pendingauthsession.go b/backend/ent/pendingauthsession.go
new file mode 100644
index 00000000..e77c065f
--- /dev/null
+++ b/backend/ent/pendingauthsession.go
@@ -0,0 +1,399 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PendingAuthSession is the model entity for the PendingAuthSession schema.
+type PendingAuthSession struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // SessionToken holds the value of the "session_token" field.
+ SessionToken string `json:"session_token,omitempty"`
+ // Intent holds the value of the "intent" field.
+ Intent string `json:"intent,omitempty"`
+ // ProviderType holds the value of the "provider_type" field.
+ ProviderType string `json:"provider_type,omitempty"`
+ // ProviderKey holds the value of the "provider_key" field.
+ ProviderKey string `json:"provider_key,omitempty"`
+ // ProviderSubject holds the value of the "provider_subject" field.
+ ProviderSubject string `json:"provider_subject,omitempty"`
+ // TargetUserID holds the value of the "target_user_id" field.
+ TargetUserID *int64 `json:"target_user_id,omitempty"`
+ // RedirectTo holds the value of the "redirect_to" field.
+ RedirectTo string `json:"redirect_to,omitempty"`
+ // ResolvedEmail holds the value of the "resolved_email" field.
+ ResolvedEmail string `json:"resolved_email,omitempty"`
+ // RegistrationPasswordHash holds the value of the "registration_password_hash" field.
+ RegistrationPasswordHash string `json:"registration_password_hash,omitempty"`
+ // UpstreamIdentityClaims holds the value of the "upstream_identity_claims" field.
+ UpstreamIdentityClaims map[string]interface{} `json:"upstream_identity_claims,omitempty"`
+ // LocalFlowState holds the value of the "local_flow_state" field.
+ LocalFlowState map[string]interface{} `json:"local_flow_state,omitempty"`
+ // BrowserSessionKey holds the value of the "browser_session_key" field.
+ BrowserSessionKey string `json:"browser_session_key,omitempty"`
+ // CompletionCodeHash holds the value of the "completion_code_hash" field.
+ CompletionCodeHash string `json:"completion_code_hash,omitempty"`
+ // CompletionCodeExpiresAt holds the value of the "completion_code_expires_at" field.
+ CompletionCodeExpiresAt *time.Time `json:"completion_code_expires_at,omitempty"`
+ // EmailVerifiedAt holds the value of the "email_verified_at" field.
+ EmailVerifiedAt *time.Time `json:"email_verified_at,omitempty"`
+ // PasswordVerifiedAt holds the value of the "password_verified_at" field.
+ PasswordVerifiedAt *time.Time `json:"password_verified_at,omitempty"`
+ // TotpVerifiedAt holds the value of the "totp_verified_at" field.
+ TotpVerifiedAt *time.Time `json:"totp_verified_at,omitempty"`
+ // ExpiresAt holds the value of the "expires_at" field.
+ ExpiresAt time.Time `json:"expires_at,omitempty"`
+ // ConsumedAt holds the value of the "consumed_at" field.
+ ConsumedAt *time.Time `json:"consumed_at,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the PendingAuthSessionQuery when eager-loading is set.
+ Edges PendingAuthSessionEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// PendingAuthSessionEdges holds the relations/edges for other nodes in the graph.
+type PendingAuthSessionEdges struct {
+ // TargetUser holds the value of the target_user edge.
+ TargetUser *User `json:"target_user,omitempty"`
+ // AdoptionDecision holds the value of the adoption_decision edge.
+ AdoptionDecision *IdentityAdoptionDecision `json:"adoption_decision,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [2]bool
+}
+
+// TargetUserOrErr returns the TargetUser value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e PendingAuthSessionEdges) TargetUserOrErr() (*User, error) {
+ if e.TargetUser != nil {
+ return e.TargetUser, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: user.Label}
+ }
+ return nil, &NotLoadedError{edge: "target_user"}
+}
+
+// AdoptionDecisionOrErr returns the AdoptionDecision value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e PendingAuthSessionEdges) AdoptionDecisionOrErr() (*IdentityAdoptionDecision, error) {
+ if e.AdoptionDecision != nil {
+ return e.AdoptionDecision, nil
+ } else if e.loadedTypes[1] {
+ return nil, &NotFoundError{label: identityadoptiondecision.Label}
+ }
+ return nil, &NotLoadedError{edge: "adoption_decision"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*PendingAuthSession) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case pendingauthsession.FieldUpstreamIdentityClaims, pendingauthsession.FieldLocalFlowState:
+ values[i] = new([]byte)
+ case pendingauthsession.FieldID, pendingauthsession.FieldTargetUserID:
+ values[i] = new(sql.NullInt64)
+ case pendingauthsession.FieldSessionToken, pendingauthsession.FieldIntent, pendingauthsession.FieldProviderType, pendingauthsession.FieldProviderKey, pendingauthsession.FieldProviderSubject, pendingauthsession.FieldRedirectTo, pendingauthsession.FieldResolvedEmail, pendingauthsession.FieldRegistrationPasswordHash, pendingauthsession.FieldBrowserSessionKey, pendingauthsession.FieldCompletionCodeHash:
+ values[i] = new(sql.NullString)
+ case pendingauthsession.FieldCreatedAt, pendingauthsession.FieldUpdatedAt, pendingauthsession.FieldCompletionCodeExpiresAt, pendingauthsession.FieldEmailVerifiedAt, pendingauthsession.FieldPasswordVerifiedAt, pendingauthsession.FieldTotpVerifiedAt, pendingauthsession.FieldExpiresAt, pendingauthsession.FieldConsumedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the PendingAuthSession fields.
+func (_m *PendingAuthSession) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case pendingauthsession.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case pendingauthsession.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case pendingauthsession.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case pendingauthsession.FieldSessionToken:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field session_token", values[i])
+ } else if value.Valid {
+ _m.SessionToken = value.String
+ }
+ case pendingauthsession.FieldIntent:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field intent", values[i])
+ } else if value.Valid {
+ _m.Intent = value.String
+ }
+ case pendingauthsession.FieldProviderType:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_type", values[i])
+ } else if value.Valid {
+ _m.ProviderType = value.String
+ }
+ case pendingauthsession.FieldProviderKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_key", values[i])
+ } else if value.Valid {
+ _m.ProviderKey = value.String
+ }
+ case pendingauthsession.FieldProviderSubject:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_subject", values[i])
+ } else if value.Valid {
+ _m.ProviderSubject = value.String
+ }
+ case pendingauthsession.FieldTargetUserID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field target_user_id", values[i])
+ } else if value.Valid {
+ _m.TargetUserID = new(int64)
+ *_m.TargetUserID = value.Int64
+ }
+ case pendingauthsession.FieldRedirectTo:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field redirect_to", values[i])
+ } else if value.Valid {
+ _m.RedirectTo = value.String
+ }
+ case pendingauthsession.FieldResolvedEmail:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field resolved_email", values[i])
+ } else if value.Valid {
+ _m.ResolvedEmail = value.String
+ }
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field registration_password_hash", values[i])
+ } else if value.Valid {
+ _m.RegistrationPasswordHash = value.String
+ }
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field upstream_identity_claims", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.UpstreamIdentityClaims); err != nil {
+ return fmt.Errorf("unmarshal field upstream_identity_claims: %w", err)
+ }
+ }
+ case pendingauthsession.FieldLocalFlowState:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field local_flow_state", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.LocalFlowState); err != nil {
+ return fmt.Errorf("unmarshal field local_flow_state: %w", err)
+ }
+ }
+ case pendingauthsession.FieldBrowserSessionKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field browser_session_key", values[i])
+ } else if value.Valid {
+ _m.BrowserSessionKey = value.String
+ }
+ case pendingauthsession.FieldCompletionCodeHash:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field completion_code_hash", values[i])
+ } else if value.Valid {
+ _m.CompletionCodeHash = value.String
+ }
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field completion_code_expires_at", values[i])
+ } else if value.Valid {
+ _m.CompletionCodeExpiresAt = new(time.Time)
+ *_m.CompletionCodeExpiresAt = value.Time
+ }
+ case pendingauthsession.FieldEmailVerifiedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field email_verified_at", values[i])
+ } else if value.Valid {
+ _m.EmailVerifiedAt = new(time.Time)
+ *_m.EmailVerifiedAt = value.Time
+ }
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field password_verified_at", values[i])
+ } else if value.Valid {
+ _m.PasswordVerifiedAt = new(time.Time)
+ *_m.PasswordVerifiedAt = value.Time
+ }
+ case pendingauthsession.FieldTotpVerifiedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field totp_verified_at", values[i])
+ } else if value.Valid {
+ _m.TotpVerifiedAt = new(time.Time)
+ *_m.TotpVerifiedAt = value.Time
+ }
+ case pendingauthsession.FieldExpiresAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field expires_at", values[i])
+ } else if value.Valid {
+ _m.ExpiresAt = value.Time
+ }
+ case pendingauthsession.FieldConsumedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field consumed_at", values[i])
+ } else if value.Valid {
+ _m.ConsumedAt = new(time.Time)
+ *_m.ConsumedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the PendingAuthSession.
+// This includes values selected through modifiers, order, etc.
+func (_m *PendingAuthSession) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryTargetUser queries the "target_user" edge of the PendingAuthSession entity.
+func (_m *PendingAuthSession) QueryTargetUser() *UserQuery {
+ return NewPendingAuthSessionClient(_m.config).QueryTargetUser(_m)
+}
+
+// QueryAdoptionDecision queries the "adoption_decision" edge of the PendingAuthSession entity.
+func (_m *PendingAuthSession) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery {
+ return NewPendingAuthSessionClient(_m.config).QueryAdoptionDecision(_m)
+}
+
+// Update returns a builder for updating this PendingAuthSession.
+// Note that you need to call PendingAuthSession.Unwrap() before calling this method if this PendingAuthSession
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *PendingAuthSession) Update() *PendingAuthSessionUpdateOne {
+ return NewPendingAuthSessionClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the PendingAuthSession entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *PendingAuthSession) Unwrap() *PendingAuthSession {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: PendingAuthSession is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *PendingAuthSession) String() string {
+ var builder strings.Builder
+ builder.WriteString("PendingAuthSession(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("session_token=")
+ builder.WriteString(_m.SessionToken)
+ builder.WriteString(", ")
+ builder.WriteString("intent=")
+ builder.WriteString(_m.Intent)
+ builder.WriteString(", ")
+ builder.WriteString("provider_type=")
+ builder.WriteString(_m.ProviderType)
+ builder.WriteString(", ")
+ builder.WriteString("provider_key=")
+ builder.WriteString(_m.ProviderKey)
+ builder.WriteString(", ")
+ builder.WriteString("provider_subject=")
+ builder.WriteString(_m.ProviderSubject)
+ builder.WriteString(", ")
+ if v := _m.TargetUserID; v != nil {
+ builder.WriteString("target_user_id=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("redirect_to=")
+ builder.WriteString(_m.RedirectTo)
+ builder.WriteString(", ")
+ builder.WriteString("resolved_email=")
+ builder.WriteString(_m.ResolvedEmail)
+ builder.WriteString(", ")
+ builder.WriteString("registration_password_hash=")
+ builder.WriteString(_m.RegistrationPasswordHash)
+ builder.WriteString(", ")
+ builder.WriteString("upstream_identity_claims=")
+ builder.WriteString(fmt.Sprintf("%v", _m.UpstreamIdentityClaims))
+ builder.WriteString(", ")
+ builder.WriteString("local_flow_state=")
+ builder.WriteString(fmt.Sprintf("%v", _m.LocalFlowState))
+ builder.WriteString(", ")
+ builder.WriteString("browser_session_key=")
+ builder.WriteString(_m.BrowserSessionKey)
+ builder.WriteString(", ")
+ builder.WriteString("completion_code_hash=")
+ builder.WriteString(_m.CompletionCodeHash)
+ builder.WriteString(", ")
+ if v := _m.CompletionCodeExpiresAt; v != nil {
+ builder.WriteString("completion_code_expires_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.EmailVerifiedAt; v != nil {
+ builder.WriteString("email_verified_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.PasswordVerifiedAt; v != nil {
+ builder.WriteString("password_verified_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.TotpVerifiedAt; v != nil {
+ builder.WriteString("totp_verified_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("expires_at=")
+ builder.WriteString(_m.ExpiresAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ if v := _m.ConsumedAt; v != nil {
+ builder.WriteString("consumed_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// PendingAuthSessions is a parsable slice of PendingAuthSession.
+type PendingAuthSessions []*PendingAuthSession
diff --git a/backend/ent/pendingauthsession/pendingauthsession.go b/backend/ent/pendingauthsession/pendingauthsession.go
new file mode 100644
index 00000000..8a3ac9bf
--- /dev/null
+++ b/backend/ent/pendingauthsession/pendingauthsession.go
@@ -0,0 +1,279 @@
+// Code generated by ent, DO NOT EDIT.
+
+package pendingauthsession
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the pendingauthsession type in the database.
+ Label = "pending_auth_session"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldSessionToken holds the string denoting the session_token field in the database.
+ FieldSessionToken = "session_token"
+ // FieldIntent holds the string denoting the intent field in the database.
+ FieldIntent = "intent"
+ // FieldProviderType holds the string denoting the provider_type field in the database.
+ FieldProviderType = "provider_type"
+ // FieldProviderKey holds the string denoting the provider_key field in the database.
+ FieldProviderKey = "provider_key"
+ // FieldProviderSubject holds the string denoting the provider_subject field in the database.
+ FieldProviderSubject = "provider_subject"
+ // FieldTargetUserID holds the string denoting the target_user_id field in the database.
+ FieldTargetUserID = "target_user_id"
+ // FieldRedirectTo holds the string denoting the redirect_to field in the database.
+ FieldRedirectTo = "redirect_to"
+ // FieldResolvedEmail holds the string denoting the resolved_email field in the database.
+ FieldResolvedEmail = "resolved_email"
+ // FieldRegistrationPasswordHash holds the string denoting the registration_password_hash field in the database.
+ FieldRegistrationPasswordHash = "registration_password_hash"
+ // FieldUpstreamIdentityClaims holds the string denoting the upstream_identity_claims field in the database.
+ FieldUpstreamIdentityClaims = "upstream_identity_claims"
+ // FieldLocalFlowState holds the string denoting the local_flow_state field in the database.
+ FieldLocalFlowState = "local_flow_state"
+ // FieldBrowserSessionKey holds the string denoting the browser_session_key field in the database.
+ FieldBrowserSessionKey = "browser_session_key"
+ // FieldCompletionCodeHash holds the string denoting the completion_code_hash field in the database.
+ FieldCompletionCodeHash = "completion_code_hash"
+ // FieldCompletionCodeExpiresAt holds the string denoting the completion_code_expires_at field in the database.
+ FieldCompletionCodeExpiresAt = "completion_code_expires_at"
+ // FieldEmailVerifiedAt holds the string denoting the email_verified_at field in the database.
+ FieldEmailVerifiedAt = "email_verified_at"
+ // FieldPasswordVerifiedAt holds the string denoting the password_verified_at field in the database.
+ FieldPasswordVerifiedAt = "password_verified_at"
+ // FieldTotpVerifiedAt holds the string denoting the totp_verified_at field in the database.
+ FieldTotpVerifiedAt = "totp_verified_at"
+ // FieldExpiresAt holds the string denoting the expires_at field in the database.
+ FieldExpiresAt = "expires_at"
+ // FieldConsumedAt holds the string denoting the consumed_at field in the database.
+ FieldConsumedAt = "consumed_at"
+ // EdgeTargetUser holds the string denoting the target_user edge name in mutations.
+ EdgeTargetUser = "target_user"
+ // EdgeAdoptionDecision holds the string denoting the adoption_decision edge name in mutations.
+ EdgeAdoptionDecision = "adoption_decision"
+ // Table holds the table name of the pendingauthsession in the database.
+ Table = "pending_auth_sessions"
+ // TargetUserTable is the table that holds the target_user relation/edge.
+ TargetUserTable = "pending_auth_sessions"
+ // TargetUserInverseTable is the table name for the User entity.
+ // It exists in this package in order to avoid circular dependency with the "user" package.
+ TargetUserInverseTable = "users"
+ // TargetUserColumn is the table column denoting the target_user relation/edge.
+ TargetUserColumn = "target_user_id"
+ // AdoptionDecisionTable is the table that holds the adoption_decision relation/edge.
+ AdoptionDecisionTable = "identity_adoption_decisions"
+ // AdoptionDecisionInverseTable is the table name for the IdentityAdoptionDecision entity.
+ // It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package.
+ AdoptionDecisionInverseTable = "identity_adoption_decisions"
+ // AdoptionDecisionColumn is the table column denoting the adoption_decision relation/edge.
+ AdoptionDecisionColumn = "pending_auth_session_id"
+)
+
+// Columns holds all SQL columns for pendingauthsession fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldSessionToken,
+ FieldIntent,
+ FieldProviderType,
+ FieldProviderKey,
+ FieldProviderSubject,
+ FieldTargetUserID,
+ FieldRedirectTo,
+ FieldResolvedEmail,
+ FieldRegistrationPasswordHash,
+ FieldUpstreamIdentityClaims,
+ FieldLocalFlowState,
+ FieldBrowserSessionKey,
+ FieldCompletionCodeHash,
+ FieldCompletionCodeExpiresAt,
+ FieldEmailVerifiedAt,
+ FieldPasswordVerifiedAt,
+ FieldTotpVerifiedAt,
+ FieldExpiresAt,
+ FieldConsumedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save.
+ SessionTokenValidator func(string) error
+ // IntentValidator is a validator for the "intent" field. It is called by the builders before save.
+ IntentValidator func(string) error
+ // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ ProviderTypeValidator func(string) error
+ // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ ProviderKeyValidator func(string) error
+ // ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
+ ProviderSubjectValidator func(string) error
+ // DefaultRedirectTo holds the default value on creation for the "redirect_to" field.
+ DefaultRedirectTo string
+ // DefaultResolvedEmail holds the default value on creation for the "resolved_email" field.
+ DefaultResolvedEmail string
+ // DefaultRegistrationPasswordHash holds the default value on creation for the "registration_password_hash" field.
+ DefaultRegistrationPasswordHash string
+ // DefaultUpstreamIdentityClaims holds the default value on creation for the "upstream_identity_claims" field.
+ DefaultUpstreamIdentityClaims func() map[string]interface{}
+ // DefaultLocalFlowState holds the default value on creation for the "local_flow_state" field.
+ DefaultLocalFlowState func() map[string]interface{}
+ // DefaultBrowserSessionKey holds the default value on creation for the "browser_session_key" field.
+ DefaultBrowserSessionKey string
+ // DefaultCompletionCodeHash holds the default value on creation for the "completion_code_hash" field.
+ DefaultCompletionCodeHash string
+)
+
+// OrderOption defines the ordering options for the PendingAuthSession queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// BySessionToken orders the results by the session_token field.
+func BySessionToken(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSessionToken, opts...).ToFunc()
+}
+
+// ByIntent orders the results by the intent field.
+func ByIntent(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIntent, opts...).ToFunc()
+}
+
+// ByProviderType orders the results by the provider_type field.
+func ByProviderType(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderType, opts...).ToFunc()
+}
+
+// ByProviderKey orders the results by the provider_key field.
+func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
+}
+
+// ByProviderSubject orders the results by the provider_subject field.
+func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderSubject, opts...).ToFunc()
+}
+
+// ByTargetUserID orders the results by the target_user_id field.
+func ByTargetUserID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTargetUserID, opts...).ToFunc()
+}
+
+// ByRedirectTo orders the results by the redirect_to field.
+func ByRedirectTo(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRedirectTo, opts...).ToFunc()
+}
+
+// ByResolvedEmail orders the results by the resolved_email field.
+func ByResolvedEmail(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldResolvedEmail, opts...).ToFunc()
+}
+
+// ByRegistrationPasswordHash orders the results by the registration_password_hash field.
+func ByRegistrationPasswordHash(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRegistrationPasswordHash, opts...).ToFunc()
+}
+
+// ByBrowserSessionKey orders the results by the browser_session_key field.
+func ByBrowserSessionKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBrowserSessionKey, opts...).ToFunc()
+}
+
+// ByCompletionCodeHash orders the results by the completion_code_hash field.
+func ByCompletionCodeHash(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCompletionCodeHash, opts...).ToFunc()
+}
+
+// ByCompletionCodeExpiresAt orders the results by the completion_code_expires_at field.
+func ByCompletionCodeExpiresAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCompletionCodeExpiresAt, opts...).ToFunc()
+}
+
+// ByEmailVerifiedAt orders the results by the email_verified_at field.
+func ByEmailVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldEmailVerifiedAt, opts...).ToFunc()
+}
+
+// ByPasswordVerifiedAt orders the results by the password_verified_at field.
+func ByPasswordVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPasswordVerifiedAt, opts...).ToFunc()
+}
+
+// ByTotpVerifiedAt orders the results by the totp_verified_at field.
+func ByTotpVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTotpVerifiedAt, opts...).ToFunc()
+}
+
+// ByExpiresAt orders the results by the expires_at field.
+func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
+}
+
+// ByConsumedAt orders the results by the consumed_at field.
+func ByConsumedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldConsumedAt, opts...).ToFunc()
+}
+
+// ByTargetUserField orders the results by target_user field.
+func ByTargetUserField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newTargetUserStep(), sql.OrderByField(field, opts...))
+ }
+}
+
+// ByAdoptionDecisionField orders the results by adoption_decision field.
+func ByAdoptionDecisionField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newTargetUserStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(TargetUserInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, TargetUserTable, TargetUserColumn),
+ )
+}
+func newAdoptionDecisionStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(AdoptionDecisionInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, false, AdoptionDecisionTable, AdoptionDecisionColumn),
+ )
+}
diff --git a/backend/ent/pendingauthsession/where.go b/backend/ent/pendingauthsession/where.go
new file mode 100644
index 00000000..cb316f47
--- /dev/null
+++ b/backend/ent/pendingauthsession/where.go
@@ -0,0 +1,1262 @@
+// Code generated by ent, DO NOT EDIT.
+
+package pendingauthsession
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// SessionToken applies equality check predicate on the "session_token" field. It's identical to SessionTokenEQ.
+func SessionToken(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldSessionToken, v))
+}
+
+// Intent applies equality check predicate on the "intent" field. It's identical to IntentEQ.
+func Intent(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldIntent, v))
+}
+
+// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ.
+func ProviderType(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
+func ProviderKey(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderSubject applies equality check predicate on the "provider_subject" field. It's identical to ProviderSubjectEQ.
+func ProviderSubject(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderSubject, v))
+}
+
+// TargetUserID applies equality check predicate on the "target_user_id" field. It's identical to TargetUserIDEQ.
+func TargetUserID(v int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldTargetUserID, v))
+}
+
+// RedirectTo applies equality check predicate on the "redirect_to" field. It's identical to RedirectToEQ.
+func RedirectTo(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldRedirectTo, v))
+}
+
+// ResolvedEmail applies equality check predicate on the "resolved_email" field. It's identical to ResolvedEmailEQ.
+func ResolvedEmail(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldResolvedEmail, v))
+}
+
+// RegistrationPasswordHash applies equality check predicate on the "registration_password_hash" field. It's identical to RegistrationPasswordHashEQ.
+func RegistrationPasswordHash(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldRegistrationPasswordHash, v))
+}
+
+// BrowserSessionKey applies equality check predicate on the "browser_session_key" field. It's identical to BrowserSessionKeyEQ.
+func BrowserSessionKey(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldBrowserSessionKey, v))
+}
+
+// CompletionCodeHash applies equality check predicate on the "completion_code_hash" field. It's identical to CompletionCodeHashEQ.
+func CompletionCodeHash(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeExpiresAt applies equality check predicate on the "completion_code_expires_at" field. It's identical to CompletionCodeExpiresAtEQ.
+func CompletionCodeExpiresAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeExpiresAt, v))
+}
+
+// EmailVerifiedAt applies equality check predicate on the "email_verified_at" field. It's identical to EmailVerifiedAtEQ.
+func EmailVerifiedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldEmailVerifiedAt, v))
+}
+
+// PasswordVerifiedAt applies equality check predicate on the "password_verified_at" field. It's identical to PasswordVerifiedAtEQ.
+func PasswordVerifiedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldPasswordVerifiedAt, v))
+}
+
+// TotpVerifiedAt applies equality check predicate on the "totp_verified_at" field. It's identical to TotpVerifiedAtEQ.
+func TotpVerifiedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldTotpVerifiedAt, v))
+}
+
+// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ.
+func ExpiresAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldExpiresAt, v))
+}
+
+// ConsumedAt applies equality check predicate on the "consumed_at" field. It's identical to ConsumedAtEQ.
+func ConsumedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldConsumedAt, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// SessionTokenEQ applies the EQ predicate on the "session_token" field.
+func SessionTokenEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldSessionToken, v))
+}
+
+// SessionTokenNEQ applies the NEQ predicate on the "session_token" field.
+func SessionTokenNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldSessionToken, v))
+}
+
+// SessionTokenIn applies the In predicate on the "session_token" field.
+func SessionTokenIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldSessionToken, vs...))
+}
+
+// SessionTokenNotIn applies the NotIn predicate on the "session_token" field.
+func SessionTokenNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldSessionToken, vs...))
+}
+
+// SessionTokenGT applies the GT predicate on the "session_token" field.
+func SessionTokenGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldSessionToken, v))
+}
+
+// SessionTokenGTE applies the GTE predicate on the "session_token" field.
+func SessionTokenGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldSessionToken, v))
+}
+
+// SessionTokenLT applies the LT predicate on the "session_token" field.
+func SessionTokenLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldSessionToken, v))
+}
+
+// SessionTokenLTE applies the LTE predicate on the "session_token" field.
+func SessionTokenLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldSessionToken, v))
+}
+
+// SessionTokenContains applies the Contains predicate on the "session_token" field.
+func SessionTokenContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldSessionToken, v))
+}
+
+// SessionTokenHasPrefix applies the HasPrefix predicate on the "session_token" field.
+func SessionTokenHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldSessionToken, v))
+}
+
+// SessionTokenHasSuffix applies the HasSuffix predicate on the "session_token" field.
+func SessionTokenHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldSessionToken, v))
+}
+
+// SessionTokenEqualFold applies the EqualFold predicate on the "session_token" field.
+func SessionTokenEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldSessionToken, v))
+}
+
+// SessionTokenContainsFold applies the ContainsFold predicate on the "session_token" field.
+func SessionTokenContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldSessionToken, v))
+}
+
+// IntentEQ applies the EQ predicate on the "intent" field.
+func IntentEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldIntent, v))
+}
+
+// IntentNEQ applies the NEQ predicate on the "intent" field.
+func IntentNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldIntent, v))
+}
+
+// IntentIn applies the In predicate on the "intent" field.
+func IntentIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldIntent, vs...))
+}
+
+// IntentNotIn applies the NotIn predicate on the "intent" field.
+func IntentNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldIntent, vs...))
+}
+
+// IntentGT applies the GT predicate on the "intent" field.
+func IntentGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldIntent, v))
+}
+
+// IntentGTE applies the GTE predicate on the "intent" field.
+func IntentGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldIntent, v))
+}
+
+// IntentLT applies the LT predicate on the "intent" field.
+func IntentLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldIntent, v))
+}
+
+// IntentLTE applies the LTE predicate on the "intent" field.
+func IntentLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldIntent, v))
+}
+
+// IntentContains applies the Contains predicate on the "intent" field.
+func IntentContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldIntent, v))
+}
+
+// IntentHasPrefix applies the HasPrefix predicate on the "intent" field.
+func IntentHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldIntent, v))
+}
+
+// IntentHasSuffix applies the HasSuffix predicate on the "intent" field.
+func IntentHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldIntent, v))
+}
+
+// IntentEqualFold applies the EqualFold predicate on the "intent" field.
+func IntentEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldIntent, v))
+}
+
+// IntentContainsFold applies the ContainsFold predicate on the "intent" field.
+func IntentContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldIntent, v))
+}
+
+// ProviderTypeEQ applies the EQ predicate on the "provider_type" field.
+func ProviderTypeEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field.
+func ProviderTypeNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderType, v))
+}
+
+// ProviderTypeIn applies the In predicate on the "provider_type" field.
+func ProviderTypeIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field.
+func ProviderTypeNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeGT applies the GT predicate on the "provider_type" field.
+func ProviderTypeGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldProviderType, v))
+}
+
+// ProviderTypeGTE applies the GTE predicate on the "provider_type" field.
+func ProviderTypeGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderType, v))
+}
+
+// ProviderTypeLT applies the LT predicate on the "provider_type" field.
+func ProviderTypeLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldProviderType, v))
+}
+
+// ProviderTypeLTE applies the LTE predicate on the "provider_type" field.
+func ProviderTypeLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderType, v))
+}
+
+// ProviderTypeContains applies the Contains predicate on the "provider_type" field.
+func ProviderTypeContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldProviderType, v))
+}
+
+// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field.
+func ProviderTypeHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderType, v))
+}
+
+// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field.
+func ProviderTypeHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderType, v))
+}
+
+// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field.
+func ProviderTypeEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderType, v))
+}
+
+// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field.
+func ProviderTypeContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderType, v))
+}
+
+// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
+func ProviderKeyEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
+func ProviderKeyNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyIn applies the In predicate on the "provider_key" field.
+func ProviderKeyIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
+func ProviderKeyNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyGT applies the GT predicate on the "provider_key" field.
+func ProviderKeyGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldProviderKey, v))
+}
+
+// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
+func ProviderKeyGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderKey, v))
+}
+
+// ProviderKeyLT applies the LT predicate on the "provider_key" field.
+func ProviderKeyLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldProviderKey, v))
+}
+
+// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
+func ProviderKeyLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderKey, v))
+}
+
+// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
+func ProviderKeyContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldProviderKey, v))
+}
+
+// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
+func ProviderKeyHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderKey, v))
+}
+
+// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
+func ProviderKeyHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderKey, v))
+}
+
+// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
+func ProviderKeyEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderKey, v))
+}
+
+// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
+func ProviderKeyContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderKey, v))
+}
+
+// ProviderSubjectEQ applies the EQ predicate on the "provider_subject" field.
+func ProviderSubjectEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderSubject, v))
+}
+
+// ProviderSubjectNEQ applies the NEQ predicate on the "provider_subject" field.
+func ProviderSubjectNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderSubject, v))
+}
+
+// ProviderSubjectIn applies the In predicate on the "provider_subject" field.
+func ProviderSubjectIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldProviderSubject, vs...))
+}
+
+// ProviderSubjectNotIn applies the NotIn predicate on the "provider_subject" field.
+func ProviderSubjectNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderSubject, vs...))
+}
+
+// ProviderSubjectGT applies the GT predicate on the "provider_subject" field.
+func ProviderSubjectGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldProviderSubject, v))
+}
+
+// ProviderSubjectGTE applies the GTE predicate on the "provider_subject" field.
+func ProviderSubjectGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderSubject, v))
+}
+
+// ProviderSubjectLT applies the LT predicate on the "provider_subject" field.
+func ProviderSubjectLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldProviderSubject, v))
+}
+
+// ProviderSubjectLTE applies the LTE predicate on the "provider_subject" field.
+func ProviderSubjectLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderSubject, v))
+}
+
+// ProviderSubjectContains applies the Contains predicate on the "provider_subject" field.
+func ProviderSubjectContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldProviderSubject, v))
+}
+
+// ProviderSubjectHasPrefix applies the HasPrefix predicate on the "provider_subject" field.
+func ProviderSubjectHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderSubject, v))
+}
+
+// ProviderSubjectHasSuffix applies the HasSuffix predicate on the "provider_subject" field.
+func ProviderSubjectHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderSubject, v))
+}
+
+// ProviderSubjectEqualFold applies the EqualFold predicate on the "provider_subject" field.
+func ProviderSubjectEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderSubject, v))
+}
+
+// ProviderSubjectContainsFold applies the ContainsFold predicate on the "provider_subject" field.
+func ProviderSubjectContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderSubject, v))
+}
+
+// TargetUserIDEQ applies the EQ predicate on the "target_user_id" field.
+func TargetUserIDEQ(v int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldTargetUserID, v))
+}
+
+// TargetUserIDNEQ applies the NEQ predicate on the "target_user_id" field.
+func TargetUserIDNEQ(v int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldTargetUserID, v))
+}
+
+// TargetUserIDIn applies the In predicate on the "target_user_id" field.
+func TargetUserIDIn(vs ...int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldTargetUserID, vs...))
+}
+
+// TargetUserIDNotIn applies the NotIn predicate on the "target_user_id" field.
+func TargetUserIDNotIn(vs ...int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldTargetUserID, vs...))
+}
+
+// TargetUserIDIsNil applies the IsNil predicate on the "target_user_id" field.
+func TargetUserIDIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldTargetUserID))
+}
+
+// TargetUserIDNotNil applies the NotNil predicate on the "target_user_id" field.
+func TargetUserIDNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldTargetUserID))
+}
+
+// RedirectToEQ applies the EQ predicate on the "redirect_to" field.
+func RedirectToEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldRedirectTo, v))
+}
+
+// RedirectToNEQ applies the NEQ predicate on the "redirect_to" field.
+func RedirectToNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldRedirectTo, v))
+}
+
+// RedirectToIn applies the In predicate on the "redirect_to" field.
+func RedirectToIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldRedirectTo, vs...))
+}
+
+// RedirectToNotIn applies the NotIn predicate on the "redirect_to" field.
+func RedirectToNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldRedirectTo, vs...))
+}
+
+// RedirectToGT applies the GT predicate on the "redirect_to" field.
+func RedirectToGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldRedirectTo, v))
+}
+
+// RedirectToGTE applies the GTE predicate on the "redirect_to" field.
+func RedirectToGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldRedirectTo, v))
+}
+
+// RedirectToLT applies the LT predicate on the "redirect_to" field.
+func RedirectToLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldRedirectTo, v))
+}
+
+// RedirectToLTE applies the LTE predicate on the "redirect_to" field.
+func RedirectToLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldRedirectTo, v))
+}
+
+// RedirectToContains applies the Contains predicate on the "redirect_to" field.
+func RedirectToContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldRedirectTo, v))
+}
+
+// RedirectToHasPrefix applies the HasPrefix predicate on the "redirect_to" field.
+func RedirectToHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldRedirectTo, v))
+}
+
+// RedirectToHasSuffix applies the HasSuffix predicate on the "redirect_to" field.
+func RedirectToHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldRedirectTo, v))
+}
+
+// RedirectToEqualFold applies the EqualFold predicate on the "redirect_to" field.
+func RedirectToEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldRedirectTo, v))
+}
+
+// RedirectToContainsFold applies the ContainsFold predicate on the "redirect_to" field.
+func RedirectToContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldRedirectTo, v))
+}
+
+// ResolvedEmailEQ applies the EQ predicate on the "resolved_email" field.
+func ResolvedEmailEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailNEQ applies the NEQ predicate on the "resolved_email" field.
+func ResolvedEmailNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailIn applies the In predicate on the "resolved_email" field.
+func ResolvedEmailIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldResolvedEmail, vs...))
+}
+
+// ResolvedEmailNotIn applies the NotIn predicate on the "resolved_email" field.
+func ResolvedEmailNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldResolvedEmail, vs...))
+}
+
+// ResolvedEmailGT applies the GT predicate on the "resolved_email" field.
+func ResolvedEmailGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailGTE applies the GTE predicate on the "resolved_email" field.
+func ResolvedEmailGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailLT applies the LT predicate on the "resolved_email" field.
+func ResolvedEmailLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailLTE applies the LTE predicate on the "resolved_email" field.
+func ResolvedEmailLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailContains applies the Contains predicate on the "resolved_email" field.
+func ResolvedEmailContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailHasPrefix applies the HasPrefix predicate on the "resolved_email" field.
+func ResolvedEmailHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailHasSuffix applies the HasSuffix predicate on the "resolved_email" field.
+func ResolvedEmailHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailEqualFold applies the EqualFold predicate on the "resolved_email" field.
+func ResolvedEmailEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailContainsFold applies the ContainsFold predicate on the "resolved_email" field.
+func ResolvedEmailContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldResolvedEmail, v))
+}
+
+// RegistrationPasswordHashEQ applies the EQ predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashNEQ applies the NEQ predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashIn applies the In predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldRegistrationPasswordHash, vs...))
+}
+
+// RegistrationPasswordHashNotIn applies the NotIn predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldRegistrationPasswordHash, vs...))
+}
+
+// RegistrationPasswordHashGT applies the GT predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashGTE applies the GTE predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashLT applies the LT predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashLTE applies the LTE predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashContains applies the Contains predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashHasPrefix applies the HasPrefix predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashHasSuffix applies the HasSuffix predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashEqualFold applies the EqualFold predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashContainsFold applies the ContainsFold predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldRegistrationPasswordHash, v))
+}
+
+// BrowserSessionKeyEQ applies the EQ predicate on the "browser_session_key" field.
+func BrowserSessionKeyEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyNEQ applies the NEQ predicate on the "browser_session_key" field.
+func BrowserSessionKeyNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyIn applies the In predicate on the "browser_session_key" field.
+func BrowserSessionKeyIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldBrowserSessionKey, vs...))
+}
+
+// BrowserSessionKeyNotIn applies the NotIn predicate on the "browser_session_key" field.
+func BrowserSessionKeyNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldBrowserSessionKey, vs...))
+}
+
+// BrowserSessionKeyGT applies the GT predicate on the "browser_session_key" field.
+func BrowserSessionKeyGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyGTE applies the GTE predicate on the "browser_session_key" field.
+func BrowserSessionKeyGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyLT applies the LT predicate on the "browser_session_key" field.
+func BrowserSessionKeyLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyLTE applies the LTE predicate on the "browser_session_key" field.
+func BrowserSessionKeyLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyContains applies the Contains predicate on the "browser_session_key" field.
+func BrowserSessionKeyContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyHasPrefix applies the HasPrefix predicate on the "browser_session_key" field.
+func BrowserSessionKeyHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyHasSuffix applies the HasSuffix predicate on the "browser_session_key" field.
+func BrowserSessionKeyHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyEqualFold applies the EqualFold predicate on the "browser_session_key" field.
+func BrowserSessionKeyEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyContainsFold applies the ContainsFold predicate on the "browser_session_key" field.
+func BrowserSessionKeyContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldBrowserSessionKey, v))
+}
+
+// CompletionCodeHashEQ applies the EQ predicate on the "completion_code_hash" field.
+func CompletionCodeHashEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashNEQ applies the NEQ predicate on the "completion_code_hash" field.
+func CompletionCodeHashNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashIn applies the In predicate on the "completion_code_hash" field.
+func CompletionCodeHashIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldCompletionCodeHash, vs...))
+}
+
+// CompletionCodeHashNotIn applies the NotIn predicate on the "completion_code_hash" field.
+func CompletionCodeHashNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldCompletionCodeHash, vs...))
+}
+
+// CompletionCodeHashGT applies the GT predicate on the "completion_code_hash" field.
+func CompletionCodeHashGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashGTE applies the GTE predicate on the "completion_code_hash" field.
+func CompletionCodeHashGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashLT applies the LT predicate on the "completion_code_hash" field.
+func CompletionCodeHashLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashLTE applies the LTE predicate on the "completion_code_hash" field.
+func CompletionCodeHashLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashContains applies the Contains predicate on the "completion_code_hash" field.
+func CompletionCodeHashContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashHasPrefix applies the HasPrefix predicate on the "completion_code_hash" field.
+func CompletionCodeHashHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashHasSuffix applies the HasSuffix predicate on the "completion_code_hash" field.
+func CompletionCodeHashHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashEqualFold applies the EqualFold predicate on the "completion_code_hash" field.
+func CompletionCodeHashEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashContainsFold applies the ContainsFold predicate on the "completion_code_hash" field.
+func CompletionCodeHashContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeExpiresAtEQ applies the EQ predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtNEQ applies the NEQ predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtIn applies the In predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldCompletionCodeExpiresAt, vs...))
+}
+
+// CompletionCodeExpiresAtNotIn applies the NotIn predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldCompletionCodeExpiresAt, vs...))
+}
+
+// CompletionCodeExpiresAtGT applies the GT predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtGTE applies the GTE predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtLT applies the LT predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtLTE applies the LTE predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtIsNil applies the IsNil predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldCompletionCodeExpiresAt))
+}
+
+// CompletionCodeExpiresAtNotNil applies the NotNil predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldCompletionCodeExpiresAt))
+}
+
+// EmailVerifiedAtEQ applies the EQ predicate on the "email_verified_at" field.
+func EmailVerifiedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtNEQ applies the NEQ predicate on the "email_verified_at" field.
+func EmailVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtIn applies the In predicate on the "email_verified_at" field.
+func EmailVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldEmailVerifiedAt, vs...))
+}
+
+// EmailVerifiedAtNotIn applies the NotIn predicate on the "email_verified_at" field.
+func EmailVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldEmailVerifiedAt, vs...))
+}
+
+// EmailVerifiedAtGT applies the GT predicate on the "email_verified_at" field.
+func EmailVerifiedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtGTE applies the GTE predicate on the "email_verified_at" field.
+func EmailVerifiedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtLT applies the LT predicate on the "email_verified_at" field.
+func EmailVerifiedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtLTE applies the LTE predicate on the "email_verified_at" field.
+func EmailVerifiedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtIsNil applies the IsNil predicate on the "email_verified_at" field.
+func EmailVerifiedAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldEmailVerifiedAt))
+}
+
+// EmailVerifiedAtNotNil applies the NotNil predicate on the "email_verified_at" field.
+func EmailVerifiedAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldEmailVerifiedAt))
+}
+
+// PasswordVerifiedAtEQ applies the EQ predicate on the "password_verified_at" field.
+func PasswordVerifiedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtNEQ applies the NEQ predicate on the "password_verified_at" field.
+func PasswordVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtIn applies the In predicate on the "password_verified_at" field.
+func PasswordVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldPasswordVerifiedAt, vs...))
+}
+
+// PasswordVerifiedAtNotIn applies the NotIn predicate on the "password_verified_at" field.
+func PasswordVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldPasswordVerifiedAt, vs...))
+}
+
+// PasswordVerifiedAtGT applies the GT predicate on the "password_verified_at" field.
+func PasswordVerifiedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtGTE applies the GTE predicate on the "password_verified_at" field.
+func PasswordVerifiedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtLT applies the LT predicate on the "password_verified_at" field.
+func PasswordVerifiedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtLTE applies the LTE predicate on the "password_verified_at" field.
+func PasswordVerifiedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtIsNil applies the IsNil predicate on the "password_verified_at" field.
+func PasswordVerifiedAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldPasswordVerifiedAt))
+}
+
+// PasswordVerifiedAtNotNil applies the NotNil predicate on the "password_verified_at" field.
+func PasswordVerifiedAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldPasswordVerifiedAt))
+}
+
+// TotpVerifiedAtEQ applies the EQ predicate on the "totp_verified_at" field.
+func TotpVerifiedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtNEQ applies the NEQ predicate on the "totp_verified_at" field.
+func TotpVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtIn applies the In predicate on the "totp_verified_at" field.
+func TotpVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldTotpVerifiedAt, vs...))
+}
+
+// TotpVerifiedAtNotIn applies the NotIn predicate on the "totp_verified_at" field.
+func TotpVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldTotpVerifiedAt, vs...))
+}
+
+// TotpVerifiedAtGT applies the GT predicate on the "totp_verified_at" field.
+func TotpVerifiedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtGTE applies the GTE predicate on the "totp_verified_at" field.
+func TotpVerifiedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtLT applies the LT predicate on the "totp_verified_at" field.
+func TotpVerifiedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtLTE applies the LTE predicate on the "totp_verified_at" field.
+func TotpVerifiedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtIsNil applies the IsNil predicate on the "totp_verified_at" field.
+func TotpVerifiedAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldTotpVerifiedAt))
+}
+
+// TotpVerifiedAtNotNil applies the NotNil predicate on the "totp_verified_at" field.
+func TotpVerifiedAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldTotpVerifiedAt))
+}
+
+// ExpiresAtEQ applies the EQ predicate on the "expires_at" field.
+func ExpiresAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldExpiresAt, v))
+}
+
+// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field.
+func ExpiresAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldExpiresAt, v))
+}
+
+// ExpiresAtIn applies the In predicate on the "expires_at" field.
+func ExpiresAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldExpiresAt, vs...))
+}
+
+// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field.
+func ExpiresAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldExpiresAt, vs...))
+}
+
+// ExpiresAtGT applies the GT predicate on the "expires_at" field.
+func ExpiresAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldExpiresAt, v))
+}
+
+// ExpiresAtGTE applies the GTE predicate on the "expires_at" field.
+func ExpiresAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldExpiresAt, v))
+}
+
+// ExpiresAtLT applies the LT predicate on the "expires_at" field.
+func ExpiresAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldExpiresAt, v))
+}
+
+// ExpiresAtLTE applies the LTE predicate on the "expires_at" field.
+func ExpiresAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldExpiresAt, v))
+}
+
+// ConsumedAtEQ applies the EQ predicate on the "consumed_at" field.
+func ConsumedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldConsumedAt, v))
+}
+
+// ConsumedAtNEQ applies the NEQ predicate on the "consumed_at" field.
+func ConsumedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldConsumedAt, v))
+}
+
+// ConsumedAtIn applies the In predicate on the "consumed_at" field.
+func ConsumedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldConsumedAt, vs...))
+}
+
+// ConsumedAtNotIn applies the NotIn predicate on the "consumed_at" field.
+func ConsumedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldConsumedAt, vs...))
+}
+
+// ConsumedAtGT applies the GT predicate on the "consumed_at" field.
+func ConsumedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldConsumedAt, v))
+}
+
+// ConsumedAtGTE applies the GTE predicate on the "consumed_at" field.
+func ConsumedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldConsumedAt, v))
+}
+
+// ConsumedAtLT applies the LT predicate on the "consumed_at" field.
+func ConsumedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldConsumedAt, v))
+}
+
+// ConsumedAtLTE applies the LTE predicate on the "consumed_at" field.
+func ConsumedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldConsumedAt, v))
+}
+
+// ConsumedAtIsNil applies the IsNil predicate on the "consumed_at" field.
+func ConsumedAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldConsumedAt))
+}
+
+// ConsumedAtNotNil applies the NotNil predicate on the "consumed_at" field.
+func ConsumedAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldConsumedAt))
+}
+
+// HasTargetUser applies the HasEdge predicate on the "target_user" edge.
+func HasTargetUser() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, TargetUserTable, TargetUserColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasTargetUserWith applies the HasEdge predicate on the "target_user" edge with a given conditions (other predicates).
+func HasTargetUserWith(preds ...predicate.User) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(func(s *sql.Selector) {
+ step := newTargetUserStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasAdoptionDecision applies the HasEdge predicate on the "adoption_decision" edge.
+func HasAdoptionDecision() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, false, AdoptionDecisionTable, AdoptionDecisionColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasAdoptionDecisionWith applies the HasEdge predicate on the "adoption_decision" edge with a given conditions (other predicates).
+func HasAdoptionDecisionWith(preds ...predicate.IdentityAdoptionDecision) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(func(s *sql.Selector) {
+ step := newAdoptionDecisionStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.PendingAuthSession) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.PendingAuthSession) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.PendingAuthSession) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.NotPredicates(p))
+}
diff --git a/backend/ent/pendingauthsession_create.go b/backend/ent/pendingauthsession_create.go
new file mode 100644
index 00000000..60276daa
--- /dev/null
+++ b/backend/ent/pendingauthsession_create.go
@@ -0,0 +1,1815 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PendingAuthSessionCreate is the builder for creating a PendingAuthSession entity.
+type PendingAuthSessionCreate struct {
+ config
+ mutation *PendingAuthSessionMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *PendingAuthSessionCreate) SetCreatedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableCreatedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *PendingAuthSessionCreate) SetUpdatedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableUpdatedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetSessionToken sets the "session_token" field.
+func (_c *PendingAuthSessionCreate) SetSessionToken(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetSessionToken(v)
+ return _c
+}
+
+// SetIntent sets the "intent" field.
+func (_c *PendingAuthSessionCreate) SetIntent(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetIntent(v)
+ return _c
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_c *PendingAuthSessionCreate) SetProviderType(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetProviderType(v)
+ return _c
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_c *PendingAuthSessionCreate) SetProviderKey(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetProviderKey(v)
+ return _c
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_c *PendingAuthSessionCreate) SetProviderSubject(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetProviderSubject(v)
+ return _c
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (_c *PendingAuthSessionCreate) SetTargetUserID(v int64) *PendingAuthSessionCreate {
+ _c.mutation.SetTargetUserID(v)
+ return _c
+}
+
+// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableTargetUserID(v *int64) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetTargetUserID(*v)
+ }
+ return _c
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (_c *PendingAuthSessionCreate) SetRedirectTo(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetRedirectTo(v)
+ return _c
+}
+
+// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableRedirectTo(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetRedirectTo(*v)
+ }
+ return _c
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (_c *PendingAuthSessionCreate) SetResolvedEmail(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetResolvedEmail(v)
+ return _c
+}
+
+// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableResolvedEmail(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetResolvedEmail(*v)
+ }
+ return _c
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (_c *PendingAuthSessionCreate) SetRegistrationPasswordHash(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetRegistrationPasswordHash(v)
+ return _c
+}
+
+// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetRegistrationPasswordHash(*v)
+ }
+ return _c
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (_c *PendingAuthSessionCreate) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionCreate {
+ _c.mutation.SetUpstreamIdentityClaims(v)
+ return _c
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (_c *PendingAuthSessionCreate) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionCreate {
+ _c.mutation.SetLocalFlowState(v)
+ return _c
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (_c *PendingAuthSessionCreate) SetBrowserSessionKey(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetBrowserSessionKey(v)
+ return _c
+}
+
+// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetBrowserSessionKey(*v)
+ }
+ return _c
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (_c *PendingAuthSessionCreate) SetCompletionCodeHash(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetCompletionCodeHash(v)
+ return _c
+}
+
+// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetCompletionCodeHash(*v)
+ }
+ return _c
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (_c *PendingAuthSessionCreate) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetCompletionCodeExpiresAt(v)
+ return _c
+}
+
+// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetCompletionCodeExpiresAt(*v)
+ }
+ return _c
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (_c *PendingAuthSessionCreate) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetEmailVerifiedAt(v)
+ return _c
+}
+
+// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetEmailVerifiedAt(*v)
+ }
+ return _c
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (_c *PendingAuthSessionCreate) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetPasswordVerifiedAt(v)
+ return _c
+}
+
+// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetPasswordVerifiedAt(*v)
+ }
+ return _c
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (_c *PendingAuthSessionCreate) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetTotpVerifiedAt(v)
+ return _c
+}
+
+// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetTotpVerifiedAt(*v)
+ }
+ return _c
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (_c *PendingAuthSessionCreate) SetExpiresAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetExpiresAt(v)
+ return _c
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (_c *PendingAuthSessionCreate) SetConsumedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetConsumedAt(v)
+ return _c
+}
+
+// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetConsumedAt(*v)
+ }
+ return _c
+}
+
+// SetTargetUser sets the "target_user" edge to the User entity.
+func (_c *PendingAuthSessionCreate) SetTargetUser(v *User) *PendingAuthSessionCreate {
+ return _c.SetTargetUserID(v.ID)
+}
+
+// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID.
+func (_c *PendingAuthSessionCreate) SetAdoptionDecisionID(id int64) *PendingAuthSessionCreate {
+ _c.mutation.SetAdoptionDecisionID(id)
+ return _c
+}
+
+// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionCreate {
+ if id != nil {
+ _c = _c.SetAdoptionDecisionID(*id)
+ }
+ return _c
+}
+
+// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_c *PendingAuthSessionCreate) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionCreate {
+ return _c.SetAdoptionDecisionID(v.ID)
+}
+
+// Mutation returns the PendingAuthSessionMutation object of the builder.
+func (_c *PendingAuthSessionCreate) Mutation() *PendingAuthSessionMutation {
+ return _c.mutation
+}
+
+// Save creates the PendingAuthSession in the database.
+func (_c *PendingAuthSessionCreate) Save(ctx context.Context) (*PendingAuthSession, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *PendingAuthSessionCreate) SaveX(ctx context.Context) *PendingAuthSession {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *PendingAuthSessionCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *PendingAuthSessionCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *PendingAuthSessionCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := pendingauthsession.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := pendingauthsession.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.RedirectTo(); !ok {
+ v := pendingauthsession.DefaultRedirectTo
+ _c.mutation.SetRedirectTo(v)
+ }
+ if _, ok := _c.mutation.ResolvedEmail(); !ok {
+ v := pendingauthsession.DefaultResolvedEmail
+ _c.mutation.SetResolvedEmail(v)
+ }
+ if _, ok := _c.mutation.RegistrationPasswordHash(); !ok {
+ v := pendingauthsession.DefaultRegistrationPasswordHash
+ _c.mutation.SetRegistrationPasswordHash(v)
+ }
+ if _, ok := _c.mutation.UpstreamIdentityClaims(); !ok {
+ v := pendingauthsession.DefaultUpstreamIdentityClaims()
+ _c.mutation.SetUpstreamIdentityClaims(v)
+ }
+ if _, ok := _c.mutation.LocalFlowState(); !ok {
+ v := pendingauthsession.DefaultLocalFlowState()
+ _c.mutation.SetLocalFlowState(v)
+ }
+ if _, ok := _c.mutation.BrowserSessionKey(); !ok {
+ v := pendingauthsession.DefaultBrowserSessionKey
+ _c.mutation.SetBrowserSessionKey(v)
+ }
+ if _, ok := _c.mutation.CompletionCodeHash(); !ok {
+ v := pendingauthsession.DefaultCompletionCodeHash
+ _c.mutation.SetCompletionCodeHash(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *PendingAuthSessionCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PendingAuthSession.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "PendingAuthSession.updated_at"`)}
+ }
+ if _, ok := _c.mutation.SessionToken(); !ok {
+ return &ValidationError{Name: "session_token", err: errors.New(`ent: missing required field "PendingAuthSession.session_token"`)}
+ }
+ if v, ok := _c.mutation.SessionToken(); ok {
+ if err := pendingauthsession.SessionTokenValidator(v); err != nil {
+ return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Intent(); !ok {
+ return &ValidationError{Name: "intent", err: errors.New(`ent: missing required field "PendingAuthSession.intent"`)}
+ }
+ if v, ok := _c.mutation.Intent(); ok {
+ if err := pendingauthsession.IntentValidator(v); err != nil {
+ return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderType(); !ok {
+ return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "PendingAuthSession.provider_type"`)}
+ }
+ if v, ok := _c.mutation.ProviderType(); ok {
+ if err := pendingauthsession.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderKey(); !ok {
+ return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "PendingAuthSession.provider_key"`)}
+ }
+ if v, ok := _c.mutation.ProviderKey(); ok {
+ if err := pendingauthsession.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderSubject(); !ok {
+ return &ValidationError{Name: "provider_subject", err: errors.New(`ent: missing required field "PendingAuthSession.provider_subject"`)}
+ }
+ if v, ok := _c.mutation.ProviderSubject(); ok {
+ if err := pendingauthsession.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.RedirectTo(); !ok {
+ return &ValidationError{Name: "redirect_to", err: errors.New(`ent: missing required field "PendingAuthSession.redirect_to"`)}
+ }
+ if _, ok := _c.mutation.ResolvedEmail(); !ok {
+ return &ValidationError{Name: "resolved_email", err: errors.New(`ent: missing required field "PendingAuthSession.resolved_email"`)}
+ }
+ if _, ok := _c.mutation.RegistrationPasswordHash(); !ok {
+ return &ValidationError{Name: "registration_password_hash", err: errors.New(`ent: missing required field "PendingAuthSession.registration_password_hash"`)}
+ }
+ if _, ok := _c.mutation.UpstreamIdentityClaims(); !ok {
+ return &ValidationError{Name: "upstream_identity_claims", err: errors.New(`ent: missing required field "PendingAuthSession.upstream_identity_claims"`)}
+ }
+ if _, ok := _c.mutation.LocalFlowState(); !ok {
+ return &ValidationError{Name: "local_flow_state", err: errors.New(`ent: missing required field "PendingAuthSession.local_flow_state"`)}
+ }
+ if _, ok := _c.mutation.BrowserSessionKey(); !ok {
+ return &ValidationError{Name: "browser_session_key", err: errors.New(`ent: missing required field "PendingAuthSession.browser_session_key"`)}
+ }
+ if _, ok := _c.mutation.CompletionCodeHash(); !ok {
+ return &ValidationError{Name: "completion_code_hash", err: errors.New(`ent: missing required field "PendingAuthSession.completion_code_hash"`)}
+ }
+ if _, ok := _c.mutation.ExpiresAt(); !ok {
+ return &ValidationError{Name: "expires_at", err: errors.New(`ent: missing required field "PendingAuthSession.expires_at"`)}
+ }
+ return nil
+}
+
+func (_c *PendingAuthSessionCreate) sqlSave(ctx context.Context) (*PendingAuthSession, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *PendingAuthSessionCreate) createSpec() (*PendingAuthSession, *sqlgraph.CreateSpec) {
+ var (
+ _node = &PendingAuthSession{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(pendingauthsession.Table, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.SessionToken(); ok {
+ _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value)
+ _node.SessionToken = value
+ }
+ if value, ok := _c.mutation.Intent(); ok {
+ _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value)
+ _node.Intent = value
+ }
+ if value, ok := _c.mutation.ProviderType(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value)
+ _node.ProviderType = value
+ }
+ if value, ok := _c.mutation.ProviderKey(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value)
+ _node.ProviderKey = value
+ }
+ if value, ok := _c.mutation.ProviderSubject(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value)
+ _node.ProviderSubject = value
+ }
+ if value, ok := _c.mutation.RedirectTo(); ok {
+ _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value)
+ _node.RedirectTo = value
+ }
+ if value, ok := _c.mutation.ResolvedEmail(); ok {
+ _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value)
+ _node.ResolvedEmail = value
+ }
+ if value, ok := _c.mutation.RegistrationPasswordHash(); ok {
+ _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value)
+ _node.RegistrationPasswordHash = value
+ }
+ if value, ok := _c.mutation.UpstreamIdentityClaims(); ok {
+ _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value)
+ _node.UpstreamIdentityClaims = value
+ }
+ if value, ok := _c.mutation.LocalFlowState(); ok {
+ _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value)
+ _node.LocalFlowState = value
+ }
+ if value, ok := _c.mutation.BrowserSessionKey(); ok {
+ _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value)
+ _node.BrowserSessionKey = value
+ }
+ if value, ok := _c.mutation.CompletionCodeHash(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value)
+ _node.CompletionCodeHash = value
+ }
+ if value, ok := _c.mutation.CompletionCodeExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value)
+ _node.CompletionCodeExpiresAt = &value
+ }
+ if value, ok := _c.mutation.EmailVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value)
+ _node.EmailVerifiedAt = &value
+ }
+ if value, ok := _c.mutation.PasswordVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value)
+ _node.PasswordVerifiedAt = &value
+ }
+ if value, ok := _c.mutation.TotpVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value)
+ _node.TotpVerifiedAt = &value
+ }
+ if value, ok := _c.mutation.ExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value)
+ _node.ExpiresAt = value
+ }
+ if value, ok := _c.mutation.ConsumedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value)
+ _node.ConsumedAt = &value
+ }
+ if nodes := _c.mutation.TargetUserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: pendingauthsession.TargetUserTable,
+ Columns: []string{pendingauthsession.TargetUserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.TargetUserID = &nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.AdoptionDecisionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: false,
+ Table: pendingauthsession.AdoptionDecisionTable,
+ Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.PendingAuthSession.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.PendingAuthSessionUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *PendingAuthSessionCreate) OnConflict(opts ...sql.ConflictOption) *PendingAuthSessionUpsertOne {
+ _c.conflict = opts
+ return &PendingAuthSessionUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *PendingAuthSessionCreate) OnConflictColumns(columns ...string) *PendingAuthSessionUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &PendingAuthSessionUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // PendingAuthSessionUpsertOne is the builder for "upsert"-ing
+ // one PendingAuthSession node.
+ PendingAuthSessionUpsertOne struct {
+ create *PendingAuthSessionCreate
+ }
+
+ // PendingAuthSessionUpsert is the "OnConflict" setter.
+ PendingAuthSessionUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *PendingAuthSessionUpsert) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateUpdatedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldUpdatedAt)
+ return u
+}
+
+// SetSessionToken sets the "session_token" field.
+func (u *PendingAuthSessionUpsert) SetSessionToken(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldSessionToken, v)
+ return u
+}
+
+// UpdateSessionToken sets the "session_token" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateSessionToken() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldSessionToken)
+ return u
+}
+
+// SetIntent sets the "intent" field.
+func (u *PendingAuthSessionUpsert) SetIntent(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldIntent, v)
+ return u
+}
+
+// UpdateIntent sets the "intent" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateIntent() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldIntent)
+ return u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *PendingAuthSessionUpsert) SetProviderType(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldProviderType, v)
+ return u
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateProviderType() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldProviderType)
+ return u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *PendingAuthSessionUpsert) SetProviderKey(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldProviderKey, v)
+ return u
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateProviderKey() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldProviderKey)
+ return u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *PendingAuthSessionUpsert) SetProviderSubject(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldProviderSubject, v)
+ return u
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateProviderSubject() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldProviderSubject)
+ return u
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (u *PendingAuthSessionUpsert) SetTargetUserID(v int64) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldTargetUserID, v)
+ return u
+}
+
+// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateTargetUserID() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldTargetUserID)
+ return u
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (u *PendingAuthSessionUpsert) ClearTargetUserID() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldTargetUserID)
+ return u
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (u *PendingAuthSessionUpsert) SetRedirectTo(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldRedirectTo, v)
+ return u
+}
+
+// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateRedirectTo() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldRedirectTo)
+ return u
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (u *PendingAuthSessionUpsert) SetResolvedEmail(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldResolvedEmail, v)
+ return u
+}
+
+// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateResolvedEmail() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldResolvedEmail)
+ return u
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (u *PendingAuthSessionUpsert) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldRegistrationPasswordHash, v)
+ return u
+}
+
+// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldRegistrationPasswordHash)
+ return u
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (u *PendingAuthSessionUpsert) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldUpstreamIdentityClaims, v)
+ return u
+}
+
+// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldUpstreamIdentityClaims)
+ return u
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (u *PendingAuthSessionUpsert) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldLocalFlowState, v)
+ return u
+}
+
+// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateLocalFlowState() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldLocalFlowState)
+ return u
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (u *PendingAuthSessionUpsert) SetBrowserSessionKey(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldBrowserSessionKey, v)
+ return u
+}
+
+// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateBrowserSessionKey() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldBrowserSessionKey)
+ return u
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (u *PendingAuthSessionUpsert) SetCompletionCodeHash(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldCompletionCodeHash, v)
+ return u
+}
+
+// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateCompletionCodeHash() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldCompletionCodeHash)
+ return u
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsert) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldCompletionCodeExpiresAt, v)
+ return u
+}
+
+// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldCompletionCodeExpiresAt)
+ return u
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsert) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldCompletionCodeExpiresAt)
+ return u
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (u *PendingAuthSessionUpsert) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldEmailVerifiedAt, v)
+ return u
+}
+
+// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateEmailVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldEmailVerifiedAt)
+ return u
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (u *PendingAuthSessionUpsert) ClearEmailVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldEmailVerifiedAt)
+ return u
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (u *PendingAuthSessionUpsert) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldPasswordVerifiedAt, v)
+ return u
+}
+
+// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldPasswordVerifiedAt)
+ return u
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (u *PendingAuthSessionUpsert) ClearPasswordVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldPasswordVerifiedAt)
+ return u
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsert) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldTotpVerifiedAt, v)
+ return u
+}
+
+// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateTotpVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldTotpVerifiedAt)
+ return u
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsert) ClearTotpVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldTotpVerifiedAt)
+ return u
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (u *PendingAuthSessionUpsert) SetExpiresAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldExpiresAt, v)
+ return u
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateExpiresAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldExpiresAt)
+ return u
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (u *PendingAuthSessionUpsert) SetConsumedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldConsumedAt, v)
+ return u
+}
+
+// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateConsumedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldConsumedAt)
+ return u
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (u *PendingAuthSessionUpsert) ClearConsumedAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldConsumedAt)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *PendingAuthSessionUpsertOne) UpdateNewValues() *PendingAuthSessionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(pendingauthsession.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *PendingAuthSessionUpsertOne) Ignore() *PendingAuthSessionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *PendingAuthSessionUpsertOne) DoNothing() *PendingAuthSessionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the PendingAuthSessionCreate.OnConflict
+// documentation for more info.
+func (u *PendingAuthSessionUpsertOne) Update(set func(*PendingAuthSessionUpsert)) *PendingAuthSessionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&PendingAuthSessionUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *PendingAuthSessionUpsertOne) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateUpdatedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetSessionToken sets the "session_token" field.
+func (u *PendingAuthSessionUpsertOne) SetSessionToken(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetSessionToken(v)
+ })
+}
+
+// UpdateSessionToken sets the "session_token" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateSessionToken() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateSessionToken()
+ })
+}
+
+// SetIntent sets the "intent" field.
+func (u *PendingAuthSessionUpsertOne) SetIntent(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetIntent(v)
+ })
+}
+
+// UpdateIntent sets the "intent" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateIntent() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateIntent()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *PendingAuthSessionUpsertOne) SetProviderType(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateProviderType() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *PendingAuthSessionUpsertOne) SetProviderKey(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateProviderKey() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *PendingAuthSessionUpsertOne) SetProviderSubject(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderSubject(v)
+ })
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateProviderSubject() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderSubject()
+ })
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (u *PendingAuthSessionUpsertOne) SetTargetUserID(v int64) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetTargetUserID(v)
+ })
+}
+
+// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateTargetUserID() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateTargetUserID()
+ })
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (u *PendingAuthSessionUpsertOne) ClearTargetUserID() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearTargetUserID()
+ })
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (u *PendingAuthSessionUpsertOne) SetRedirectTo(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetRedirectTo(v)
+ })
+}
+
+// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateRedirectTo() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateRedirectTo()
+ })
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (u *PendingAuthSessionUpsertOne) SetResolvedEmail(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetResolvedEmail(v)
+ })
+}
+
+// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateResolvedEmail() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateResolvedEmail()
+ })
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (u *PendingAuthSessionUpsertOne) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetRegistrationPasswordHash(v)
+ })
+}
+
+// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateRegistrationPasswordHash()
+ })
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (u *PendingAuthSessionUpsertOne) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetUpstreamIdentityClaims(v)
+ })
+}
+
+// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateUpstreamIdentityClaims()
+ })
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (u *PendingAuthSessionUpsertOne) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetLocalFlowState(v)
+ })
+}
+
+// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateLocalFlowState() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateLocalFlowState()
+ })
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (u *PendingAuthSessionUpsertOne) SetBrowserSessionKey(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetBrowserSessionKey(v)
+ })
+}
+
+// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateBrowserSessionKey() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateBrowserSessionKey()
+ })
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (u *PendingAuthSessionUpsertOne) SetCompletionCodeHash(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetCompletionCodeHash(v)
+ })
+}
+
+// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateCompletionCodeHash() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateCompletionCodeHash()
+ })
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsertOne) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetCompletionCodeExpiresAt(v)
+ })
+}
+
+// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateCompletionCodeExpiresAt()
+ })
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearCompletionCodeExpiresAt()
+ })
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetEmailVerifiedAt(v)
+ })
+}
+
+// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateEmailVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateEmailVerifiedAt()
+ })
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearEmailVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearEmailVerifiedAt()
+ })
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetPasswordVerifiedAt(v)
+ })
+}
+
+// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdatePasswordVerifiedAt()
+ })
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearPasswordVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearPasswordVerifiedAt()
+ })
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetTotpVerifiedAt(v)
+ })
+}
+
+// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateTotpVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateTotpVerifiedAt()
+ })
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearTotpVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearTotpVerifiedAt()
+ })
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (u *PendingAuthSessionUpsertOne) SetExpiresAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetExpiresAt(v)
+ })
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateExpiresAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateExpiresAt()
+ })
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (u *PendingAuthSessionUpsertOne) SetConsumedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetConsumedAt(v)
+ })
+}
+
+// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateConsumedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateConsumedAt()
+ })
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearConsumedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearConsumedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *PendingAuthSessionUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for PendingAuthSessionCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *PendingAuthSessionUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *PendingAuthSessionUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *PendingAuthSessionUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// PendingAuthSessionCreateBulk is the builder for creating many PendingAuthSession entities in bulk.
+type PendingAuthSessionCreateBulk struct {
+ config
+ err error
+ builders []*PendingAuthSessionCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the PendingAuthSession entities in the database.
+func (_c *PendingAuthSessionCreateBulk) Save(ctx context.Context) ([]*PendingAuthSession, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*PendingAuthSession, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*PendingAuthSessionMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *PendingAuthSessionCreateBulk) SaveX(ctx context.Context) []*PendingAuthSession {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *PendingAuthSessionCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *PendingAuthSessionCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.PendingAuthSession.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.PendingAuthSessionUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *PendingAuthSessionCreateBulk) OnConflict(opts ...sql.ConflictOption) *PendingAuthSessionUpsertBulk {
+ _c.conflict = opts
+ return &PendingAuthSessionUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *PendingAuthSessionCreateBulk) OnConflictColumns(columns ...string) *PendingAuthSessionUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &PendingAuthSessionUpsertBulk{
+ create: _c,
+ }
+}
+
+// PendingAuthSessionUpsertBulk is the builder for "upsert"-ing
+// a bulk of PendingAuthSession nodes.
+type PendingAuthSessionUpsertBulk struct {
+ create *PendingAuthSessionCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *PendingAuthSessionUpsertBulk) UpdateNewValues() *PendingAuthSessionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(pendingauthsession.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *PendingAuthSessionUpsertBulk) Ignore() *PendingAuthSessionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *PendingAuthSessionUpsertBulk) DoNothing() *PendingAuthSessionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the PendingAuthSessionCreateBulk.OnConflict
+// documentation for more info.
+func (u *PendingAuthSessionUpsertBulk) Update(set func(*PendingAuthSessionUpsert)) *PendingAuthSessionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&PendingAuthSessionUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateUpdatedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetSessionToken sets the "session_token" field.
+func (u *PendingAuthSessionUpsertBulk) SetSessionToken(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetSessionToken(v)
+ })
+}
+
+// UpdateSessionToken sets the "session_token" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateSessionToken() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateSessionToken()
+ })
+}
+
+// SetIntent sets the "intent" field.
+func (u *PendingAuthSessionUpsertBulk) SetIntent(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetIntent(v)
+ })
+}
+
+// UpdateIntent sets the "intent" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateIntent() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateIntent()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *PendingAuthSessionUpsertBulk) SetProviderType(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateProviderType() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *PendingAuthSessionUpsertBulk) SetProviderKey(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateProviderKey() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *PendingAuthSessionUpsertBulk) SetProviderSubject(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderSubject(v)
+ })
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateProviderSubject() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderSubject()
+ })
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (u *PendingAuthSessionUpsertBulk) SetTargetUserID(v int64) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetTargetUserID(v)
+ })
+}
+
+// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateTargetUserID() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateTargetUserID()
+ })
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (u *PendingAuthSessionUpsertBulk) ClearTargetUserID() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearTargetUserID()
+ })
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (u *PendingAuthSessionUpsertBulk) SetRedirectTo(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetRedirectTo(v)
+ })
+}
+
+// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateRedirectTo() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateRedirectTo()
+ })
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (u *PendingAuthSessionUpsertBulk) SetResolvedEmail(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetResolvedEmail(v)
+ })
+}
+
+// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateResolvedEmail() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateResolvedEmail()
+ })
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (u *PendingAuthSessionUpsertBulk) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetRegistrationPasswordHash(v)
+ })
+}
+
+// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateRegistrationPasswordHash()
+ })
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (u *PendingAuthSessionUpsertBulk) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetUpstreamIdentityClaims(v)
+ })
+}
+
+// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateUpstreamIdentityClaims()
+ })
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (u *PendingAuthSessionUpsertBulk) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetLocalFlowState(v)
+ })
+}
+
+// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateLocalFlowState() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateLocalFlowState()
+ })
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (u *PendingAuthSessionUpsertBulk) SetBrowserSessionKey(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetBrowserSessionKey(v)
+ })
+}
+
+// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateBrowserSessionKey() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateBrowserSessionKey()
+ })
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (u *PendingAuthSessionUpsertBulk) SetCompletionCodeHash(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetCompletionCodeHash(v)
+ })
+}
+
+// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateCompletionCodeHash() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateCompletionCodeHash()
+ })
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetCompletionCodeExpiresAt(v)
+ })
+}
+
+// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateCompletionCodeExpiresAt()
+ })
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearCompletionCodeExpiresAt()
+ })
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetEmailVerifiedAt(v)
+ })
+}
+
+// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateEmailVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateEmailVerifiedAt()
+ })
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearEmailVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearEmailVerifiedAt()
+ })
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetPasswordVerifiedAt(v)
+ })
+}
+
+// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdatePasswordVerifiedAt()
+ })
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearPasswordVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearPasswordVerifiedAt()
+ })
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetTotpVerifiedAt(v)
+ })
+}
+
+// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateTotpVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateTotpVerifiedAt()
+ })
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearTotpVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearTotpVerifiedAt()
+ })
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetExpiresAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetExpiresAt(v)
+ })
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateExpiresAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateExpiresAt()
+ })
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetConsumedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetConsumedAt(v)
+ })
+}
+
+// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateConsumedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateConsumedAt()
+ })
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearConsumedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearConsumedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *PendingAuthSessionUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the PendingAuthSessionCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for PendingAuthSessionCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *PendingAuthSessionUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/pendingauthsession_delete.go b/backend/ent/pendingauthsession_delete.go
new file mode 100644
index 00000000..ee4fe605
--- /dev/null
+++ b/backend/ent/pendingauthsession_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// PendingAuthSessionDelete is the builder for deleting a PendingAuthSession entity.
+type PendingAuthSessionDelete struct {
+ config
+ hooks []Hook
+ mutation *PendingAuthSessionMutation
+}
+
+// Where appends a list predicates to the PendingAuthSessionDelete builder.
+func (_d *PendingAuthSessionDelete) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *PendingAuthSessionDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *PendingAuthSessionDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *PendingAuthSessionDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(pendingauthsession.Table, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// PendingAuthSessionDeleteOne is the builder for deleting a single PendingAuthSession entity.
+type PendingAuthSessionDeleteOne struct {
+ _d *PendingAuthSessionDelete
+}
+
+// Where appends a list predicates to the PendingAuthSessionDelete builder.
+func (_d *PendingAuthSessionDeleteOne) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *PendingAuthSessionDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{pendingauthsession.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *PendingAuthSessionDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/pendingauthsession_query.go b/backend/ent/pendingauthsession_query.go
new file mode 100644
index 00000000..78e29cd2
--- /dev/null
+++ b/backend/ent/pendingauthsession_query.go
@@ -0,0 +1,717 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "database/sql/driver"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PendingAuthSessionQuery is the builder for querying PendingAuthSession entities.
+type PendingAuthSessionQuery struct {
+ config
+ ctx *QueryContext
+ order []pendingauthsession.OrderOption
+ inters []Interceptor
+ predicates []predicate.PendingAuthSession
+ withTargetUser *UserQuery
+ withAdoptionDecision *IdentityAdoptionDecisionQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the PendingAuthSessionQuery builder.
+func (_q *PendingAuthSessionQuery) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *PendingAuthSessionQuery) Limit(limit int) *PendingAuthSessionQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *PendingAuthSessionQuery) Offset(offset int) *PendingAuthSessionQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *PendingAuthSessionQuery) Unique(unique bool) *PendingAuthSessionQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *PendingAuthSessionQuery) Order(o ...pendingauthsession.OrderOption) *PendingAuthSessionQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryTargetUser chains the current query on the "target_user" edge.
+func (_q *PendingAuthSessionQuery) QueryTargetUser() *UserQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryAdoptionDecision chains the current query on the "adoption_decision" edge.
+func (_q *PendingAuthSessionQuery) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector),
+ sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first PendingAuthSession entity from the query.
+// Returns a *NotFoundError when no PendingAuthSession was found.
+func (_q *PendingAuthSessionQuery) First(ctx context.Context) (*PendingAuthSession, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{pendingauthsession.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) FirstX(ctx context.Context) *PendingAuthSession {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first PendingAuthSession ID from the query.
+// Returns a *NotFoundError when no PendingAuthSession ID was found.
+func (_q *PendingAuthSessionQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{pendingauthsession.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single PendingAuthSession entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one PendingAuthSession entity is found.
+// Returns a *NotFoundError when no PendingAuthSession entities are found.
+func (_q *PendingAuthSessionQuery) Only(ctx context.Context) (*PendingAuthSession, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{pendingauthsession.Label}
+ default:
+ return nil, &NotSingularError{pendingauthsession.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) OnlyX(ctx context.Context) *PendingAuthSession {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only PendingAuthSession ID in the query.
+// Returns a *NotSingularError when more than one PendingAuthSession ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *PendingAuthSessionQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{pendingauthsession.Label}
+ default:
+ err = &NotSingularError{pendingauthsession.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of PendingAuthSessions.
+func (_q *PendingAuthSessionQuery) All(ctx context.Context) ([]*PendingAuthSession, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*PendingAuthSession, *PendingAuthSessionQuery]()
+ return withInterceptors[[]*PendingAuthSession](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) AllX(ctx context.Context) []*PendingAuthSession {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of PendingAuthSession IDs.
+func (_q *PendingAuthSessionQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(pendingauthsession.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *PendingAuthSessionQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*PendingAuthSessionQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *PendingAuthSessionQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the PendingAuthSessionQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *PendingAuthSessionQuery) Clone() *PendingAuthSessionQuery {
+ if _q == nil {
+ return nil
+ }
+ return &PendingAuthSessionQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]pendingauthsession.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.PendingAuthSession{}, _q.predicates...),
+ withTargetUser: _q.withTargetUser.Clone(),
+ withAdoptionDecision: _q.withAdoptionDecision.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithTargetUser tells the query-builder to eager-load the nodes that are connected to
+// the "target_user" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *PendingAuthSessionQuery) WithTargetUser(opts ...func(*UserQuery)) *PendingAuthSessionQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withTargetUser = query
+ return _q
+}
+
+// WithAdoptionDecision tells the query-builder to eager-load the nodes that are connected to
+// the "adoption_decision" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *PendingAuthSessionQuery) WithAdoptionDecision(opts ...func(*IdentityAdoptionDecisionQuery)) *PendingAuthSessionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withAdoptionDecision = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.PendingAuthSession.Query().
+// GroupBy(pendingauthsession.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *PendingAuthSessionQuery) GroupBy(field string, fields ...string) *PendingAuthSessionGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &PendingAuthSessionGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = pendingauthsession.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.PendingAuthSession.Query().
+// Select(pendingauthsession.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *PendingAuthSessionQuery) Select(fields ...string) *PendingAuthSessionSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &PendingAuthSessionSelect{PendingAuthSessionQuery: _q}
+ sbuild.label = pendingauthsession.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a PendingAuthSessionSelect configured with the given aggregations.
+func (_q *PendingAuthSessionQuery) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *PendingAuthSessionQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !pendingauthsession.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *PendingAuthSessionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*PendingAuthSession, error) {
+ var (
+ nodes = []*PendingAuthSession{}
+ _spec = _q.querySpec()
+ loadedTypes = [2]bool{
+ _q.withTargetUser != nil,
+ _q.withAdoptionDecision != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*PendingAuthSession).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &PendingAuthSession{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withTargetUser; query != nil {
+ if err := _q.loadTargetUser(ctx, query, nodes, nil,
+ func(n *PendingAuthSession, e *User) { n.Edges.TargetUser = e }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withAdoptionDecision; query != nil {
+ if err := _q.loadAdoptionDecision(ctx, query, nodes, nil,
+ func(n *PendingAuthSession, e *IdentityAdoptionDecision) { n.Edges.AdoptionDecision = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *PendingAuthSessionQuery) loadTargetUser(ctx context.Context, query *UserQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *User)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*PendingAuthSession)
+ for i := range nodes {
+ if nodes[i].TargetUserID == nil {
+ continue
+ }
+ fk := *nodes[i].TargetUserID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(user.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "target_user_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+func (_q *PendingAuthSessionQuery) loadAdoptionDecision(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *IdentityAdoptionDecision)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*PendingAuthSession)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(identityadoptiondecision.FieldPendingAuthSessionID)
+ }
+ query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(pendingauthsession.AdoptionDecisionColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.PendingAuthSessionID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "pending_auth_session_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+
+func (_q *PendingAuthSessionQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *PendingAuthSessionQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, pendingauthsession.FieldID)
+ for i := range fields {
+ if fields[i] != pendingauthsession.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withTargetUser != nil {
+ _spec.Node.AddColumnOnce(pendingauthsession.FieldTargetUserID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *PendingAuthSessionQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(pendingauthsession.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = pendingauthsession.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *PendingAuthSessionQuery) ForUpdate(opts ...sql.LockOption) *PendingAuthSessionQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *PendingAuthSessionQuery) ForShare(opts ...sql.LockOption) *PendingAuthSessionQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// PendingAuthSessionGroupBy is the group-by builder for PendingAuthSession entities.
+type PendingAuthSessionGroupBy struct {
+ selector
+ build *PendingAuthSessionQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *PendingAuthSessionGroupBy) Aggregate(fns ...AggregateFunc) *PendingAuthSessionGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *PendingAuthSessionGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *PendingAuthSessionGroupBy) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// PendingAuthSessionSelect is the builder for selecting fields of PendingAuthSession entities.
+type PendingAuthSessionSelect struct {
+ *PendingAuthSessionQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *PendingAuthSessionSelect) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *PendingAuthSessionSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionSelect](ctx, _s.PendingAuthSessionQuery, _s, _s.inters, v)
+}
+
+func (_s *PendingAuthSessionSelect) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/pendingauthsession_update.go b/backend/ent/pendingauthsession_update.go
new file mode 100644
index 00000000..00066f69
--- /dev/null
+++ b/backend/ent/pendingauthsession_update.go
@@ -0,0 +1,1178 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PendingAuthSessionUpdate is the builder for updating PendingAuthSession entities.
+type PendingAuthSessionUpdate struct {
+ config
+ hooks []Hook
+ mutation *PendingAuthSessionMutation
+}
+
+// Where appends a list predicates to the PendingAuthSessionUpdate builder.
+func (_u *PendingAuthSessionUpdate) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *PendingAuthSessionUpdate) SetUpdatedAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetSessionToken sets the "session_token" field.
+func (_u *PendingAuthSessionUpdate) SetSessionToken(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetSessionToken(v)
+ return _u
+}
+
+// SetNillableSessionToken sets the "session_token" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableSessionToken(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetSessionToken(*v)
+ }
+ return _u
+}
+
+// SetIntent sets the "intent" field.
+func (_u *PendingAuthSessionUpdate) SetIntent(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetIntent(v)
+ return _u
+}
+
+// SetNillableIntent sets the "intent" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableIntent(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetIntent(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *PendingAuthSessionUpdate) SetProviderType(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableProviderType(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *PendingAuthSessionUpdate) SetProviderKey(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableProviderKey(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_u *PendingAuthSessionUpdate) SetProviderSubject(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetProviderSubject(v)
+ return _u
+}
+
+// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableProviderSubject(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetProviderSubject(*v)
+ }
+ return _u
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (_u *PendingAuthSessionUpdate) SetTargetUserID(v int64) *PendingAuthSessionUpdate {
+ _u.mutation.SetTargetUserID(v)
+ return _u
+}
+
+// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableTargetUserID(v *int64) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetTargetUserID(*v)
+ }
+ return _u
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (_u *PendingAuthSessionUpdate) ClearTargetUserID() *PendingAuthSessionUpdate {
+ _u.mutation.ClearTargetUserID()
+ return _u
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (_u *PendingAuthSessionUpdate) SetRedirectTo(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetRedirectTo(v)
+ return _u
+}
+
+// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableRedirectTo(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetRedirectTo(*v)
+ }
+ return _u
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (_u *PendingAuthSessionUpdate) SetResolvedEmail(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetResolvedEmail(v)
+ return _u
+}
+
+// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableResolvedEmail(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetResolvedEmail(*v)
+ }
+ return _u
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (_u *PendingAuthSessionUpdate) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetRegistrationPasswordHash(v)
+ return _u
+}
+
+// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetRegistrationPasswordHash(*v)
+ }
+ return _u
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (_u *PendingAuthSessionUpdate) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpdate {
+ _u.mutation.SetUpstreamIdentityClaims(v)
+ return _u
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (_u *PendingAuthSessionUpdate) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpdate {
+ _u.mutation.SetLocalFlowState(v)
+ return _u
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (_u *PendingAuthSessionUpdate) SetBrowserSessionKey(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetBrowserSessionKey(v)
+ return _u
+}
+
+// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetBrowserSessionKey(*v)
+ }
+ return _u
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (_u *PendingAuthSessionUpdate) SetCompletionCodeHash(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetCompletionCodeHash(v)
+ return _u
+}
+
+// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetCompletionCodeHash(*v)
+ }
+ return _u
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (_u *PendingAuthSessionUpdate) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetCompletionCodeExpiresAt(v)
+ return _u
+}
+
+// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetCompletionCodeExpiresAt(*v)
+ }
+ return _u
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (_u *PendingAuthSessionUpdate) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpdate {
+ _u.mutation.ClearCompletionCodeExpiresAt()
+ return _u
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (_u *PendingAuthSessionUpdate) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetEmailVerifiedAt(v)
+ return _u
+}
+
+// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetEmailVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (_u *PendingAuthSessionUpdate) ClearEmailVerifiedAt() *PendingAuthSessionUpdate {
+ _u.mutation.ClearEmailVerifiedAt()
+ return _u
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (_u *PendingAuthSessionUpdate) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetPasswordVerifiedAt(v)
+ return _u
+}
+
+// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetPasswordVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (_u *PendingAuthSessionUpdate) ClearPasswordVerifiedAt() *PendingAuthSessionUpdate {
+ _u.mutation.ClearPasswordVerifiedAt()
+ return _u
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (_u *PendingAuthSessionUpdate) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetTotpVerifiedAt(v)
+ return _u
+}
+
+// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetTotpVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (_u *PendingAuthSessionUpdate) ClearTotpVerifiedAt() *PendingAuthSessionUpdate {
+ _u.mutation.ClearTotpVerifiedAt()
+ return _u
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (_u *PendingAuthSessionUpdate) SetExpiresAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetExpiresAt(v)
+ return _u
+}
+
+// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableExpiresAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetExpiresAt(*v)
+ }
+ return _u
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (_u *PendingAuthSessionUpdate) SetConsumedAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetConsumedAt(v)
+ return _u
+}
+
+// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetConsumedAt(*v)
+ }
+ return _u
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (_u *PendingAuthSessionUpdate) ClearConsumedAt() *PendingAuthSessionUpdate {
+ _u.mutation.ClearConsumedAt()
+ return _u
+}
+
+// SetTargetUser sets the "target_user" edge to the User entity.
+func (_u *PendingAuthSessionUpdate) SetTargetUser(v *User) *PendingAuthSessionUpdate {
+ return _u.SetTargetUserID(v.ID)
+}
+
+// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID.
+func (_u *PendingAuthSessionUpdate) SetAdoptionDecisionID(id int64) *PendingAuthSessionUpdate {
+ _u.mutation.SetAdoptionDecisionID(id)
+ return _u
+}
+
+// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionUpdate {
+ if id != nil {
+ _u = _u.SetAdoptionDecisionID(*id)
+ }
+ return _u
+}
+
+// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_u *PendingAuthSessionUpdate) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionUpdate {
+ return _u.SetAdoptionDecisionID(v.ID)
+}
+
+// Mutation returns the PendingAuthSessionMutation object of the builder.
+func (_u *PendingAuthSessionUpdate) Mutation() *PendingAuthSessionMutation {
+ return _u.mutation
+}
+
+// ClearTargetUser clears the "target_user" edge to the User entity.
+func (_u *PendingAuthSessionUpdate) ClearTargetUser() *PendingAuthSessionUpdate {
+ _u.mutation.ClearTargetUser()
+ return _u
+}
+
+// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_u *PendingAuthSessionUpdate) ClearAdoptionDecision() *PendingAuthSessionUpdate {
+ _u.mutation.ClearAdoptionDecision()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *PendingAuthSessionUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *PendingAuthSessionUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *PendingAuthSessionUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *PendingAuthSessionUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *PendingAuthSessionUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := pendingauthsession.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *PendingAuthSessionUpdate) check() error {
+ if v, ok := _u.mutation.SessionToken(); ok {
+ if err := pendingauthsession.SessionTokenValidator(v); err != nil {
+ return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Intent(); ok {
+ if err := pendingauthsession.IntentValidator(v); err != nil {
+ return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := pendingauthsession.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := pendingauthsession.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderSubject(); ok {
+ if err := pendingauthsession.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *PendingAuthSessionUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.SessionToken(); ok {
+ _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Intent(); ok {
+ _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderSubject(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RedirectTo(); ok {
+ _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ResolvedEmail(); ok {
+ _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RegistrationPasswordHash(); ok {
+ _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.UpstreamIdentityClaims(); ok {
+ _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.LocalFlowState(); ok {
+ _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.BrowserSessionKey(); ok {
+ _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.CompletionCodeHash(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.CompletionCodeExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value)
+ }
+ if _u.mutation.CompletionCodeExpiresAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.EmailVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.EmailVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.PasswordVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.PasswordVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.TotpVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.TotpVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.ExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ConsumedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value)
+ }
+ if _u.mutation.ConsumedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldConsumedAt, field.TypeTime)
+ }
+ if _u.mutation.TargetUserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: pendingauthsession.TargetUserTable,
+ Columns: []string{pendingauthsession.TargetUserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.TargetUserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: pendingauthsession.TargetUserTable,
+ Columns: []string{pendingauthsession.TargetUserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.AdoptionDecisionCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: false,
+ Table: pendingauthsession.AdoptionDecisionTable,
+ Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AdoptionDecisionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: false,
+ Table: pendingauthsession.AdoptionDecisionTable,
+ Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{pendingauthsession.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// PendingAuthSessionUpdateOne is the builder for updating a single PendingAuthSession entity.
+type PendingAuthSessionUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *PendingAuthSessionMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetUpdatedAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetSessionToken sets the "session_token" field.
+func (_u *PendingAuthSessionUpdateOne) SetSessionToken(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetSessionToken(v)
+ return _u
+}
+
+// SetNillableSessionToken sets the "session_token" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableSessionToken(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetSessionToken(*v)
+ }
+ return _u
+}
+
+// SetIntent sets the "intent" field.
+func (_u *PendingAuthSessionUpdateOne) SetIntent(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetIntent(v)
+ return _u
+}
+
+// SetNillableIntent sets the "intent" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableIntent(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetIntent(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *PendingAuthSessionUpdateOne) SetProviderType(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableProviderType(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *PendingAuthSessionUpdateOne) SetProviderKey(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableProviderKey(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_u *PendingAuthSessionUpdateOne) SetProviderSubject(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetProviderSubject(v)
+ return _u
+}
+
+// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableProviderSubject(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetProviderSubject(*v)
+ }
+ return _u
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (_u *PendingAuthSessionUpdateOne) SetTargetUserID(v int64) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetTargetUserID(v)
+ return _u
+}
+
+// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableTargetUserID(v *int64) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetTargetUserID(*v)
+ }
+ return _u
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (_u *PendingAuthSessionUpdateOne) ClearTargetUserID() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearTargetUserID()
+ return _u
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (_u *PendingAuthSessionUpdateOne) SetRedirectTo(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetRedirectTo(v)
+ return _u
+}
+
+// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableRedirectTo(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetRedirectTo(*v)
+ }
+ return _u
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (_u *PendingAuthSessionUpdateOne) SetResolvedEmail(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetResolvedEmail(v)
+ return _u
+}
+
+// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableResolvedEmail(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetResolvedEmail(*v)
+ }
+ return _u
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (_u *PendingAuthSessionUpdateOne) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetRegistrationPasswordHash(v)
+ return _u
+}
+
+// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetRegistrationPasswordHash(*v)
+ }
+ return _u
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (_u *PendingAuthSessionUpdateOne) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetUpstreamIdentityClaims(v)
+ return _u
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (_u *PendingAuthSessionUpdateOne) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetLocalFlowState(v)
+ return _u
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (_u *PendingAuthSessionUpdateOne) SetBrowserSessionKey(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetBrowserSessionKey(v)
+ return _u
+}
+
+// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetBrowserSessionKey(*v)
+ }
+ return _u
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (_u *PendingAuthSessionUpdateOne) SetCompletionCodeHash(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetCompletionCodeHash(v)
+ return _u
+}
+
+// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetCompletionCodeHash(*v)
+ }
+ return _u
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetCompletionCodeExpiresAt(v)
+ return _u
+}
+
+// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetCompletionCodeExpiresAt(*v)
+ }
+ return _u
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearCompletionCodeExpiresAt()
+ return _u
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetEmailVerifiedAt(v)
+ return _u
+}
+
+// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetEmailVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearEmailVerifiedAt() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearEmailVerifiedAt()
+ return _u
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetPasswordVerifiedAt(v)
+ return _u
+}
+
+// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetPasswordVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearPasswordVerifiedAt() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearPasswordVerifiedAt()
+ return _u
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetTotpVerifiedAt(v)
+ return _u
+}
+
+// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetTotpVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearTotpVerifiedAt() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearTotpVerifiedAt()
+ return _u
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetExpiresAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetExpiresAt(v)
+ return _u
+}
+
+// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableExpiresAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetExpiresAt(*v)
+ }
+ return _u
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetConsumedAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetConsumedAt(v)
+ return _u
+}
+
+// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetConsumedAt(*v)
+ }
+ return _u
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearConsumedAt() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearConsumedAt()
+ return _u
+}
+
+// SetTargetUser sets the "target_user" edge to the User entity.
+func (_u *PendingAuthSessionUpdateOne) SetTargetUser(v *User) *PendingAuthSessionUpdateOne {
+ return _u.SetTargetUserID(v.ID)
+}
+
+// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID.
+func (_u *PendingAuthSessionUpdateOne) SetAdoptionDecisionID(id int64) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetAdoptionDecisionID(id)
+ return _u
+}
+
+// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionUpdateOne {
+ if id != nil {
+ _u = _u.SetAdoptionDecisionID(*id)
+ }
+ return _u
+}
+
+// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_u *PendingAuthSessionUpdateOne) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionUpdateOne {
+ return _u.SetAdoptionDecisionID(v.ID)
+}
+
+// Mutation returns the PendingAuthSessionMutation object of the builder.
+func (_u *PendingAuthSessionUpdateOne) Mutation() *PendingAuthSessionMutation {
+ return _u.mutation
+}
+
+// ClearTargetUser clears the "target_user" edge to the User entity.
+func (_u *PendingAuthSessionUpdateOne) ClearTargetUser() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearTargetUser()
+ return _u
+}
+
+// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_u *PendingAuthSessionUpdateOne) ClearAdoptionDecision() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearAdoptionDecision()
+ return _u
+}
+
+// Where appends a list predicates to the PendingAuthSessionUpdate builder.
+func (_u *PendingAuthSessionUpdateOne) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *PendingAuthSessionUpdateOne) Select(field string, fields ...string) *PendingAuthSessionUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated PendingAuthSession entity.
+func (_u *PendingAuthSessionUpdateOne) Save(ctx context.Context) (*PendingAuthSession, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *PendingAuthSessionUpdateOne) SaveX(ctx context.Context) *PendingAuthSession {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *PendingAuthSessionUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *PendingAuthSessionUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *PendingAuthSessionUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := pendingauthsession.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *PendingAuthSessionUpdateOne) check() error {
+ if v, ok := _u.mutation.SessionToken(); ok {
+ if err := pendingauthsession.SessionTokenValidator(v); err != nil {
+ return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Intent(); ok {
+ if err := pendingauthsession.IntentValidator(v); err != nil {
+ return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := pendingauthsession.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := pendingauthsession.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderSubject(); ok {
+ if err := pendingauthsession.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *PendingAuthSessionUpdateOne) sqlSave(ctx context.Context) (_node *PendingAuthSession, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "PendingAuthSession.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, pendingauthsession.FieldID)
+ for _, f := range fields {
+ if !pendingauthsession.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != pendingauthsession.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.SessionToken(); ok {
+ _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Intent(); ok {
+ _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderSubject(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RedirectTo(); ok {
+ _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ResolvedEmail(); ok {
+ _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RegistrationPasswordHash(); ok {
+ _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.UpstreamIdentityClaims(); ok {
+ _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.LocalFlowState(); ok {
+ _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.BrowserSessionKey(); ok {
+ _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.CompletionCodeHash(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.CompletionCodeExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value)
+ }
+ if _u.mutation.CompletionCodeExpiresAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.EmailVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.EmailVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.PasswordVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.PasswordVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.TotpVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.TotpVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.ExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ConsumedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value)
+ }
+ if _u.mutation.ConsumedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldConsumedAt, field.TypeTime)
+ }
+ if _u.mutation.TargetUserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: pendingauthsession.TargetUserTable,
+ Columns: []string{pendingauthsession.TargetUserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.TargetUserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: pendingauthsession.TargetUserTable,
+ Columns: []string{pendingauthsession.TargetUserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.AdoptionDecisionCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: false,
+ Table: pendingauthsession.AdoptionDecisionTable,
+ Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AdoptionDecisionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: false,
+ Table: pendingauthsession.AdoptionDecisionTable,
+ Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &PendingAuthSession{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{pendingauthsession.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go
index ef551940..0aa90b90 100644
--- a/backend/ent/predicate/predicate.go
+++ b/backend/ent/predicate/predicate.go
@@ -21,6 +21,12 @@ type Announcement func(*sql.Selector)
// AnnouncementRead is the predicate function for announcementread builders.
type AnnouncementRead func(*sql.Selector)
+// AuthIdentity is the predicate function for authidentity builders.
+type AuthIdentity func(*sql.Selector)
+
+// AuthIdentityChannel is the predicate function for authidentitychannel builders.
+type AuthIdentityChannel func(*sql.Selector)
+
// ErrorPassthroughRule is the predicate function for errorpassthroughrule builders.
type ErrorPassthroughRule func(*sql.Selector)
@@ -30,6 +36,9 @@ type Group func(*sql.Selector)
// IdempotencyRecord is the predicate function for idempotencyrecord builders.
type IdempotencyRecord func(*sql.Selector)
+// IdentityAdoptionDecision is the predicate function for identityadoptiondecision builders.
+type IdentityAdoptionDecision func(*sql.Selector)
+
// PaymentAuditLog is the predicate function for paymentauditlog builders.
type PaymentAuditLog func(*sql.Selector)
@@ -39,6 +48,9 @@ type PaymentOrder func(*sql.Selector)
// PaymentProviderInstance is the predicate function for paymentproviderinstance builders.
type PaymentProviderInstance func(*sql.Selector)
+// PendingAuthSession is the predicate function for pendingauthsession builders.
+type PendingAuthSession func(*sql.Selector)
+
// PromoCode is the predicate function for promocode builders.
type PromoCode func(*sql.Selector)
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index fbdd08c7..268e9ddb 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -10,12 +10,16 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
@@ -309,6 +313,120 @@ func init() {
announcementreadDescCreatedAt := announcementreadFields[3].Descriptor()
// announcementread.DefaultCreatedAt holds the default value on creation for the created_at field.
announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time)
+ authidentityMixin := schema.AuthIdentity{}.Mixin()
+ authidentityMixinFields0 := authidentityMixin[0].Fields()
+ _ = authidentityMixinFields0
+ authidentityFields := schema.AuthIdentity{}.Fields()
+ _ = authidentityFields
+ // authidentityDescCreatedAt is the schema descriptor for created_at field.
+ authidentityDescCreatedAt := authidentityMixinFields0[0].Descriptor()
+ // authidentity.DefaultCreatedAt holds the default value on creation for the created_at field.
+ authidentity.DefaultCreatedAt = authidentityDescCreatedAt.Default.(func() time.Time)
+ // authidentityDescUpdatedAt is the schema descriptor for updated_at field.
+ authidentityDescUpdatedAt := authidentityMixinFields0[1].Descriptor()
+ // authidentity.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ authidentity.DefaultUpdatedAt = authidentityDescUpdatedAt.Default.(func() time.Time)
+ // authidentity.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ authidentity.UpdateDefaultUpdatedAt = authidentityDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // authidentityDescProviderType is the schema descriptor for provider_type field.
+ authidentityDescProviderType := authidentityFields[1].Descriptor()
+ // authidentity.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ authidentity.ProviderTypeValidator = func() func(string) error {
+ validators := authidentityDescProviderType.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ validators[2].(func(string) error),
+ }
+ return func(provider_type string) error {
+ for _, fn := range fns {
+ if err := fn(provider_type); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // authidentityDescProviderKey is the schema descriptor for provider_key field.
+ authidentityDescProviderKey := authidentityFields[2].Descriptor()
+ // authidentity.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ authidentity.ProviderKeyValidator = authidentityDescProviderKey.Validators[0].(func(string) error)
+ // authidentityDescProviderSubject is the schema descriptor for provider_subject field.
+ authidentityDescProviderSubject := authidentityFields[3].Descriptor()
+ // authidentity.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
+ authidentity.ProviderSubjectValidator = authidentityDescProviderSubject.Validators[0].(func(string) error)
+ // authidentityDescMetadata is the schema descriptor for metadata field.
+ authidentityDescMetadata := authidentityFields[6].Descriptor()
+ // authidentity.DefaultMetadata holds the default value on creation for the metadata field.
+ authidentity.DefaultMetadata = authidentityDescMetadata.Default.(func() map[string]interface{})
+ authidentitychannelMixin := schema.AuthIdentityChannel{}.Mixin()
+ authidentitychannelMixinFields0 := authidentitychannelMixin[0].Fields()
+ _ = authidentitychannelMixinFields0
+ authidentitychannelFields := schema.AuthIdentityChannel{}.Fields()
+ _ = authidentitychannelFields
+ // authidentitychannelDescCreatedAt is the schema descriptor for created_at field.
+ authidentitychannelDescCreatedAt := authidentitychannelMixinFields0[0].Descriptor()
+ // authidentitychannel.DefaultCreatedAt holds the default value on creation for the created_at field.
+ authidentitychannel.DefaultCreatedAt = authidentitychannelDescCreatedAt.Default.(func() time.Time)
+ // authidentitychannelDescUpdatedAt is the schema descriptor for updated_at field.
+ authidentitychannelDescUpdatedAt := authidentitychannelMixinFields0[1].Descriptor()
+ // authidentitychannel.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ authidentitychannel.DefaultUpdatedAt = authidentitychannelDescUpdatedAt.Default.(func() time.Time)
+ // authidentitychannel.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ authidentitychannel.UpdateDefaultUpdatedAt = authidentitychannelDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // authidentitychannelDescProviderType is the schema descriptor for provider_type field.
+ authidentitychannelDescProviderType := authidentitychannelFields[1].Descriptor()
+ // authidentitychannel.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ authidentitychannel.ProviderTypeValidator = func() func(string) error {
+ validators := authidentitychannelDescProviderType.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ validators[2].(func(string) error),
+ }
+ return func(provider_type string) error {
+ for _, fn := range fns {
+ if err := fn(provider_type); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // authidentitychannelDescProviderKey is the schema descriptor for provider_key field.
+ authidentitychannelDescProviderKey := authidentitychannelFields[2].Descriptor()
+ // authidentitychannel.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ authidentitychannel.ProviderKeyValidator = authidentitychannelDescProviderKey.Validators[0].(func(string) error)
+ // authidentitychannelDescChannel is the schema descriptor for channel field.
+ authidentitychannelDescChannel := authidentitychannelFields[3].Descriptor()
+ // authidentitychannel.ChannelValidator is a validator for the "channel" field. It is called by the builders before save.
+ authidentitychannel.ChannelValidator = func() func(string) error {
+ validators := authidentitychannelDescChannel.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(channel string) error {
+ for _, fn := range fns {
+ if err := fn(channel); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // authidentitychannelDescChannelAppID is the schema descriptor for channel_app_id field.
+ authidentitychannelDescChannelAppID := authidentitychannelFields[4].Descriptor()
+ // authidentitychannel.ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save.
+ authidentitychannel.ChannelAppIDValidator = authidentitychannelDescChannelAppID.Validators[0].(func(string) error)
+ // authidentitychannelDescChannelSubject is the schema descriptor for channel_subject field.
+ authidentitychannelDescChannelSubject := authidentitychannelFields[5].Descriptor()
+ // authidentitychannel.ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save.
+ authidentitychannel.ChannelSubjectValidator = authidentitychannelDescChannelSubject.Validators[0].(func(string) error)
+ // authidentitychannelDescMetadata is the schema descriptor for metadata field.
+ authidentitychannelDescMetadata := authidentitychannelFields[6].Descriptor()
+ // authidentitychannel.DefaultMetadata holds the default value on creation for the metadata field.
+ authidentitychannel.DefaultMetadata = authidentitychannelDescMetadata.Default.(func() map[string]interface{})
errorpassthroughruleMixin := schema.ErrorPassthroughRule{}.Mixin()
errorpassthroughruleMixinFields0 := errorpassthroughruleMixin[0].Fields()
_ = errorpassthroughruleMixinFields0
@@ -512,6 +630,33 @@ func init() {
idempotencyrecordDescErrorReason := idempotencyrecordFields[6].Descriptor()
// idempotencyrecord.ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save.
idempotencyrecord.ErrorReasonValidator = idempotencyrecordDescErrorReason.Validators[0].(func(string) error)
+ identityadoptiondecisionMixin := schema.IdentityAdoptionDecision{}.Mixin()
+ identityadoptiondecisionMixinFields0 := identityadoptiondecisionMixin[0].Fields()
+ _ = identityadoptiondecisionMixinFields0
+ identityadoptiondecisionFields := schema.IdentityAdoptionDecision{}.Fields()
+ _ = identityadoptiondecisionFields
+ // identityadoptiondecisionDescCreatedAt is the schema descriptor for created_at field.
+ identityadoptiondecisionDescCreatedAt := identityadoptiondecisionMixinFields0[0].Descriptor()
+ // identityadoptiondecision.DefaultCreatedAt holds the default value on creation for the created_at field.
+ identityadoptiondecision.DefaultCreatedAt = identityadoptiondecisionDescCreatedAt.Default.(func() time.Time)
+ // identityadoptiondecisionDescUpdatedAt is the schema descriptor for updated_at field.
+ identityadoptiondecisionDescUpdatedAt := identityadoptiondecisionMixinFields0[1].Descriptor()
+ // identityadoptiondecision.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ identityadoptiondecision.DefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.Default.(func() time.Time)
+ // identityadoptiondecision.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ identityadoptiondecision.UpdateDefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // identityadoptiondecisionDescAdoptDisplayName is the schema descriptor for adopt_display_name field.
+ identityadoptiondecisionDescAdoptDisplayName := identityadoptiondecisionFields[2].Descriptor()
+ // identityadoptiondecision.DefaultAdoptDisplayName holds the default value on creation for the adopt_display_name field.
+ identityadoptiondecision.DefaultAdoptDisplayName = identityadoptiondecisionDescAdoptDisplayName.Default.(bool)
+ // identityadoptiondecisionDescAdoptAvatar is the schema descriptor for adopt_avatar field.
+ identityadoptiondecisionDescAdoptAvatar := identityadoptiondecisionFields[3].Descriptor()
+ // identityadoptiondecision.DefaultAdoptAvatar holds the default value on creation for the adopt_avatar field.
+ identityadoptiondecision.DefaultAdoptAvatar = identityadoptiondecisionDescAdoptAvatar.Default.(bool)
+ // identityadoptiondecisionDescDecidedAt is the schema descriptor for decided_at field.
+ identityadoptiondecisionDescDecidedAt := identityadoptiondecisionFields[4].Descriptor()
+ // identityadoptiondecision.DefaultDecidedAt holds the default value on creation for the decided_at field.
+ identityadoptiondecision.DefaultDecidedAt = identityadoptiondecisionDescDecidedAt.Default.(func() time.Time)
paymentauditlogFields := schema.PaymentAuditLog{}.Fields()
_ = paymentauditlogFields
// paymentauditlogDescOrderID is the schema descriptor for order_id field.
@@ -682,6 +827,113 @@ func init() {
paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time)
// paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
paymentproviderinstance.UpdateDefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.UpdateDefault.(func() time.Time)
+ pendingauthsessionMixin := schema.PendingAuthSession{}.Mixin()
+ pendingauthsessionMixinFields0 := pendingauthsessionMixin[0].Fields()
+ _ = pendingauthsessionMixinFields0
+ pendingauthsessionFields := schema.PendingAuthSession{}.Fields()
+ _ = pendingauthsessionFields
+ // pendingauthsessionDescCreatedAt is the schema descriptor for created_at field.
+ pendingauthsessionDescCreatedAt := pendingauthsessionMixinFields0[0].Descriptor()
+ // pendingauthsession.DefaultCreatedAt holds the default value on creation for the created_at field.
+ pendingauthsession.DefaultCreatedAt = pendingauthsessionDescCreatedAt.Default.(func() time.Time)
+ // pendingauthsessionDescUpdatedAt is the schema descriptor for updated_at field.
+ pendingauthsessionDescUpdatedAt := pendingauthsessionMixinFields0[1].Descriptor()
+ // pendingauthsession.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ pendingauthsession.DefaultUpdatedAt = pendingauthsessionDescUpdatedAt.Default.(func() time.Time)
+ // pendingauthsession.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ pendingauthsession.UpdateDefaultUpdatedAt = pendingauthsessionDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // pendingauthsessionDescSessionToken is the schema descriptor for session_token field.
+ pendingauthsessionDescSessionToken := pendingauthsessionFields[0].Descriptor()
+ // pendingauthsession.SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save.
+ pendingauthsession.SessionTokenValidator = func() func(string) error {
+ validators := pendingauthsessionDescSessionToken.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(session_token string) error {
+ for _, fn := range fns {
+ if err := fn(session_token); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // pendingauthsessionDescIntent is the schema descriptor for intent field.
+ pendingauthsessionDescIntent := pendingauthsessionFields[1].Descriptor()
+ // pendingauthsession.IntentValidator is a validator for the "intent" field. It is called by the builders before save.
+ pendingauthsession.IntentValidator = func() func(string) error {
+ validators := pendingauthsessionDescIntent.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ validators[2].(func(string) error),
+ }
+ return func(intent string) error {
+ for _, fn := range fns {
+ if err := fn(intent); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // pendingauthsessionDescProviderType is the schema descriptor for provider_type field.
+ pendingauthsessionDescProviderType := pendingauthsessionFields[2].Descriptor()
+ // pendingauthsession.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ pendingauthsession.ProviderTypeValidator = func() func(string) error {
+ validators := pendingauthsessionDescProviderType.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ validators[2].(func(string) error),
+ }
+ return func(provider_type string) error {
+ for _, fn := range fns {
+ if err := fn(provider_type); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // pendingauthsessionDescProviderKey is the schema descriptor for provider_key field.
+ pendingauthsessionDescProviderKey := pendingauthsessionFields[3].Descriptor()
+ // pendingauthsession.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ pendingauthsession.ProviderKeyValidator = pendingauthsessionDescProviderKey.Validators[0].(func(string) error)
+ // pendingauthsessionDescProviderSubject is the schema descriptor for provider_subject field.
+ pendingauthsessionDescProviderSubject := pendingauthsessionFields[4].Descriptor()
+ // pendingauthsession.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
+ pendingauthsession.ProviderSubjectValidator = pendingauthsessionDescProviderSubject.Validators[0].(func(string) error)
+ // pendingauthsessionDescRedirectTo is the schema descriptor for redirect_to field.
+ pendingauthsessionDescRedirectTo := pendingauthsessionFields[6].Descriptor()
+ // pendingauthsession.DefaultRedirectTo holds the default value on creation for the redirect_to field.
+ pendingauthsession.DefaultRedirectTo = pendingauthsessionDescRedirectTo.Default.(string)
+ // pendingauthsessionDescResolvedEmail is the schema descriptor for resolved_email field.
+ pendingauthsessionDescResolvedEmail := pendingauthsessionFields[7].Descriptor()
+ // pendingauthsession.DefaultResolvedEmail holds the default value on creation for the resolved_email field.
+ pendingauthsession.DefaultResolvedEmail = pendingauthsessionDescResolvedEmail.Default.(string)
+ // pendingauthsessionDescRegistrationPasswordHash is the schema descriptor for registration_password_hash field.
+ pendingauthsessionDescRegistrationPasswordHash := pendingauthsessionFields[8].Descriptor()
+ // pendingauthsession.DefaultRegistrationPasswordHash holds the default value on creation for the registration_password_hash field.
+ pendingauthsession.DefaultRegistrationPasswordHash = pendingauthsessionDescRegistrationPasswordHash.Default.(string)
+ // pendingauthsessionDescUpstreamIdentityClaims is the schema descriptor for upstream_identity_claims field.
+ pendingauthsessionDescUpstreamIdentityClaims := pendingauthsessionFields[9].Descriptor()
+ // pendingauthsession.DefaultUpstreamIdentityClaims holds the default value on creation for the upstream_identity_claims field.
+ pendingauthsession.DefaultUpstreamIdentityClaims = pendingauthsessionDescUpstreamIdentityClaims.Default.(func() map[string]interface{})
+ // pendingauthsessionDescLocalFlowState is the schema descriptor for local_flow_state field.
+ pendingauthsessionDescLocalFlowState := pendingauthsessionFields[10].Descriptor()
+ // pendingauthsession.DefaultLocalFlowState holds the default value on creation for the local_flow_state field.
+ pendingauthsession.DefaultLocalFlowState = pendingauthsessionDescLocalFlowState.Default.(func() map[string]interface{})
+ // pendingauthsessionDescBrowserSessionKey is the schema descriptor for browser_session_key field.
+ pendingauthsessionDescBrowserSessionKey := pendingauthsessionFields[11].Descriptor()
+ // pendingauthsession.DefaultBrowserSessionKey holds the default value on creation for the browser_session_key field.
+ pendingauthsession.DefaultBrowserSessionKey = pendingauthsessionDescBrowserSessionKey.Default.(string)
+ // pendingauthsessionDescCompletionCodeHash is the schema descriptor for completion_code_hash field.
+ pendingauthsessionDescCompletionCodeHash := pendingauthsessionFields[12].Descriptor()
+ // pendingauthsession.DefaultCompletionCodeHash holds the default value on creation for the completion_code_hash field.
+ pendingauthsession.DefaultCompletionCodeHash = pendingauthsessionDescCompletionCodeHash.Default.(string)
promocodeFields := schema.PromoCode{}.Fields()
_ = promocodeFields
// promocodeDescCode is the schema descriptor for code field.
@@ -1297,20 +1549,26 @@ func init() {
userDescTotpEnabled := userFields[9].Descriptor()
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
+ // userDescSignupSource is the schema descriptor for signup_source field.
+ userDescSignupSource := userFields[11].Descriptor()
+ // user.DefaultSignupSource holds the default value on creation for the signup_source field.
+ user.DefaultSignupSource = userDescSignupSource.Default.(string)
+ // user.SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save.
+ user.SignupSourceValidator = userDescSignupSource.Validators[0].(func(string) error)
// userDescBalanceNotifyEnabled is the schema descriptor for balance_notify_enabled field.
- userDescBalanceNotifyEnabled := userFields[11].Descriptor()
+ userDescBalanceNotifyEnabled := userFields[14].Descriptor()
// user.DefaultBalanceNotifyEnabled holds the default value on creation for the balance_notify_enabled field.
user.DefaultBalanceNotifyEnabled = userDescBalanceNotifyEnabled.Default.(bool)
// userDescBalanceNotifyThresholdType is the schema descriptor for balance_notify_threshold_type field.
- userDescBalanceNotifyThresholdType := userFields[12].Descriptor()
+ userDescBalanceNotifyThresholdType := userFields[15].Descriptor()
// user.DefaultBalanceNotifyThresholdType holds the default value on creation for the balance_notify_threshold_type field.
user.DefaultBalanceNotifyThresholdType = userDescBalanceNotifyThresholdType.Default.(string)
// userDescBalanceNotifyExtraEmails is the schema descriptor for balance_notify_extra_emails field.
- userDescBalanceNotifyExtraEmails := userFields[14].Descriptor()
+ userDescBalanceNotifyExtraEmails := userFields[17].Descriptor()
// user.DefaultBalanceNotifyExtraEmails holds the default value on creation for the balance_notify_extra_emails field.
user.DefaultBalanceNotifyExtraEmails = userDescBalanceNotifyExtraEmails.Default.(string)
// userDescTotalRecharged is the schema descriptor for total_recharged field.
- userDescTotalRecharged := userFields[15].Descriptor()
+ userDescTotalRecharged := userFields[18].Descriptor()
// user.DefaultTotalRecharged holds the default value on creation for the total_recharged field.
user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
diff --git a/backend/ent/schema/auth_identity.go b/backend/ent/schema/auth_identity.go
new file mode 100644
index 00000000..e4b9ac90
--- /dev/null
+++ b/backend/ent/schema/auth_identity.go
@@ -0,0 +1,93 @@
+package schema
+
+import (
+ "fmt"
+
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+var authProviderTypes = map[string]struct{}{
+ "email": {},
+ "linuxdo": {},
+ "oidc": {},
+ "wechat": {},
+}
+
+func validateAuthProviderType(value string) error {
+ if _, ok := authProviderTypes[value]; ok {
+ return nil
+ }
+ return fmt.Errorf("invalid auth provider type %q", value)
+}
+
+// AuthIdentity stores the canonical login identity for an account.
+type AuthIdentity struct {
+ ent.Schema
+}
+
+func (AuthIdentity) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "auth_identities"},
+ }
+}
+
+func (AuthIdentity) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (AuthIdentity) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("user_id"),
+ field.String("provider_type").
+ MaxLen(20).
+ NotEmpty().
+ Validate(validateAuthProviderType),
+ field.String("provider_key").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("provider_subject").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.Time("verified_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.String("issuer").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.JSON("metadata", map[string]any{}).
+ Default(func() map[string]any { return map[string]any{} }).
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
+ }
+}
+
+func (AuthIdentity) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("user", User.Type).
+ Ref("auth_identities").
+ Field("user_id").
+ Required().
+ Unique(),
+ edge.To("channels", AuthIdentityChannel.Type),
+ edge.To("adoption_decisions", IdentityAdoptionDecision.Type),
+ }
+}
+
+func (AuthIdentity) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("provider_type", "provider_key", "provider_subject").Unique(),
+ index.Fields("user_id"),
+ index.Fields("user_id", "provider_type"),
+ }
+}
diff --git a/backend/ent/schema/auth_identity_channel.go b/backend/ent/schema/auth_identity_channel.go
new file mode 100644
index 00000000..69f2ad02
--- /dev/null
+++ b/backend/ent/schema/auth_identity_channel.go
@@ -0,0 +1,72 @@
+package schema
+
+import (
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// AuthIdentityChannel stores channel-scoped identifiers for a canonical identity.
+type AuthIdentityChannel struct {
+ ent.Schema
+}
+
+func (AuthIdentityChannel) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "auth_identity_channels"},
+ }
+}
+
+func (AuthIdentityChannel) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (AuthIdentityChannel) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("identity_id"),
+ field.String("provider_type").
+ MaxLen(20).
+ NotEmpty().
+ Validate(validateAuthProviderType),
+ field.String("provider_key").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("channel").
+ MaxLen(20).
+ NotEmpty(),
+ field.String("channel_app_id").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("channel_subject").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.JSON("metadata", map[string]any{}).
+ Default(func() map[string]any { return map[string]any{} }).
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
+ }
+}
+
+func (AuthIdentityChannel) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("identity", AuthIdentity.Type).
+ Ref("channels").
+ Field("identity_id").
+ Required().
+ Unique(),
+ }
+}
+
+func (AuthIdentityChannel) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("provider_type", "provider_key", "channel", "channel_app_id", "channel_subject").Unique(),
+ index.Fields("identity_id"),
+ }
+}
diff --git a/backend/ent/schema/auth_identity_schema_test.go b/backend/ent/schema/auth_identity_schema_test.go
new file mode 100644
index 00000000..de55dd69
--- /dev/null
+++ b/backend/ent/schema/auth_identity_schema_test.go
@@ -0,0 +1,124 @@
+package schema
+
+import (
+ "testing"
+
+ "entgo.io/ent/entc/load"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthIdentityFoundationSchemas(t *testing.T) {
+ spec, err := (&load.Config{Path: "."}).Load()
+ require.NoError(t, err)
+
+ schemas := map[string]*load.Schema{}
+ for _, schema := range spec.Schemas {
+ schemas[schema.Name] = schema
+ }
+
+ authIdentity := requireSchema(t, schemas, "AuthIdentity")
+ requireSchemaFields(t, authIdentity,
+ "user_id",
+ "provider_type",
+ "provider_key",
+ "provider_subject",
+ "verified_at",
+ "issuer",
+ "metadata",
+ )
+ requireHasUniqueIndex(t, authIdentity, "provider_type", "provider_key", "provider_subject")
+
+ authIdentityChannel := requireSchema(t, schemas, "AuthIdentityChannel")
+ requireSchemaFields(t, authIdentityChannel,
+ "identity_id",
+ "provider_type",
+ "provider_key",
+ "channel",
+ "channel_app_id",
+ "channel_subject",
+ "metadata",
+ )
+ requireHasUniqueIndex(t, authIdentityChannel, "provider_type", "provider_key", "channel", "channel_app_id", "channel_subject")
+
+ pendingAuthSession := requireSchema(t, schemas, "PendingAuthSession")
+ requireSchemaFields(t, pendingAuthSession,
+ "intent",
+ "provider_type",
+ "provider_key",
+ "provider_subject",
+ "target_user_id",
+ "redirect_to",
+ "resolved_email",
+ "registration_password_hash",
+ "upstream_identity_claims",
+ "local_flow_state",
+ "browser_session_key",
+ "completion_code_hash",
+ "completion_code_expires_at",
+ "email_verified_at",
+ "password_verified_at",
+ "totp_verified_at",
+ "expires_at",
+ "consumed_at",
+ )
+
+ adoptionDecision := requireSchema(t, schemas, "IdentityAdoptionDecision")
+ requireSchemaFields(t, adoptionDecision,
+ "pending_auth_session_id",
+ "identity_id",
+ "adopt_display_name",
+ "adopt_avatar",
+ "decided_at",
+ )
+ requireHasUniqueIndex(t, adoptionDecision, "pending_auth_session_id")
+
+ userSchema := requireSchema(t, schemas, "User")
+ requireSchemaFields(t, userSchema, "signup_source", "last_login_at", "last_active_at")
+}
+
+func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema {
+ t.Helper()
+
+ schema, ok := schemas[name]
+ require.True(t, ok, "schema %s should exist", name)
+ return schema
+}
+
+func requireSchemaFields(t *testing.T, schema *load.Schema, names ...string) {
+ t.Helper()
+
+ fields := map[string]struct{}{}
+ for _, field := range schema.Fields {
+ fields[field.Name] = struct{}{}
+ }
+
+ for _, name := range names {
+ _, ok := fields[name]
+ require.True(t, ok, "schema %s should include field %s", schema.Name, name)
+ }
+}
+
+func requireHasUniqueIndex(t *testing.T, schema *load.Schema, fields ...string) {
+ t.Helper()
+
+ for _, index := range schema.Indexes {
+ if !index.Unique {
+ continue
+ }
+ if len(index.Fields) != len(fields) {
+ continue
+ }
+ match := true
+ for i := range fields {
+ if index.Fields[i] != fields[i] {
+ match = false
+ break
+ }
+ }
+ if match {
+ return
+ }
+ }
+
+ require.Failf(t, "missing unique index", "schema %s should include unique index on %v", schema.Name, fields)
+}
diff --git a/backend/ent/schema/identity_adoption_decision.go b/backend/ent/schema/identity_adoption_decision.go
new file mode 100644
index 00000000..9fdd26fb
--- /dev/null
+++ b/backend/ent/schema/identity_adoption_decision.go
@@ -0,0 +1,70 @@
+package schema
+
+import (
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// IdentityAdoptionDecision stores the one-time profile adoption choice captured during a pending auth flow.
+type IdentityAdoptionDecision struct {
+ ent.Schema
+}
+
+func (IdentityAdoptionDecision) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "identity_adoption_decisions"},
+ }
+}
+
+func (IdentityAdoptionDecision) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (IdentityAdoptionDecision) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("pending_auth_session_id"),
+ field.Int64("identity_id").
+ Optional().
+ Nillable(),
+ field.Bool("adopt_display_name").
+ Default(false),
+ field.Bool("adopt_avatar").
+ Default(false),
+ field.Time("decided_at").
+ Immutable().
+ Default(time.Now).
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ }
+}
+
+func (IdentityAdoptionDecision) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("pending_auth_session", PendingAuthSession.Type).
+ Ref("adoption_decision").
+ Field("pending_auth_session_id").
+ Required().
+ Unique(),
+ edge.From("identity", AuthIdentity.Type).
+ Ref("adoption_decisions").
+ Field("identity_id").
+ Unique(),
+ }
+}
+
+func (IdentityAdoptionDecision) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("pending_auth_session_id").Unique(),
+ index.Fields("identity_id"),
+ }
+}
diff --git a/backend/ent/schema/pending_auth_session.go b/backend/ent/schema/pending_auth_session.go
new file mode 100644
index 00000000..91341d49
--- /dev/null
+++ b/backend/ent/schema/pending_auth_session.go
@@ -0,0 +1,134 @@
+package schema
+
+import (
+ "fmt"
+
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+var pendingAuthIntents = map[string]struct{}{
+ "login": {},
+ "bind_current_user": {},
+ "adopt_existing_user_by_email": {},
+}
+
+func validatePendingAuthIntent(value string) error {
+ if _, ok := pendingAuthIntents[value]; ok {
+ return nil
+ }
+ return fmt.Errorf("invalid pending auth intent %q", value)
+}
+
+// PendingAuthSession stores a short-lived post-auth decision session.
+type PendingAuthSession struct {
+ ent.Schema
+}
+
+func (PendingAuthSession) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "pending_auth_sessions"},
+ }
+}
+
+func (PendingAuthSession) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (PendingAuthSession) Fields() []ent.Field {
+ return []ent.Field{
+ field.String("session_token").
+ MaxLen(255).
+ NotEmpty(),
+ field.String("intent").
+ MaxLen(40).
+ NotEmpty().
+ Validate(validatePendingAuthIntent),
+ field.String("provider_type").
+ MaxLen(20).
+ NotEmpty().
+ Validate(validateAuthProviderType),
+ field.String("provider_key").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("provider_subject").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.Int64("target_user_id").
+ Optional().
+ Nillable(),
+ field.String("redirect_to").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("resolved_email").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("registration_password_hash").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.JSON("upstream_identity_claims", map[string]any{}).
+ Default(func() map[string]any { return map[string]any{} }).
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
+ field.JSON("local_flow_state", map[string]any{}).
+ Default(func() map[string]any { return map[string]any{} }).
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
+ field.String("browser_session_key").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("completion_code_hash").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.Time("completion_code_expires_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("email_verified_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("password_verified_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("totp_verified_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("expires_at").
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("consumed_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ }
+}
+
+func (PendingAuthSession) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("target_user", User.Type).
+ Ref("pending_auth_sessions").
+ Field("target_user_id").
+ Unique(),
+ edge.To("adoption_decision", IdentityAdoptionDecision.Type).
+ Unique(),
+ }
+}
+
+func (PendingAuthSession) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("session_token").Unique(),
+ index.Fields("target_user_id"),
+ index.Fields("expires_at"),
+ index.Fields("provider_type", "provider_key", "provider_subject"),
+ index.Fields("completion_code_hash"),
+ }
+}
diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go
index ef52e985..bb58d9e3 100644
--- a/backend/ent/schema/user.go
+++ b/backend/ent/schema/user.go
@@ -72,6 +72,17 @@ func (User) Fields() []ent.Field {
field.Time("totp_enabled_at").
Optional().
Nillable(),
+ field.String("signup_source").
+ MaxLen(20).
+ Default("email"),
+ field.Time("last_login_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("last_active_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
// 余额不足通知
field.Bool("balance_notify_enabled").
@@ -104,6 +115,8 @@ func (User) Edges() []ent.Edge {
edge.To("attribute_values", UserAttributeValue.Type),
edge.To("promo_code_usages", PromoCodeUsage.Type),
edge.To("payment_orders", PaymentOrder.Type),
+ edge.To("auth_identities", AuthIdentity.Type),
+ edge.To("pending_auth_sessions", PendingAuthSession.Type),
}
}
diff --git a/backend/ent/tx.go b/backend/ent/tx.go
index bb3139d5..bde3e35b 100644
--- a/backend/ent/tx.go
+++ b/backend/ent/tx.go
@@ -24,18 +24,26 @@ type Tx struct {
Announcement *AnnouncementClient
// AnnouncementRead is the client for interacting with the AnnouncementRead builders.
AnnouncementRead *AnnouncementReadClient
+ // AuthIdentity is the client for interacting with the AuthIdentity builders.
+ AuthIdentity *AuthIdentityClient
+ // AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders.
+ AuthIdentityChannel *AuthIdentityChannelClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// IdempotencyRecord is the client for interacting with the IdempotencyRecord builders.
IdempotencyRecord *IdempotencyRecordClient
+ // IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders.
+ IdentityAdoptionDecision *IdentityAdoptionDecisionClient
// PaymentAuditLog is the client for interacting with the PaymentAuditLog builders.
PaymentAuditLog *PaymentAuditLogClient
// PaymentOrder is the client for interacting with the PaymentOrder builders.
PaymentOrder *PaymentOrderClient
// PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders.
PaymentProviderInstance *PaymentProviderInstanceClient
+ // PendingAuthSession is the client for interacting with the PendingAuthSession builders.
+ PendingAuthSession *PendingAuthSessionClient
// PromoCode is the client for interacting with the PromoCode builders.
PromoCode *PromoCodeClient
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
@@ -202,12 +210,16 @@ func (tx *Tx) init() {
tx.AccountGroup = NewAccountGroupClient(tx.config)
tx.Announcement = NewAnnouncementClient(tx.config)
tx.AnnouncementRead = NewAnnouncementReadClient(tx.config)
+ tx.AuthIdentity = NewAuthIdentityClient(tx.config)
+ tx.AuthIdentityChannel = NewAuthIdentityChannelClient(tx.config)
tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config)
tx.Group = NewGroupClient(tx.config)
tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config)
+ tx.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(tx.config)
tx.PaymentAuditLog = NewPaymentAuditLogClient(tx.config)
tx.PaymentOrder = NewPaymentOrderClient(tx.config)
tx.PaymentProviderInstance = NewPaymentProviderInstanceClient(tx.config)
+ tx.PendingAuthSession = NewPendingAuthSessionClient(tx.config)
tx.PromoCode = NewPromoCodeClient(tx.config)
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
tx.Proxy = NewProxyClient(tx.config)
diff --git a/backend/ent/user.go b/backend/ent/user.go
index 9fa91f74..66f33623 100644
--- a/backend/ent/user.go
+++ b/backend/ent/user.go
@@ -45,6 +45,12 @@ type User struct {
TotpEnabled bool `json:"totp_enabled,omitempty"`
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
+ // SignupSource holds the value of the "signup_source" field.
+ SignupSource string `json:"signup_source,omitempty"`
+ // LastLoginAt holds the value of the "last_login_at" field.
+ LastLoginAt *time.Time `json:"last_login_at,omitempty"`
+ // LastActiveAt holds the value of the "last_active_at" field.
+ LastActiveAt *time.Time `json:"last_active_at,omitempty"`
// BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field.
BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"`
// BalanceNotifyThresholdType holds the value of the "balance_notify_threshold_type" field.
@@ -83,11 +89,15 @@ type UserEdges struct {
PromoCodeUsages []*PromoCodeUsage `json:"promo_code_usages,omitempty"`
// PaymentOrders holds the value of the payment_orders edge.
PaymentOrders []*PaymentOrder `json:"payment_orders,omitempty"`
+ // AuthIdentities holds the value of the auth_identities edge.
+ AuthIdentities []*AuthIdentity `json:"auth_identities,omitempty"`
+ // PendingAuthSessions holds the value of the pending_auth_sessions edge.
+ PendingAuthSessions []*PendingAuthSession `json:"pending_auth_sessions,omitempty"`
// UserAllowedGroups holds the value of the user_allowed_groups edge.
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
- loadedTypes [11]bool
+ loadedTypes [13]bool
}
// APIKeysOrErr returns the APIKeys value or an error if the edge
@@ -180,10 +190,28 @@ func (e UserEdges) PaymentOrdersOrErr() ([]*PaymentOrder, error) {
return nil, &NotLoadedError{edge: "payment_orders"}
}
+// AuthIdentitiesOrErr returns the AuthIdentities value or an error if the edge
+// was not loaded in eager-loading.
+func (e UserEdges) AuthIdentitiesOrErr() ([]*AuthIdentity, error) {
+ if e.loadedTypes[10] {
+ return e.AuthIdentities, nil
+ }
+ return nil, &NotLoadedError{edge: "auth_identities"}
+}
+
+// PendingAuthSessionsOrErr returns the PendingAuthSessions value or an error if the edge
+// was not loaded in eager-loading.
+func (e UserEdges) PendingAuthSessionsOrErr() ([]*PendingAuthSession, error) {
+ if e.loadedTypes[11] {
+ return e.PendingAuthSessions, nil
+ }
+ return nil, &NotLoadedError{edge: "pending_auth_sessions"}
+}
+
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
- if e.loadedTypes[10] {
+ if e.loadedTypes[12] {
return e.UserAllowedGroups, nil
}
return nil, &NotLoadedError{edge: "user_allowed_groups"}
@@ -200,9 +228,9 @@ func (*User) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64)
case user.FieldID, user.FieldConcurrency:
values[i] = new(sql.NullInt64)
- case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
+ case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldSignupSource, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
values[i] = new(sql.NullString)
- case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
+ case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt, user.FieldLastLoginAt, user.FieldLastActiveAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@@ -312,6 +340,26 @@ func (_m *User) assignValues(columns []string, values []any) error {
_m.TotpEnabledAt = new(time.Time)
*_m.TotpEnabledAt = value.Time
}
+ case user.FieldSignupSource:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field signup_source", values[i])
+ } else if value.Valid {
+ _m.SignupSource = value.String
+ }
+ case user.FieldLastLoginAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field last_login_at", values[i])
+ } else if value.Valid {
+ _m.LastLoginAt = new(time.Time)
+ *_m.LastLoginAt = value.Time
+ }
+ case user.FieldLastActiveAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field last_active_at", values[i])
+ } else if value.Valid {
+ _m.LastActiveAt = new(time.Time)
+ *_m.LastActiveAt = value.Time
+ }
case user.FieldBalanceNotifyEnabled:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field balance_notify_enabled", values[i])
@@ -406,6 +454,16 @@ func (_m *User) QueryPaymentOrders() *PaymentOrderQuery {
return NewUserClient(_m.config).QueryPaymentOrders(_m)
}
+// QueryAuthIdentities queries the "auth_identities" edge of the User entity.
+func (_m *User) QueryAuthIdentities() *AuthIdentityQuery {
+ return NewUserClient(_m.config).QueryAuthIdentities(_m)
+}
+
+// QueryPendingAuthSessions queries the "pending_auth_sessions" edge of the User entity.
+func (_m *User) QueryPendingAuthSessions() *PendingAuthSessionQuery {
+ return NewUserClient(_m.config).QueryPendingAuthSessions(_m)
+}
+
// QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity.
func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery {
return NewUserClient(_m.config).QueryUserAllowedGroups(_m)
@@ -482,6 +540,19 @@ func (_m *User) String() string {
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
+ builder.WriteString("signup_source=")
+ builder.WriteString(_m.SignupSource)
+ builder.WriteString(", ")
+ if v := _m.LastLoginAt; v != nil {
+ builder.WriteString("last_login_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.LastActiveAt; v != nil {
+ builder.WriteString("last_active_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
builder.WriteString("balance_notify_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled))
builder.WriteString(", ")
diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go
index d88a3a38..567e3b14 100644
--- a/backend/ent/user/user.go
+++ b/backend/ent/user/user.go
@@ -43,6 +43,12 @@ const (
FieldTotpEnabled = "totp_enabled"
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
FieldTotpEnabledAt = "totp_enabled_at"
+ // FieldSignupSource holds the string denoting the signup_source field in the database.
+ FieldSignupSource = "signup_source"
+ // FieldLastLoginAt holds the string denoting the last_login_at field in the database.
+ FieldLastLoginAt = "last_login_at"
+ // FieldLastActiveAt holds the string denoting the last_active_at field in the database.
+ FieldLastActiveAt = "last_active_at"
// FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database.
FieldBalanceNotifyEnabled = "balance_notify_enabled"
// FieldBalanceNotifyThresholdType holds the string denoting the balance_notify_threshold_type field in the database.
@@ -73,6 +79,10 @@ const (
EdgePromoCodeUsages = "promo_code_usages"
// EdgePaymentOrders holds the string denoting the payment_orders edge name in mutations.
EdgePaymentOrders = "payment_orders"
+ // EdgeAuthIdentities holds the string denoting the auth_identities edge name in mutations.
+ EdgeAuthIdentities = "auth_identities"
+ // EdgePendingAuthSessions holds the string denoting the pending_auth_sessions edge name in mutations.
+ EdgePendingAuthSessions = "pending_auth_sessions"
// EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations.
EdgeUserAllowedGroups = "user_allowed_groups"
// Table holds the table name of the user in the database.
@@ -145,6 +155,20 @@ const (
PaymentOrdersInverseTable = "payment_orders"
// PaymentOrdersColumn is the table column denoting the payment_orders relation/edge.
PaymentOrdersColumn = "user_id"
+ // AuthIdentitiesTable is the table that holds the auth_identities relation/edge.
+ AuthIdentitiesTable = "auth_identities"
+ // AuthIdentitiesInverseTable is the table name for the AuthIdentity entity.
+ // It exists in this package in order to avoid circular dependency with the "authidentity" package.
+ AuthIdentitiesInverseTable = "auth_identities"
+ // AuthIdentitiesColumn is the table column denoting the auth_identities relation/edge.
+ AuthIdentitiesColumn = "user_id"
+ // PendingAuthSessionsTable is the table that holds the pending_auth_sessions relation/edge.
+ PendingAuthSessionsTable = "pending_auth_sessions"
+ // PendingAuthSessionsInverseTable is the table name for the PendingAuthSession entity.
+ // It exists in this package in order to avoid circular dependency with the "pendingauthsession" package.
+ PendingAuthSessionsInverseTable = "pending_auth_sessions"
+ // PendingAuthSessionsColumn is the table column denoting the pending_auth_sessions relation/edge.
+ PendingAuthSessionsColumn = "target_user_id"
// UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge.
UserAllowedGroupsTable = "user_allowed_groups"
// UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity.
@@ -171,6 +195,9 @@ var Columns = []string{
FieldTotpSecretEncrypted,
FieldTotpEnabled,
FieldTotpEnabledAt,
+ FieldSignupSource,
+ FieldLastLoginAt,
+ FieldLastActiveAt,
FieldBalanceNotifyEnabled,
FieldBalanceNotifyThresholdType,
FieldBalanceNotifyThreshold,
@@ -232,6 +259,10 @@ var (
DefaultNotes string
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
DefaultTotpEnabled bool
+ // DefaultSignupSource holds the default value on creation for the "signup_source" field.
+ DefaultSignupSource string
+ // SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save.
+ SignupSourceValidator func(string) error
// DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field.
DefaultBalanceNotifyEnabled bool
// DefaultBalanceNotifyThresholdType holds the default value on creation for the "balance_notify_threshold_type" field.
@@ -320,6 +351,21 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
}
+// BySignupSource orders the results by the signup_source field.
+func BySignupSource(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSignupSource, opts...).ToFunc()
+}
+
+// ByLastLoginAt orders the results by the last_login_at field.
+func ByLastLoginAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldLastLoginAt, opts...).ToFunc()
+}
+
+// ByLastActiveAt orders the results by the last_active_at field.
+func ByLastActiveAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldLastActiveAt, opts...).ToFunc()
+}
+
// ByBalanceNotifyEnabled orders the results by the balance_notify_enabled field.
func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc()
@@ -485,6 +531,34 @@ func ByPaymentOrders(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
}
}
+// ByAuthIdentitiesCount orders the results by auth_identities count.
+func ByAuthIdentitiesCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newAuthIdentitiesStep(), opts...)
+ }
+}
+
+// ByAuthIdentities orders the results by auth_identities terms.
+func ByAuthIdentities(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newAuthIdentitiesStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+
+// ByPendingAuthSessionsCount orders the results by pending_auth_sessions count.
+func ByPendingAuthSessionsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newPendingAuthSessionsStep(), opts...)
+ }
+}
+
+// ByPendingAuthSessions orders the results by pending_auth_sessions terms.
+func ByPendingAuthSessions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+
// ByUserAllowedGroupsCount orders the results by user_allowed_groups count.
func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
@@ -568,6 +642,20 @@ func newPaymentOrdersStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.O2M, false, PaymentOrdersTable, PaymentOrdersColumn),
)
}
+func newAuthIdentitiesStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(AuthIdentitiesInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn),
+ )
+}
+func newPendingAuthSessionsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(PendingAuthSessionsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn),
+ )
+}
func newUserAllowedGroupsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go
index 2788aa7a..cbcfcc26 100644
--- a/backend/ent/user/where.go
+++ b/backend/ent/user/where.go
@@ -125,6 +125,21 @@ func TotpEnabledAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
+// SignupSource applies equality check predicate on the "signup_source" field. It's identical to SignupSourceEQ.
+func SignupSource(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldSignupSource, v))
+}
+
+// LastLoginAt applies equality check predicate on the "last_login_at" field. It's identical to LastLoginAtEQ.
+func LastLoginAt(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldLastLoginAt, v))
+}
+
+// LastActiveAt applies equality check predicate on the "last_active_at" field. It's identical to LastActiveAtEQ.
+func LastActiveAt(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldLastActiveAt, v))
+}
+
// BalanceNotifyEnabled applies equality check predicate on the "balance_notify_enabled" field. It's identical to BalanceNotifyEnabledEQ.
func BalanceNotifyEnabled(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
@@ -885,6 +900,171 @@ func TotpEnabledAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
}
+// SignupSourceEQ applies the EQ predicate on the "signup_source" field.
+func SignupSourceEQ(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldSignupSource, v))
+}
+
+// SignupSourceNEQ applies the NEQ predicate on the "signup_source" field.
+func SignupSourceNEQ(v string) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldSignupSource, v))
+}
+
+// SignupSourceIn applies the In predicate on the "signup_source" field.
+func SignupSourceIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldIn(FieldSignupSource, vs...))
+}
+
+// SignupSourceNotIn applies the NotIn predicate on the "signup_source" field.
+func SignupSourceNotIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldSignupSource, vs...))
+}
+
+// SignupSourceGT applies the GT predicate on the "signup_source" field.
+func SignupSourceGT(v string) predicate.User {
+ return predicate.User(sql.FieldGT(FieldSignupSource, v))
+}
+
+// SignupSourceGTE applies the GTE predicate on the "signup_source" field.
+func SignupSourceGTE(v string) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldSignupSource, v))
+}
+
+// SignupSourceLT applies the LT predicate on the "signup_source" field.
+func SignupSourceLT(v string) predicate.User {
+ return predicate.User(sql.FieldLT(FieldSignupSource, v))
+}
+
+// SignupSourceLTE applies the LTE predicate on the "signup_source" field.
+func SignupSourceLTE(v string) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldSignupSource, v))
+}
+
+// SignupSourceContains applies the Contains predicate on the "signup_source" field.
+func SignupSourceContains(v string) predicate.User {
+ return predicate.User(sql.FieldContains(FieldSignupSource, v))
+}
+
+// SignupSourceHasPrefix applies the HasPrefix predicate on the "signup_source" field.
+func SignupSourceHasPrefix(v string) predicate.User {
+ return predicate.User(sql.FieldHasPrefix(FieldSignupSource, v))
+}
+
+// SignupSourceHasSuffix applies the HasSuffix predicate on the "signup_source" field.
+func SignupSourceHasSuffix(v string) predicate.User {
+ return predicate.User(sql.FieldHasSuffix(FieldSignupSource, v))
+}
+
+// SignupSourceEqualFold applies the EqualFold predicate on the "signup_source" field.
+func SignupSourceEqualFold(v string) predicate.User {
+ return predicate.User(sql.FieldEqualFold(FieldSignupSource, v))
+}
+
+// SignupSourceContainsFold applies the ContainsFold predicate on the "signup_source" field.
+func SignupSourceContainsFold(v string) predicate.User {
+ return predicate.User(sql.FieldContainsFold(FieldSignupSource, v))
+}
+
+// LastLoginAtEQ applies the EQ predicate on the "last_login_at" field.
+func LastLoginAtEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldLastLoginAt, v))
+}
+
+// LastLoginAtNEQ applies the NEQ predicate on the "last_login_at" field.
+func LastLoginAtNEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldLastLoginAt, v))
+}
+
+// LastLoginAtIn applies the In predicate on the "last_login_at" field.
+func LastLoginAtIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldIn(FieldLastLoginAt, vs...))
+}
+
+// LastLoginAtNotIn applies the NotIn predicate on the "last_login_at" field.
+func LastLoginAtNotIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldLastLoginAt, vs...))
+}
+
+// LastLoginAtGT applies the GT predicate on the "last_login_at" field.
+func LastLoginAtGT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGT(FieldLastLoginAt, v))
+}
+
+// LastLoginAtGTE applies the GTE predicate on the "last_login_at" field.
+func LastLoginAtGTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldLastLoginAt, v))
+}
+
+// LastLoginAtLT applies the LT predicate on the "last_login_at" field.
+func LastLoginAtLT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLT(FieldLastLoginAt, v))
+}
+
+// LastLoginAtLTE applies the LTE predicate on the "last_login_at" field.
+func LastLoginAtLTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldLastLoginAt, v))
+}
+
+// LastLoginAtIsNil applies the IsNil predicate on the "last_login_at" field.
+func LastLoginAtIsNil() predicate.User {
+ return predicate.User(sql.FieldIsNull(FieldLastLoginAt))
+}
+
+// LastLoginAtNotNil applies the NotNil predicate on the "last_login_at" field.
+func LastLoginAtNotNil() predicate.User {
+ return predicate.User(sql.FieldNotNull(FieldLastLoginAt))
+}
+
+// LastActiveAtEQ applies the EQ predicate on the "last_active_at" field.
+func LastActiveAtEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldLastActiveAt, v))
+}
+
+// LastActiveAtNEQ applies the NEQ predicate on the "last_active_at" field.
+func LastActiveAtNEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldLastActiveAt, v))
+}
+
+// LastActiveAtIn applies the In predicate on the "last_active_at" field.
+func LastActiveAtIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldIn(FieldLastActiveAt, vs...))
+}
+
+// LastActiveAtNotIn applies the NotIn predicate on the "last_active_at" field.
+func LastActiveAtNotIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldLastActiveAt, vs...))
+}
+
+// LastActiveAtGT applies the GT predicate on the "last_active_at" field.
+func LastActiveAtGT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGT(FieldLastActiveAt, v))
+}
+
+// LastActiveAtGTE applies the GTE predicate on the "last_active_at" field.
+func LastActiveAtGTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldLastActiveAt, v))
+}
+
+// LastActiveAtLT applies the LT predicate on the "last_active_at" field.
+func LastActiveAtLT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLT(FieldLastActiveAt, v))
+}
+
+// LastActiveAtLTE applies the LTE predicate on the "last_active_at" field.
+func LastActiveAtLTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldLastActiveAt, v))
+}
+
+// LastActiveAtIsNil applies the IsNil predicate on the "last_active_at" field.
+func LastActiveAtIsNil() predicate.User {
+ return predicate.User(sql.FieldIsNull(FieldLastActiveAt))
+}
+
+// LastActiveAtNotNil applies the NotNil predicate on the "last_active_at" field.
+func LastActiveAtNotNil() predicate.User {
+ return predicate.User(sql.FieldNotNull(FieldLastActiveAt))
+}
+
// BalanceNotifyEnabledEQ applies the EQ predicate on the "balance_notify_enabled" field.
func BalanceNotifyEnabledEQ(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
@@ -1345,6 +1525,52 @@ func HasPaymentOrdersWith(preds ...predicate.PaymentOrder) predicate.User {
})
}
+// HasAuthIdentities applies the HasEdge predicate on the "auth_identities" edge.
+func HasAuthIdentities() predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasAuthIdentitiesWith applies the HasEdge predicate on the "auth_identities" edge with a given conditions (other predicates).
+func HasAuthIdentitiesWith(preds ...predicate.AuthIdentity) predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := newAuthIdentitiesStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasPendingAuthSessions applies the HasEdge predicate on the "pending_auth_sessions" edge.
+func HasPendingAuthSessions() predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasPendingAuthSessionsWith applies the HasEdge predicate on the "pending_auth_sessions" edge with a given conditions (other predicates).
+func HasPendingAuthSessionsWith(preds ...predicate.PendingAuthSession) predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := newPendingAuthSessionsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
func HasUserAllowedGroups() predicate.User {
return predicate.User(func(s *sql.Selector) {
diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go
index fbc64f9c..db95e813 100644
--- a/backend/ent/user_create.go
+++ b/backend/ent/user_create.go
@@ -13,8 +13,10 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
@@ -211,6 +213,48 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
return _c
}
+// SetSignupSource sets the "signup_source" field.
+func (_c *UserCreate) SetSignupSource(v string) *UserCreate {
+ _c.mutation.SetSignupSource(v)
+ return _c
+}
+
+// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
+func (_c *UserCreate) SetNillableSignupSource(v *string) *UserCreate {
+ if v != nil {
+ _c.SetSignupSource(*v)
+ }
+ return _c
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (_c *UserCreate) SetLastLoginAt(v time.Time) *UserCreate {
+ _c.mutation.SetLastLoginAt(v)
+ return _c
+}
+
+// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
+func (_c *UserCreate) SetNillableLastLoginAt(v *time.Time) *UserCreate {
+ if v != nil {
+ _c.SetLastLoginAt(*v)
+ }
+ return _c
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (_c *UserCreate) SetLastActiveAt(v time.Time) *UserCreate {
+ _c.mutation.SetLastActiveAt(v)
+ return _c
+}
+
+// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
+func (_c *UserCreate) SetNillableLastActiveAt(v *time.Time) *UserCreate {
+ if v != nil {
+ _c.SetLastActiveAt(*v)
+ }
+ return _c
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_c *UserCreate) SetBalanceNotifyEnabled(v bool) *UserCreate {
_c.mutation.SetBalanceNotifyEnabled(v)
@@ -431,6 +475,36 @@ func (_c *UserCreate) AddPaymentOrders(v ...*PaymentOrder) *UserCreate {
return _c.AddPaymentOrderIDs(ids...)
}
+// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
+func (_c *UserCreate) AddAuthIdentityIDs(ids ...int64) *UserCreate {
+ _c.mutation.AddAuthIdentityIDs(ids...)
+ return _c
+}
+
+// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
+func (_c *UserCreate) AddAuthIdentities(v ...*AuthIdentity) *UserCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddAuthIdentityIDs(ids...)
+}
+
+// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
+func (_c *UserCreate) AddPendingAuthSessionIDs(ids ...int64) *UserCreate {
+ _c.mutation.AddPendingAuthSessionIDs(ids...)
+ return _c
+}
+
+// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_c *UserCreate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddPendingAuthSessionIDs(ids...)
+}
+
// Mutation returns the UserMutation object of the builder.
func (_c *UserCreate) Mutation() *UserMutation {
return _c.mutation
@@ -510,6 +584,10 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultTotpEnabled
_c.mutation.SetTotpEnabled(v)
}
+ if _, ok := _c.mutation.SignupSource(); !ok {
+ v := user.DefaultSignupSource
+ _c.mutation.SetSignupSource(v)
+ }
if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
v := user.DefaultBalanceNotifyEnabled
_c.mutation.SetBalanceNotifyEnabled(v)
@@ -589,6 +667,14 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.TotpEnabled(); !ok {
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
}
+ if _, ok := _c.mutation.SignupSource(); !ok {
+ return &ValidationError{Name: "signup_source", err: errors.New(`ent: missing required field "User.signup_source"`)}
+ }
+ if v, ok := _c.mutation.SignupSource(); ok {
+ if err := user.SignupSourceValidator(v); err != nil {
+ return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
+ }
+ }
if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)}
}
@@ -684,6 +770,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
_node.TotpEnabledAt = &value
}
+ if value, ok := _c.mutation.SignupSource(); ok {
+ _spec.SetField(user.FieldSignupSource, field.TypeString, value)
+ _node.SignupSource = value
+ }
+ if value, ok := _c.mutation.LastLoginAt(); ok {
+ _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
+ _node.LastLoginAt = &value
+ }
+ if value, ok := _c.mutation.LastActiveAt(); ok {
+ _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
+ _node.LastActiveAt = &value
+ }
if value, ok := _c.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
_node.BalanceNotifyEnabled = value
@@ -868,6 +966,38 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
}
_spec.Edges = append(_spec.Edges, edge)
}
+ if nodes := _c.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
return _node, _spec
}
@@ -1106,6 +1236,54 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
return u
}
+// SetSignupSource sets the "signup_source" field.
+func (u *UserUpsert) SetSignupSource(v string) *UserUpsert {
+ u.Set(user.FieldSignupSource, v)
+ return u
+}
+
+// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
+func (u *UserUpsert) UpdateSignupSource() *UserUpsert {
+ u.SetExcluded(user.FieldSignupSource)
+ return u
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (u *UserUpsert) SetLastLoginAt(v time.Time) *UserUpsert {
+ u.Set(user.FieldLastLoginAt, v)
+ return u
+}
+
+// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
+func (u *UserUpsert) UpdateLastLoginAt() *UserUpsert {
+ u.SetExcluded(user.FieldLastLoginAt)
+ return u
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (u *UserUpsert) ClearLastLoginAt() *UserUpsert {
+ u.SetNull(user.FieldLastLoginAt)
+ return u
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (u *UserUpsert) SetLastActiveAt(v time.Time) *UserUpsert {
+ u.Set(user.FieldLastActiveAt, v)
+ return u
+}
+
+// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
+func (u *UserUpsert) UpdateLastActiveAt() *UserUpsert {
+ u.SetExcluded(user.FieldLastActiveAt)
+ return u
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (u *UserUpsert) ClearLastActiveAt() *UserUpsert {
+ u.SetNull(user.FieldLastActiveAt)
+ return u
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsert) SetBalanceNotifyEnabled(v bool) *UserUpsert {
u.Set(user.FieldBalanceNotifyEnabled, v)
@@ -1446,6 +1624,62 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
})
}
+// SetSignupSource sets the "signup_source" field.
+func (u *UserUpsertOne) SetSignupSource(v string) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetSignupSource(v)
+ })
+}
+
+// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateSignupSource() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateSignupSource()
+ })
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (u *UserUpsertOne) SetLastLoginAt(v time.Time) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetLastLoginAt(v)
+ })
+}
+
+// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateLastLoginAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateLastLoginAt()
+ })
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (u *UserUpsertOne) ClearLastLoginAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearLastLoginAt()
+ })
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (u *UserUpsertOne) SetLastActiveAt(v time.Time) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetLastActiveAt(v)
+ })
+}
+
+// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateLastActiveAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateLastActiveAt()
+ })
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (u *UserUpsertOne) ClearLastActiveAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearLastActiveAt()
+ })
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsertOne) SetBalanceNotifyEnabled(v bool) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
@@ -1965,6 +2199,62 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
})
}
+// SetSignupSource sets the "signup_source" field.
+func (u *UserUpsertBulk) SetSignupSource(v string) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetSignupSource(v)
+ })
+}
+
+// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateSignupSource() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateSignupSource()
+ })
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (u *UserUpsertBulk) SetLastLoginAt(v time.Time) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetLastLoginAt(v)
+ })
+}
+
+// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateLastLoginAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateLastLoginAt()
+ })
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (u *UserUpsertBulk) ClearLastLoginAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearLastLoginAt()
+ })
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (u *UserUpsertBulk) SetLastActiveAt(v time.Time) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetLastActiveAt(v)
+ })
+}
+
+// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateLastActiveAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateLastActiveAt()
+ })
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (u *UserUpsertBulk) ClearLastActiveAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearLastActiveAt()
+ })
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsertBulk) SetBalanceNotifyEnabled(v bool) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
diff --git a/backend/ent/user_query.go b/backend/ent/user_query.go
index 113d87ac..f1ee5cfe 100644
--- a/backend/ent/user_query.go
+++ b/backend/ent/user_query.go
@@ -15,8 +15,10 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
@@ -44,6 +46,8 @@ type UserQuery struct {
withAttributeValues *UserAttributeValueQuery
withPromoCodeUsages *PromoCodeUsageQuery
withPaymentOrders *PaymentOrderQuery
+ withAuthIdentities *AuthIdentityQuery
+ withPendingAuthSessions *PendingAuthSessionQuery
withUserAllowedGroups *UserAllowedGroupQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
@@ -302,6 +306,50 @@ func (_q *UserQuery) QueryPaymentOrders() *PaymentOrderQuery {
return query
}
+// QueryAuthIdentities chains the current query on the "auth_identities" edge.
+func (_q *UserQuery) QueryAuthIdentities() *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, selector),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryPendingAuthSessions chains the current query on the "pending_auth_sessions" edge.
+func (_q *UserQuery) QueryPendingAuthSessions() *PendingAuthSessionQuery {
+ query := (&PendingAuthSessionClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, selector),
+ sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: _q.config}).Query()
@@ -526,6 +574,8 @@ func (_q *UserQuery) Clone() *UserQuery {
withAttributeValues: _q.withAttributeValues.Clone(),
withPromoCodeUsages: _q.withPromoCodeUsages.Clone(),
withPaymentOrders: _q.withPaymentOrders.Clone(),
+ withAuthIdentities: _q.withAuthIdentities.Clone(),
+ withPendingAuthSessions: _q.withPendingAuthSessions.Clone(),
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
@@ -643,6 +693,28 @@ func (_q *UserQuery) WithPaymentOrders(opts ...func(*PaymentOrderQuery)) *UserQu
return _q
}
+// WithAuthIdentities tells the query-builder to eager-load the nodes that are connected to
+// the "auth_identities" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *UserQuery) WithAuthIdentities(opts ...func(*AuthIdentityQuery)) *UserQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withAuthIdentities = query
+ return _q
+}
+
+// WithPendingAuthSessions tells the query-builder to eager-load the nodes that are connected to
+// the "pending_auth_sessions" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *UserQuery) WithPendingAuthSessions(opts ...func(*PendingAuthSessionQuery)) *UserQuery {
+ query := (&PendingAuthSessionClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withPendingAuthSessions = query
+ return _q
+}
+
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
@@ -732,7 +804,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var (
nodes = []*User{}
_spec = _q.querySpec()
- loadedTypes = [11]bool{
+ loadedTypes = [13]bool{
_q.withAPIKeys != nil,
_q.withRedeemCodes != nil,
_q.withSubscriptions != nil,
@@ -743,6 +815,8 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
_q.withAttributeValues != nil,
_q.withPromoCodeUsages != nil,
_q.withPaymentOrders != nil,
+ _q.withAuthIdentities != nil,
+ _q.withPendingAuthSessions != nil,
_q.withUserAllowedGroups != nil,
}
)
@@ -839,6 +913,22 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nil, err
}
}
+ if query := _q.withAuthIdentities; query != nil {
+ if err := _q.loadAuthIdentities(ctx, query, nodes,
+ func(n *User) { n.Edges.AuthIdentities = []*AuthIdentity{} },
+ func(n *User, e *AuthIdentity) { n.Edges.AuthIdentities = append(n.Edges.AuthIdentities, e) }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withPendingAuthSessions; query != nil {
+ if err := _q.loadPendingAuthSessions(ctx, query, nodes,
+ func(n *User) { n.Edges.PendingAuthSessions = []*PendingAuthSession{} },
+ func(n *User, e *PendingAuthSession) {
+ n.Edges.PendingAuthSessions = append(n.Edges.PendingAuthSessions, e)
+ }); err != nil {
+ return nil, err
+ }
+ }
if query := _q.withUserAllowedGroups; query != nil {
if err := _q.loadUserAllowedGroups(ctx, query, nodes,
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
@@ -1186,6 +1276,69 @@ func (_q *UserQuery) loadPaymentOrders(ctx context.Context, query *PaymentOrderQ
}
return nil
}
+func (_q *UserQuery) loadAuthIdentities(ctx context.Context, query *AuthIdentityQuery, nodes []*User, init func(*User), assign func(*User, *AuthIdentity)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*User)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(authidentity.FieldUserID)
+ }
+ query.Where(predicate.AuthIdentity(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(user.AuthIdentitiesColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.UserID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+func (_q *UserQuery) loadPendingAuthSessions(ctx context.Context, query *PendingAuthSessionQuery, nodes []*User, init func(*User), assign func(*User, *PendingAuthSession)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*User)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(pendingauthsession.FieldTargetUserID)
+ }
+ query.Where(predicate.PendingAuthSession(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(user.PendingAuthSessionsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.TargetUserID
+ if fk == nil {
+ return fmt.Errorf(`foreign-key "target_user_id" is nil for node %v`, n.ID)
+ }
+ node, ok := nodeids[*fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "target_user_id" returned %v for node %v`, *fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go
index 6b355247..677eeb6b 100644
--- a/backend/ent/user_update.go
+++ b/backend/ent/user_update.go
@@ -13,8 +13,10 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
@@ -243,6 +245,60 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
return _u
}
+// SetSignupSource sets the "signup_source" field.
+func (_u *UserUpdate) SetSignupSource(v string) *UserUpdate {
+ _u.mutation.SetSignupSource(v)
+ return _u
+}
+
+// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableSignupSource(v *string) *UserUpdate {
+ if v != nil {
+ _u.SetSignupSource(*v)
+ }
+ return _u
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (_u *UserUpdate) SetLastLoginAt(v time.Time) *UserUpdate {
+ _u.mutation.SetLastLoginAt(v)
+ return _u
+}
+
+// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableLastLoginAt(v *time.Time) *UserUpdate {
+ if v != nil {
+ _u.SetLastLoginAt(*v)
+ }
+ return _u
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (_u *UserUpdate) ClearLastLoginAt() *UserUpdate {
+ _u.mutation.ClearLastLoginAt()
+ return _u
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (_u *UserUpdate) SetLastActiveAt(v time.Time) *UserUpdate {
+ _u.mutation.SetLastActiveAt(v)
+ return _u
+}
+
+// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableLastActiveAt(v *time.Time) *UserUpdate {
+ if v != nil {
+ _u.SetLastActiveAt(*v)
+ }
+ return _u
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (_u *UserUpdate) ClearLastActiveAt() *UserUpdate {
+ _u.mutation.ClearLastActiveAt()
+ return _u
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_u *UserUpdate) SetBalanceNotifyEnabled(v bool) *UserUpdate {
_u.mutation.SetBalanceNotifyEnabled(v)
@@ -483,6 +539,36 @@ func (_u *UserUpdate) AddPaymentOrders(v ...*PaymentOrder) *UserUpdate {
return _u.AddPaymentOrderIDs(ids...)
}
+// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
+func (_u *UserUpdate) AddAuthIdentityIDs(ids ...int64) *UserUpdate {
+ _u.mutation.AddAuthIdentityIDs(ids...)
+ return _u
+}
+
+// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
+func (_u *UserUpdate) AddAuthIdentities(v ...*AuthIdentity) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAuthIdentityIDs(ids...)
+}
+
+// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
+func (_u *UserUpdate) AddPendingAuthSessionIDs(ids ...int64) *UserUpdate {
+ _u.mutation.AddPendingAuthSessionIDs(ids...)
+ return _u
+}
+
+// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_u *UserUpdate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddPendingAuthSessionIDs(ids...)
+}
+
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdate) Mutation() *UserMutation {
return _u.mutation
@@ -698,6 +784,48 @@ func (_u *UserUpdate) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdate {
return _u.RemovePaymentOrderIDs(ids...)
}
+// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity.
+func (_u *UserUpdate) ClearAuthIdentities() *UserUpdate {
+ _u.mutation.ClearAuthIdentities()
+ return _u
+}
+
+// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs.
+func (_u *UserUpdate) RemoveAuthIdentityIDs(ids ...int64) *UserUpdate {
+ _u.mutation.RemoveAuthIdentityIDs(ids...)
+ return _u
+}
+
+// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities.
+func (_u *UserUpdate) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAuthIdentityIDs(ids...)
+}
+
+// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_u *UserUpdate) ClearPendingAuthSessions() *UserUpdate {
+ _u.mutation.ClearPendingAuthSessions()
+ return _u
+}
+
+// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs.
+func (_u *UserUpdate) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdate {
+ _u.mutation.RemovePendingAuthSessionIDs(ids...)
+ return _u
+}
+
+// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities.
+func (_u *UserUpdate) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemovePendingAuthSessionIDs(ids...)
+}
+
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil {
@@ -767,6 +895,11 @@ func (_u *UserUpdate) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
+ if v, ok := _u.mutation.SignupSource(); ok {
+ if err := user.SignupSourceValidator(v); err != nil {
+ return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
+ }
+ }
return nil
}
@@ -836,6 +969,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
+ if value, ok := _u.mutation.SignupSource(); ok {
+ _spec.SetField(user.FieldSignupSource, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.LastLoginAt(); ok {
+ _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastLoginAtCleared() {
+ _spec.ClearField(user.FieldLastLoginAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.LastActiveAt(); ok {
+ _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastActiveAtCleared() {
+ _spec.ClearField(user.FieldLastActiveAt, field.TypeTime)
+ }
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
}
@@ -1322,6 +1470,96 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
+ if _u.mutation.AuthIdentitiesCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.PendingAuthSessionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label}
@@ -1548,6 +1786,60 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
return _u
}
+// SetSignupSource sets the "signup_source" field.
+func (_u *UserUpdateOne) SetSignupSource(v string) *UserUpdateOne {
+ _u.mutation.SetSignupSource(v)
+ return _u
+}
+
+// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableSignupSource(v *string) *UserUpdateOne {
+ if v != nil {
+ _u.SetSignupSource(*v)
+ }
+ return _u
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (_u *UserUpdateOne) SetLastLoginAt(v time.Time) *UserUpdateOne {
+ _u.mutation.SetLastLoginAt(v)
+ return _u
+}
+
+// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableLastLoginAt(v *time.Time) *UserUpdateOne {
+ if v != nil {
+ _u.SetLastLoginAt(*v)
+ }
+ return _u
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (_u *UserUpdateOne) ClearLastLoginAt() *UserUpdateOne {
+ _u.mutation.ClearLastLoginAt()
+ return _u
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (_u *UserUpdateOne) SetLastActiveAt(v time.Time) *UserUpdateOne {
+ _u.mutation.SetLastActiveAt(v)
+ return _u
+}
+
+// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableLastActiveAt(v *time.Time) *UserUpdateOne {
+ if v != nil {
+ _u.SetLastActiveAt(*v)
+ }
+ return _u
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (_u *UserUpdateOne) ClearLastActiveAt() *UserUpdateOne {
+ _u.mutation.ClearLastActiveAt()
+ return _u
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_u *UserUpdateOne) SetBalanceNotifyEnabled(v bool) *UserUpdateOne {
_u.mutation.SetBalanceNotifyEnabled(v)
@@ -1788,6 +2080,36 @@ func (_u *UserUpdateOne) AddPaymentOrders(v ...*PaymentOrder) *UserUpdateOne {
return _u.AddPaymentOrderIDs(ids...)
}
+// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
+func (_u *UserUpdateOne) AddAuthIdentityIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.AddAuthIdentityIDs(ids...)
+ return _u
+}
+
+// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
+func (_u *UserUpdateOne) AddAuthIdentities(v ...*AuthIdentity) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAuthIdentityIDs(ids...)
+}
+
+// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
+func (_u *UserUpdateOne) AddPendingAuthSessionIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.AddPendingAuthSessionIDs(ids...)
+ return _u
+}
+
+// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_u *UserUpdateOne) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddPendingAuthSessionIDs(ids...)
+}
+
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdateOne) Mutation() *UserMutation {
return _u.mutation
@@ -2003,6 +2325,48 @@ func (_u *UserUpdateOne) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdateOne
return _u.RemovePaymentOrderIDs(ids...)
}
+// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity.
+func (_u *UserUpdateOne) ClearAuthIdentities() *UserUpdateOne {
+ _u.mutation.ClearAuthIdentities()
+ return _u
+}
+
+// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs.
+func (_u *UserUpdateOne) RemoveAuthIdentityIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.RemoveAuthIdentityIDs(ids...)
+ return _u
+}
+
+// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities.
+func (_u *UserUpdateOne) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAuthIdentityIDs(ids...)
+}
+
+// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_u *UserUpdateOne) ClearPendingAuthSessions() *UserUpdateOne {
+ _u.mutation.ClearPendingAuthSessions()
+ return _u
+}
+
+// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs.
+func (_u *UserUpdateOne) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.RemovePendingAuthSessionIDs(ids...)
+ return _u
+}
+
+// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities.
+func (_u *UserUpdateOne) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemovePendingAuthSessionIDs(ids...)
+}
+
// Where appends a list predicates to the UserUpdate builder.
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
_u.mutation.Where(ps...)
@@ -2085,6 +2449,11 @@ func (_u *UserUpdateOne) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
+ if v, ok := _u.mutation.SignupSource(); ok {
+ if err := user.SignupSourceValidator(v); err != nil {
+ return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
+ }
+ }
return nil
}
@@ -2171,6 +2540,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
+ if value, ok := _u.mutation.SignupSource(); ok {
+ _spec.SetField(user.FieldSignupSource, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.LastLoginAt(); ok {
+ _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastLoginAtCleared() {
+ _spec.ClearField(user.FieldLastLoginAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.LastActiveAt(); ok {
+ _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastActiveAtCleared() {
+ _spec.ClearField(user.FieldLastActiveAt, field.TypeTime)
+ }
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
}
@@ -2657,6 +3041,96 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
+ if _u.mutation.AuthIdentitiesCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.PendingAuthSessionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
_node = &User{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index dd9a4e58..6136e9ea 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -1608,6 +1608,9 @@ func (c *Config) Validate() error {
return fmt.Errorf("security.csp.policy is required when CSP is enabled")
}
if c.LinuxDo.Enabled {
+ if !c.LinuxDo.UsePKCE {
+ return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.enabled=true")
+ }
if strings.TrimSpace(c.LinuxDo.ClientID) == "" {
return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true")
}
@@ -1629,9 +1632,6 @@ func (c *Config) Validate() error {
default:
return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
}
- if method == "none" && !c.LinuxDo.UsePKCE {
- return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none")
- }
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
@@ -1663,6 +1663,12 @@ func (c *Config) Validate() error {
warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL)
}
if c.OIDC.Enabled {
+ if !c.OIDC.UsePKCE {
+ return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.enabled=true")
+ }
+ if !c.OIDC.ValidateIDToken {
+ return fmt.Errorf("oidc_connect.validate_id_token must be true when oidc_connect.enabled=true")
+ }
if strings.TrimSpace(c.OIDC.ClientID) == "" {
return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true")
}
@@ -1685,9 +1691,6 @@ func (c *Config) Validate() error {
default:
return fmt.Errorf("oidc_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
}
- if method == "none" && !c.OIDC.UsePKCE {
- return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.token_auth_method=none")
- }
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
strings.TrimSpace(c.OIDC.ClientSecret) == "" {
return fmt.Errorf("oidc_connect.client_secret is required when oidc_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index bec0f126..fe5c7928 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -73,6 +73,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ authSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
// Check if ops monitoring is enabled (respects config.ops.enabled)
opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context())
@@ -93,7 +98,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
paymentCfg = &service.PaymentConfig{}
}
- response.Success(c, dto.SystemSettings{
+ payload := dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
@@ -200,7 +205,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow,
PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit,
PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode,
- })
+ }
+ response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
}
// UpdateSettingsRequest 更新设置请求
@@ -276,9 +282,30 @@ type UpdateSettingsRequest struct {
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
// 默认配置
- DefaultConcurrency int `json:"default_concurrency"`
- DefaultBalance float64 `json:"default_balance"`
- DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
+ DefaultConcurrency int `json:"default_concurrency"`
+ DefaultBalance float64 `json:"default_balance"`
+ DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
+ AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
+ AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
+ AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"`
+ AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"`
+ AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"`
+ AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"`
+ AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"`
+ AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"`
+ AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"`
+ AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"`
+ AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"`
+ AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"`
+ AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"`
+ AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"`
+ AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"`
+ AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"`
+ AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"`
+ AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
+ AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
+ AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
+ ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -357,6 +384,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ previousAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
// 验证参数
if req.DefaultConcurrency < 1 {
@@ -381,6 +413,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.SMTPPort = 587
}
req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
+ req.AuthSourceDefaultEmailSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultEmailSubscriptions)
+ req.AuthSourceDefaultLinuxDoSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultLinuxDoSubscriptions)
+ req.AuthSourceDefaultOIDCSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultOIDCSubscriptions)
+ req.AuthSourceDefaultWeChatSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultWeChatSubscriptions)
// SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置
// 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置
@@ -538,25 +574,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.BadRequest(c, "OIDC scopes must contain openid")
return
}
+ if !req.OIDCConnectUsePKCE {
+ response.BadRequest(c, "OIDC PKCE must be enabled")
+ return
+ }
+ if !req.OIDCConnectValidateIDToken {
+ response.BadRequest(c, "OIDC ID Token validation must be enabled")
+ return
+ }
switch req.OIDCConnectTokenAuthMethod {
case "", "client_secret_post", "client_secret_basic", "none":
default:
response.BadRequest(c, "OIDC Token Auth Method must be one of client_secret_post/client_secret_basic/none")
return
}
- if req.OIDCConnectTokenAuthMethod == "none" && !req.OIDCConnectUsePKCE {
- response.BadRequest(c, "OIDC PKCE must be enabled when token_auth_method=none")
- return
- }
if req.OIDCConnectClockSkewSeconds < 0 || req.OIDCConnectClockSkewSeconds > 600 {
response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600")
return
}
- if req.OIDCConnectValidateIDToken {
- if req.OIDCConnectAllowedSigningAlgs == "" {
- response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
- return
- }
+ if req.OIDCConnectAllowedSigningAlgs == "" {
+ response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
+ return
}
if req.OIDCConnectJWKSURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectJWKSURL); err != nil {
@@ -933,6 +971,41 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ authSourceDefaults := &service.AuthSourceDefaultSettings{
+ Email: service.ProviderDefaultGrantSettings{
+ Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance),
+ Concurrency: intValueOrDefault(req.AuthSourceDefaultEmailConcurrency, previousAuthSourceDefaults.Email.Concurrency),
+ Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultEmailSubscriptions, previousAuthSourceDefaults.Email.Subscriptions),
+ GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnSignup, previousAuthSourceDefaults.Email.GrantOnSignup),
+ GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnFirstBind, previousAuthSourceDefaults.Email.GrantOnFirstBind),
+ },
+ LinuxDo: service.ProviderDefaultGrantSettings{
+ Balance: float64ValueOrDefault(req.AuthSourceDefaultLinuxDoBalance, previousAuthSourceDefaults.LinuxDo.Balance),
+ Concurrency: intValueOrDefault(req.AuthSourceDefaultLinuxDoConcurrency, previousAuthSourceDefaults.LinuxDo.Concurrency),
+ Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultLinuxDoSubscriptions, previousAuthSourceDefaults.LinuxDo.Subscriptions),
+ GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnSignup, previousAuthSourceDefaults.LinuxDo.GrantOnSignup),
+ GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnFirstBind, previousAuthSourceDefaults.LinuxDo.GrantOnFirstBind),
+ },
+ OIDC: service.ProviderDefaultGrantSettings{
+ Balance: float64ValueOrDefault(req.AuthSourceDefaultOIDCBalance, previousAuthSourceDefaults.OIDC.Balance),
+ Concurrency: intValueOrDefault(req.AuthSourceDefaultOIDCConcurrency, previousAuthSourceDefaults.OIDC.Concurrency),
+ Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultOIDCSubscriptions, previousAuthSourceDefaults.OIDC.Subscriptions),
+ GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnSignup, previousAuthSourceDefaults.OIDC.GrantOnSignup),
+ GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnFirstBind, previousAuthSourceDefaults.OIDC.GrantOnFirstBind),
+ },
+ WeChat: service.ProviderDefaultGrantSettings{
+ Balance: float64ValueOrDefault(req.AuthSourceDefaultWeChatBalance, previousAuthSourceDefaults.WeChat.Balance),
+ Concurrency: intValueOrDefault(req.AuthSourceDefaultWeChatConcurrency, previousAuthSourceDefaults.WeChat.Concurrency),
+ Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultWeChatSubscriptions, previousAuthSourceDefaults.WeChat.Subscriptions),
+ GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup),
+ GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind),
+ },
+ ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
+ }
+ if err := h.settingService.UpdateAuthSourceDefaultSettings(c.Request.Context(), authSourceDefaults); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
// Update payment configuration (integrated into system settings).
// Skip if no payment fields were provided (prevents accidental wipe).
@@ -977,6 +1050,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ updatedAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions))
for _, sub := range updatedSettings.DefaultSubscriptions {
updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{
@@ -994,7 +1072,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
updatedPaymentCfg = &service.PaymentConfig{}
}
- response.Success(c, dto.SystemSettings{
+ payload := dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
@@ -1100,7 +1178,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow,
PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit,
PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode,
- })
+ }
+ response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
}
// hasPaymentFields returns true if any payment-related field was explicitly provided.
@@ -1412,6 +1491,84 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto
return normalized
}
+func normalizeOptionalDefaultSubscriptions(input *[]dto.DefaultSubscriptionSetting) *[]dto.DefaultSubscriptionSetting {
+ if input == nil {
+ return nil
+ }
+ normalized := normalizeDefaultSubscriptions(*input)
+ return &normalized
+}
+
+func float64ValueOrDefault(value *float64, fallback float64) float64 {
+ if value == nil {
+ return fallback
+ }
+ return *value
+}
+
+func intValueOrDefault(value *int, fallback int) int {
+ if value == nil {
+ return fallback
+ }
+ return *value
+}
+
+func boolValueOrDefault(value *bool, fallback bool) bool {
+ if value == nil {
+ return fallback
+ }
+ return *value
+}
+
+func defaultSubscriptionsValueOrDefault(input *[]dto.DefaultSubscriptionSetting, fallback []service.DefaultSubscriptionSetting) []service.DefaultSubscriptionSetting {
+ if input == nil {
+ return fallback
+ }
+ result := make([]service.DefaultSubscriptionSetting, 0, len(*input))
+ for _, item := range *input {
+ result = append(result, service.DefaultSubscriptionSetting{
+ GroupID: item.GroupID,
+ ValidityDays: item.ValidityDays,
+ })
+ }
+ return result
+}
+
+func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults *service.AuthSourceDefaultSettings) map[string]any {
+ data := make(map[string]any)
+ raw, err := json.Marshal(settings)
+ if err == nil {
+ _ = json.Unmarshal(raw, &data)
+ }
+ if authSourceDefaults == nil {
+ authSourceDefaults = &service.AuthSourceDefaultSettings{}
+ }
+
+ data["auth_source_default_email_balance"] = authSourceDefaults.Email.Balance
+ data["auth_source_default_email_concurrency"] = authSourceDefaults.Email.Concurrency
+ data["auth_source_default_email_subscriptions"] = authSourceDefaults.Email.Subscriptions
+ data["auth_source_default_email_grant_on_signup"] = authSourceDefaults.Email.GrantOnSignup
+ data["auth_source_default_email_grant_on_first_bind"] = authSourceDefaults.Email.GrantOnFirstBind
+ data["auth_source_default_linuxdo_balance"] = authSourceDefaults.LinuxDo.Balance
+ data["auth_source_default_linuxdo_concurrency"] = authSourceDefaults.LinuxDo.Concurrency
+ data["auth_source_default_linuxdo_subscriptions"] = authSourceDefaults.LinuxDo.Subscriptions
+ data["auth_source_default_linuxdo_grant_on_signup"] = authSourceDefaults.LinuxDo.GrantOnSignup
+ data["auth_source_default_linuxdo_grant_on_first_bind"] = authSourceDefaults.LinuxDo.GrantOnFirstBind
+ data["auth_source_default_oidc_balance"] = authSourceDefaults.OIDC.Balance
+ data["auth_source_default_oidc_concurrency"] = authSourceDefaults.OIDC.Concurrency
+ data["auth_source_default_oidc_subscriptions"] = authSourceDefaults.OIDC.Subscriptions
+ data["auth_source_default_oidc_grant_on_signup"] = authSourceDefaults.OIDC.GrantOnSignup
+ data["auth_source_default_oidc_grant_on_first_bind"] = authSourceDefaults.OIDC.GrantOnFirstBind
+ data["auth_source_default_wechat_balance"] = authSourceDefaults.WeChat.Balance
+ data["auth_source_default_wechat_concurrency"] = authSourceDefaults.WeChat.Concurrency
+ data["auth_source_default_wechat_subscriptions"] = authSourceDefaults.WeChat.Subscriptions
+ data["auth_source_default_wechat_grant_on_signup"] = authSourceDefaults.WeChat.GrantOnSignup
+ data["auth_source_default_wechat_grant_on_first_bind"] = authSourceDefaults.WeChat.GrantOnFirstBind
+ data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup
+
+ return data
+}
+
func equalStringSlice(a, b []string) bool {
if len(a) != len(b) {
return false
diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
new file mode 100644
index 00000000..b26fa447
--- /dev/null
+++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
@@ -0,0 +1,149 @@
+package admin
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type settingHandlerRepoStub struct {
+ values map[string]string
+ lastUpdates map[string]string
+}
+
+func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ panic("unexpected GetValue call")
+}
+
+func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error {
+ panic("unexpected Set call")
+}
+
+func (s *settingHandlerRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *settingHandlerRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ s.lastUpdates = make(map[string]string, len(settings))
+ for key, value := range settings {
+ s.lastUpdates[key] = value
+ if s.values == nil {
+ s.values = map[string]string{}
+ }
+ s.values[key] = value
+ }
+ return nil
+}
+
+func (s *settingHandlerRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ out := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ out[key] = value
+ }
+ return out, nil
+}
+
+func (s *settingHandlerRepoStub) Delete(ctx context.Context, key string) error {
+ panic("unexpected Delete call")
+}
+
+func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
+ service.SettingKeyForceEmailOnThirdPartySignup: "true",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/settings", nil)
+
+ handler.GetSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, 9.5, data["auth_source_default_email_balance"])
+ require.Equal(t, float64(8), data["auth_source_default_email_concurrency"])
+ require.Equal(t, true, data["force_email_on_third_party_signup"])
+
+ subscriptions, ok := data["auth_source_default_email_subscriptions"].([]any)
+ require.True(t, ok)
+ require.Len(t, subscriptions, 1)
+}
+
+func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyRegistrationEnabled: "false",
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false",
+ service.SettingKeyForceEmailOnThirdPartySignup: "true",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "registration_enabled": true,
+ "promo_code_enabled": true,
+ "auth_source_default_email_balance": 12.75,
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "12.75000000", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance])
+ require.Equal(t, "8", repo.values[service.SettingKeyAuthSourceDefaultEmailConcurrency])
+ require.Equal(t, `[{"group_id":31,"validity_days":15}]`, repo.values[service.SettingKeyAuthSourceDefaultEmailSubscriptions])
+ require.Equal(t, "true", repo.values[service.SettingKeyForceEmailOnThirdPartySignup])
+
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, 12.75, data["auth_source_default_email_balance"])
+ require.Equal(t, float64(8), data["auth_source_default_email_concurrency"])
+ require.Equal(t, true, data["force_email_on_third_party_signup"])
+}
diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go
index 2f182642..b0edcf5a 100644
--- a/backend/internal/handler/auth_linuxdo_oauth.go
+++ b/backend/internal/handler/auth_linuxdo_oauth.go
@@ -219,7 +219,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
}
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
- tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
@@ -262,6 +262,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
ProviderKey: "linuxdo",
ProviderSubject: subject,
},
+ TargetUserID: &user.ID,
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
@@ -287,7 +288,9 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
}
type completeLinuxDoOAuthRequest struct {
- InvitationCode string `json:"invitation_code" binding:"required"`
+ InvitationCode string `json:"invitation_code" binding:"required"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating
@@ -335,11 +338,23 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
return
}
- tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
return
}
+ decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
+ AdoptDisplayName: req.AdoptDisplayName,
+ AdoptAvatar: req.AdoptAvatar,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
+ return
+ }
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go
index 90bc10d1..661c0da0 100644
--- a/backend/internal/handler/auth_linuxdo_oauth_test.go
+++ b/backend/internal/handler/auth_linuxdo_oauth_test.go
@@ -1,10 +1,21 @@
package handler
import (
+ "bytes"
+ "context"
+ "net/http"
+ "net/http/httptest"
"strings"
"testing"
+ "time"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -110,3 +121,79 @@ func TestSingleLineStripsWhitespace(t *testing.T) {
require.Equal(t, "hello world", singleLine("hello\r\nworld"))
require.Equal(t, "", singleLine("\n\t\r"))
}
+
+func TestCompleteLinuxDoOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("linuxdo-complete-session").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("linuxdo-subject-1").
+ SetResolvedEmail("linuxdo-subject-1@linuxdo-connect.invalid").
+ SetBrowserSessionKey("linuxdo-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ "suggested_display_name": "LinuxDo Display",
+ "suggested_avatar_url": "https://cdn.example/linuxdo.png",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ AdoptAvatar: true,
+ })
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser")})
+ c.Request = req
+
+ handler.CompleteLinuxDoOAuthRegistration(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.NotEmpty(t, responseData["access_token"])
+
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ(session.ResolvedEmail)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "LinuxDo Display", userEntity.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("linuxdo-subject-1"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+ require.Equal(t, "LinuxDo Display", identity.Metadata["display_name"])
+ require.Equal(t, "https://cdn.example/linuxdo.png", identity.Metadata["avatar_url"])
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.True(t, decision.AdoptDisplayName)
+ require.True(t, decision.AdoptAvatar)
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+}
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
index a758c0b9..da8ac858 100644
--- a/backend/internal/handler/auth_oauth_pending_flow.go
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -1,10 +1,17 @@
package handler
import (
+ "context"
+ "errors"
+ "io"
"net/http"
"net/url"
"strings"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -26,6 +33,7 @@ const (
type oauthPendingSessionPayload struct {
Intent string
Identity service.PendingAuthIdentityKey
+ TargetUserID *int64
ResolvedEmail string
RedirectTo string
BrowserSessionKey string
@@ -33,6 +41,11 @@ type oauthPendingSessionPayload struct {
CompletionResponse map[string]any
}
+type oauthAdoptionDecisionRequest struct {
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
+}
+
func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) {
if h == nil || h.authService == nil || h.authService.EntClient() == nil {
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
@@ -125,6 +138,7 @@ func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPen
session, err := svc.CreatePendingSession(c.Request.Context(), service.CreatePendingAuthSessionInput{
Intent: strings.TrimSpace(payload.Intent),
Identity: payload.Identity,
+ TargetUserID: payload.TargetUserID,
ResolvedEmail: strings.TrimSpace(payload.ResolvedEmail),
RedirectTo: strings.TrimSpace(payload.RedirectTo),
BrowserSessionKey: strings.TrimSpace(payload.BrowserSessionKey),
@@ -175,6 +189,291 @@ func pendingSessionWantsInvitation(payload map[string]any) bool {
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required")
}
+func (r oauthAdoptionDecisionRequest) hasDecision() bool {
+ return r.AdoptDisplayName != nil || r.AdoptAvatar != nil
+}
+
+func (r oauthAdoptionDecisionRequest) toServiceInput(sessionID int64) service.PendingIdentityAdoptionDecisionInput {
+ input := service.PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: sessionID,
+ }
+ if r.AdoptDisplayName != nil {
+ input.AdoptDisplayName = *r.AdoptDisplayName
+ }
+ if r.AdoptAvatar != nil {
+ input.AdoptAvatar = *r.AdoptAvatar
+ }
+ return input
+}
+
+func bindOptionalOAuthAdoptionDecision(c *gin.Context) (oauthAdoptionDecisionRequest, error) {
+ var req oauthAdoptionDecisionRequest
+ if c == nil || c.Request == nil || c.Request.Body == nil {
+ return req, nil
+ }
+ if err := c.ShouldBindJSON(&req); err != nil {
+ if errors.Is(err, io.EOF) {
+ return req, nil
+ }
+ return req, err
+ }
+ return req, nil
+}
+
+func persistPendingOAuthAdoptionDecision(
+ c *gin.Context,
+ svc *service.AuthPendingIdentityService,
+ sessionID int64,
+ req oauthAdoptionDecisionRequest,
+) error {
+ if !req.hasDecision() {
+ return nil
+ }
+ if svc == nil {
+ return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+ if _, err := svc.UpsertAdoptionDecision(c.Request.Context(), req.toServiceInput(sessionID)); err != nil {
+ return infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
+ }
+ return nil
+}
+
+func cloneOAuthMetadata(values map[string]any) map[string]any {
+ if len(values) == 0 {
+ return map[string]any{}
+ }
+ cloned := make(map[string]any, len(values))
+ for key, value := range values {
+ cloned[key] = value
+ }
+ return cloned
+}
+
+func normalizeAdoptedOAuthDisplayName(value string) string {
+ value = strings.TrimSpace(value)
+ if len([]rune(value)) > 100 {
+ value = string([]rune(value)[:100])
+ }
+ return value
+}
+
+func (h *AuthHandler) entClient() *dbent.Client {
+ if h == nil || h.authService == nil {
+ return nil
+ }
+ return h.authService.EntClient()
+}
+
+func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
+ c *gin.Context,
+ sessionID int64,
+ req oauthAdoptionDecisionRequest,
+) (*dbent.IdentityAdoptionDecision, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ existing, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(sessionID)).
+ Only(c.Request.Context())
+ if err != nil && !dbent.IsNotFound(err) {
+ return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_LOAD_FAILED", "failed to load oauth profile adoption decision").WithCause(err)
+ }
+ if existing != nil && !req.hasDecision() {
+ return existing, nil
+ }
+ if existing == nil && !req.hasDecision() {
+ return nil, nil
+ }
+
+ input := service.PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: sessionID,
+ }
+ if existing != nil {
+ input.AdoptDisplayName = existing.AdoptDisplayName
+ input.AdoptAvatar = existing.AdoptAvatar
+ input.IdentityID = existing.IdentityID
+ }
+ if req.AdoptDisplayName != nil {
+ input.AdoptDisplayName = *req.AdoptDisplayName
+ }
+ if req.AdoptAvatar != nil {
+ input.AdoptAvatar = *req.AdoptAvatar
+ }
+
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ return nil, err
+ }
+ decision, err := svc.UpsertAdoptionDecision(c.Request.Context(), input)
+ if err != nil {
+ return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
+ }
+ return decision, nil
+}
+
+func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) {
+ if session == nil {
+ return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
+ }
+ if session.TargetUserID != nil && *session.TargetUserID > 0 {
+ return *session.TargetUserID, nil
+ }
+ email := strings.TrimSpace(session.ResolvedEmail)
+ if email == "" {
+ return 0, infraerrors.BadRequest("PENDING_AUTH_TARGET_USER_MISSING", "pending auth target user is missing")
+ }
+
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ(email)).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return 0, infraerrors.InternalServer("PENDING_AUTH_TARGET_USER_NOT_FOUND", "pending auth target user was not found")
+ }
+ return 0, err
+ }
+ return userEntity.ID, nil
+}
+
+func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string {
+ if session == nil {
+ return nil
+ }
+ switch strings.TrimSpace(session.ProviderType) {
+ case "oidc":
+ issuer := strings.TrimSpace(session.ProviderKey)
+ if issuer == "" {
+ issuer = pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer")
+ }
+ if issuer == "" {
+ return nil
+ }
+ return &issuer
+ default:
+ issuer := pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer")
+ if issuer == "" {
+ return nil
+ }
+ return &issuer
+ }
+}
+
+func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) {
+ client := tx.Client()
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)),
+ authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)),
+ ).
+ Only(ctx)
+ if err != nil && !dbent.IsNotFound(err) {
+ return nil, err
+ }
+ if identity != nil {
+ if identity.UserID != userID {
+ return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ }
+ return identity, nil
+ }
+
+ create := client.AuthIdentity.Create().
+ SetUserID(userID).
+ SetProviderType(strings.TrimSpace(session.ProviderType)).
+ SetProviderKey(strings.TrimSpace(session.ProviderKey)).
+ SetProviderSubject(strings.TrimSpace(session.ProviderSubject)).
+ SetMetadata(cloneOAuthMetadata(session.UpstreamIdentityClaims))
+ if issuer := oauthIdentityIssuer(session); issuer != nil {
+ create = create.SetIssuer(strings.TrimSpace(*issuer))
+ }
+ return create.Save(ctx)
+}
+
+func applyPendingOAuthAdoption(
+ ctx context.Context,
+ client *dbent.Client,
+ session *dbent.PendingAuthSession,
+ decision *dbent.IdentityAdoptionDecision,
+ overrideUserID *int64,
+) error {
+ if client == nil || session == nil || decision == nil {
+ return nil
+ }
+ if !decision.AdoptDisplayName && !decision.AdoptAvatar {
+ return nil
+ }
+
+ targetUserID := int64(0)
+ if overrideUserID != nil && *overrideUserID > 0 {
+ targetUserID = *overrideUserID
+ } else {
+ resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session)
+ if err != nil {
+ return err
+ }
+ targetUserID = resolvedUserID
+ }
+
+ adoptedDisplayName := ""
+ if decision.AdoptDisplayName {
+ adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name"))
+ }
+ adoptedAvatarURL := ""
+ if decision.AdoptAvatar {
+ adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url")
+ }
+
+ tx, err := client.Tx(ctx)
+ if err != nil {
+ return err
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ if decision.AdoptDisplayName && adoptedDisplayName != "" {
+ if err := tx.Client().User.UpdateOneID(targetUserID).
+ SetUsername(adoptedDisplayName).
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID)
+ if err != nil {
+ return err
+ }
+
+ metadata := cloneOAuthMetadata(identity.Metadata)
+ for key, value := range session.UpstreamIdentityClaims {
+ metadata[key] = value
+ }
+ if decision.AdoptDisplayName && adoptedDisplayName != "" {
+ metadata["display_name"] = adoptedDisplayName
+ }
+ if decision.AdoptAvatar && adoptedAvatarURL != "" {
+ metadata["avatar_url"] = adoptedAvatarURL
+ }
+
+ updateIdentity := tx.Client().AuthIdentity.UpdateOneID(identity.ID).SetMetadata(metadata)
+ if issuer := oauthIdentityIssuer(session); issuer != nil {
+ updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer))
+ }
+ if _, err := updateIdentity.Save(ctx); err != nil {
+ return err
+ }
+
+ if decision.IdentityID == nil || *decision.IdentityID != identity.ID {
+ if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
+ SetIdentityID(identity.ID).
+ Save(ctx); err != nil {
+ return err
+ }
+ }
+
+ return tx.Commit()
+}
+
func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) {
if len(payload) == 0 || len(upstream) == 0 {
return
@@ -206,6 +505,11 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
}
+ adoptionDecision, err := bindOptionalOAuthAdoptionDecision(c)
+ if err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil || strings.TrimSpace(sessionToken) == "" {
@@ -248,9 +552,30 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims)
if pendingSessionWantsInvitation(payload) {
+ if adoptionDecision.hasDecision() {
+ decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ _ = decision
+ }
response.Success(c, payload)
return
}
+ if !adoptionDecision.hasDecision() {
+ response.Success(c, payload)
+ return
+ }
+ decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, session.TargetUserID); err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
+ return
+ }
if _, err := svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearCookies()
diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go
index 5517bae2..829fc217 100644
--- a/backend/internal/handler/auth_oauth_pending_flow_test.go
+++ b/backend/internal/handler/auth_oauth_pending_flow_test.go
@@ -1,9 +1,30 @@
package handler
import (
+ "bytes"
+ "context"
+ "database/sql"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
"testing"
+ "time"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
)
func TestApplySuggestedProfileToCompletionResponse(t *testing.T) {
@@ -38,3 +59,439 @@ func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t *
require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"])
require.Equal(t, true, payload["adoption_required"])
}
+
+func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("linuxdo-123@linuxdo-connect.invalid").
+ SetUsername("legacy-name").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("pending-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ "suggested_display_name": "Alice Example",
+ "suggested_avatar_url": "https://cdn.example/alice.png",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ "redirect": "/dashboard",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ previewRecorder := httptest.NewRecorder()
+ previewCtx, _ := gin.CreateTestContext(previewRecorder)
+ previewReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+ previewReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ previewReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")})
+ previewCtx.Request = previewReq
+
+ handler.ExchangePendingOAuthCompletion(previewCtx)
+
+ require.Equal(t, http.StatusOK, previewRecorder.Code)
+ previewData := decodeJSONResponseData(t, previewRecorder)
+ require.Equal(t, "Alice Example", previewData["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/alice.png", previewData["suggested_avatar_url"])
+ require.Equal(t, true, previewData["adoption_required"])
+
+ storedUser, err := client.User.Get(ctx, userEntity.ID)
+ require.NoError(t, err)
+ require.Equal(t, "legacy-name", storedUser.Username)
+
+ previewSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Nil(t, previewSession.ConsumedAt)
+
+ body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`)
+ finalizeRecorder := httptest.NewRecorder()
+ finalizeCtx, _ := gin.CreateTestContext(finalizeRecorder)
+ finalizeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+ finalizeReq.Header.Set("Content-Type", "application/json")
+ finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")})
+ finalizeCtx.Request = finalizeReq
+
+ handler.ExchangePendingOAuthCompletion(finalizeCtx)
+
+ require.Equal(t, http.StatusOK, finalizeRecorder.Code)
+
+ storedUser, err = client.User.Get(ctx, userEntity.ID)
+ require.NoError(t, err)
+ require.Equal(t, "Alice Example", storedUser.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+ require.Equal(t, "Alice Example", identity.Metadata["display_name"])
+ require.Equal(t, "https://cdn.example/alice.png", identity.Metadata["avatar_url"])
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.True(t, decision.AdoptDisplayName)
+ require.True(t, decision.AdoptAvatar)
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+}
+
+func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_oauth_pending_flow_handler?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 0,
+ UserConcurrency: 1,
+ },
+ }
+ settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{
+ values: map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
+ },
+ }, cfg)
+ authSvc := service.NewAuthService(
+ client,
+ &oauthPendingFlowUserRepo{client: client},
+ nil,
+ &oauthPendingFlowRefreshTokenCacheStub{},
+ cfg,
+ settingSvc,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ )
+
+ return &AuthHandler{
+ authService: authSvc,
+ settingSvc: settingSvc,
+ }, client
+}
+
+func boolSettingValue(v bool) string {
+ if v {
+ return "true"
+ }
+ return "false"
+}
+
+func boolPtr(v bool) *bool {
+ return &v
+}
+
+type oauthPendingFlowSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *oauthPendingFlowSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
+ return nil, service.ErrSettingNotFound
+}
+
+func (s *oauthPendingFlowSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ value, ok := s.values[key]
+ if !ok {
+ return "", service.ErrSettingNotFound
+ }
+ return value, nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) Set(context.Context, string, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ result := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ result[key] = value
+ }
+ }
+ return result, nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ result := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ result[key] = value
+ }
+ return result, nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) Delete(context.Context, string) error {
+ return nil
+}
+
+type oauthPendingFlowRefreshTokenCacheStub struct{}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
+ return nil, service.ErrRefreshTokenNotFound
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
+ return nil, nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
+ return false, nil
+}
+
+func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
+ t.Helper()
+
+ var envelope struct {
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &envelope))
+ return envelope.Data
+}
+
+func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
+ t.Helper()
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ return payload
+}
+
+type oauthPendingFlowUserRepo struct {
+ client *dbent.Client
+}
+
+func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.User) error {
+ entity, err := r.client.User.Create().
+ SetEmail(user.Email).
+ SetUsername(user.Username).
+ SetNotes(user.Notes).
+ SetPasswordHash(user.PasswordHash).
+ SetRole(user.Role).
+ SetBalance(user.Balance).
+ SetConcurrency(user.Concurrency).
+ SetStatus(user.Status).
+ SetSignupSource(user.SignupSource).
+ SetNillableLastLoginAt(user.LastLoginAt).
+ SetNillableLastActiveAt(user.LastActiveAt).
+ Save(ctx)
+ if err != nil {
+ return err
+ }
+ user.ID = entity.ID
+ user.CreatedAt = entity.CreatedAt
+ user.UpdatedAt = entity.UpdatedAt
+ return nil
+}
+
+func (r *oauthPendingFlowUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) {
+ entity, err := r.client.User.Get(ctx, id)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, service.ErrUserNotFound
+ }
+ return nil, err
+ }
+ return oauthPendingFlowServiceUser(entity), nil
+}
+
+func (r *oauthPendingFlowUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) {
+ entity, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, service.ErrUserNotFound
+ }
+ return nil, err
+ }
+ return oauthPendingFlowServiceUser(entity), nil
+}
+
+func (r *oauthPendingFlowUserRepo) GetFirstAdmin(context.Context) (*service.User, error) {
+ panic("unexpected GetFirstAdmin call")
+}
+
+func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.User) error {
+ entity, err := r.client.User.UpdateOneID(user.ID).
+ SetEmail(user.Email).
+ SetUsername(user.Username).
+ SetNotes(user.Notes).
+ SetPasswordHash(user.PasswordHash).
+ SetRole(user.Role).
+ SetBalance(user.Balance).
+ SetConcurrency(user.Concurrency).
+ SetStatus(user.Status).
+ SetSignupSource(user.SignupSource).
+ SetNillableLastLoginAt(user.LastLoginAt).
+ SetNillableLastActiveAt(user.LastActiveAt).
+ Save(ctx)
+ if err != nil {
+ return err
+ }
+ user.UpdatedAt = entity.UpdatedAt
+ return nil
+}
+
+func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error {
+ return r.client.User.DeleteOneID(id).Exec(ctx)
+}
+
+func (r *oauthPendingFlowUserRepo) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
+ return nil, service.ErrUserNotFound
+}
+
+func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ panic("unexpected UpsertUserAvatar call")
+}
+
+func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(context.Context, int64) error {
+ return nil
+}
+
+func (r *oauthPendingFlowUserRepo) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
+ panic("unexpected List call")
+}
+
+func (r *oauthPendingFlowUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
+ panic("unexpected ListWithFilters call")
+}
+
+func (r *oauthPendingFlowUserRepo) UpdateBalance(context.Context, int64, float64) error {
+ panic("unexpected UpdateBalance call")
+}
+
+func (r *oauthPendingFlowUserRepo) DeductBalance(context.Context, int64, float64) error {
+ panic("unexpected DeductBalance call")
+}
+
+func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int) error {
+ panic("unexpected UpdateConcurrency call")
+}
+
+func (r *oauthPendingFlowUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
+ count, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Count(ctx)
+ return count > 0, err
+}
+
+func (r *oauthPendingFlowUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ panic("unexpected RemoveGroupFromAllowedGroups call")
+}
+
+func (r *oauthPendingFlowUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error {
+ panic("unexpected AddGroupToAllowedGroups call")
+}
+
+func (r *oauthPendingFlowUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
+ panic("unexpected RemoveGroupFromUserAllowedGroups call")
+}
+
+func (r *oauthPendingFlowUserRepo) UpdateTotpSecret(context.Context, int64, *string) error {
+ panic("unexpected UpdateTotpSecret call")
+}
+
+func (r *oauthPendingFlowUserRepo) EnableTotp(context.Context, int64) error {
+ panic("unexpected EnableTotp call")
+}
+
+func (r *oauthPendingFlowUserRepo) DisableTotp(context.Context, int64) error {
+ panic("unexpected DisableTotp call")
+}
+
+func oauthPendingFlowServiceUser(entity *dbent.User) *service.User {
+ if entity == nil {
+ return nil
+ }
+ return &service.User{
+ ID: entity.ID,
+ Email: entity.Email,
+ Username: entity.Username,
+ Notes: entity.Notes,
+ PasswordHash: entity.PasswordHash,
+ Role: entity.Role,
+ Balance: entity.Balance,
+ Concurrency: entity.Concurrency,
+ Status: entity.Status,
+ SignupSource: entity.SignupSource,
+ LastLoginAt: entity.LastLoginAt,
+ LastActiveAt: entity.LastActiveAt,
+ CreatedAt: entity.CreatedAt,
+ UpdatedAt: entity.UpdatedAt,
+ }
+}
diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go
index e3694c8f..ceda633c 100644
--- a/backend/internal/handler/auth_oidc_oauth.go
+++ b/backend/internal/handler/auth_oidc_oauth.go
@@ -326,7 +326,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
)
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
- tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
@@ -371,6 +371,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
ProviderKey: issuer,
ProviderSubject: subject,
},
+ TargetUserID: &user.ID,
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
@@ -399,7 +400,9 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
}
type completeOIDCOAuthRequest struct {
- InvitationCode string `json:"invitation_code" binding:"required"`
+ InvitationCode string `json:"invitation_code" binding:"required"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
// CompleteOIDCOAuthRegistration completes a pending OAuth registration by validating
@@ -447,11 +450,23 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
return
}
- tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
return
}
+ decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
+ AdoptDisplayName: req.AdoptDisplayName,
+ AdoptAvatar: req.AdoptAvatar,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
+ return
+ }
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go
index c389db51..9107e13a 100644
--- a/backend/internal/handler/auth_oidc_oauth_test.go
+++ b/backend/internal/handler/auth_oidc_oauth_test.go
@@ -1,6 +1,7 @@
package handler
import (
+ "bytes"
"context"
"crypto/rand"
"crypto/rsa"
@@ -12,7 +13,13 @@ import (
"testing"
"time"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
)
@@ -123,3 +130,80 @@ func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK {
E: e,
}
}
+
+func TestCompleteOIDCOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("oidc-complete-session").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example.com").
+ SetProviderSubject("oidc-subject-1").
+ SetResolvedEmail("93a310f4c1944c5bbd2e246df1f76485@oidc-connect.invalid").
+ SetBrowserSessionKey("oidc-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "issuer": "https://issuer.example.com",
+ "suggested_display_name": "OIDC Display",
+ "suggested_avatar_url": "https://cdn.example/oidc.png",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ AdoptAvatar: true,
+ })
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser")})
+ c.Request = req
+
+ handler.CompleteOIDCOAuthRegistration(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.NotEmpty(t, responseData["access_token"])
+
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ(session.ResolvedEmail)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "OIDC Display", userEntity.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example.com"),
+ authidentity.ProviderSubjectEQ("oidc-subject-1"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+ require.Equal(t, "OIDC Display", identity.Metadata["display_name"])
+ require.Equal(t, "https://cdn.example/oidc.png", identity.Metadata["avatar_url"])
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.True(t, decision.AdoptDisplayName)
+ require.True(t, decision.AdoptAvatar)
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+}
diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go
new file mode 100644
index 00000000..867a77a1
--- /dev/null
+++ b/backend/internal/handler/auth_wechat_oauth.go
@@ -0,0 +1,618 @@
+package handler
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "os"
+ "strings"
+ "time"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ wechatOAuthCookiePath = "/api/v1/auth/oauth/wechat"
+ wechatOAuthCookieMaxAgeSec = 10 * 60
+ wechatOAuthStateCookieName = "wechat_oauth_state"
+ wechatOAuthRedirectCookieName = "wechat_oauth_redirect"
+ wechatOAuthIntentCookieName = "wechat_oauth_intent"
+ wechatOAuthModeCookieName = "wechat_oauth_mode"
+ wechatOAuthDefaultRedirectTo = "/dashboard"
+ wechatOAuthDefaultFrontendCB = "/auth/wechat/callback"
+ wechatOAuthProviderKey = "wechat-main"
+
+ wechatOAuthIntentLogin = "login"
+ wechatOAuthIntentBind = "bind_current_user"
+ wechatOAuthIntentAdoptEmail = "adopt_existing_user_by_email"
+)
+
+var (
+ wechatOAuthAccessTokenURL = "https://api.weixin.qq.com/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = "https://api.weixin.qq.com/sns/userinfo"
+)
+
+type wechatOAuthConfig struct {
+ mode string
+ appID string
+ appSecret string
+ authorizeURL string
+ scope string
+ redirectURI string
+ frontendCallback string
+}
+
+type wechatOAuthTokenResponse struct {
+ AccessToken string `json:"access_token"`
+ ExpiresIn int64 `json:"expires_in"`
+ RefreshToken string `json:"refresh_token"`
+ OpenID string `json:"openid"`
+ Scope string `json:"scope"`
+ UnionID string `json:"unionid"`
+ ErrCode int64 `json:"errcode"`
+ ErrMsg string `json:"errmsg"`
+}
+
+type wechatOAuthUserInfoResponse struct {
+ OpenID string `json:"openid"`
+ Nickname string `json:"nickname"`
+ HeadImgURL string `json:"headimgurl"`
+ UnionID string `json:"unionid"`
+ ErrCode int64 `json:"errcode"`
+ ErrMsg string `json:"errmsg"`
+}
+
+// WeChatOAuthStart starts the WeChat OAuth login flow and stores the short-lived
+// browser cookies required by the rebuild pending-auth bridge.
+func (h *AuthHandler) WeChatOAuthStart(c *gin.Context) {
+ cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), c.Query("mode"), c)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ state, err := oauth.GenerateState()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
+ return
+ }
+
+ redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
+ if redirectTo == "" {
+ redirectTo = wechatOAuthDefaultRedirectTo
+ }
+
+ browserSessionKey, err := generateOAuthPendingBrowserSession()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err))
+ return
+ }
+
+ intent := normalizeWeChatOAuthIntent(c.Query("intent"))
+ secureCookie := isRequestHTTPS(c)
+ wechatSetCookie(c, wechatOAuthStateCookieName, encodeCookieValue(state), wechatOAuthCookieMaxAgeSec, secureCookie)
+ wechatSetCookie(c, wechatOAuthRedirectCookieName, encodeCookieValue(redirectTo), wechatOAuthCookieMaxAgeSec, secureCookie)
+ wechatSetCookie(c, wechatOAuthIntentCookieName, encodeCookieValue(intent), wechatOAuthCookieMaxAgeSec, secureCookie)
+ wechatSetCookie(c, wechatOAuthModeCookieName, encodeCookieValue(cfg.mode), wechatOAuthCookieMaxAgeSec, secureCookie)
+ setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie)
+ clearOAuthPendingSessionCookie(c, secureCookie)
+
+ authURL, err := buildWeChatAuthorizeURL(cfg, state)
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
+ return
+ }
+
+ c.Redirect(http.StatusFound, authURL)
+}
+
+// WeChatOAuthCallback exchanges the code with WeChat, resolves openid/unionid,
+// and stores the result in the unified pending-auth flow.
+func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
+ frontendCallback := wechatOAuthFrontendCallback()
+
+ if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
+ redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
+ return
+ }
+
+ code := strings.TrimSpace(c.Query("code"))
+ state := strings.TrimSpace(c.Query("state"))
+ if code == "" || state == "" {
+ redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
+ return
+ }
+
+ secureCookie := isRequestHTTPS(c)
+ defer func() {
+ wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie)
+ wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie)
+ wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie)
+ wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie)
+ }()
+
+ expectedState, err := readCookieDecoded(c, wechatOAuthStateCookieName)
+ if err != nil || expectedState == "" || state != expectedState {
+ redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
+ return
+ }
+
+ redirectTo, _ := readCookieDecoded(c, wechatOAuthRedirectCookieName)
+ redirectTo = sanitizeFrontendRedirectPath(redirectTo)
+ if redirectTo == "" {
+ redirectTo = wechatOAuthDefaultRedirectTo
+ }
+ browserSessionKey, _ := readOAuthPendingBrowserCookie(c)
+ if strings.TrimSpace(browserSessionKey) == "" {
+ redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "")
+ return
+ }
+
+ intent, _ := readCookieDecoded(c, wechatOAuthIntentCookieName)
+ mode, err := readCookieDecoded(c, wechatOAuthModeCookieName)
+ if err != nil || strings.TrimSpace(mode) == "" {
+ redirectOAuthError(c, frontendCallback, "invalid_state", "missing oauth mode", "")
+ return
+ }
+
+ cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), mode, c)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "provider_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+
+ tokenResp, userInfo, err := fetchWeChatOAuthIdentity(c.Request.Context(), cfg, code)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "provider_error", "wechat_identity_fetch_failed", singleLine(err.Error()))
+ return
+ }
+
+ unionid := strings.TrimSpace(firstNonEmpty(userInfo.UnionID, tokenResp.UnionID))
+ openid := strings.TrimSpace(firstNonEmpty(userInfo.OpenID, tokenResp.OpenID))
+ providerSubject := firstNonEmpty(unionid, openid)
+ if providerSubject == "" {
+ redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_subject", "")
+ return
+ }
+
+ username := firstNonEmpty(userInfo.Nickname, wechatFallbackUsername(providerSubject))
+ email := wechatSyntheticEmail(providerSubject)
+ upstreamClaims := map[string]any{
+ "email": email,
+ "username": username,
+ "subject": providerSubject,
+ "openid": openid,
+ "unionid": unionid,
+ "mode": cfg.mode,
+ "suggested_display_name": strings.TrimSpace(userInfo.Nickname),
+ "suggested_avatar_url": strings.TrimSpace(userInfo.HeadImgURL),
+ }
+
+ tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
+ if err != nil {
+ if err := h.createWeChatPendingSession(c, normalizeWeChatOAuthIntent(intent), providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, err); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ if err := h.createWeChatPendingSession(c, normalizeWeChatOAuthIntent(intent), providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+}
+
+type completeWeChatOAuthRequest struct {
+ InvitationCode string `json:"invitation_code" binding:"required"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
+}
+
+// CompleteWeChatOAuthRegistration completes a pending WeChat OAuth registration by
+// validating the invitation code and consuming the current pending browser session.
+// POST /api/v1/auth/oauth/wechat/complete-registration
+func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
+ var req completeWeChatOAuthRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
+ return
+ }
+
+ secureCookie := isRequestHTTPS(c)
+ sessionToken, err := readOAuthPendingSessionCookie(c)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
+ return
+ }
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
+ return
+ }
+ pendingSvc, err := h.pendingIdentityService()
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ email := strings.TrimSpace(session.ResolvedEmail)
+ username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
+ if email == "" || username == "" {
+ response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid"))
+ return
+ }
+
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
+ AdoptDisplayName: req.AdoptDisplayName,
+ AdoptAvatar: req.AdoptAvatar,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
+ return
+ }
+ if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, err)
+ return
+ }
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+
+ c.JSON(http.StatusOK, gin.H{
+ "access_token": tokenPair.AccessToken,
+ "refresh_token": tokenPair.RefreshToken,
+ "expires_in": tokenPair.ExpiresIn,
+ "token_type": "Bearer",
+ })
+}
+
+func (h *AuthHandler) createWeChatPendingSession(
+ c *gin.Context,
+ intent string,
+ providerSubject string,
+ email string,
+ redirectTo string,
+ browserSessionKey string,
+ upstreamClaims map[string]any,
+ tokenPair *service.TokenPair,
+ authErr error,
+) error {
+ completionResponse := map[string]any{
+ "redirect": redirectTo,
+ }
+ if authErr != nil {
+ if errors.Is(authErr, service.ErrOAuthInvitationRequired) {
+ completionResponse["error"] = "invitation_required"
+ } else {
+ return authErr
+ }
+ } else if tokenPair != nil {
+ completionResponse["access_token"] = tokenPair.AccessToken
+ completionResponse["refresh_token"] = tokenPair.RefreshToken
+ completionResponse["expires_in"] = tokenPair.ExpiresIn
+ completionResponse["token_type"] = "Bearer"
+ }
+
+ return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: intent,
+ Identity: service.PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: wechatOAuthProviderKey,
+ ProviderSubject: providerSubject,
+ },
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: completionResponse,
+ })
+}
+
+func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, c *gin.Context) (wechatOAuthConfig, error) {
+ mode, err := resolveWeChatOAuthMode(rawMode, c)
+ if err != nil {
+ return wechatOAuthConfig{}, err
+ }
+
+ apiBaseURL := ""
+ if h != nil && h.settingSvc != nil {
+ settings, err := h.settingSvc.GetAllSettings(ctx)
+ if err == nil && settings != nil {
+ apiBaseURL = strings.TrimSpace(settings.APIBaseURL)
+ }
+ }
+
+ cfg := wechatOAuthConfig{
+ mode: mode,
+ redirectURI: resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/callback"),
+ frontendCallback: wechatOAuthFrontendCallback(),
+ }
+
+ switch mode {
+ case "mp":
+ cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID"))
+ cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET"))
+ cfg.authorizeURL = "https://open.weixin.qq.com/connect/oauth2/authorize"
+ cfg.scope = "snsapi_userinfo"
+ default:
+ cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID"))
+ cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET"))
+ cfg.authorizeURL = "https://open.weixin.qq.com/connect/qrconnect"
+ cfg.scope = "snsapi_login"
+ }
+
+ if cfg.appID == "" || cfg.appSecret == "" {
+ return wechatOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled")
+ }
+ if strings.TrimSpace(cfg.redirectURI) == "" {
+ return wechatOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured")
+ }
+
+ return cfg, nil
+}
+
+func wechatOAuthFrontendCallback() string {
+ return firstNonEmpty(strings.TrimSpace(os.Getenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL")), wechatOAuthDefaultFrontendCB)
+}
+
+func resolveWeChatOAuthMode(rawMode string, c *gin.Context) (string, error) {
+ mode := strings.ToLower(strings.TrimSpace(rawMode))
+ if mode == "" {
+ if isWeChatBrowserRequest(c) {
+ return "mp", nil
+ }
+ return "open", nil
+ }
+ if mode != "open" && mode != "mp" {
+ return "", infraerrors.BadRequest("INVALID_MODE", "wechat oauth mode must be open or mp")
+ }
+ return mode, nil
+}
+
+func isWeChatBrowserRequest(c *gin.Context) bool {
+ if c == nil || c.Request == nil {
+ return false
+ }
+ return strings.Contains(strings.ToLower(strings.TrimSpace(c.GetHeader("User-Agent"))), "micromessenger")
+}
+
+func normalizeWeChatOAuthIntent(raw string) string {
+ switch strings.ToLower(strings.TrimSpace(raw)) {
+ case "", "login":
+ return wechatOAuthIntentLogin
+ case "bind", "bind_current_user":
+ return wechatOAuthIntentBind
+ case "adopt", "adopt_existing_user_by_email":
+ return wechatOAuthIntentAdoptEmail
+ default:
+ return wechatOAuthIntentLogin
+ }
+}
+
+func buildWeChatAuthorizeURL(cfg wechatOAuthConfig, state string) (string, error) {
+ u, err := url.Parse(cfg.authorizeURL)
+ if err != nil {
+ return "", fmt.Errorf("parse authorize url: %w", err)
+ }
+ query := u.Query()
+ query.Set("appid", cfg.appID)
+ query.Set("redirect_uri", cfg.redirectURI)
+ query.Set("response_type", "code")
+ query.Set("scope", cfg.scope)
+ query.Set("state", state)
+ u.RawQuery = query.Encode()
+ u.Fragment = "wechat_redirect"
+ return u.String(), nil
+}
+
+func resolveWeChatOAuthAbsoluteURL(apiBaseURL string, c *gin.Context, callbackPath string) string {
+ callbackPath = strings.TrimSpace(callbackPath)
+ if callbackPath == "" {
+ return ""
+ }
+
+ if raw := strings.TrimSpace(apiBaseURL); raw != "" {
+ if parsed, err := url.Parse(raw); err == nil && parsed.Scheme != "" && parsed.Host != "" {
+ basePath := strings.TrimRight(parsed.EscapedPath(), "/")
+ targetPath := callbackPath
+ if basePath != "" && strings.HasSuffix(basePath, "/api/v1") && strings.HasPrefix(callbackPath, "/api/v1") {
+ targetPath = basePath + strings.TrimPrefix(callbackPath, "/api/v1")
+ } else if basePath != "" {
+ targetPath = basePath + callbackPath
+ }
+ return parsed.Scheme + "://" + parsed.Host + targetPath
+ }
+ }
+
+ if c == nil || c.Request == nil {
+ return ""
+ }
+ scheme := "http"
+ if isRequestHTTPS(c) {
+ scheme = "https"
+ }
+ host := strings.TrimSpace(c.Request.Host)
+ if forwardedHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); forwardedHost != "" {
+ host = forwardedHost
+ }
+ if host == "" {
+ return ""
+ }
+ return scheme + "://" + host + callbackPath
+}
+
+func fetchWeChatOAuthIdentity(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, *wechatOAuthUserInfoResponse, error) {
+ tokenResp, err := exchangeWeChatOAuthCode(ctx, cfg, code)
+ if err != nil {
+ return nil, nil, err
+ }
+ userInfo, err := fetchWeChatUserInfo(ctx, tokenResp)
+ if err != nil {
+ return nil, nil, err
+ }
+ return tokenResp, userInfo, nil
+}
+
+func exchangeWeChatOAuthCode(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, error) {
+ endpoint, err := url.Parse(wechatOAuthAccessTokenURL)
+ if err != nil {
+ return nil, fmt.Errorf("parse wechat access token url: %w", err)
+ }
+
+ query := endpoint.Query()
+ query.Set("appid", cfg.appID)
+ query.Set("secret", cfg.appSecret)
+ query.Set("code", strings.TrimSpace(code))
+ query.Set("grant_type", "authorization_code")
+ endpoint.RawQuery = query.Encode()
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil)
+ if err != nil {
+ return nil, fmt.Errorf("build wechat access token request: %w", err)
+ }
+
+ client := &http.Client{Timeout: 30 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("request wechat access token: %w", err)
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("read wechat access token response: %w", err)
+ }
+ if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
+ return nil, fmt.Errorf("wechat access token status=%d", resp.StatusCode)
+ }
+
+ var tokenResp wechatOAuthTokenResponse
+ if err := json.Unmarshal(body, &tokenResp); err != nil {
+ return nil, fmt.Errorf("decode wechat access token response: %w", err)
+ }
+ if tokenResp.ErrCode != 0 {
+ return nil, fmt.Errorf("wechat access token error=%d %s", tokenResp.ErrCode, strings.TrimSpace(tokenResp.ErrMsg))
+ }
+ if strings.TrimSpace(tokenResp.AccessToken) == "" {
+ return nil, fmt.Errorf("wechat access token missing access_token")
+ }
+ return &tokenResp, nil
+}
+
+func fetchWeChatUserInfo(ctx context.Context, tokenResp *wechatOAuthTokenResponse) (*wechatOAuthUserInfoResponse, error) {
+ if tokenResp == nil {
+ return nil, fmt.Errorf("wechat token response is nil")
+ }
+
+ endpoint, err := url.Parse(wechatOAuthUserInfoURL)
+ if err != nil {
+ return nil, fmt.Errorf("parse wechat userinfo url: %w", err)
+ }
+ query := endpoint.Query()
+ query.Set("access_token", strings.TrimSpace(tokenResp.AccessToken))
+ query.Set("openid", strings.TrimSpace(tokenResp.OpenID))
+ query.Set("lang", "zh_CN")
+ endpoint.RawQuery = query.Encode()
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil)
+ if err != nil {
+ return nil, fmt.Errorf("build wechat userinfo request: %w", err)
+ }
+
+ client := &http.Client{Timeout: 30 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("request wechat userinfo: %w", err)
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("read wechat userinfo response: %w", err)
+ }
+ if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
+ return nil, fmt.Errorf("wechat userinfo status=%d", resp.StatusCode)
+ }
+
+ var userInfo wechatOAuthUserInfoResponse
+ if err := json.Unmarshal(body, &userInfo); err != nil {
+ return nil, fmt.Errorf("decode wechat userinfo response: %w", err)
+ }
+ if userInfo.ErrCode != 0 {
+ return nil, fmt.Errorf("wechat userinfo error=%d %s", userInfo.ErrCode, strings.TrimSpace(userInfo.ErrMsg))
+ }
+ return &userInfo, nil
+}
+
+func wechatSyntheticEmail(subject string) string {
+ subject = strings.TrimSpace(subject)
+ if subject == "" {
+ return ""
+ }
+ return "wechat-" + subject + service.WeChatConnectSyntheticEmailDomain
+}
+
+func wechatFallbackUsername(subject string) string {
+ subject = strings.TrimSpace(subject)
+ if subject == "" {
+ return "wechat_user"
+ }
+ return "wechat_" + truncateFragmentValue(subject)
+}
+
+func wechatSetCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: name,
+ Value: value,
+ Path: wechatOAuthCookiePath,
+ MaxAge: maxAgeSec,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func wechatClearCookie(c *gin.Context, name string, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: name,
+ Value: "",
+ Path: wechatOAuthCookiePath,
+ MaxAge: -1,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go
new file mode 100644
index 00000000..1a765dcc
--- /dev/null
+++ b/backend/internal/handler/auth_wechat_oauth_test.go
@@ -0,0 +1,411 @@
+//go:build unit
+
+package handler
+
+import (
+ "bytes"
+ "context"
+ "database/sql"
+ "encoding/base64"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/repository"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) {
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
+
+ gin.SetMode(gin.TestMode)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil)
+ c.Request.Host = "api.example.com"
+
+ handler := &AuthHandler{}
+ handler.WeChatOAuthStart(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ require.NotEmpty(t, location)
+ require.Contains(t, location, "open.weixin.qq.com")
+ require.Contains(t, location, "appid=wx-open-app")
+ require.Contains(t, location, "scope=snsapi_login")
+
+ cookies := recorder.Result().Cookies()
+ require.NotEmpty(t, findCookie(cookies, wechatOAuthStateCookieName))
+ require.NotEmpty(t, findCookie(cookies, wechatOAuthRedirectCookieName))
+ require.NotEmpty(t, findCookie(cookies, wechatOAuthModeCookieName))
+ require.NotEmpty(t, findCookie(cookies, oauthPendingBrowserCookieName))
+}
+
+func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) {
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
+ t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
+
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ ctx := context.Background()
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "wechat", session.ProviderType)
+ require.Equal(t, "wechat-main", session.ProviderKey)
+ require.Equal(t, "union-456", session.ProviderSubject)
+ require.Equal(t, "wechat-union-456@wechat-connect.invalid", session.ResolvedEmail)
+ require.Equal(t, "WeChat Nick", session.UpstreamIdentityClaims["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/avatar.png", session.UpstreamIdentityClaims["suggested_avatar_url"])
+ require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"])
+ require.Equal(t, "openid-123", session.UpstreamIdentityClaims["openid"])
+}
+
+func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) {
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
+ t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
+
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, true)
+ defer client.Close()
+
+ ctx := context.Background()
+ redeemRepo := repository.NewRedeemCodeRepository(client)
+ require.NoError(t, redeemRepo.Create(ctx, &service.RedeemCode{
+ Code: "invite-1",
+ Type: service.RedeemTypeInvitation,
+ Status: service.StatusUnused,
+ }))
+
+ callbackRecorder := httptest.NewRecorder()
+ callbackCtx, _ := gin.CreateTestContext(callbackRecorder)
+ callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ callbackReq.Host = "api.example.com"
+ callbackReq.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ callbackReq.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ callbackReq.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ callbackReq.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ callbackCtx.Request = callbackReq
+
+ handler.WeChatOAuthCallback(callbackCtx)
+
+ require.Equal(t, http.StatusFound, callbackRecorder.Code)
+ require.Equal(t, "/auth/wechat/callback", callbackRecorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(callbackRecorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+ sessionToken := decodeCookieValueForTest(t, sessionCookie.Value)
+
+ pendingSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(sessionToken)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "invitation_required", pendingSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)["error"])
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true,"adopt_avatar":true}`)
+ completeRecorder := httptest.NewRecorder()
+ completeCtx, _ := gin.CreateTestContext(completeRecorder)
+ completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
+ completeReq.Header.Set("Content-Type", "application/json")
+ completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(sessionToken)})
+ completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-123")})
+ completeCtx.Request = completeReq
+
+ handler.CompleteWeChatOAuthRegistration(completeCtx)
+
+ require.Equal(t, http.StatusOK, completeRecorder.Code)
+ responseData := decodeJSONBody(t, completeRecorder)
+ require.NotEmpty(t, responseData["access_token"])
+
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ("wechat-union-456@wechat-connect.invalid")).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "WeChat Display", userEntity.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ("wechat-main"),
+ authidentity.ProviderSubjectEQ("union-456"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+ require.Equal(t, "WeChat Display", identity.Metadata["display_name"])
+ require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"])
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.True(t, decision.AdoptDisplayName)
+ require.True(t, decision.AdoptAvatar)
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(pendingSession.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+}
+
+func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_wechat_oauth?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+
+ userRepo := &oauthPendingFlowUserRepo{client: client}
+ redeemRepo := repository.NewRedeemCodeRepository(client)
+ settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{
+ values: map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
+ },
+ }, &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 0,
+ UserConcurrency: 1,
+ },
+ })
+
+ authSvc := service.NewAuthService(
+ client,
+ userRepo,
+ redeemRepo,
+ &wechatOAuthRefreshTokenCacheStub{},
+ &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 0,
+ UserConcurrency: 1,
+ },
+ },
+ settingSvc,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ )
+
+ return &AuthHandler{
+ authService: authSvc,
+ settingSvc: settingSvc,
+ }, client
+}
+
+func encodedCookie(name, value string) *http.Cookie {
+ return &http.Cookie{
+ Name: name,
+ Value: encodeCookieValue(value),
+ Path: "/",
+ }
+}
+
+func findCookie(cookies []*http.Cookie, name string) *http.Cookie {
+ for _, cookie := range cookies {
+ if cookie.Name == name {
+ return cookie
+ }
+ }
+ return nil
+}
+
+func decodeCookieValueForTest(t *testing.T, value string) string {
+ t.Helper()
+ raw, err := base64.RawURLEncoding.DecodeString(value)
+ require.NoError(t, err)
+ return string(raw)
+}
+
+type wechatOAuthSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *wechatOAuthSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
+ return nil, service.ErrSettingNotFound
+}
+
+func (s *wechatOAuthSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ value, ok := s.values[key]
+ if !ok {
+ return "", service.ErrSettingNotFound
+ }
+ return value, nil
+}
+
+func (s *wechatOAuthSettingRepoStub) Set(context.Context, string, string) error {
+ return nil
+}
+
+func (s *wechatOAuthSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ result := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ result[key] = value
+ }
+ }
+ return result, nil
+}
+
+func (s *wechatOAuthSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ return nil
+}
+
+func (s *wechatOAuthSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ result := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ result[key] = value
+ }
+ return result, nil
+}
+
+func (s *wechatOAuthSettingRepoStub) Delete(context.Context, string) error {
+ return nil
+}
+
+type wechatOAuthRefreshTokenCacheStub struct{}
+
+func (s *wechatOAuthRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
+ return nil, service.ErrRefreshTokenNotFound
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
+ return nil, nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
+ return false, nil
+}
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index 3659e79b..f44b3e3b 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -189,6 +189,7 @@ type PublicSettings struct {
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
+ WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
SoraClientEnabled bool `json:"sora_client_enabled"`
diff --git a/backend/internal/handler/payment_webhook_handler.go b/backend/internal/handler/payment_webhook_handler.go
index 8a83bfeb..9fdefa93 100644
--- a/backend/internal/handler/payment_webhook_handler.go
+++ b/backend/internal/handler/payment_webhook_handler.go
@@ -120,7 +120,7 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
// This allows looking up the correct provider instance before verification.
func extractOutTradeNo(rawBody, providerKey string) string {
switch providerKey {
- case payment.TypeEasyPay:
+ case payment.TypeEasyPay, payment.TypeAlipay:
values, err := url.ParseQuery(rawBody)
if err == nil {
return values.Get("out_trade_no")
diff --git a/backend/internal/handler/payment_webhook_handler_test.go b/backend/internal/handler/payment_webhook_handler_test.go
index bdef1766..6f448131 100644
--- a/backend/internal/handler/payment_webhook_handler_test.go
+++ b/backend/internal/handler/payment_webhook_handler_test.go
@@ -97,3 +97,37 @@ func TestWebhookConstants(t *testing.T) {
assert.Equal(t, 200, webhookLogTruncateLen)
})
}
+
+func TestExtractOutTradeNo(t *testing.T) {
+ tests := []struct {
+ name string
+ providerKey string
+ rawBody string
+ want string
+ }{
+ {
+ name: "easypay query payload",
+ providerKey: "easypay",
+ rawBody: "out_trade_no=sub2_123&trade_status=TRADE_SUCCESS",
+ want: "sub2_123",
+ },
+ {
+ name: "alipay query payload",
+ providerKey: "alipay",
+ rawBody: "notify_time=2026-04-20+12%3A00%3A00&out_trade_no=sub2_456",
+ want: "sub2_456",
+ },
+ {
+ name: "unknown provider",
+ providerKey: "wxpay",
+ rawBody: "{}",
+ want: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ assert.Equal(t, tt.want, extractOutTradeNo(tt.rawBody, tt.providerKey))
+ })
+ }
+}
diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go
index 1717b7a1..c7bc3e2a 100644
--- a/backend/internal/handler/setting_handler.go
+++ b/backend/internal/handler/setting_handler.go
@@ -56,6 +56,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
+ WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
BackendModeEnabled: settings.BackendModeEnabled,
diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go
index 2535ea5e..904341d0 100644
--- a/backend/internal/handler/user_handler.go
+++ b/backend/internal/handler/user_handler.go
@@ -34,10 +34,16 @@ type ChangePasswordRequest struct {
// UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct {
Username *string `json:"username"`
+ AvatarURL *string `json:"avatar_url"`
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
}
+type userProfileResponse struct {
+ dto.User
+ AvatarURL string `json:"avatar_url,omitempty"`
+}
+
// GetProfile handles getting user profile
// GET /api/v1/users/me
func (h *UserHandler) GetProfile(c *gin.Context) {
@@ -47,13 +53,13 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
return
}
- userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
+ userData, err := h.userService.GetProfile(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
- response.Success(c, dto.UserFromService(userData))
+ response.Success(c, userProfileResponseFromService(userData))
}
// ChangePassword handles changing user password
@@ -101,6 +107,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
svcReq := service.UpdateProfileRequest{
Username: req.Username,
+ AvatarURL: req.AvatarURL,
BalanceNotifyEnabled: req.BalanceNotifyEnabled,
BalanceNotifyThreshold: req.BalanceNotifyThreshold,
}
@@ -110,7 +117,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
return
}
- response.Success(c, dto.UserFromService(updatedUser))
+ response.Success(c, userProfileResponseFromService(updatedUser))
}
// SendNotifyEmailCodeRequest represents the request to send notify email verification code
@@ -176,7 +183,7 @@ func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) {
return
}
- response.Success(c, dto.UserFromService(updatedUser))
+ response.Success(c, userProfileResponseFromService(updatedUser))
}
// RemoveNotifyEmailRequest represents the request to remove a notify email
@@ -212,7 +219,7 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) {
return
}
- response.Success(c, dto.UserFromService(updatedUser))
+ response.Success(c, userProfileResponseFromService(updatedUser))
}
// ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state
@@ -248,5 +255,16 @@ func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) {
return
}
- response.Success(c, dto.UserFromService(updatedUser))
+ response.Success(c, userProfileResponseFromService(updatedUser))
+}
+
+func userProfileResponseFromService(user *service.User) userProfileResponse {
+ base := dto.UserFromService(user)
+ if base == nil {
+ return userProfileResponse{}
+ }
+ return userProfileResponse{
+ User: *base,
+ AvatarURL: user.AvatarURL,
+ }
}
diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go
new file mode 100644
index 00000000..1973f59e
--- /dev/null
+++ b/backend/internal/handler/user_handler_test.go
@@ -0,0 +1,136 @@
+//go:build unit
+
+package handler
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type userHandlerRepoStub struct {
+ user *service.User
+}
+
+func (s *userHandlerRepoStub) Create(context.Context, *service.User) error { return nil }
+func (s *userHandlerRepoStub) GetByID(context.Context, int64) (*service.User, error) {
+ cloned := *s.user
+ return &cloned, nil
+}
+func (s *userHandlerRepoStub) GetByEmail(context.Context, string) (*service.User, error) {
+ cloned := *s.user
+ return &cloned, nil
+}
+func (s *userHandlerRepoStub) GetFirstAdmin(context.Context) (*service.User, error) {
+ cloned := *s.user
+ return &cloned, nil
+}
+func (s *userHandlerRepoStub) Update(_ context.Context, user *service.User) error {
+ cloned := *user
+ s.user = &cloned
+ return nil
+}
+func (s *userHandlerRepoStub) Delete(context.Context, int64) error { return nil }
+func (s *userHandlerRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
+ if s.user == nil || s.user.AvatarURL == "" {
+ return nil, nil
+ }
+ return &service.UserAvatar{
+ StorageProvider: s.user.AvatarSource,
+ URL: s.user.AvatarURL,
+ ContentType: s.user.AvatarMIME,
+ ByteSize: s.user.AvatarByteSize,
+ SHA256: s.user.AvatarSHA256,
+ }, nil
+}
+func (s *userHandlerRepoStub) UpsertUserAvatar(_ context.Context, _ int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ s.user.AvatarURL = input.URL
+ s.user.AvatarSource = input.StorageProvider
+ s.user.AvatarMIME = input.ContentType
+ s.user.AvatarByteSize = input.ByteSize
+ s.user.AvatarSHA256 = input.SHA256
+ return &service.UserAvatar{
+ StorageProvider: input.StorageProvider,
+ URL: input.URL,
+ ContentType: input.ContentType,
+ ByteSize: input.ByteSize,
+ SHA256: input.SHA256,
+ }, nil
+}
+func (s *userHandlerRepoStub) DeleteUserAvatar(context.Context, int64) error {
+ s.user.AvatarURL = ""
+ s.user.AvatarSource = ""
+ s.user.AvatarMIME = ""
+ s.user.AvatarByteSize = 0
+ s.user.AvatarSHA256 = ""
+ return nil
+}
+func (s *userHandlerRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
+func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
+func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
+func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
+func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+func (s *userHandlerRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+func (s *userHandlerRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+func (s *userHandlerRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+func (s *userHandlerRepoStub) EnableTotp(context.Context, int64) error { return nil }
+func (s *userHandlerRepoStub) DisableTotp(context.Context, int64) error { return nil }
+
+func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 11,
+ Email: "handler-avatar@example.com",
+ Username: "handler-avatar",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil)
+
+ body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/user", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
+
+ handler.UpdateProfile(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ AvatarURL string `json:"avatar_url"`
+ Username string `json:"username"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, "https://cdn.example.com/avatar.png", resp.Data.AvatarURL)
+ require.Equal(t, "handler-avatar", resp.Data.Username)
+}
diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go
index 38ea9bde..36d80309 100644
--- a/backend/internal/repository/api_key_repo.go
+++ b/backend/internal/repository/api_key_repo.go
@@ -149,6 +149,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
user.FieldBalanceNotifyThreshold,
user.FieldBalanceNotifyExtraEmails,
user.FieldTotalRecharged,
+ user.FieldSignupSource,
+ user.FieldLastLoginAt,
+ user.FieldLastActiveAt,
)
}).
WithGroup(func(q *dbent.GroupQuery) {
@@ -656,6 +659,9 @@ func userEntityToService(u *dbent.User) *service.User {
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
+ SignupSource: u.SignupSource,
+ LastLoginAt: u.LastLoginAt,
+ LastActiveAt: u.LastActiveAt,
TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,
diff --git a/backend/internal/repository/auth_identity_migration_report.go b/backend/internal/repository/auth_identity_migration_report.go
new file mode 100644
index 00000000..70f298c1
--- /dev/null
+++ b/backend/internal/repository/auth_identity_migration_report.go
@@ -0,0 +1,148 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+)
+
+type AuthIdentityMigrationReport struct {
+ ID int64
+ ReportType string
+ ReportKey string
+ Details map[string]any
+ CreatedAt time.Time
+}
+
+type AuthIdentityMigrationReportQuery struct {
+ ReportType string
+ Limit int
+ Offset int
+}
+
+type AuthIdentityMigrationReportSummary struct {
+ Total int64
+ ByType map[string]int64
+}
+
+func (r *userRepository) ListAuthIdentityMigrationReports(ctx context.Context, query AuthIdentityMigrationReportQuery) ([]AuthIdentityMigrationReport, error) {
+ exec := txAwareSQLExecutor(ctx, r.sql, r.client)
+ if exec == nil {
+ return nil, fmt.Errorf("sql executor is not configured")
+ }
+
+ limit := query.Limit
+ if limit <= 0 {
+ limit = 100
+ }
+ rows, err := exec.QueryContext(ctx, `
+SELECT id, report_type, report_key, details, created_at
+FROM auth_identity_migration_reports
+WHERE ($1 = '' OR report_type = $1)
+ORDER BY created_at DESC, id DESC
+LIMIT $2 OFFSET $3`,
+ strings.TrimSpace(query.ReportType),
+ limit,
+ query.Offset,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ reports := make([]AuthIdentityMigrationReport, 0)
+ for rows.Next() {
+ report, scanErr := scanAuthIdentityMigrationReport(rows)
+ if scanErr != nil {
+ return nil, scanErr
+ }
+ reports = append(reports, report)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return reports, nil
+}
+
+func (r *userRepository) GetAuthIdentityMigrationReport(ctx context.Context, reportType, reportKey string) (*AuthIdentityMigrationReport, error) {
+ exec := txAwareSQLExecutor(ctx, r.sql, r.client)
+ if exec == nil {
+ return nil, fmt.Errorf("sql executor is not configured")
+ }
+
+ rows, err := exec.QueryContext(ctx, `
+SELECT id, report_type, report_key, details, created_at
+FROM auth_identity_migration_reports
+WHERE report_type = $1 AND report_key = $2
+LIMIT 1`,
+ strings.TrimSpace(reportType),
+ strings.TrimSpace(reportKey),
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ if !rows.Next() {
+ return nil, sql.ErrNoRows
+ }
+ report, err := scanAuthIdentityMigrationReport(rows)
+ if err != nil {
+ return nil, err
+ }
+ return &report, rows.Err()
+}
+
+func (r *userRepository) SummarizeAuthIdentityMigrationReports(ctx context.Context) (*AuthIdentityMigrationReportSummary, error) {
+ exec := txAwareSQLExecutor(ctx, r.sql, r.client)
+ if exec == nil {
+ return nil, fmt.Errorf("sql executor is not configured")
+ }
+
+ rows, err := exec.QueryContext(ctx, `
+SELECT report_type, COUNT(*)
+FROM auth_identity_migration_reports
+GROUP BY report_type
+ORDER BY report_type ASC`)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ summary := &AuthIdentityMigrationReportSummary{
+ ByType: make(map[string]int64),
+ }
+ for rows.Next() {
+ var reportType string
+ var count int64
+ if err := rows.Scan(&reportType, &count); err != nil {
+ return nil, err
+ }
+ summary.ByType[reportType] = count
+ summary.Total += count
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return summary, nil
+}
+
+func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) {
+ var (
+ report AuthIdentityMigrationReport
+ details []byte
+ )
+ if err := scanner.Scan(&report.ID, &report.ReportType, &report.ReportKey, &details, &report.CreatedAt); err != nil {
+ return AuthIdentityMigrationReport{}, err
+ }
+ report.Details = map[string]any{}
+ if len(details) > 0 {
+ if err := json.Unmarshal(details, &report.Details); err != nil {
+ return AuthIdentityMigrationReport{}, err
+ }
+ }
+ return report, nil
+}
diff --git a/backend/internal/repository/user_profile_identity_repo.go b/backend/internal/repository/user_profile_identity_repo.go
new file mode 100644
index 00000000..4ecae4a4
--- /dev/null
+++ b/backend/internal/repository/user_profile_identity_repo.go
@@ -0,0 +1,544 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "reflect"
+ "strings"
+ "time"
+ "unsafe"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+var (
+ ErrAuthIdentityOwnershipConflict = infraerrors.Conflict(
+ "AUTH_IDENTITY_OWNERSHIP_CONFLICT",
+ "auth identity already belongs to another user",
+ )
+ ErrAuthIdentityChannelOwnershipConflict = infraerrors.Conflict(
+ "AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT",
+ "auth identity channel already belongs to another user",
+ )
+)
+
+type ProviderGrantReason string
+
+const (
+ ProviderGrantReasonSignup ProviderGrantReason = "signup"
+ ProviderGrantReasonFirstBind ProviderGrantReason = "first_bind"
+)
+
+type AuthIdentityKey struct {
+ ProviderType string
+ ProviderKey string
+ ProviderSubject string
+}
+
+type AuthIdentityChannelKey struct {
+ ProviderType string
+ ProviderKey string
+ Channel string
+ ChannelAppID string
+ ChannelSubject string
+}
+
+type CreateAuthIdentityInput struct {
+ UserID int64
+ Canonical AuthIdentityKey
+ Channel *AuthIdentityChannelKey
+ Issuer *string
+ VerifiedAt *time.Time
+ Metadata map[string]any
+ ChannelMetadata map[string]any
+}
+
+type BindAuthIdentityInput = CreateAuthIdentityInput
+
+type CreateAuthIdentityResult struct {
+ Identity *dbent.AuthIdentity
+ Channel *dbent.AuthIdentityChannel
+}
+
+func (r *CreateAuthIdentityResult) IdentityRef() AuthIdentityKey {
+ if r == nil || r.Identity == nil {
+ return AuthIdentityKey{}
+ }
+ return AuthIdentityKey{
+ ProviderType: r.Identity.ProviderType,
+ ProviderKey: r.Identity.ProviderKey,
+ ProviderSubject: r.Identity.ProviderSubject,
+ }
+}
+
+func (r *CreateAuthIdentityResult) ChannelRef() *AuthIdentityChannelKey {
+ if r == nil || r.Channel == nil {
+ return nil
+ }
+ return &AuthIdentityChannelKey{
+ ProviderType: r.Channel.ProviderType,
+ ProviderKey: r.Channel.ProviderKey,
+ Channel: r.Channel.Channel,
+ ChannelAppID: r.Channel.ChannelAppID,
+ ChannelSubject: r.Channel.ChannelSubject,
+ }
+}
+
+type UserAuthIdentityLookup struct {
+ User *dbent.User
+ Identity *dbent.AuthIdentity
+ Channel *dbent.AuthIdentityChannel
+}
+
+type ProviderGrantRecordInput struct {
+ UserID int64
+ ProviderType string
+ GrantReason ProviderGrantReason
+}
+
+type IdentityAdoptionDecisionInput struct {
+ PendingAuthSessionID int64
+ IdentityID *int64
+ AdoptDisplayName bool
+ AdoptAvatar bool
+}
+
+type sqlQueryExecutor interface {
+ ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
+ QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
+}
+
+func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
+ if dbent.TxFromContext(ctx) != nil {
+ return fn(ctx)
+ }
+
+ tx, err := r.client.Tx(ctx)
+ if err != nil {
+ return err
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ txCtx := dbent.NewTxContext(ctx, tx)
+ if err := fn(txCtx); err != nil {
+ return err
+ }
+ return tx.Commit()
+}
+
+func (r *userRepository) CreateAuthIdentity(ctx context.Context, input CreateAuthIdentityInput) (*CreateAuthIdentityResult, error) {
+ client := clientFromContext(ctx, r.client)
+
+ create := client.AuthIdentity.Create().
+ SetUserID(input.UserID).
+ SetProviderType(strings.TrimSpace(input.Canonical.ProviderType)).
+ SetProviderKey(strings.TrimSpace(input.Canonical.ProviderKey)).
+ SetProviderSubject(strings.TrimSpace(input.Canonical.ProviderSubject)).
+ SetMetadata(copyMetadata(input.Metadata)).
+ SetNillableIssuer(input.Issuer).
+ SetNillableVerifiedAt(input.VerifiedAt)
+
+ identity, err := create.Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ var channel *dbent.AuthIdentityChannel
+ if input.Channel != nil {
+ channel, err = client.AuthIdentityChannel.Create().
+ SetIdentityID(identity.ID).
+ SetProviderType(strings.TrimSpace(input.Channel.ProviderType)).
+ SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)).
+ SetChannel(strings.TrimSpace(input.Channel.Channel)).
+ SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)).
+ SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)).
+ SetMetadata(copyMetadata(input.ChannelMetadata)).
+ Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return &CreateAuthIdentityResult{Identity: identity, Channel: channel}, nil
+}
+
+func (r *userRepository) GetUserByCanonicalIdentity(ctx context.Context, key AuthIdentityKey) (*UserAuthIdentityLookup, error) {
+ identity, err := clientFromContext(ctx, r.client).AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)),
+ authidentity.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(key.ProviderSubject)),
+ ).
+ WithUser().
+ Only(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ return &UserAuthIdentityLookup{
+ User: identity.Edges.User,
+ Identity: identity,
+ }, nil
+}
+
+func (r *userRepository) GetUserByChannelIdentity(ctx context.Context, key AuthIdentityChannelKey) (*UserAuthIdentityLookup, error) {
+ channel, err := clientFromContext(ctx, r.client).AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)),
+ authidentitychannel.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)),
+ authidentitychannel.ChannelEQ(strings.TrimSpace(key.Channel)),
+ authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(key.ChannelAppID)),
+ authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(key.ChannelSubject)),
+ ).
+ WithIdentity(func(q *dbent.AuthIdentityQuery) {
+ q.WithUser()
+ }).
+ Only(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ return &UserAuthIdentityLookup{
+ User: channel.Edges.Identity.Edges.User,
+ Identity: channel.Edges.Identity,
+ Channel: channel,
+ }, nil
+}
+
+func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) {
+ var result *CreateAuthIdentityResult
+ err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
+ client := clientFromContext(txCtx, r.client)
+ canonical := input.Canonical
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(strings.TrimSpace(canonical.ProviderType)),
+ authidentity.ProviderKeyEQ(strings.TrimSpace(canonical.ProviderKey)),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(canonical.ProviderSubject)),
+ ).
+ Only(txCtx)
+ if err != nil && !dbent.IsNotFound(err) {
+ return err
+ }
+ if identity != nil && identity.UserID != input.UserID {
+ return ErrAuthIdentityOwnershipConflict
+ }
+ if identity == nil {
+ identity, err = client.AuthIdentity.Create().
+ SetUserID(input.UserID).
+ SetProviderType(strings.TrimSpace(canonical.ProviderType)).
+ SetProviderKey(strings.TrimSpace(canonical.ProviderKey)).
+ SetProviderSubject(strings.TrimSpace(canonical.ProviderSubject)).
+ SetMetadata(copyMetadata(input.Metadata)).
+ SetNillableIssuer(input.Issuer).
+ SetNillableVerifiedAt(input.VerifiedAt).
+ Save(txCtx)
+ if err != nil {
+ return err
+ }
+ } else {
+ update := client.AuthIdentity.UpdateOneID(identity.ID)
+ if input.Metadata != nil {
+ update = update.SetMetadata(copyMetadata(input.Metadata))
+ }
+ if input.Issuer != nil {
+ update = update.SetIssuer(strings.TrimSpace(*input.Issuer))
+ }
+ if input.VerifiedAt != nil {
+ update = update.SetVerifiedAt(*input.VerifiedAt)
+ }
+ identity, err = update.Save(txCtx)
+ if err != nil {
+ return err
+ }
+ }
+
+ var channel *dbent.AuthIdentityChannel
+ if input.Channel != nil {
+ channel, err = client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ(strings.TrimSpace(input.Channel.ProviderType)),
+ authidentitychannel.ProviderKeyEQ(strings.TrimSpace(input.Channel.ProviderKey)),
+ authidentitychannel.ChannelEQ(strings.TrimSpace(input.Channel.Channel)),
+ authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(input.Channel.ChannelAppID)),
+ authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(input.Channel.ChannelSubject)),
+ ).
+ WithIdentity().
+ Only(txCtx)
+ if err != nil && !dbent.IsNotFound(err) {
+ return err
+ }
+ if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != input.UserID {
+ return ErrAuthIdentityChannelOwnershipConflict
+ }
+ if channel == nil {
+ channel, err = client.AuthIdentityChannel.Create().
+ SetIdentityID(identity.ID).
+ SetProviderType(strings.TrimSpace(input.Channel.ProviderType)).
+ SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)).
+ SetChannel(strings.TrimSpace(input.Channel.Channel)).
+ SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)).
+ SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)).
+ SetMetadata(copyMetadata(input.ChannelMetadata)).
+ Save(txCtx)
+ if err != nil {
+ return err
+ }
+ } else {
+ update := client.AuthIdentityChannel.UpdateOneID(channel.ID).
+ SetIdentityID(identity.ID)
+ if input.ChannelMetadata != nil {
+ update = update.SetMetadata(copyMetadata(input.ChannelMetadata))
+ }
+ channel, err = update.Save(txCtx)
+ if err != nil {
+ return err
+ }
+ }
+ }
+
+ result = &CreateAuthIdentityResult{Identity: identity, Channel: channel}
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+ return result, nil
+}
+
+func (r *userRepository) RecordProviderGrant(ctx context.Context, input ProviderGrantRecordInput) (bool, error) {
+ exec := txAwareSQLExecutor(ctx, r.sql, r.client)
+ if exec == nil {
+ return false, fmt.Errorf("sql executor is not configured")
+ }
+
+ result, err := exec.ExecContext(ctx, `
+INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason)
+VALUES ($1, $2, $3)
+ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
+ input.UserID,
+ strings.TrimSpace(input.ProviderType),
+ string(input.GrantReason),
+ )
+ if err != nil {
+ return false, err
+ }
+ affected, err := result.RowsAffected()
+ if err != nil {
+ return false, err
+ }
+ return affected > 0, nil
+}
+
+func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
+ client := clientFromContext(ctx, r.client)
+ current, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
+ Only(ctx)
+ if err != nil && !dbent.IsNotFound(err) {
+ return nil, err
+ }
+ now := time.Now().UTC()
+ if current == nil {
+ create := client.IdentityAdoptionDecision.Create().
+ SetPendingAuthSessionID(input.PendingAuthSessionID).
+ SetAdoptDisplayName(input.AdoptDisplayName).
+ SetAdoptAvatar(input.AdoptAvatar).
+ SetDecidedAt(now)
+ if input.IdentityID != nil {
+ create = create.SetIdentityID(*input.IdentityID)
+ }
+ return create.Save(ctx)
+ }
+
+ update := client.IdentityAdoptionDecision.UpdateOneID(current.ID).
+ SetAdoptDisplayName(input.AdoptDisplayName).
+ SetAdoptAvatar(input.AdoptAvatar)
+ if input.IdentityID != nil {
+ update = update.SetIdentityID(*input.IdentityID)
+ }
+ return update.Save(ctx)
+}
+
+func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) {
+ return clientFromContext(ctx, r.client).IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingAuthSessionID)).
+ Only(ctx)
+}
+
+func (r *userRepository) UpdateUserLastLoginAt(ctx context.Context, userID int64, loginAt time.Time) error {
+ _, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID).
+ SetLastLoginAt(loginAt).
+ Save(ctx)
+ return err
+}
+
+func (r *userRepository) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+ _, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID).
+ SetLastActiveAt(activeAt).
+ Save(ctx)
+ return err
+}
+
+func (r *userRepository) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
+ exec, err := r.userProfileIdentitySQL(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ rows, err := exec.QueryContext(ctx, `
+SELECT storage_provider, storage_key, url, content_type, byte_size, sha256
+FROM user_avatars
+WHERE user_id = $1`, userID)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ if !rows.Next() {
+ return nil, rows.Err()
+ }
+
+ var avatar service.UserAvatar
+ if err := rows.Scan(
+ &avatar.StorageProvider,
+ &avatar.StorageKey,
+ &avatar.URL,
+ &avatar.ContentType,
+ &avatar.ByteSize,
+ &avatar.SHA256,
+ ); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return &avatar, nil
+}
+
+func (r *userRepository) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ exec, err := r.userProfileIdentitySQL(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ _, err = exec.ExecContext(ctx, `
+INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at)
+VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
+ON CONFLICT (user_id) DO UPDATE SET
+ storage_provider = EXCLUDED.storage_provider,
+ storage_key = EXCLUDED.storage_key,
+ url = EXCLUDED.url,
+ content_type = EXCLUDED.content_type,
+ byte_size = EXCLUDED.byte_size,
+ sha256 = EXCLUDED.sha256,
+ updated_at = NOW()`,
+ userID,
+ strings.TrimSpace(input.StorageProvider),
+ strings.TrimSpace(input.StorageKey),
+ strings.TrimSpace(input.URL),
+ strings.TrimSpace(input.ContentType),
+ input.ByteSize,
+ strings.TrimSpace(input.SHA256),
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ return &service.UserAvatar{
+ StorageProvider: strings.TrimSpace(input.StorageProvider),
+ StorageKey: strings.TrimSpace(input.StorageKey),
+ URL: strings.TrimSpace(input.URL),
+ ContentType: strings.TrimSpace(input.ContentType),
+ ByteSize: input.ByteSize,
+ SHA256: strings.TrimSpace(input.SHA256),
+ }, nil
+}
+
+func (r *userRepository) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ exec, err := r.userProfileIdentitySQL(ctx)
+ if err != nil {
+ return err
+ }
+ _, err = exec.ExecContext(ctx, `DELETE FROM user_avatars WHERE user_id = $1`, userID)
+ return err
+}
+
+func (r *userRepository) attachUserAvatar(ctx context.Context, user *service.User) error {
+ if user == nil {
+ return nil
+ }
+
+ avatar, err := r.GetUserAvatar(ctx, user.ID)
+ if err != nil {
+ return err
+ }
+ if avatar == nil {
+ return nil
+ }
+
+ user.AvatarURL = avatar.URL
+ user.AvatarSource = avatar.StorageProvider
+ user.AvatarMIME = avatar.ContentType
+ user.AvatarByteSize = avatar.ByteSize
+ user.AvatarSHA256 = avatar.SHA256
+ return nil
+}
+
+func copyMetadata(in map[string]any) map[string]any {
+ if len(in) == 0 {
+ return map[string]any{}
+ }
+ out := make(map[string]any, len(in))
+ for k, v := range in {
+ out[k] = v
+ }
+ return out
+}
+
+func txAwareSQLExecutor(ctx context.Context, fallback sqlExecutor, client *dbent.Client) sqlQueryExecutor {
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ if exec := sqlExecutorFromEntClient(tx.Client()); exec != nil {
+ return exec
+ }
+ }
+ if fallback != nil {
+ return fallback
+ }
+ return sqlExecutorFromEntClient(client)
+}
+
+func (r *userRepository) userProfileIdentitySQL(ctx context.Context) (sqlQueryExecutor, error) {
+ exec := txAwareSQLExecutor(ctx, r.sql, r.client)
+ if exec == nil {
+ return nil, fmt.Errorf("sql executor is not configured")
+ }
+ return exec, nil
+}
+
+func sqlExecutorFromEntClient(client *dbent.Client) sqlQueryExecutor {
+ if client == nil {
+ return nil
+ }
+
+ clientValue := reflect.ValueOf(client).Elem()
+ configValue := clientValue.FieldByName("config")
+ driverValue := configValue.FieldByName("driver")
+ if !driverValue.IsValid() {
+ return nil
+ }
+
+ driver := reflect.NewAt(driverValue.Type(), unsafe.Pointer(driverValue.UnsafeAddr())).Elem().Interface()
+ exec, ok := driver.(sqlQueryExecutor)
+ if !ok {
+ return nil
+ }
+ return exec
+}
diff --git a/backend/internal/repository/user_profile_identity_repo_contract_test.go b/backend/internal/repository/user_profile_identity_repo_contract_test.go
new file mode 100644
index 00000000..19022ec1
--- /dev/null
+++ b/backend/internal/repository/user_profile_identity_repo_contract_test.go
@@ -0,0 +1,428 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/suite"
+)
+
+type UserProfileIdentityRepoSuite struct {
+ suite.Suite
+ ctx context.Context
+ client *dbent.Client
+ repo *userRepository
+}
+
+func TestUserProfileIdentityRepoSuite(t *testing.T) {
+ suite.Run(t, new(UserProfileIdentityRepoSuite))
+}
+
+func (s *UserProfileIdentityRepoSuite) SetupTest() {
+ s.ctx = context.Background()
+ s.client = testEntClient(s.T())
+ s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
+
+ _, err := integrationDB.ExecContext(s.ctx, `
+TRUNCATE TABLE
+ identity_adoption_decisions,
+ auth_identity_channels,
+ auth_identities,
+ pending_auth_sessions,
+ auth_identity_migration_reports,
+ user_provider_default_grants,
+ user_avatars
+RESTART IDENTITY`)
+ s.Require().NoError(err)
+}
+
+func (s *UserProfileIdentityRepoSuite) mustCreateUser(label string) *dbent.User {
+ s.T().Helper()
+
+ user, err := s.client.User.Create().
+ SetEmail(fmt.Sprintf("%s-%d@example.com", label, time.Now().UnixNano())).
+ SetPasswordHash("test-password-hash").
+ SetRole("user").
+ SetStatus("active").
+ Save(s.ctx)
+ s.Require().NoError(err)
+ return user
+}
+
+func (s *UserProfileIdentityRepoSuite) mustCreatePendingAuthSession(key AuthIdentityKey) *dbent.PendingAuthSession {
+ s.T().Helper()
+
+ session, err := s.client.PendingAuthSession.Create().
+ SetSessionToken(fmt.Sprintf("pending-%d", time.Now().UnixNano())).
+ SetIntent("bind_current_user").
+ SetProviderType(key.ProviderType).
+ SetProviderKey(key.ProviderKey).
+ SetProviderSubject(key.ProviderSubject).
+ SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)).
+ SetUpstreamIdentityClaims(map[string]any{"provider_subject": key.ProviderSubject}).
+ SetLocalFlowState(map[string]any{"step": "pending"}).
+ Save(s.ctx)
+ s.Require().NoError(err)
+ return session
+}
+
+func (s *UserProfileIdentityRepoSuite) TestCreateAndLookupCanonicalAndChannelIdentity() {
+ user := s.mustCreateUser("canonical-channel")
+
+ verifiedAt := time.Now().UTC().Truncate(time.Second)
+ created, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
+ UserID: user.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-123",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ Channel: "mp",
+ ChannelAppID: "wx-app",
+ ChannelSubject: "openid-123",
+ },
+ Issuer: stringPtr("https://issuer.example"),
+ VerifiedAt: &verifiedAt,
+ Metadata: map[string]any{"unionid": "union-123"},
+ ChannelMetadata: map[string]any{"openid": "openid-123"},
+ })
+ s.Require().NoError(err)
+ s.Require().NotNil(created.Identity)
+ s.Require().NotNil(created.Channel)
+
+ canonical, err := s.repo.GetUserByCanonicalIdentity(s.ctx, created.IdentityRef())
+ s.Require().NoError(err)
+ s.Require().Equal(user.ID, canonical.User.ID)
+ s.Require().Equal(created.Identity.ID, canonical.Identity.ID)
+ s.Require().Equal("union-123", canonical.Identity.ProviderSubject)
+
+ channel, err := s.repo.GetUserByChannelIdentity(s.ctx, *created.ChannelRef())
+ s.Require().NoError(err)
+ s.Require().Equal(user.ID, channel.User.ID)
+ s.Require().Equal(created.Identity.ID, channel.Identity.ID)
+ s.Require().Equal(created.Channel.ID, channel.Channel.ID)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAndRejectsOtherOwners() {
+ owner := s.mustCreateUser("owner")
+ other := s.mustCreateUser("other")
+
+ first, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+ UserID: owner.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ ProviderSubject: "subject-1",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ Channel: "oauth",
+ ChannelAppID: "linuxdo-web",
+ ChannelSubject: "subject-1",
+ },
+ Metadata: map[string]any{"username": "first"},
+ ChannelMetadata: map[string]any{"scope": "read"},
+ })
+ s.Require().NoError(err)
+
+ second, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+ UserID: owner.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ ProviderSubject: "subject-1",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ Channel: "oauth",
+ ChannelAppID: "linuxdo-web",
+ ChannelSubject: "subject-1",
+ },
+ Metadata: map[string]any{"username": "second"},
+ ChannelMetadata: map[string]any{"scope": "write"},
+ })
+ s.Require().NoError(err)
+ s.Require().Equal(first.Identity.ID, second.Identity.ID)
+ s.Require().Equal(first.Channel.ID, second.Channel.ID)
+ s.Require().Equal("second", second.Identity.Metadata["username"])
+ s.Require().Equal("write", second.Channel.Metadata["scope"])
+
+ _, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+ UserID: other.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ ProviderSubject: "subject-1",
+ },
+ })
+ s.Require().ErrorIs(err, ErrAuthIdentityOwnershipConflict)
+
+ _, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+ UserID: other.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ ProviderSubject: "subject-2",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ Channel: "oauth",
+ ChannelAppID: "linuxdo-web",
+ ChannelSubject: "subject-1",
+ },
+ })
+ s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_RollsBackIdentityAndGrantOnError() {
+ user := s.mustCreateUser("tx-rollback")
+ expectedErr := errors.New("rollback")
+
+ err := s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error {
+ _, err := s.repo.CreateAuthIdentity(txCtx, CreateAuthIdentityInput{
+ UserID: user.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-rollback",
+ },
+ })
+ s.Require().NoError(err)
+
+ inserted, err := s.repo.RecordProviderGrant(txCtx, ProviderGrantRecordInput{
+ UserID: user.ID,
+ ProviderType: "oidc",
+ GrantReason: ProviderGrantReasonFirstBind,
+ })
+ s.Require().NoError(err)
+ s.Require().True(inserted)
+ return expectedErr
+ })
+ s.Require().ErrorIs(err, expectedErr)
+
+ _, err = s.repo.GetUserByCanonicalIdentity(s.ctx, AuthIdentityKey{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-rollback",
+ })
+ s.Require().True(dbent.IsNotFound(err))
+
+ var count int
+ s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
+SELECT COUNT(*)
+FROM user_provider_default_grants
+WHERE user_id = $1 AND provider_type = $2 AND grant_reason = $3`,
+ user.ID,
+ "oidc",
+ string(ProviderGrantReasonFirstBind),
+ ).Scan(&count))
+ s.Require().Zero(count)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestRecordProviderGrant_IsIdempotentPerReason() {
+ user := s.mustCreateUser("grant")
+
+ inserted, err := s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
+ UserID: user.ID,
+ ProviderType: "wechat",
+ GrantReason: ProviderGrantReasonFirstBind,
+ })
+ s.Require().NoError(err)
+ s.Require().True(inserted)
+
+ inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
+ UserID: user.ID,
+ ProviderType: "wechat",
+ GrantReason: ProviderGrantReasonFirstBind,
+ })
+ s.Require().NoError(err)
+ s.Require().False(inserted)
+
+ inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
+ UserID: user.ID,
+ ProviderType: "wechat",
+ GrantReason: ProviderGrantReasonSignup,
+ })
+ s.Require().NoError(err)
+ s.Require().True(inserted)
+
+ var count int
+ s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
+SELECT COUNT(*)
+FROM user_provider_default_grants
+WHERE user_id = $1 AND provider_type = $2`,
+ user.ID,
+ "wechat",
+ ).Scan(&count))
+ s.Require().Equal(2, count)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_PersistsAndLinksIdentity() {
+ user := s.mustCreateUser("adoption")
+ identity, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
+ UserID: user.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-adoption",
+ },
+ })
+ s.Require().NoError(err)
+
+ session := s.mustCreatePendingAuthSession(identity.IdentityRef())
+
+ first, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: false,
+ })
+ s.Require().NoError(err)
+ s.Require().True(first.AdoptDisplayName)
+ s.Require().False(first.AdoptAvatar)
+ s.Require().Nil(first.IdentityID)
+
+ second, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ IdentityID: &identity.Identity.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: true,
+ })
+ s.Require().NoError(err)
+ s.Require().Equal(first.ID, second.ID)
+ s.Require().NotNil(second.IdentityID)
+ s.Require().Equal(identity.Identity.ID, *second.IdentityID)
+ s.Require().True(second.AdoptAvatar)
+
+ loaded, err := s.repo.GetIdentityAdoptionDecisionByPendingAuthSessionID(s.ctx, session.ID)
+ s.Require().NoError(err)
+ s.Require().Equal(second.ID, loaded.ID)
+ s.Require().Equal(identity.Identity.ID, *loaded.IdentityID)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestUserAvatarCRUDAndUserLookup() {
+ user := s.mustCreateUser("avatar")
+
+ inlineAvatar, err := s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{
+ StorageProvider: "inline",
+ URL: "data:image/png;base64,QUJD",
+ ContentType: "image/png",
+ ByteSize: 3,
+ SHA256: "902fbdd2b1df0c4f70b4a5d23525e932",
+ })
+ s.Require().NoError(err)
+ s.Require().Equal("inline", inlineAvatar.StorageProvider)
+ s.Require().Equal("data:image/png;base64,QUJD", inlineAvatar.URL)
+
+ loadedAvatar, err := s.repo.GetUserAvatar(s.ctx, user.ID)
+ s.Require().NoError(err)
+ s.Require().NotNil(loadedAvatar)
+ s.Require().Equal("image/png", loadedAvatar.ContentType)
+ s.Require().Equal(3, loadedAvatar.ByteSize)
+
+ _, err = s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{
+ StorageProvider: "remote_url",
+ URL: "https://cdn.example.com/avatar.png",
+ })
+ s.Require().NoError(err)
+
+ loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID)
+ s.Require().NoError(err)
+ s.Require().NotNil(loadedAvatar)
+ s.Require().Equal("remote_url", loadedAvatar.StorageProvider)
+ s.Require().Equal("https://cdn.example.com/avatar.png", loadedAvatar.URL)
+ s.Require().Zero(loadedAvatar.ByteSize)
+
+ s.Require().NoError(s.repo.DeleteUserAvatar(s.ctx, user.ID))
+ loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID)
+ s.Require().NoError(err)
+ s.Require().Nil(loadedAvatar)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestAuthIdentityMigrationReportHelpers_ListAndSummarize() {
+ _, err := integrationDB.ExecContext(s.ctx, `
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details, created_at)
+VALUES
+ ('wechat_openid_only_requires_remediation', 'u-1', '{"user_id":1}'::jsonb, '2026-04-20T10:00:00Z'),
+ ('wechat_openid_only_requires_remediation', 'u-2', '{"user_id":2}'::jsonb, '2026-04-20T11:00:00Z'),
+ ('oidc_synthetic_email_requires_manual_recovery', 'u-3', '{"user_id":3}'::jsonb, '2026-04-20T12:00:00Z')`)
+ s.Require().NoError(err)
+
+ summary, err := s.repo.SummarizeAuthIdentityMigrationReports(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Equal(int64(3), summary.Total)
+ s.Require().Equal(int64(2), summary.ByType["wechat_openid_only_requires_remediation"])
+ s.Require().Equal(int64(1), summary.ByType["oidc_synthetic_email_requires_manual_recovery"])
+
+ reports, err := s.repo.ListAuthIdentityMigrationReports(s.ctx, AuthIdentityMigrationReportQuery{
+ ReportType: "wechat_openid_only_requires_remediation",
+ Limit: 10,
+ })
+ s.Require().NoError(err)
+ s.Require().Len(reports, 2)
+ s.Require().Equal("u-2", reports[0].ReportKey)
+ s.Require().Equal(float64(2), reports[0].Details["user_id"])
+
+ report, err := s.repo.GetAuthIdentityMigrationReport(s.ctx, "oidc_synthetic_email_requires_manual_recovery", "u-3")
+ s.Require().NoError(err)
+ s.Require().Equal("u-3", report.ReportKey)
+ s.Require().Equal(float64(3), report.Details["user_id"])
+}
+
+func (s *UserProfileIdentityRepoSuite) TestUpdateUserLastLoginAndActiveAt_UsesDedicatedColumns() {
+ user := s.mustCreateUser("activity")
+ loginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC)
+ activeAt := loginAt.Add(5 * time.Minute)
+
+ s.Require().NoError(s.repo.UpdateUserLastLoginAt(s.ctx, user.ID, loginAt))
+ s.Require().NoError(s.repo.UpdateUserLastActiveAt(s.ctx, user.ID, activeAt))
+
+ var storedLoginAt sqlNullTime
+ var storedActiveAt sqlNullTime
+ s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
+SELECT last_login_at, last_active_at
+FROM users
+WHERE id = $1`,
+ user.ID,
+ ).Scan(&storedLoginAt, &storedActiveAt))
+ s.Require().True(storedLoginAt.Valid)
+ s.Require().True(storedActiveAt.Valid)
+ s.Require().True(storedLoginAt.Time.Equal(loginAt))
+ s.Require().True(storedActiveAt.Time.Equal(activeAt))
+}
+
+type sqlNullTime struct {
+ Time time.Time
+ Valid bool
+}
+
+func (t *sqlNullTime) Scan(value any) error {
+ switch v := value.(type) {
+ case time.Time:
+ t.Time = v
+ t.Valid = true
+ return nil
+ case nil:
+ t.Time = time.Time{}
+ t.Valid = false
+ return nil
+ default:
+ return fmt.Errorf("unsupported scan type %T", value)
+ }
+}
+
+func stringPtr(v string) *string {
+ return &v
+}
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index 913e1c40..0c607ecc 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -64,6 +64,9 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
+ SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
+ SetNillableLastLoginAt(userIn.LastLoginAt).
+ SetNillableLastActiveAt(userIn.LastActiveAt).
Save(ctx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
@@ -151,6 +154,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
SetTotalRecharged(userIn.TotalRecharged)
+ if userIn.SignupSource != "" {
+ updateOp = updateOp.SetSignupSource(userIn.SignupSource)
+ }
+ if userIn.LastLoginAt != nil {
+ updateOp = updateOp.SetLastLoginAt(*userIn.LastLoginAt)
+ }
+ if userIn.LastActiveAt != nil {
+ updateOp = updateOp.SetLastActiveAt(*userIn.LastActiveAt)
+ }
if userIn.BalanceNotifyThreshold == nil {
updateOp = updateOp.ClearBalanceNotifyThreshold()
}
@@ -300,6 +312,7 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
var field string
defaultField := true
+ nullsLastField := false
switch sortBy {
case "email":
field = dbuser.FieldEmail
@@ -322,6 +335,14 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
case "created_at":
field = dbuser.FieldCreatedAt
defaultField = false
+ case "last_login_at":
+ field = dbuser.FieldLastLoginAt
+ defaultField = false
+ nullsLastField = true
+ case "last_active_at":
+ field = dbuser.FieldLastActiveAt
+ defaultField = false
+ nullsLastField = true
default:
field = dbuser.FieldID
}
@@ -330,11 +351,23 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)}
}
+ if nullsLastField {
+ return []func(*entsql.Selector){
+ entsql.OrderByField(field, entsql.OrderNullsLast()).ToFunc(),
+ dbent.Asc(dbuser.FieldID),
+ }
+ }
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)}
}
if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)}
}
+ if nullsLastField {
+ return []func(*entsql.Selector){
+ entsql.OrderByField(field, entsql.OrderDesc(), entsql.OrderNullsLast()).ToFunc(),
+ dbent.Desc(dbuser.FieldID),
+ }
+ }
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)}
}
@@ -558,10 +591,21 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
return
}
dst.ID = src.ID
+ dst.SignupSource = src.SignupSource
+ dst.LastLoginAt = src.LastLoginAt
+ dst.LastActiveAt = src.LastActiveAt
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}
+func userSignupSourceOrDefault(signupSource string) string {
+ signupSource = strings.TrimSpace(signupSource)
+ if signupSource == "" {
+ return "email"
+ }
+ return signupSource
+}
+
// marshalExtraEmails serializes notify email entries to JSON for storage.
func marshalExtraEmails(entries []service.NotifyEmailEntry) string {
return service.MarshalNotifyEmails(entries)
diff --git a/backend/internal/repository/user_repo_sort_integration_test.go b/backend/internal/repository/user_repo_sort_integration_test.go
index ab84b0e9..8abef45a 100644
--- a/backend/internal/repository/user_repo_sort_integration_test.go
+++ b/backend/internal/repository/user_repo_sort_integration_test.go
@@ -4,6 +4,7 @@ package repository
import (
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -36,4 +37,86 @@ func (s *UserRepoSuite) TestList_DefaultSortByNewestFirst() {
s.Require().Equal(first.ID, users[1].ID)
}
+func (s *UserRepoSuite) TestCreateAndRead_PreservesSignupSourceAndActivityTimestamps() {
+ lastLoginAt := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Microsecond)
+ lastActiveAt := time.Now().Add(-30 * time.Minute).UTC().Truncate(time.Microsecond)
+
+ created := s.mustCreateUser(&service.User{
+ Email: "identity-meta@example.com",
+ SignupSource: "github",
+ LastLoginAt: &lastLoginAt,
+ LastActiveAt: &lastActiveAt,
+ })
+
+ got, err := s.repo.GetByID(s.ctx, created.ID)
+ s.Require().NoError(err)
+ s.Require().Equal("github", got.SignupSource)
+ s.Require().NotNil(got.LastLoginAt)
+ s.Require().NotNil(got.LastActiveAt)
+ s.Require().True(got.LastLoginAt.Equal(lastLoginAt))
+ s.Require().True(got.LastActiveAt.Equal(lastActiveAt))
+}
+
+func (s *UserRepoSuite) TestUpdate_PersistsSignupSourceAndActivityTimestamps() {
+ created := s.mustCreateUser(&service.User{Email: "identity-update@example.com"})
+ lastLoginAt := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Microsecond)
+ lastActiveAt := time.Now().Add(-15 * time.Minute).UTC().Truncate(time.Microsecond)
+
+ created.SignupSource = "oidc"
+ created.LastLoginAt = &lastLoginAt
+ created.LastActiveAt = &lastActiveAt
+
+ s.Require().NoError(s.repo.Update(s.ctx, created))
+
+ got, err := s.repo.GetByID(s.ctx, created.ID)
+ s.Require().NoError(err)
+ s.Require().Equal("oidc", got.SignupSource)
+ s.Require().NotNil(got.LastLoginAt)
+ s.Require().NotNil(got.LastActiveAt)
+ s.Require().True(got.LastLoginAt.Equal(lastLoginAt))
+ s.Require().True(got.LastActiveAt.Equal(lastActiveAt))
+}
+
+func (s *UserRepoSuite) TestListWithFilters_SortByLastLoginAtDesc() {
+ older := time.Now().Add(-4 * time.Hour).UTC().Truncate(time.Microsecond)
+ newer := time.Now().Add(-1 * time.Hour).UTC().Truncate(time.Microsecond)
+
+ s.mustCreateUser(&service.User{Email: "nil-login@example.com"})
+ s.mustCreateUser(&service.User{Email: "older-login@example.com", LastLoginAt: &older})
+ s.mustCreateUser(&service.User{Email: "newer-login@example.com", LastLoginAt: &newer})
+
+ users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
+ Page: 1,
+ PageSize: 10,
+ SortBy: "last_login_at",
+ SortOrder: "desc",
+ }, service.UserListFilters{})
+ s.Require().NoError(err)
+ s.Require().Len(users, 3)
+ s.Require().Equal("newer-login@example.com", users[0].Email)
+ s.Require().Equal("older-login@example.com", users[1].Email)
+ s.Require().Equal("nil-login@example.com", users[2].Email)
+}
+
+func (s *UserRepoSuite) TestListWithFilters_SortByLastActiveAtAsc() {
+ earlier := time.Now().Add(-3 * time.Hour).UTC().Truncate(time.Microsecond)
+ later := time.Now().Add(-45 * time.Minute).UTC().Truncate(time.Microsecond)
+
+ s.mustCreateUser(&service.User{Email: "nil-active@example.com"})
+ s.mustCreateUser(&service.User{Email: "later-active@example.com", LastActiveAt: &later})
+ s.mustCreateUser(&service.User{Email: "earlier-active@example.com", LastActiveAt: &earlier})
+
+ users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
+ Page: 1,
+ PageSize: 10,
+ SortBy: "last_active_at",
+ SortOrder: "asc",
+ }, service.UserListFilters{})
+ s.Require().NoError(err)
+ s.Require().Len(users, 3)
+ s.Require().Equal("earlier-active@example.com", users[0].Email)
+ s.Require().Equal("later-active@example.com", users[1].Email)
+ s.Require().Equal("nil-active@example.com", users[2].Email)
+}
+
func TestUserRepoSortSuiteSmoke(_ *testing.T) {}
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index b686b986..e903898f 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -479,7 +479,7 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyOIDCConnectRedirectURL: "",
service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
- service.SettingKeyOIDCConnectUsePKCE: "false",
+ service.SettingKeyOIDCConnectUsePKCE: "true",
service.SettingKeyOIDCConnectValidateIDToken: "true",
service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256",
service.SettingKeyOIDCConnectClockSkewSeconds: "120",
@@ -549,7 +549,7 @@ func TestAPIContracts(t *testing.T) {
"oidc_connect_redirect_url": "",
"oidc_connect_frontend_redirect_url": "/auth/oidc/callback",
"oidc_connect_token_auth_method": "client_secret_post",
- "oidc_connect_use_pkce": false,
+ "oidc_connect_use_pkce": true,
"oidc_connect_validate_id_token": true,
"oidc_connect_allowed_signing_algs": "RS256,ES256,PS256",
"oidc_connect_clock_skew_seconds": 120,
diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go
index c143b030..911a4064 100644
--- a/backend/internal/server/routes/auth.go
+++ b/backend/internal/server/routes/auth.go
@@ -64,12 +64,26 @@ func RegisterAuthRoutes(
}), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
+ auth.GET("/oauth/wechat/start", h.Auth.WeChatOAuthStart)
+ auth.GET("/oauth/wechat/callback", h.Auth.WeChatOAuthCallback)
+ auth.POST("/oauth/pending/exchange",
+ rateLimiter.LimitWithOptions("oauth-pending-exchange", 20, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.ExchangePendingOAuthCompletion,
+ )
auth.POST("/oauth/linuxdo/complete-registration",
rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteLinuxDoOAuthRegistration,
)
+ auth.POST("/oauth/wechat/complete-registration",
+ rateLimiter.LimitWithOptions("oauth-wechat-complete", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CompleteWeChatOAuthRegistration,
+ )
auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart)
auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback)
auth.POST("/oauth/oidc/complete-registration",
diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go
index 419ddbc3..b802a9c2 100644
--- a/backend/internal/service/admin_service_apikey_test.go
+++ b/backend/internal/service/admin_service_apikey_test.go
@@ -44,6 +44,15 @@ func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, erro
}
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) GetUserAvatar(context.Context, int64) (*UserAvatar, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) DeleteUserAvatar(context.Context, int64) error {
+ panic("unexpected")
+}
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected")
}
diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go
index fbc856cf..323286b0 100644
--- a/backend/internal/service/admin_service_delete_test.go
+++ b/backend/internal/service/admin_service_delete_test.go
@@ -62,6 +62,18 @@ func (s *userRepoStub) Delete(ctx context.Context, id int64) error {
return s.deleteErr
}
+func (s *userRepoStub) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
+ panic("unexpected GetUserAvatar call")
+}
+
+func (s *userRepoStub) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
+ panic("unexpected UpsertUserAvatar call")
+}
+
+func (s *userRepoStub) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ panic("unexpected DeleteUserAvatar call")
+}
+
func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
diff --git a/backend/internal/service/auth_pending_identity_service.go b/backend/internal/service/auth_pending_identity_service.go
new file mode 100644
index 00000000..b7e86e12
--- /dev/null
+++ b/backend/internal/service/auth_pending_identity_service.go
@@ -0,0 +1,326 @@
+package service
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/hex"
+ "fmt"
+ "strings"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+var (
+ ErrPendingAuthSessionNotFound = infraerrors.NotFound("PENDING_AUTH_SESSION_NOT_FOUND", "pending auth session not found")
+ ErrPendingAuthSessionExpired = infraerrors.Unauthorized("PENDING_AUTH_SESSION_EXPIRED", "pending auth session has expired")
+ ErrPendingAuthSessionConsumed = infraerrors.Unauthorized("PENDING_AUTH_SESSION_CONSUMED", "pending auth session has already been used")
+ ErrPendingAuthCodeInvalid = infraerrors.Unauthorized("PENDING_AUTH_CODE_INVALID", "pending auth completion code is invalid")
+ ErrPendingAuthCodeExpired = infraerrors.Unauthorized("PENDING_AUTH_CODE_EXPIRED", "pending auth completion code has expired")
+ ErrPendingAuthCodeConsumed = infraerrors.Unauthorized("PENDING_AUTH_CODE_CONSUMED", "pending auth completion code has already been used")
+ ErrPendingAuthBrowserMismatch = infraerrors.Unauthorized("PENDING_AUTH_BROWSER_MISMATCH", "pending auth completion code does not match this browser session")
+)
+
+const (
+ defaultPendingAuthTTL = 15 * time.Minute
+ defaultPendingAuthCompletionTTL = 5 * time.Minute
+)
+
+type PendingAuthIdentityKey struct {
+ ProviderType string
+ ProviderKey string
+ ProviderSubject string
+}
+
+type CreatePendingAuthSessionInput struct {
+ SessionToken string
+ Intent string
+ Identity PendingAuthIdentityKey
+ TargetUserID *int64
+ RedirectTo string
+ ResolvedEmail string
+ RegistrationPasswordHash string
+ BrowserSessionKey string
+ UpstreamIdentityClaims map[string]any
+ LocalFlowState map[string]any
+ ExpiresAt time.Time
+}
+
+type IssuePendingAuthCompletionCodeInput struct {
+ PendingAuthSessionID int64
+ BrowserSessionKey string
+ TTL time.Duration
+}
+
+type IssuePendingAuthCompletionCodeResult struct {
+ Code string
+ ExpiresAt time.Time
+}
+
+type PendingIdentityAdoptionDecisionInput struct {
+ PendingAuthSessionID int64
+ IdentityID *int64
+ AdoptDisplayName bool
+ AdoptAvatar bool
+}
+
+type AuthPendingIdentityService struct {
+ entClient *dbent.Client
+}
+
+func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService {
+ return &AuthPendingIdentityService{entClient: entClient}
+}
+
+func (s *AuthPendingIdentityService) CreatePendingSession(ctx context.Context, input CreatePendingAuthSessionInput) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ sessionToken := strings.TrimSpace(input.SessionToken)
+ if sessionToken == "" {
+ var err error
+ sessionToken, err = randomOpaqueToken(24)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ expiresAt := input.ExpiresAt.UTC()
+ if expiresAt.IsZero() {
+ expiresAt = time.Now().UTC().Add(defaultPendingAuthTTL)
+ }
+
+ create := s.entClient.PendingAuthSession.Create().
+ SetSessionToken(sessionToken).
+ SetIntent(strings.TrimSpace(input.Intent)).
+ SetProviderType(strings.TrimSpace(input.Identity.ProviderType)).
+ SetProviderKey(strings.TrimSpace(input.Identity.ProviderKey)).
+ SetProviderSubject(strings.TrimSpace(input.Identity.ProviderSubject)).
+ SetRedirectTo(strings.TrimSpace(input.RedirectTo)).
+ SetResolvedEmail(strings.TrimSpace(input.ResolvedEmail)).
+ SetRegistrationPasswordHash(strings.TrimSpace(input.RegistrationPasswordHash)).
+ SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)).
+ SetUpstreamIdentityClaims(copyPendingMap(input.UpstreamIdentityClaims)).
+ SetLocalFlowState(copyPendingMap(input.LocalFlowState)).
+ SetExpiresAt(expiresAt)
+ if input.TargetUserID != nil {
+ create = create.SetTargetUserID(*input.TargetUserID)
+ }
+ return create.Save(ctx)
+}
+
+func (s *AuthPendingIdentityService) IssueCompletionCode(ctx context.Context, input IssuePendingAuthCompletionCodeInput) (*IssuePendingAuthCompletionCodeResult, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ session, err := s.entClient.PendingAuthSession.Get(ctx, input.PendingAuthSessionID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, ErrPendingAuthSessionNotFound
+ }
+ return nil, err
+ }
+
+ code, err := randomOpaqueToken(24)
+ if err != nil {
+ return nil, err
+ }
+ ttl := input.TTL
+ if ttl <= 0 {
+ ttl = defaultPendingAuthCompletionTTL
+ }
+ expiresAt := time.Now().UTC().Add(ttl)
+
+ update := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
+ SetCompletionCodeHash(hashPendingAuthCode(code)).
+ SetCompletionCodeExpiresAt(expiresAt)
+ if strings.TrimSpace(input.BrowserSessionKey) != "" {
+ update = update.SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey))
+ }
+ if _, err := update.Save(ctx); err != nil {
+ return nil, err
+ }
+
+ return &IssuePendingAuthCompletionCodeResult{
+ Code: code,
+ ExpiresAt: expiresAt,
+ }, nil
+}
+
+func (s *AuthPendingIdentityService) ConsumeCompletionCode(ctx context.Context, rawCode, browserSessionKey string) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ codeHash := hashPendingAuthCode(strings.TrimSpace(rawCode))
+ session, err := s.entClient.PendingAuthSession.Query().
+ Where(pendingauthsession.CompletionCodeHashEQ(codeHash)).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, ErrPendingAuthCodeInvalid
+ }
+ return nil, err
+ }
+
+ return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthCodeExpired, ErrPendingAuthCodeConsumed)
+}
+
+func (s *AuthPendingIdentityService) ConsumeBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ session, err := s.getBrowserSession(ctx, sessionToken)
+ if err != nil {
+ return nil, err
+ }
+
+ return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
+}
+
+func (s *AuthPendingIdentityService) GetBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ session, err := s.getBrowserSession(ctx, sessionToken)
+ if err != nil {
+ return nil, err
+ }
+ if err := validatePendingSessionState(session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed); err != nil {
+ return nil, err
+ }
+ return session, nil
+}
+
+func (s *AuthPendingIdentityService) getBrowserSession(ctx context.Context, sessionToken string) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ sessionToken = strings.TrimSpace(sessionToken)
+ if sessionToken == "" {
+ return nil, ErrPendingAuthSessionNotFound
+ }
+
+ session, err := s.entClient.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(sessionToken)).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, ErrPendingAuthSessionNotFound
+ }
+ return nil, err
+ }
+ return session, nil
+}
+
+func (s *AuthPendingIdentityService) consumeSession(
+ ctx context.Context,
+ session *dbent.PendingAuthSession,
+ browserSessionKey string,
+ expiredErr error,
+ consumedErr error,
+) (*dbent.PendingAuthSession, error) {
+ if err := validatePendingSessionState(session, browserSessionKey, expiredErr, consumedErr); err != nil {
+ return nil, err
+ }
+
+ now := time.Now().UTC()
+ updated, err := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
+ SetConsumedAt(now).
+ SetCompletionCodeHash("").
+ ClearCompletionCodeExpiresAt().
+ Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return updated, nil
+}
+
+func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error {
+ if session == nil {
+ return ErrPendingAuthSessionNotFound
+ }
+
+ now := time.Now().UTC()
+ if session.ConsumedAt != nil {
+ return consumedErr
+ }
+ if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) {
+ return expiredErr
+ }
+ if session.CompletionCodeExpiresAt != nil && now.After(*session.CompletionCodeExpiresAt) {
+ return expiredErr
+ }
+ if strings.TrimSpace(session.BrowserSessionKey) != "" && strings.TrimSpace(browserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) {
+ return ErrPendingAuthBrowserMismatch
+ }
+ return nil
+}
+
+func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, input PendingIdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ existing, err := s.entClient.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
+ Only(ctx)
+ if err != nil && !dbent.IsNotFound(err) {
+ return nil, err
+ }
+ if existing == nil {
+ create := s.entClient.IdentityAdoptionDecision.Create().
+ SetPendingAuthSessionID(input.PendingAuthSessionID).
+ SetAdoptDisplayName(input.AdoptDisplayName).
+ SetAdoptAvatar(input.AdoptAvatar).
+ SetDecidedAt(time.Now().UTC())
+ if input.IdentityID != nil {
+ create = create.SetIdentityID(*input.IdentityID)
+ }
+ return create.Save(ctx)
+ }
+
+ update := s.entClient.IdentityAdoptionDecision.UpdateOneID(existing.ID).
+ SetAdoptDisplayName(input.AdoptDisplayName).
+ SetAdoptAvatar(input.AdoptAvatar)
+ if input.IdentityID != nil {
+ update = update.SetIdentityID(*input.IdentityID)
+ }
+ return update.Save(ctx)
+}
+
+func copyPendingMap(in map[string]any) map[string]any {
+ if len(in) == 0 {
+ return map[string]any{}
+ }
+ out := make(map[string]any, len(in))
+ for k, v := range in {
+ out[k] = v
+ }
+ return out
+}
+
+func randomOpaqueToken(byteLen int) (string, error) {
+ if byteLen <= 0 {
+ byteLen = 16
+ }
+ buf := make([]byte, byteLen)
+ if _, err := rand.Read(buf); err != nil {
+ return "", err
+ }
+ return hex.EncodeToString(buf), nil
+}
+
+func hashPendingAuthCode(code string) string {
+ sum := sha256.Sum256([]byte(code))
+ return hex.EncodeToString(sum[:])
+}
diff --git a/backend/internal/service/auth_pending_identity_service_test.go b/backend/internal/service/auth_pending_identity_service_test.go
new file mode 100644
index 00000000..c69ebfd2
--- /dev/null
+++ b/backend/internal/service/auth_pending_identity_service_test.go
@@ -0,0 +1,224 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+func newAuthPendingIdentityServiceTestClient(t *testing.T) (*AuthPendingIdentityService, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_pending_identity_service?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ return NewAuthPendingIdentityService(client), client
+}
+
+func TestAuthPendingIdentityService_CreatePendingSessionStoresSeparatedState(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ targetUser, err := client.User.Create().
+ SetEmail("pending-target@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-123",
+ },
+ TargetUserID: &targetUser.ID,
+ RedirectTo: "/profile",
+ ResolvedEmail: "user@example.com",
+ BrowserSessionKey: "browser-1",
+ UpstreamIdentityClaims: map[string]any{"nickname": "wx-user", "avatar_url": "https://cdn.example/avatar.png"},
+ LocalFlowState: map[string]any{"step": "email_required"},
+ })
+ require.NoError(t, err)
+ require.NotEmpty(t, session.SessionToken)
+ require.Equal(t, "bind_current_user", session.Intent)
+ require.Equal(t, "wechat", session.ProviderType)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, targetUser.ID, *session.TargetUserID)
+ require.Equal(t, "wx-user", session.UpstreamIdentityClaims["nickname"])
+ require.Equal(t, "email_required", session.LocalFlowState["step"])
+}
+
+func TestAuthPendingIdentityService_CompletionCodeIsBrowserBoundAndOneTime(t *testing.T) {
+ svc, _ := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "login",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ ProviderSubject: "subject-1",
+ },
+ BrowserSessionKey: "browser-expected",
+ UpstreamIdentityClaims: map[string]any{"nickname": "linux-user"},
+ LocalFlowState: map[string]any{"step": "pending"},
+ })
+ require.NoError(t, err)
+
+ issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{
+ PendingAuthSessionID: session.ID,
+ BrowserSessionKey: "browser-expected",
+ })
+ require.NoError(t, err)
+ require.NotEmpty(t, issued.Code)
+
+ _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-other")
+ require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch)
+
+ consumed, err := svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected")
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+ require.Empty(t, consumed.CompletionCodeHash)
+ require.Nil(t, consumed.CompletionCodeExpiresAt)
+
+ _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected")
+ require.ErrorIs(t, err, ErrPendingAuthCodeInvalid)
+}
+
+func TestAuthPendingIdentityService_CompletionCodeExpires(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "login",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-1",
+ },
+ BrowserSessionKey: "browser-expired",
+ })
+ require.NoError(t, err)
+
+ issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{
+ PendingAuthSessionID: session.ID,
+ BrowserSessionKey: "browser-expired",
+ TTL: time.Second,
+ })
+ require.NoError(t, err)
+
+ _, err = client.PendingAuthSession.UpdateOneID(session.ID).
+ SetCompletionCodeExpiresAt(time.Now().UTC().Add(-time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expired")
+ require.ErrorIs(t, err, ErrPendingAuthCodeExpired)
+}
+
+func TestAuthPendingIdentityService_UpsertAdoptionDecision(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("adoption@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ identity, err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat-open").
+ SetProviderSubject("union-adoption").
+ SetMetadata(map[string]any{}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-adoption",
+ },
+ })
+ require.NoError(t, err)
+
+ first, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: false,
+ })
+ require.NoError(t, err)
+ require.True(t, first.AdoptDisplayName)
+ require.False(t, first.AdoptAvatar)
+ require.Nil(t, first.IdentityID)
+
+ second, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ IdentityID: &identity.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: true,
+ })
+ require.NoError(t, err)
+ require.Equal(t, first.ID, second.ID)
+ require.NotNil(t, second.IdentityID)
+ require.Equal(t, identity.ID, *second.IdentityID)
+ require.True(t, second.AdoptAvatar)
+}
+
+func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) {
+ svc, _ := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "login",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "subject-session-token",
+ },
+ BrowserSessionKey: "browser-session",
+ LocalFlowState: map[string]any{
+ "completion_response": map[string]any{
+ "access_token": "token",
+ },
+ },
+ })
+ require.NoError(t, err)
+
+ _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-other")
+ require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch)
+
+ consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+
+ _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
+ require.ErrorIs(t, err, ErrPendingAuthSessionConsumed)
+}
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index fd28cd42..962009ce 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -13,6 +13,7 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -106,6 +107,13 @@ func NewAuthService(
}
}
+func (s *AuthService) EntClient() *dbent.Client {
+ if s == nil {
+ return nil
+ }
+ return s.entClient
+}
+
// Register 用户注册,返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "", "", "")
@@ -205,6 +213,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err)
return "", nil, ErrServiceUnavailable
}
+ s.postAuthUserBootstrap(ctx, user, "email", true)
s.assignDefaultSubscriptions(ctx, user.ID)
// 标记邀请码为已使用(如果使用了邀请码)
@@ -421,6 +430,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
if !user.IsActive() {
return "", nil, ErrUserNotActive
}
+ s.touchUserLogin(ctx, user.ID)
// 生成JWT token
token, err := s.GenerateToken(user)
@@ -501,6 +511,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
}
} else {
user = newUser
+ s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
s.assignDefaultSubscriptions(ctx, user.ID)
}
} else {
@@ -520,6 +531,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
+ s.touchUserLogin(ctx, user.ID)
token, err := s.GenerateToken(user)
if err != nil {
@@ -630,6 +642,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, ErrServiceUnavailable
}
user = newUser
+ s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
s.assignDefaultSubscriptions(ctx, user.ID)
}
} else {
@@ -646,6 +659,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
}
} else {
user = newUser
+ s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
s.assignDefaultSubscriptions(ctx, user.ID)
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
@@ -670,6 +684,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
+ s.touchUserLogin(ctx, user.ID)
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
@@ -678,63 +693,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return tokenPair, user, nil
}
-// pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
-const pendingOAuthTokenTTL = 10 * time.Minute
-
-// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
-const pendingOAuthPurpose = "pending_oauth_registration"
-
-type pendingOAuthClaims struct {
- Email string `json:"email"`
- Username string `json:"username"`
- Purpose string `json:"purpose"`
- jwt.RegisteredClaims
-}
-
-// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity
-// while waiting for the user to supply an invitation code.
-func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) {
- now := time.Now()
- claims := &pendingOAuthClaims{
- Email: email,
- Username: username,
- Purpose: pendingOAuthPurpose,
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
- IssuedAt: jwt.NewNumericDate(now),
- NotBefore: jwt.NewNumericDate(now),
- },
- }
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- return token.SignedString([]byte(s.cfg.JWT.Secret))
-}
-
-// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity.
-// Returns ErrInvalidToken when the token is invalid or expired.
-func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) {
- if len(tokenStr) > maxTokenLength {
- return "", "", ErrInvalidToken
- }
- parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
- token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) {
- if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
- return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
- }
- return []byte(s.cfg.JWT.Secret), nil
- })
- if parseErr != nil {
- return "", "", ErrInvalidToken
- }
- claims, ok := token.Claims.(*pendingOAuthClaims)
- if !ok || !token.Valid {
- return "", "", ErrInvalidToken
- }
- if claims.Purpose != pendingOAuthPurpose {
- return "", "", ErrInvalidToken
- }
- return claims.Email, claims.Username, nil
-}
-
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
return
@@ -752,6 +710,95 @@ func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int
}
}
+func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
+ if user == nil || user.ID <= 0 {
+ return
+ }
+
+ if strings.TrimSpace(signupSource) == "" {
+ signupSource = "email"
+ }
+ s.updateUserSignupSource(ctx, user.ID, signupSource)
+
+ if signupSource == "email" {
+ s.ensureEmailAuthIdentity(ctx, user)
+ }
+ if touchLogin {
+ s.touchUserLogin(ctx, user.ID)
+ }
+}
+
+func (s *AuthService) updateUserSignupSource(ctx context.Context, userID int64, signupSource string) {
+ if s == nil || s.entClient == nil || userID <= 0 {
+ return
+ }
+ if strings.TrimSpace(signupSource) == "" {
+ return
+ }
+ if err := s.entClient.User.UpdateOneID(userID).
+ SetSignupSource(signupSource).
+ Exec(ctx); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to update signup source: user_id=%d source=%s err=%v", userID, signupSource, err)
+ }
+}
+
+func (s *AuthService) touchUserLogin(ctx context.Context, userID int64) {
+ if s == nil || s.entClient == nil || userID <= 0 {
+ return
+ }
+ now := time.Now().UTC()
+ if err := s.entClient.User.UpdateOneID(userID).
+ SetLastLoginAt(now).
+ SetLastActiveAt(now).
+ Exec(ctx); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to touch login timestamps: user_id=%d err=%v", userID, err)
+ }
+}
+
+func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) {
+ if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
+ return
+ }
+
+ email := strings.ToLower(strings.TrimSpace(user.Email))
+ if email == "" || isReservedEmail(email) {
+ return
+ }
+
+ if err := s.entClient.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject(email).
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{
+ "source": "auth_service_dual_write",
+ }).
+ OnConflictColumns(
+ authidentity.FieldProviderType,
+ authidentity.FieldProviderKey,
+ authidentity.FieldProviderSubject,
+ ).
+ DoNothing().
+ Exec(ctx); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
+ }
+}
+
+func inferLegacySignupSource(email string) string {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ switch {
+ case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain):
+ return "linuxdo"
+ case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain):
+ return "oidc"
+ case strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain):
+ return "wechat"
+ default:
+ return "email"
+ }
+}
+
func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error {
if s.settingService == nil {
return nil
@@ -834,7 +881,8 @@ func randomHexString(byteLength int) (string, error) {
func isReservedEmail(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) ||
- strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain)
+ strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain)
}
// GenerateToken 生成JWT access token
diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go
new file mode 100644
index 00000000..5bd2b25d
--- /dev/null
+++ b/backend/internal/service/auth_service_identity_sync_test.go
@@ -0,0 +1,153 @@
+//go:build unit
+
+package service_test
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/repository"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+type authIdentitySettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *authIdentitySettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *authIdentitySettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ if v, ok := s.values[key]; ok {
+ return v, nil
+ }
+ return "", service.ErrSettingNotFound
+}
+
+func (s *authIdentitySettingRepoStub) Set(context.Context, string, string) error {
+ panic("unexpected Set call")
+}
+
+func (s *authIdentitySettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
+ panic("unexpected GetMultiple call")
+}
+
+func (s *authIdentitySettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ panic("unexpected SetMultiple call")
+}
+
+func (s *authIdentitySettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *authIdentitySettingRepoStub) Delete(context.Context, string) error {
+ panic("unexpected Delete call")
+}
+
+func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepository, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_service_identity_sync?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ repo := repository.NewUserRepository(client, db)
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-auth-identity-secret",
+ ExpireHour: 1,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 3.5,
+ UserConcurrency: 2,
+ },
+ }
+ settingSvc := service.NewSettingService(&authIdentitySettingRepoStub{
+ values: map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ },
+ }, cfg)
+
+ svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, nil)
+ return svc, repo, client
+}
+
+func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
+ svc, _, client := newAuthServiceWithEnt(t)
+ ctx := context.Background()
+
+ token, user, err := svc.Register(ctx, "user@example.com", "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, user)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, "email", storedUser.SignupSource)
+ require.NotNil(t, storedUser.LastLoginAt)
+ require.NotNil(t, storedUser.LastActiveAt)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("user@example.com"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, user.ID, identity.UserID)
+ require.NotNil(t, identity.VerifiedAt)
+}
+
+func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) {
+ svc, repo, client := newAuthServiceWithEnt(t)
+ ctx := context.Background()
+
+ user := &service.User{
+ Email: "login@example.com",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Balance: 1,
+ Concurrency: 1,
+ }
+ require.NoError(t, user.SetPassword("password"))
+ require.NoError(t, repo.Create(ctx, user))
+
+ old := time.Now().Add(-2 * time.Hour).UTC().Round(time.Second)
+ _, err := client.User.UpdateOneID(user.ID).
+ SetLastLoginAt(old).
+ SetLastActiveAt(old).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedUser.LastLoginAt)
+ require.NotNil(t, storedUser.LastActiveAt)
+ require.True(t, storedUser.LastLoginAt.After(old))
+ require.True(t, storedUser.LastActiveAt.After(old))
+}
diff --git a/backend/internal/service/auth_service_pending_oauth_test.go b/backend/internal/service/auth_service_pending_oauth_test.go
deleted file mode 100644
index 0472e06c..00000000
--- a/backend/internal/service/auth_service_pending_oauth_test.go
+++ /dev/null
@@ -1,146 +0,0 @@
-//go:build unit
-
-package service
-
-import (
- "testing"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- "github.com/golang-jwt/jwt/v5"
- "github.com/stretchr/testify/require"
-)
-
-func newAuthServiceForPendingOAuthTest() *AuthService {
- cfg := &config.Config{
- JWT: config.JWTConfig{
- Secret: "test-secret-pending-oauth",
- ExpireHour: 1,
- },
- }
- return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
-}
-
-// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。
-func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- token, err := svc.CreatePendingOAuthToken("user@example.com", "alice")
- require.NoError(t, err)
- require.NotEmpty(t, token)
-
- email, username, err := svc.VerifyPendingOAuthToken(token)
- require.NoError(t, err)
- require.Equal(t, "user@example.com", email)
- require.Equal(t, "alice", username)
-}
-
-// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- // 签发一个普通 access token(JWTClaims,无 Purpose 字段)
- accessToken, err := svc.GenerateToken(&User{
- ID: 1,
- Email: "user@example.com",
- Role: RoleUser,
- })
- require.NoError(t, err)
-
- _, _, err = svc.VerifyPendingOAuthToken(accessToken)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- now := time.Now()
- claims := &pendingOAuthClaims{
- Email: "user@example.com",
- Username: "alice",
- Purpose: "some_other_purpose",
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
- IssuedAt: jwt.NewNumericDate(now),
- NotBefore: jwt.NewNumericDate(now),
- },
- }
- tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
- require.NoError(t, err)
-
- _, _, err = svc.VerifyPendingOAuthToken(tokenStr)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- now := time.Now()
- claims := &pendingOAuthClaims{
- Email: "user@example.com",
- Username: "alice",
- Purpose: "", // 旧 token 无此字段,反序列化后为零值
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
- IssuedAt: jwt.NewNumericDate(now),
- NotBefore: jwt.NewNumericDate(now),
- },
- }
- tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
- require.NoError(t, err)
-
- _, _, err = svc.VerifyPendingOAuthToken(tokenStr)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- past := time.Now().Add(-1 * time.Hour)
- claims := &pendingOAuthClaims{
- Email: "user@example.com",
- Username: "alice",
- Purpose: pendingOAuthPurpose,
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(past),
- IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
- NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
- },
- }
- tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
- require.NoError(t, err)
-
- _, _, err = svc.VerifyPendingOAuthToken(tokenStr)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) {
- other := NewAuthService(nil, nil, nil, nil, &config.Config{
- JWT: config.JWTConfig{Secret: "other-secret"},
- }, nil, nil, nil, nil, nil, nil)
-
- token, err := other.CreatePendingOAuthToken("user@example.com", "alice")
- require.NoError(t, err)
-
- svc := newAuthServiceForPendingOAuthTest()
- _, _, err = svc.VerifyPendingOAuthToken(token)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_TooLong(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
- giant := make([]byte, maxTokenLength+1)
- for i := range giant {
- giant[i] = 'a'
- }
- _, _, err := svc.VerifyPendingOAuthToken(string(giant))
- require.ErrorIs(t, err, ErrInvalidToken)
-}
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index cb452efb..1dddf77e 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -74,6 +74,9 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀(RFC 保留域名)。
const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid"
+// WeChatConnectSyntheticEmailDomain 是 WeChat Connect 用户的合成邮箱后缀(RFC 保留域名)。
+const WeChatConnectSyntheticEmailDomain = "@wechat-connect.invalid"
+
// Setting keys
const (
// 注册设置
@@ -153,6 +156,29 @@ const (
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
+ // 第三方认证来源默认授予配置
+ SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance"
+ SettingKeyAuthSourceDefaultEmailConcurrency = "auth_source_default_email_concurrency"
+ SettingKeyAuthSourceDefaultEmailSubscriptions = "auth_source_default_email_subscriptions"
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup = "auth_source_default_email_grant_on_signup"
+ SettingKeyAuthSourceDefaultEmailGrantOnFirstBind = "auth_source_default_email_grant_on_first_bind"
+ SettingKeyAuthSourceDefaultLinuxDoBalance = "auth_source_default_linuxdo_balance"
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency = "auth_source_default_linuxdo_concurrency"
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions = "auth_source_default_linuxdo_subscriptions"
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup = "auth_source_default_linuxdo_grant_on_signup"
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind = "auth_source_default_linuxdo_grant_on_first_bind"
+ SettingKeyAuthSourceDefaultOIDCBalance = "auth_source_default_oidc_balance"
+ SettingKeyAuthSourceDefaultOIDCConcurrency = "auth_source_default_oidc_concurrency"
+ SettingKeyAuthSourceDefaultOIDCSubscriptions = "auth_source_default_oidc_subscriptions"
+ SettingKeyAuthSourceDefaultOIDCGrantOnSignup = "auth_source_default_oidc_grant_on_signup"
+ SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind = "auth_source_default_oidc_grant_on_first_bind"
+ SettingKeyAuthSourceDefaultWeChatBalance = "auth_source_default_wechat_balance"
+ SettingKeyAuthSourceDefaultWeChatConcurrency = "auth_source_default_wechat_concurrency"
+ SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions"
+ SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup"
+ SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind"
+ SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup"
+
// 管理员 API Key
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go
index 6c09e354..09e60220 100644
--- a/backend/internal/service/openai_account_scheduler.go
+++ b/backend/internal/service/openai_account_scheduler.go
@@ -13,14 +13,30 @@ import (
"sync"
"sync/atomic"
"time"
+
+ "golang.org/x/sync/singleflight"
)
const (
openAIAccountScheduleLayerPreviousResponse = "previous_response_id"
openAIAccountScheduleLayerSessionSticky = "session_hash"
openAIAccountScheduleLayerLoadBalance = "load_balance"
+ openAIAdvancedSchedulerSettingKey = "openai_advanced_scheduler_enabled"
)
+const (
+ openAIAdvancedSchedulerSettingCacheTTL = 5 * time.Second
+ openAIAdvancedSchedulerSettingDBTimeout = 2 * time.Second
+)
+
+type cachedOpenAIAdvancedSchedulerSetting struct {
+ enabled bool
+ expiresAt int64
+}
+
+var openAIAdvancedSchedulerSettingCache atomic.Value // *cachedOpenAIAdvancedSchedulerSetting
+var openAIAdvancedSchedulerSettingSF singleflight.Group
+
type OpenAIAccountScheduleRequest struct {
GroupID *int64
SessionHash string
@@ -805,10 +821,56 @@ func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountScheduler
return snapshot
}
-func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler {
+func (s *OpenAIGatewayService) openAIAdvancedSchedulerSettingRepo() SettingRepository {
+ if s == nil || s.rateLimitService == nil || s.rateLimitService.settingService == nil {
+ return nil
+ }
+ return s.rateLimitService.settingService.settingRepo
+}
+
+func (s *OpenAIGatewayService) isOpenAIAdvancedSchedulerEnabled(ctx context.Context) bool {
+ if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil {
+ if time.Now().UnixNano() < cached.expiresAt {
+ return cached.enabled
+ }
+ }
+
+ result, _, _ := openAIAdvancedSchedulerSettingSF.Do(openAIAdvancedSchedulerSettingKey, func() (any, error) {
+ if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil {
+ if time.Now().UnixNano() < cached.expiresAt {
+ return cached.enabled, nil
+ }
+ }
+
+ enabled := false
+ if repo := s.openAIAdvancedSchedulerSettingRepo(); repo != nil {
+ dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAIAdvancedSchedulerSettingDBTimeout)
+ defer cancel()
+
+ value, err := repo.GetValue(dbCtx, openAIAdvancedSchedulerSettingKey)
+ if err == nil {
+ enabled = strings.EqualFold(strings.TrimSpace(value), "true")
+ }
+ }
+
+ openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
+ enabled: enabled,
+ expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
+ })
+ return enabled, nil
+ })
+
+ enabled, _ := result.(bool)
+ return enabled
+}
+
+func (s *OpenAIGatewayService) getOpenAIAccountScheduler(ctx context.Context) OpenAIAccountScheduler {
if s == nil {
return nil
}
+ if !s.isOpenAIAdvancedSchedulerEnabled(ctx) {
+ return nil
+ }
s.openaiSchedulerOnce.Do(func() {
if s.openaiAccountStats == nil {
s.openaiAccountStats = newOpenAIAccountRuntimeStats()
@@ -820,6 +882,11 @@ func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountSchedule
return s.openaiScheduler
}
+func resetOpenAIAdvancedSchedulerSettingCacheForTest() {
+ openAIAdvancedSchedulerSettingCache = atomic.Value{}
+ openAIAdvancedSchedulerSettingSF = singleflight.Group{}
+}
+
func (s *OpenAIGatewayService) SelectAccountWithScheduler(
ctx context.Context,
groupID *int64,
@@ -830,7 +897,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
requiredTransport OpenAIUpstreamTransport,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
decision := OpenAIAccountScheduleDecision{}
- scheduler := s.getOpenAIAccountScheduler()
+ scheduler := s.getOpenAIAccountScheduler(ctx)
if scheduler == nil {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
decision.Layer = openAIAccountScheduleLayerLoadBalance
@@ -856,7 +923,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
}
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
- scheduler := s.getOpenAIAccountScheduler()
+ scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return
}
@@ -864,7 +931,7 @@ func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64
}
func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
- scheduler := s.getOpenAIAccountScheduler()
+ scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return
}
@@ -872,7 +939,7 @@ func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
}
func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot {
- scheduler := s.getOpenAIAccountScheduler()
+ scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return OpenAIAccountSchedulerMetricsSnapshot{}
}
diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go
index 088815ed..a54f2614 100644
--- a/backend/internal/service/openai_account_scheduler_test.go
+++ b/backend/internal/service/openai_account_scheduler_test.go
@@ -2,6 +2,7 @@ package service
import (
"context"
+ "errors"
"fmt"
"math"
"sync"
@@ -18,6 +19,202 @@ type openAISnapshotCacheStub struct {
accountsByID map[int64]*Account
}
+type schedulerTestOpenAIAccountRepo struct {
+ AccountRepository
+ accounts []Account
+}
+
+func (r schedulerTestOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) {
+ for i := range r.accounts {
+ if r.accounts[i].ID == id {
+ return &r.accounts[i], nil
+ }
+ }
+ return nil, errors.New("account not found")
+}
+
+func (r schedulerTestOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
+ var result []Account
+ for _, acc := range r.accounts {
+ if acc.Platform == platform {
+ result = append(result, acc)
+ }
+ }
+ return result, nil
+}
+
+func (r schedulerTestOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
+ var result []Account
+ for _, acc := range r.accounts {
+ if acc.Platform == platform {
+ result = append(result, acc)
+ }
+ }
+ return result, nil
+}
+
+func (r schedulerTestOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
+ return r.ListSchedulableByPlatform(ctx, platform)
+}
+
+type schedulerTestConcurrencyCache struct {
+ ConcurrencyCache
+ loadBatchErr error
+ loadMap map[int64]*AccountLoadInfo
+ acquireResults map[int64]bool
+ waitCounts map[int64]int
+ skipDefaultLoad bool
+}
+
+func (c schedulerTestConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
+ if c.acquireResults != nil {
+ if result, ok := c.acquireResults[accountID]; ok {
+ return result, nil
+ }
+ }
+ return true, nil
+}
+
+func (c schedulerTestConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
+ return nil
+}
+
+func (c schedulerTestConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
+ if c.loadBatchErr != nil {
+ return nil, c.loadBatchErr
+ }
+ out := make(map[int64]*AccountLoadInfo, len(accounts))
+ if c.skipDefaultLoad && c.loadMap != nil {
+ for _, acc := range accounts {
+ if load, ok := c.loadMap[acc.ID]; ok {
+ out[acc.ID] = load
+ }
+ }
+ return out, nil
+ }
+ for _, acc := range accounts {
+ if c.loadMap != nil {
+ if load, ok := c.loadMap[acc.ID]; ok {
+ out[acc.ID] = load
+ continue
+ }
+ }
+ out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
+ }
+ return out, nil
+}
+
+func (c schedulerTestConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
+ if c.waitCounts != nil {
+ if count, ok := c.waitCounts[accountID]; ok {
+ return count, nil
+ }
+ }
+ return 0, nil
+}
+
+type schedulerTestGatewayCache struct {
+ sessionBindings map[string]int64
+ deletedSessions map[string]int
+}
+
+func (c *schedulerTestGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
+ if id, ok := c.sessionBindings[sessionHash]; ok {
+ return id, nil
+ }
+ return 0, errors.New("not found")
+}
+
+func (c *schedulerTestGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
+ if c.sessionBindings == nil {
+ c.sessionBindings = make(map[string]int64)
+ }
+ c.sessionBindings[sessionHash] = accountID
+ return nil
+}
+
+func (c *schedulerTestGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
+ return nil
+}
+
+func (c *schedulerTestGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
+ if c.sessionBindings == nil {
+ return nil
+ }
+ if c.deletedSessions == nil {
+ c.deletedSessions = make(map[string]int)
+ }
+ c.deletedSessions[sessionHash]++
+ delete(c.sessionBindings, sessionHash)
+ return nil
+}
+
+func newSchedulerTestOpenAIWSV2Config() *config.Config {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
+ return cfg
+}
+
+type openAIAdvancedSchedulerSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
+ value, err := s.GetValue(ctx, key)
+ if err != nil {
+ return nil, err
+ }
+ return &Setting{Key: key, Value: value}, nil
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ if s == nil || s.values == nil {
+ return "", ErrSettingNotFound
+ }
+ value, ok := s.values[key]
+ if !ok {
+ return "", ErrSettingNotFound
+ }
+ return value, nil
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) Set(context.Context, string, string) error {
+ panic("unexpected call to Set")
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
+ panic("unexpected call to GetMultiple")
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ panic("unexpected call to SetMultiple")
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ panic("unexpected call to GetAll")
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) Delete(context.Context, string) error {
+ panic("unexpected call to Delete")
+}
+
+func newOpenAIAdvancedSchedulerRateLimitService(enabled string) *RateLimitService {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+ repo := &openAIAdvancedSchedulerSettingRepoStub{
+ values: map[string]string{},
+ }
+ if enabled != "" {
+ repo.values[openAIAdvancedSchedulerSettingKey] = enabled
+ }
+ return &RateLimitService{
+ settingService: NewSettingService(repo, &config.Config{}),
+ }
+}
+
func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) {
if len(s.snapshotAccounts) == 0 {
return nil, false, nil
@@ -45,6 +242,138 @@ func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int6
return &cloned, nil
}
+func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLegacyLoadAwareness(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ ctx := context.Background()
+ groupID := int64(10106)
+ accounts := []Account{
+ {
+ ID: 36001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 5,
+ },
+ {
+ ID: 36002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ },
+ }
+ cfg := &config.Config{}
+ cfg.Gateway.Scheduling.LoadBatchEnabled = false
+ cache := &schedulerTestGatewayCache{}
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
+
+ store := svc.getOpenAIWSStateStore()
+ require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_disabled_001", 36001, time.Hour))
+ require.False(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx))
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "resp_disabled_001",
+ "",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(36002), selection.Account.ID)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+ require.False(t, decision.StickyPreviousHit)
+}
+
+func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ ctx := context.Background()
+ groupID := int64(10107)
+ accounts := []Account{
+ {
+ ID: 37001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 5,
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_enabled": true,
+ },
+ },
+ {
+ ID: 37002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ },
+ }
+ cfg := &config.Config{}
+ cfg.Gateway.Scheduling.LoadBatchEnabled = false
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
+ cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
+
+ store := svc.getOpenAIWSStateStore()
+ require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_enabled_001", 37001, time.Hour))
+ require.True(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx))
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "resp_enabled_001",
+ "",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(37001), selection.Account.ID)
+ require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer)
+ require.True(t, decision.StickyPreviousHit)
+}
+
+func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics_DisabledNoOp(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ svc := &OpenAIGatewayService{}
+ ttft := 120
+ svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
+ svc.RecordOpenAIAccountSwitch()
+
+ snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
+ require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot)
+}
+
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) {
ctx := context.Background()
groupID := int64(10101)
@@ -53,10 +382,17 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite
staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
- cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
+ cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
- svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})}
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}},
+ cache: cache,
+ cfg: &config.Config{},
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ schedulerSnapshot: snapshotService,
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
require.NoError(t, err)
@@ -76,7 +412,12 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
- svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService}
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}},
+ cfg: &config.Config{},
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ schedulerSnapshot: snapshotService,
+ }
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil)
require.NoError(t, err)
@@ -92,18 +433,19 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeR
staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
- cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
+ cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
snapshotCache := &openAISnapshotCacheStub{
snapshotAccounts: []*Account{staleSticky, staleBackup},
accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup},
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
cache: cache,
cfg: &config.Config{},
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
@@ -128,8 +470,9 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeReche
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
cfg: &config.Config{},
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
}
@@ -153,7 +496,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
- cache := &stubGatewayCache{}
+ cache := &schedulerTestGatewayCache{}
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
@@ -163,10 +506,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: cfg,
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
store := svc.getOpenAIWSStateStore()
@@ -204,17 +548,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin
Schedulable: true,
Concurrency: 1,
}
- cache := &stubGatewayCache{
+ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_abc": account.ID,
},
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
@@ -260,7 +605,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
Priority: 9,
},
}
- cache := &stubGatewayCache{
+ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_sticky_busy": 21001,
},
@@ -273,7 +618,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- concurrencyCache := stubConcurrencyCache{
+ concurrencyCache := schedulerTestConcurrencyCache{
acquireResults: map[int64]bool{
21001: false, // sticky 账号已满
21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换)
@@ -288,9 +633,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: cache,
cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@@ -328,17 +674,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP
"openai_ws_force_http": true,
},
}
- cache := &stubGatewayCache{
+ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_force_http": account.ID,
},
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
@@ -387,15 +734,15 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
},
},
}
- cache := &stubGatewayCache{
+ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_ws_only": 2201,
},
}
- cfg := newOpenAIWSV2TestConfig()
+ cfg := newSchedulerTestOpenAIWSV2Config()
// 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。
- concurrencyCache := stubConcurrencyCache{
+ concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0},
2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5},
@@ -403,9 +750,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: cache,
cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@@ -445,10 +793,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: accounts},
- cache: &stubGatewayCache{},
- cfg: newOpenAIWSV2TestConfig(),
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
+ cfg: newSchedulerTestOpenAIWSV2Config(),
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
@@ -507,7 +856,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1
- concurrencyCache := stubConcurrencyCache{
+ concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8},
3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1},
@@ -520,9 +869,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: accounts},
- cache: &stubGatewayCache{},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@@ -559,16 +909,17 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
Schedulable: true,
Concurrency: 1,
}
- cache := &stubGatewayCache{
+ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_metrics": account.ID,
},
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
@@ -749,7 +1100,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
- concurrencyCache := stubConcurrencyCache{
+ concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1},
5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1},
@@ -757,9 +1108,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
},
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: accounts},
- cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{sessionBindings: map[string]int64{}},
cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@@ -905,12 +1257,14 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) {
}
func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
svc := &OpenAIGatewayService{}
ttft := 120
svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
svc.RecordOpenAIAccountSwitch()
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
- require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
+ require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot)
require.Equal(t, 7, svc.openAIWSLBTopK())
require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL())
@@ -947,7 +1301,7 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *
require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE))
require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2))
- cfg := newOpenAIWSV2TestConfig()
+ cfg := newSchedulerTestOpenAIWSV2Config()
scheduler.service = &OpenAIGatewayService{cfg: cfg}
account := &Account{
ID: 8801,
diff --git a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go
index c5de8203..ddafc6eb 100644
--- a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go
+++ b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go
@@ -38,11 +38,12 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{*account}},
- cache: &stubGatewayCache{},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*account}},
+ cache: &schedulerTestGatewayCache{},
cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache},
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go
index 59764b29..34462a3a 100644
--- a/backend/internal/service/payment_config_service.go
+++ b/backend/internal/service/payment_config_service.go
@@ -196,12 +196,25 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo
SettingHelpImageURL, SettingHelpText,
SettingCancelRateLimitOn, SettingCancelRateLimitMax,
SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode,
+ SettingPaymentVisibleMethodAlipayEnabled, SettingPaymentVisibleMethodAlipaySource,
+ SettingPaymentVisibleMethodWxpayEnabled, SettingPaymentVisibleMethodWxpaySource,
}
vals, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return nil, fmt.Errorf("get payment config settings: %w", err)
}
cfg := s.parsePaymentConfig(vals)
+ if s.entClient != nil {
+ instances, err := s.entClient.PaymentProviderInstance.Query().
+ Where(paymentproviderinstance.EnabledEQ(true)).
+ All(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("list enabled provider instances: %w", err)
+ }
+ cfg.EnabledTypes = applyVisibleMethodRoutingToEnabledTypes(cfg.EnabledTypes, vals, buildVisibleMethodSourceAvailability(instances))
+ } else {
+ cfg.EnabledTypes = applyVisibleMethodRoutingToEnabledTypes(cfg.EnabledTypes, vals, nil)
+ }
// Load Stripe publishable key from the first enabled Stripe provider instance
cfg.StripePublishableKey = s.getStripePublishableKey(ctx)
return cfg, nil
@@ -234,18 +247,23 @@ func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *Payme
cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy
}
if raw := vals[SettingEnabledPaymentTypes]; raw != "" {
+ types := make([]string, 0, len(strings.Split(raw, ",")))
for _, t := range strings.Split(raw, ",") {
t = strings.TrimSpace(t)
if t != "" {
- cfg.EnabledTypes = append(cfg.EnabledTypes, t)
+ types = append(types, t)
}
}
+ cfg.EnabledTypes = NormalizeVisibleMethods(types)
}
return cfg
}
// getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance.
func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string {
+ if s.entClient == nil {
+ return ""
+ }
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(
paymentproviderinstance.EnabledEQ(true),
@@ -385,3 +403,79 @@ func pcParseInt(s string, defaultVal int) int {
}
return v
}
+
+func buildVisibleMethodSourceAvailability(instances []*dbent.PaymentProviderInstance) map[string]bool {
+ available := make(map[string]bool, 4)
+ for _, inst := range instances {
+ switch inst.ProviderKey {
+ case payment.TypeAlipay:
+ if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipayDirect) {
+ available[VisibleMethodSourceOfficialAlipay] = true
+ }
+ case payment.TypeWxpay:
+ if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpayDirect) {
+ available[VisibleMethodSourceOfficialWechat] = true
+ }
+ case payment.TypeEasyPay:
+ for _, supportedType := range splitTypes(inst.SupportedTypes) {
+ switch NormalizeVisibleMethod(supportedType) {
+ case payment.TypeAlipay:
+ available[VisibleMethodSourceEasyPayAlipay] = true
+ case payment.TypeWxpay:
+ available[VisibleMethodSourceEasyPayWechat] = true
+ }
+ }
+ }
+ }
+ return available
+}
+
+func applyVisibleMethodRoutingToEnabledTypes(base []string, vals map[string]string, available map[string]bool) []string {
+ shouldExpose := map[string]bool{
+ payment.TypeAlipay: visibleMethodShouldBeExposed(payment.TypeAlipay, vals, available),
+ payment.TypeWxpay: visibleMethodShouldBeExposed(payment.TypeWxpay, vals, available),
+ }
+
+ seen := make(map[string]struct{}, len(base)+2)
+ out := make([]string, 0, len(base)+2)
+ appendType := func(paymentType string) {
+ paymentType = NormalizeVisibleMethod(paymentType)
+ if paymentType == "" {
+ return
+ }
+ if _, ok := seen[paymentType]; ok {
+ return
+ }
+ seen[paymentType] = struct{}{}
+ out = append(out, paymentType)
+ }
+
+ for _, paymentType := range base {
+ visibleMethod := NormalizeVisibleMethod(paymentType)
+ switch visibleMethod {
+ case payment.TypeAlipay, payment.TypeWxpay:
+ if shouldExpose[visibleMethod] {
+ appendType(visibleMethod)
+ }
+ default:
+ appendType(visibleMethod)
+ }
+ }
+
+ for _, visibleMethod := range []string{payment.TypeAlipay, payment.TypeWxpay} {
+ if shouldExpose[visibleMethod] {
+ appendType(visibleMethod)
+ }
+ }
+ return out
+}
+
+func visibleMethodShouldBeExposed(method string, vals map[string]string, available map[string]bool) bool {
+ enabledKey := visibleMethodEnabledSettingKey(method)
+ sourceKey := visibleMethodSourceSettingKey(method)
+ if enabledKey == "" || sourceKey == "" || vals[enabledKey] != "true" {
+ return false
+ }
+ source := NormalizeVisibleMethodSource(method, vals[sourceKey])
+ return source != "" && available[source]
+}
diff --git a/backend/internal/service/payment_config_service_test.go b/backend/internal/service/payment_config_service_test.go
index 027bb796..10919058 100644
--- a/backend/internal/service/payment_config_service_test.go
+++ b/backend/internal/service/payment_config_service_test.go
@@ -1,9 +1,17 @@
package service
import (
+ "context"
+ "database/sql"
"testing"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/payment"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
)
func TestPcParseFloat(t *testing.T) {
@@ -163,6 +171,20 @@ func TestParsePaymentConfig(t *testing.T) {
}
})
+ t.Run("enabled types are normalized to visible methods and deduplicated", func(t *testing.T) {
+ t.Parallel()
+ vals := map[string]string{
+ SettingEnabledPaymentTypes: "alipay_direct, alipay, wxpay_direct, wxpay",
+ }
+ cfg := svc.parsePaymentConfig(vals)
+ if len(cfg.EnabledTypes) != 2 {
+ t.Fatalf("EnabledTypes len = %d, want 2", len(cfg.EnabledTypes))
+ }
+ if cfg.EnabledTypes[0] != "alipay" || cfg.EnabledTypes[1] != "wxpay" {
+ t.Fatalf("EnabledTypes = %v, want [alipay wxpay]", cfg.EnabledTypes)
+ }
+ })
+
t.Run("empty enabled types string", func(t *testing.T) {
t.Parallel()
vals := map[string]string{
@@ -204,3 +226,167 @@ func TestGetBasePaymentType(t *testing.T) {
})
}
}
+
+func TestApplyVisibleMethodRoutingToEnabledTypes(t *testing.T) {
+ t.Parallel()
+
+ base := []string{"alipay", "wxpay", "stripe"}
+ vals := map[string]string{
+ SettingPaymentVisibleMethodAlipayEnabled: "true",
+ SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay,
+ SettingPaymentVisibleMethodWxpayEnabled: "true",
+ SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
+ }
+ available := map[string]bool{
+ VisibleMethodSourceOfficialAlipay: true,
+ VisibleMethodSourceOfficialWechat: false,
+ }
+
+ got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available)
+ want := []string{"alipay", "stripe"}
+ if len(got) != len(want) {
+ t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got)
+ }
+ for i := range want {
+ if got[i] != want[i] {
+ t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
+ }
+ }
+}
+
+func TestApplyVisibleMethodRoutingAddsConfiguredVisibleMethod(t *testing.T) {
+ t.Parallel()
+
+ base := []string{"stripe"}
+ vals := map[string]string{
+ SettingPaymentVisibleMethodAlipayEnabled: "true",
+ SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceEasyPayAlipay,
+ }
+ available := map[string]bool{
+ VisibleMethodSourceEasyPayAlipay: true,
+ }
+
+ got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available)
+ want := []string{"stripe", "alipay"}
+ if len(got) != len(want) {
+ t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got)
+ }
+ for i := range want {
+ if got[i] != want[i] {
+ t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
+ }
+ }
+}
+
+func TestBuildVisibleMethodSourceAvailability(t *testing.T) {
+ t.Parallel()
+
+ instances := []*dbent.PaymentProviderInstance{
+ {ProviderKey: payment.TypeAlipay, SupportedTypes: "alipay"},
+ {ProviderKey: payment.TypeEasyPay, SupportedTypes: "wxpay_direct, alipay"},
+ {ProviderKey: payment.TypeWxpay, SupportedTypes: "wxpay_direct"},
+ }
+
+ got := buildVisibleMethodSourceAvailability(instances)
+ if !got[VisibleMethodSourceOfficialAlipay] {
+ t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialAlipay)
+ }
+ if !got[VisibleMethodSourceEasyPayAlipay] {
+ t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayAlipay)
+ }
+ if !got[VisibleMethodSourceOfficialWechat] {
+ t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialWechat)
+ }
+ if !got[VisibleMethodSourceEasyPayWechat] {
+ t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayWechat)
+ }
+}
+
+func TestGetPaymentConfigAppliesVisibleMethodRouting(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeEasyPay).
+ SetName("EasyPay Alipay").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create easypay instance: %v", err)
+ }
+
+ svc := &PaymentConfigService{
+ entClient: client,
+ settingRepo: &paymentConfigSettingRepoStub{
+ values: map[string]string{
+ SettingEnabledPaymentTypes: "alipay,wxpay,stripe",
+ SettingPaymentVisibleMethodAlipayEnabled: "true",
+ SettingPaymentVisibleMethodAlipaySource: "easypay",
+ SettingPaymentVisibleMethodWxpayEnabled: "true",
+ SettingPaymentVisibleMethodWxpaySource: "wxpay",
+ },
+ },
+ }
+
+ cfg, err := svc.GetPaymentConfig(ctx)
+ if err != nil {
+ t.Fatalf("GetPaymentConfig returned error: %v", err)
+ }
+
+ want := []string{payment.TypeAlipay, payment.TypeStripe}
+ if len(cfg.EnabledTypes) != len(want) {
+ t.Fatalf("EnabledTypes len = %d, want %d (%v)", len(cfg.EnabledTypes), len(want), cfg.EnabledTypes)
+ }
+ for i := range want {
+ if cfg.EnabledTypes[i] != want[i] {
+ t.Fatalf("EnabledTypes[%d] = %q, want %q (full=%v)", i, cfg.EnabledTypes[i], want[i], cfg.EnabledTypes)
+ }
+ }
+}
+
+func newPaymentConfigServiceTestClient(t *testing.T) *dbent.Client {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:payment_config_service?mode=memory&cache=shared")
+ if err != nil {
+ t.Fatalf("open sqlite: %v", err)
+ }
+ t.Cleanup(func() { _ = db.Close() })
+
+ if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
+ t.Fatalf("enable foreign keys: %v", err)
+ }
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+ return client
+}
+
+type paymentConfigSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *paymentConfigSettingRepoStub) Get(context.Context, string) (*Setting, error) {
+ return nil, nil
+}
+func (s *paymentConfigSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ return s.values[key], nil
+}
+func (s *paymentConfigSettingRepoStub) Set(context.Context, string, string) error { return nil }
+func (s *paymentConfigSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ out[key] = s.values[key]
+ }
+ return out, nil
+}
+func (s *paymentConfigSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ return nil
+}
+func (s *paymentConfigSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ return s.values, nil
+}
+func (s *paymentConfigSettingRepoStub) Delete(context.Context, string) error { return nil }
diff --git a/backend/internal/service/payment_resume_service.go b/backend/internal/service/payment_resume_service.go
new file mode 100644
index 00000000..894a8198
--- /dev/null
+++ b/backend/internal/service/payment_resume_service.go
@@ -0,0 +1,248 @@
+package service
+
+import (
+ "context"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+const (
+ PaymentSourceHostedRedirect = "hosted_redirect"
+ PaymentSourceWechatInAppResume = "wechat_in_app_resume"
+
+ paymentResumeFallbackSigningKey = "sub2api-payment-resume"
+
+ SettingPaymentVisibleMethodAlipaySource = "payment_visible_method_alipay_source"
+ SettingPaymentVisibleMethodWxpaySource = "payment_visible_method_wxpay_source"
+ SettingPaymentVisibleMethodAlipayEnabled = "payment_visible_method_alipay_enabled"
+ SettingPaymentVisibleMethodWxpayEnabled = "payment_visible_method_wxpay_enabled"
+
+ VisibleMethodSourceOfficialAlipay = "official_alipay"
+ VisibleMethodSourceEasyPayAlipay = "easypay_alipay"
+ VisibleMethodSourceOfficialWechat = "official_wxpay"
+ VisibleMethodSourceEasyPayWechat = "easypay_wxpay"
+)
+
+type ResumeTokenClaims struct {
+ OrderID int64 `json:"oid"`
+ UserID int64 `json:"uid,omitempty"`
+ ProviderInstanceID string `json:"pi,omitempty"`
+ ProviderKey string `json:"pk,omitempty"`
+ PaymentType string `json:"pt,omitempty"`
+ CanonicalReturnURL string `json:"ru,omitempty"`
+ IssuedAt int64 `json:"iat"`
+}
+
+type PaymentResumeService struct {
+ signingKey []byte
+}
+
+type visibleMethodLoadBalancer struct {
+ inner payment.LoadBalancer
+ configService *PaymentConfigService
+}
+
+func NewPaymentResumeService(signingKey []byte) *PaymentResumeService {
+ return &PaymentResumeService{signingKey: signingKey}
+}
+
+func NormalizeVisibleMethod(method string) string {
+ return payment.GetBasePaymentType(strings.TrimSpace(method))
+}
+
+func NormalizeVisibleMethods(methods []string) []string {
+ if len(methods) == 0 {
+ return nil
+ }
+ seen := make(map[string]struct{}, len(methods))
+ out := make([]string, 0, len(methods))
+ for _, method := range methods {
+ normalized := NormalizeVisibleMethod(method)
+ if normalized == "" {
+ continue
+ }
+ if _, ok := seen[normalized]; ok {
+ continue
+ }
+ seen[normalized] = struct{}{}
+ out = append(out, normalized)
+ }
+ return out
+}
+
+func NormalizePaymentSource(source string) string {
+ switch strings.TrimSpace(strings.ToLower(source)) {
+ case "", PaymentSourceHostedRedirect:
+ return PaymentSourceHostedRedirect
+ case "wechat_in_app", "wxpay_resume", PaymentSourceWechatInAppResume:
+ return PaymentSourceWechatInAppResume
+ default:
+ return strings.TrimSpace(strings.ToLower(source))
+ }
+}
+
+func NormalizeVisibleMethodSource(method, source string) string {
+ switch NormalizeVisibleMethod(method) {
+ case payment.TypeAlipay:
+ switch strings.TrimSpace(strings.ToLower(source)) {
+ case VisibleMethodSourceOfficialAlipay, payment.TypeAlipay, payment.TypeAlipayDirect, "official":
+ return VisibleMethodSourceOfficialAlipay
+ case VisibleMethodSourceEasyPayAlipay, payment.TypeEasyPay:
+ return VisibleMethodSourceEasyPayAlipay
+ }
+ case payment.TypeWxpay:
+ switch strings.TrimSpace(strings.ToLower(source)) {
+ case VisibleMethodSourceOfficialWechat, payment.TypeWxpay, payment.TypeWxpayDirect, "wechat", "official":
+ return VisibleMethodSourceOfficialWechat
+ case VisibleMethodSourceEasyPayWechat, payment.TypeEasyPay:
+ return VisibleMethodSourceEasyPayWechat
+ }
+ }
+ return ""
+}
+
+func VisibleMethodProviderKeyForSource(method, source string) (string, bool) {
+ switch NormalizeVisibleMethodSource(method, source) {
+ case VisibleMethodSourceOfficialAlipay:
+ return payment.TypeAlipay, NormalizeVisibleMethod(method) == payment.TypeAlipay
+ case VisibleMethodSourceEasyPayAlipay:
+ return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeAlipay
+ case VisibleMethodSourceOfficialWechat:
+ return payment.TypeWxpay, NormalizeVisibleMethod(method) == payment.TypeWxpay
+ case VisibleMethodSourceEasyPayWechat:
+ return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeWxpay
+ default:
+ return "", false
+ }
+}
+
+func newVisibleMethodLoadBalancer(inner payment.LoadBalancer, configService *PaymentConfigService) payment.LoadBalancer {
+ if inner == nil || configService == nil || configService.settingRepo == nil {
+ return inner
+ }
+ return &visibleMethodLoadBalancer{inner: inner, configService: configService}
+}
+
+func (lb *visibleMethodLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) {
+ return lb.inner.GetInstanceConfig(ctx, instanceID)
+}
+
+func (lb *visibleMethodLoadBalancer) SelectInstance(ctx context.Context, providerKey string, paymentType payment.PaymentType, strategy payment.Strategy, orderAmount float64) (*payment.InstanceSelection, error) {
+ visibleMethod := NormalizeVisibleMethod(paymentType)
+ if providerKey != "" || (visibleMethod != payment.TypeAlipay && visibleMethod != payment.TypeWxpay) {
+ return lb.inner.SelectInstance(ctx, providerKey, paymentType, strategy, orderAmount)
+ }
+
+ enabledKey := visibleMethodEnabledSettingKey(visibleMethod)
+ sourceKey := visibleMethodSourceSettingKey(visibleMethod)
+ vals, err := lb.configService.settingRepo.GetMultiple(ctx, []string{enabledKey, sourceKey})
+ if err != nil {
+ return nil, fmt.Errorf("load visible method routing for %s: %w", visibleMethod, err)
+ }
+ if vals[enabledKey] != "true" {
+ return nil, fmt.Errorf("visible payment method %s is disabled", visibleMethod)
+ }
+
+ targetProviderKey, ok := VisibleMethodProviderKeyForSource(visibleMethod, vals[sourceKey])
+ if !ok {
+ return nil, fmt.Errorf("visible payment method %s has no valid source", visibleMethod)
+ }
+ return lb.inner.SelectInstance(ctx, targetProviderKey, paymentType, strategy, orderAmount)
+}
+
+func visibleMethodEnabledSettingKey(method string) string {
+ switch NormalizeVisibleMethod(method) {
+ case payment.TypeAlipay:
+ return SettingPaymentVisibleMethodAlipayEnabled
+ case payment.TypeWxpay:
+ return SettingPaymentVisibleMethodWxpayEnabled
+ default:
+ return ""
+ }
+}
+
+func visibleMethodSourceSettingKey(method string) string {
+ switch NormalizeVisibleMethod(method) {
+ case payment.TypeAlipay:
+ return SettingPaymentVisibleMethodAlipaySource
+ case payment.TypeWxpay:
+ return SettingPaymentVisibleMethodWxpaySource
+ default:
+ return ""
+ }
+}
+
+func CanonicalizeReturnURL(raw string) (string, error) {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return "", nil
+ }
+ parsed, err := url.Parse(raw)
+ if err != nil || !parsed.IsAbs() || parsed.Host == "" {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be an absolute http/https URL")
+ }
+ if parsed.Scheme != "http" && parsed.Scheme != "https" {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use http or https")
+ }
+ parsed.Fragment = ""
+ if parsed.Path == "" {
+ parsed.Path = "/"
+ }
+ return parsed.String(), nil
+}
+
+func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) {
+ if claims.OrderID <= 0 {
+ return "", fmt.Errorf("resume token requires order id")
+ }
+ if claims.IssuedAt == 0 {
+ claims.IssuedAt = time.Now().Unix()
+ }
+ payload, err := json.Marshal(claims)
+ if err != nil {
+ return "", fmt.Errorf("marshal resume claims: %w", err)
+ }
+ encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
+ return encodedPayload + "." + s.sign(encodedPayload), nil
+}
+
+func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) {
+ parts := strings.Split(token, ".")
+ if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
+ return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
+ }
+ if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) {
+ return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
+ }
+ payload, err := base64.RawURLEncoding.DecodeString(parts[0])
+ if err != nil {
+ return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is malformed")
+ }
+ var claims ResumeTokenClaims
+ if err := json.Unmarshal(payload, &claims); err != nil {
+ return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid")
+ }
+ if claims.OrderID <= 0 {
+ return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token missing order id")
+ }
+ return &claims, nil
+}
+
+func (s *PaymentResumeService) sign(payload string) string {
+ key := s.signingKey
+ if len(key) == 0 {
+ key = []byte(paymentResumeFallbackSigningKey)
+ }
+ mac := hmac.New(sha256.New, key)
+ _, _ = mac.Write([]byte(payload))
+ return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
+}
diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go
new file mode 100644
index 00000000..e56b4a88
--- /dev/null
+++ b/backend/internal/service/payment_resume_service_test.go
@@ -0,0 +1,240 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+)
+
+func TestNormalizeVisibleMethods(t *testing.T) {
+ t.Parallel()
+
+ got := NormalizeVisibleMethods([]string{
+ "alipay_direct",
+ "alipay",
+ " wxpay_direct ",
+ "wxpay",
+ "stripe",
+ })
+
+ want := []string{"alipay", "wxpay", "stripe"}
+ if len(got) != len(want) {
+ t.Fatalf("NormalizeVisibleMethods len = %d, want %d (%v)", len(got), len(want), got)
+ }
+ for i := range want {
+ if got[i] != want[i] {
+ t.Fatalf("NormalizeVisibleMethods[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
+ }
+ }
+}
+
+func TestNormalizePaymentSource(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ input string
+ expect string
+ }{
+ {name: "empty uses default", input: "", expect: PaymentSourceHostedRedirect},
+ {name: "wechat alias normalized", input: "wechat_in_app", expect: PaymentSourceWechatInAppResume},
+ {name: "canonical value preserved", input: PaymentSourceWechatInAppResume, expect: PaymentSourceWechatInAppResume},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ if got := NormalizePaymentSource(tt.input); got != tt.expect {
+ t.Fatalf("NormalizePaymentSource(%q) = %q, want %q", tt.input, got, tt.expect)
+ }
+ })
+ }
+}
+
+func TestCanonicalizeReturnURL(t *testing.T) {
+ t.Parallel()
+
+ got, err := CanonicalizeReturnURL("https://example.com/pay/result?b=2#a")
+ if err != nil {
+ t.Fatalf("CanonicalizeReturnURL returned error: %v", err)
+ }
+ if got != "https://example.com/pay/result?b=2" {
+ t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://example.com/pay/result?b=2")
+ }
+}
+
+func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) {
+ t.Parallel()
+
+ if _, err := CanonicalizeReturnURL("/payment/result"); err == nil {
+ t.Fatal("CanonicalizeReturnURL should reject relative URLs")
+ }
+}
+
+func TestPaymentResumeTokenRoundTrip(t *testing.T) {
+ t.Parallel()
+
+ svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := svc.CreateToken(ResumeTokenClaims{
+ OrderID: 42,
+ UserID: 7,
+ ProviderInstanceID: "19",
+ ProviderKey: "easypay",
+ PaymentType: "wxpay",
+ CanonicalReturnURL: "https://example.com/payment/result",
+ IssuedAt: 1234567890,
+ })
+ if err != nil {
+ t.Fatalf("CreateToken returned error: %v", err)
+ }
+
+ claims, err := svc.ParseToken(token)
+ if err != nil {
+ t.Fatalf("ParseToken returned error: %v", err)
+ }
+ if claims.OrderID != 42 || claims.UserID != 7 {
+ t.Fatalf("claims mismatch: %+v", claims)
+ }
+ if claims.ProviderInstanceID != "19" || claims.ProviderKey != "easypay" || claims.PaymentType != "wxpay" {
+ t.Fatalf("claims provider snapshot mismatch: %+v", claims)
+ }
+ if claims.CanonicalReturnURL != "https://example.com/payment/result" {
+ t.Fatalf("claims return URL = %q", claims.CanonicalReturnURL)
+ }
+}
+
+func TestNormalizeVisibleMethodSource(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ method string
+ input string
+ want string
+ }{
+ {name: "alipay official alias", method: payment.TypeAlipay, input: "alipay", want: VisibleMethodSourceOfficialAlipay},
+ {name: "alipay easypay alias", method: payment.TypeAlipay, input: "easypay", want: VisibleMethodSourceEasyPayAlipay},
+ {name: "wxpay official alias", method: payment.TypeWxpay, input: "wxpay", want: VisibleMethodSourceOfficialWechat},
+ {name: "wxpay easypay alias", method: payment.TypeWxpay, input: "easypay", want: VisibleMethodSourceEasyPayWechat},
+ {name: "unsupported source", method: payment.TypeWxpay, input: "stripe", want: ""},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ if got := NormalizeVisibleMethodSource(tt.method, tt.input); got != tt.want {
+ t.Fatalf("NormalizeVisibleMethodSource(%q, %q) = %q, want %q", tt.method, tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestVisibleMethodProviderKeyForSource(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ method string
+ source string
+ want string
+ ok bool
+ }{
+ {name: "official alipay", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialAlipay, want: payment.TypeAlipay, ok: true},
+ {name: "easypay alipay", method: payment.TypeAlipay, source: VisibleMethodSourceEasyPayAlipay, want: payment.TypeEasyPay, ok: true},
+ {name: "official wechat", method: payment.TypeWxpay, source: VisibleMethodSourceOfficialWechat, want: payment.TypeWxpay, ok: true},
+ {name: "easypay wechat", method: payment.TypeWxpay, source: VisibleMethodSourceEasyPayWechat, want: payment.TypeEasyPay, ok: true},
+ {name: "mismatched method and source", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialWechat, want: "", ok: false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got, ok := VisibleMethodProviderKeyForSource(tt.method, tt.source)
+ if got != tt.want || ok != tt.ok {
+ t.Fatalf("VisibleMethodProviderKeyForSource(%q, %q) = (%q, %v), want (%q, %v)", tt.method, tt.source, got, ok, tt.want, tt.ok)
+ }
+ })
+ }
+}
+
+func TestVisibleMethodLoadBalancerUsesConfiguredSource(t *testing.T) {
+ t.Parallel()
+
+ inner := &captureLoadBalancer{}
+ configService := &PaymentConfigService{
+ settingRepo: &paymentSettingRepoStub{
+ values: map[string]string{
+ SettingPaymentVisibleMethodAlipayEnabled: "true",
+ SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay,
+ },
+ },
+ }
+ lb := newVisibleMethodLoadBalancer(inner, configService)
+
+ _, err := lb.SelectInstance(context.Background(), "", payment.TypeAlipay, payment.StrategyRoundRobin, 12.5)
+ if err != nil {
+ t.Fatalf("SelectInstance returned error: %v", err)
+ }
+ if inner.lastProviderKey != payment.TypeAlipay {
+ t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, payment.TypeAlipay)
+ }
+}
+
+func TestVisibleMethodLoadBalancerRejectsDisabledVisibleMethod(t *testing.T) {
+ t.Parallel()
+
+ inner := &captureLoadBalancer{}
+ configService := &PaymentConfigService{
+ settingRepo: &paymentSettingRepoStub{
+ values: map[string]string{
+ SettingPaymentVisibleMethodWxpayEnabled: "false",
+ SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
+ },
+ },
+ }
+ lb := newVisibleMethodLoadBalancer(inner, configService)
+
+ if _, err := lb.SelectInstance(context.Background(), "", payment.TypeWxpay, payment.StrategyRoundRobin, 9.9); err == nil {
+ t.Fatal("SelectInstance should reject disabled visible method")
+ }
+}
+
+type paymentSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *paymentSettingRepoStub) Get(context.Context, string) (*Setting, error) { return nil, nil }
+func (s *paymentSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ return s.values[key], nil
+}
+func (s *paymentSettingRepoStub) Set(context.Context, string, string) error { return nil }
+func (s *paymentSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ out[key] = s.values[key]
+ }
+ return out, nil
+}
+func (s *paymentSettingRepoStub) SetMultiple(context.Context, map[string]string) error { return nil }
+func (s *paymentSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ return s.values, nil
+}
+func (s *paymentSettingRepoStub) Delete(context.Context, string) error { return nil }
+
+type captureLoadBalancer struct {
+ lastProviderKey string
+ lastPaymentType string
+}
+
+func (c *captureLoadBalancer) GetInstanceConfig(context.Context, int64) (map[string]string, error) {
+ return map[string]string{}, nil
+}
+
+func (c *captureLoadBalancer) SelectInstance(_ context.Context, providerKey string, paymentType payment.PaymentType, _ payment.Strategy, _ float64) (*payment.InstanceSelection, error) {
+ c.lastProviderKey = providerKey
+ c.lastPaymentType = paymentType
+ return &payment.InstanceSelection{ProviderKey: providerKey, SupportedTypes: paymentType}, nil
+}
diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go
index 6fc23f97..e897741a 100644
--- a/backend/internal/service/payment_service.go
+++ b/backend/internal/service/payment_service.go
@@ -65,15 +65,17 @@ func generateRandomString(n int) string {
}
type CreateOrderRequest struct {
- UserID int64
- Amount float64
- PaymentType string
- ClientIP string
- IsMobile bool
- SrcHost string
- SrcURL string
- OrderType string
- PlanID int64
+ UserID int64
+ Amount float64
+ PaymentType string
+ ClientIP string
+ IsMobile bool
+ SrcHost string
+ SrcURL string
+ ReturnURL string
+ PaymentSource string
+ OrderType string
+ PlanID int64
}
type CreateOrderResponse struct {
@@ -88,6 +90,7 @@ type CreateOrderResponse struct {
ClientSecret string `json:"client_secret,omitempty"`
ExpiresAt time.Time `json:"expires_at"`
PaymentMode string `json:"payment_mode,omitempty"`
+ ResumeToken string `json:"resume_token,omitempty"`
}
type OrderListParams struct {
@@ -165,10 +168,13 @@ type PaymentService struct {
configService *PaymentConfigService
userRepo UserRepository
groupRepo GroupRepository
+ resumeService *PaymentResumeService
}
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
- return &PaymentService{entClient: entClient, registry: registry, loadBalancer: loadBalancer, redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
+ svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
+ svc.resumeService = NewPaymentResumeService(psResumeSigningKey(configService))
+ return svc
}
// --- Provider Registry ---
@@ -262,6 +268,20 @@ func psNilIfEmpty(s string) *string {
return &s
}
+func (s *PaymentService) paymentResume() *PaymentResumeService {
+ if s.resumeService != nil {
+ return s.resumeService
+ }
+ return NewPaymentResumeService(psResumeSigningKey(s.configService))
+}
+
+func psResumeSigningKey(configService *PaymentConfigService) []byte {
+ if configService == nil {
+ return nil
+ }
+ return configService.encryptionKey
+}
+
func psSliceContains(sl []string, s string) bool {
for _, v := range sl {
if v == s {
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index 7f4a2eb1..de555478 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -9,6 +9,7 @@ import (
"fmt"
"log/slog"
"net/url"
+ "os"
"sort"
"strconv"
"strings"
@@ -114,6 +115,66 @@ type SettingService struct {
webSearchManagerBuilder WebSearchManagerBuilder
}
+type ProviderDefaultGrantSettings struct {
+ Balance float64
+ Concurrency int
+ Subscriptions []DefaultSubscriptionSetting
+ GrantOnSignup bool
+ GrantOnFirstBind bool
+}
+
+type AuthSourceDefaultSettings struct {
+ Email ProviderDefaultGrantSettings
+ LinuxDo ProviderDefaultGrantSettings
+ OIDC ProviderDefaultGrantSettings
+ WeChat ProviderDefaultGrantSettings
+ ForceEmailOnThirdPartySignup bool
+}
+
+type authSourceDefaultKeySet struct {
+ balance string
+ concurrency string
+ subscriptions string
+ grantOnSignup string
+ grantOnFirstBind string
+}
+
+var (
+ emailAuthSourceDefaultKeys = authSourceDefaultKeySet{
+ balance: SettingKeyAuthSourceDefaultEmailBalance,
+ concurrency: SettingKeyAuthSourceDefaultEmailConcurrency,
+ subscriptions: SettingKeyAuthSourceDefaultEmailSubscriptions,
+ grantOnSignup: SettingKeyAuthSourceDefaultEmailGrantOnSignup,
+ grantOnFirstBind: SettingKeyAuthSourceDefaultEmailGrantOnFirstBind,
+ }
+ linuxDoAuthSourceDefaultKeys = authSourceDefaultKeySet{
+ balance: SettingKeyAuthSourceDefaultLinuxDoBalance,
+ concurrency: SettingKeyAuthSourceDefaultLinuxDoConcurrency,
+ subscriptions: SettingKeyAuthSourceDefaultLinuxDoSubscriptions,
+ grantOnSignup: SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup,
+ grantOnFirstBind: SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind,
+ }
+ oidcAuthSourceDefaultKeys = authSourceDefaultKeySet{
+ balance: SettingKeyAuthSourceDefaultOIDCBalance,
+ concurrency: SettingKeyAuthSourceDefaultOIDCConcurrency,
+ subscriptions: SettingKeyAuthSourceDefaultOIDCSubscriptions,
+ grantOnSignup: SettingKeyAuthSourceDefaultOIDCGrantOnSignup,
+ grantOnFirstBind: SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind,
+ }
+ weChatAuthSourceDefaultKeys = authSourceDefaultKeySet{
+ balance: SettingKeyAuthSourceDefaultWeChatBalance,
+ concurrency: SettingKeyAuthSourceDefaultWeChatConcurrency,
+ subscriptions: SettingKeyAuthSourceDefaultWeChatSubscriptions,
+ grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
+ grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
+ }
+)
+
+const (
+ defaultAuthSourceBalance = 0
+ defaultAuthSourceConcurrency = 5
+)
+
// NewSettingService 创建系统设置服务实例
func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *SettingService {
return &SettingService{
@@ -212,6 +273,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
if oidcProviderName == "" {
oidcProviderName = "OIDC"
}
+ weChatEnabled := isWeChatOAuthConfigured()
// Password reset requires email verification to be enabled
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
@@ -254,6 +316,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled,
+ WeChatOAuthEnabled: weChatEnabled,
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
PaymentEnabled: settings[SettingPaymentEnabled] == "true",
OIDCOAuthEnabled: oidcEnabled,
@@ -310,6 +373,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
+ WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
PaymentEnabled bool `json:"payment_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
@@ -344,6 +408,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
+ WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
PaymentEnabled: settings.PaymentEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
@@ -392,6 +457,14 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage {
return result
}
+func isWeChatOAuthConfigured() bool {
+ openConfigured := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID")) != "" &&
+ strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET")) != ""
+ mpConfigured := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID")) != "" &&
+ strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET")) != ""
+ return openConfigured || mpConfigured
+}
+
// safeRawJSONArray returns raw as json.RawMessage if it's valid JSON, otherwise "[]".
func safeRawJSONArray(raw string) json.RawMessage {
raw = strings.TrimSpace(raw)
@@ -919,6 +992,74 @@ func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultS
return parseDefaultSubscriptions(value)
}
+func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*AuthSourceDefaultSettings, error) {
+ keys := []string{
+ SettingKeyAuthSourceDefaultEmailBalance,
+ SettingKeyAuthSourceDefaultEmailConcurrency,
+ SettingKeyAuthSourceDefaultEmailSubscriptions,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup,
+ SettingKeyAuthSourceDefaultEmailGrantOnFirstBind,
+ SettingKeyAuthSourceDefaultLinuxDoBalance,
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency,
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions,
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup,
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind,
+ SettingKeyAuthSourceDefaultOIDCBalance,
+ SettingKeyAuthSourceDefaultOIDCConcurrency,
+ SettingKeyAuthSourceDefaultOIDCSubscriptions,
+ SettingKeyAuthSourceDefaultOIDCGrantOnSignup,
+ SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind,
+ SettingKeyAuthSourceDefaultWeChatBalance,
+ SettingKeyAuthSourceDefaultWeChatConcurrency,
+ SettingKeyAuthSourceDefaultWeChatSubscriptions,
+ SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
+ SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
+ SettingKeyForceEmailOnThirdPartySignup,
+ }
+
+ settings, err := s.settingRepo.GetMultiple(ctx, keys)
+ if err != nil {
+ return nil, fmt.Errorf("get auth source default settings: %w", err)
+ }
+
+ return &AuthSourceDefaultSettings{
+ Email: parseProviderDefaultGrantSettings(settings, emailAuthSourceDefaultKeys),
+ LinuxDo: parseProviderDefaultGrantSettings(settings, linuxDoAuthSourceDefaultKeys),
+ OIDC: parseProviderDefaultGrantSettings(settings, oidcAuthSourceDefaultKeys),
+ WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys),
+ ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
+ }, nil
+}
+
+func (s *SettingService) UpdateAuthSourceDefaultSettings(ctx context.Context, settings *AuthSourceDefaultSettings) error {
+ if settings == nil {
+ return nil
+ }
+
+ for _, subscriptions := range [][]DefaultSubscriptionSetting{
+ settings.Email.Subscriptions,
+ settings.LinuxDo.Subscriptions,
+ settings.OIDC.Subscriptions,
+ settings.WeChat.Subscriptions,
+ } {
+ if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
+ return err
+ }
+ }
+
+ updates := make(map[string]string, 21)
+ writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
+ writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
+ writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
+ writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
+ updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
+
+ if err := s.settingRepo.SetMultiple(ctx, updates); err != nil {
+ return fmt.Errorf("update auth source default settings: %w", err)
+ }
+ return nil
+}
+
// InitializeDefaultSettings 初始化默认设置
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 检查是否已有设置
@@ -933,25 +1074,46 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 初始化默认设置
defaults := map[string]string{
- SettingKeyRegistrationEnabled: "true",
- SettingKeyEmailVerifyEnabled: "false",
- SettingKeyRegistrationEmailSuffixWhitelist: "[]",
- SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
- SettingKeySiteName: "Sub2API",
- SettingKeySiteLogo: "",
- SettingKeyPurchaseSubscriptionEnabled: "false",
- SettingKeyPurchaseSubscriptionURL: "",
- SettingKeyTableDefaultPageSize: "20",
- SettingKeyTablePageSizeOptions: "[10,20,50,100]",
- SettingKeyCustomMenuItems: "[]",
- SettingKeyCustomEndpoints: "[]",
- SettingKeyOIDCConnectEnabled: "false",
- SettingKeyOIDCConnectProviderName: "OIDC",
- SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
- SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
- SettingKeyDefaultSubscriptions: "[]",
- SettingKeySMTPPort: "587",
- SettingKeySMTPUseTLS: "false",
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyEmailVerifyEnabled: "false",
+ SettingKeyRegistrationEmailSuffixWhitelist: "[]",
+ SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
+ SettingKeySiteName: "Sub2API",
+ SettingKeySiteLogo: "",
+ SettingKeyPurchaseSubscriptionEnabled: "false",
+ SettingKeyPurchaseSubscriptionURL: "",
+ SettingKeyTableDefaultPageSize: "20",
+ SettingKeyTablePageSizeOptions: "[10,20,50,100]",
+ SettingKeyCustomMenuItems: "[]",
+ SettingKeyCustomEndpoints: "[]",
+ SettingKeyOIDCConnectEnabled: "false",
+ SettingKeyOIDCConnectProviderName: "OIDC",
+ SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
+ SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
+ SettingKeyDefaultSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultEmailBalance: "0",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "5",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
+ SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false",
+ SettingKeyAuthSourceDefaultLinuxDoBalance: "0",
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency: "5",
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "false",
+ SettingKeyAuthSourceDefaultOIDCBalance: "0",
+ SettingKeyAuthSourceDefaultOIDCConcurrency: "5",
+ SettingKeyAuthSourceDefaultOIDCSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultOIDCGrantOnSignup: "true",
+ SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "false",
+ SettingKeyAuthSourceDefaultWeChatBalance: "0",
+ SettingKeyAuthSourceDefaultWeChatConcurrency: "5",
+ SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "true",
+ SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false",
+ SettingKeyForceEmailOnThirdPartySignup: "false",
+ SettingKeySMTPPort: "587",
+ SettingKeySMTPUseTLS: "false",
// Model fallback defaults
SettingKeyEnableModelFallback: "false",
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
@@ -1164,6 +1326,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
} else {
result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken
}
+ result.OIDCConnectUsePKCE = true
+ result.OIDCConnectValidateIDToken = true
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v)
} else {
@@ -1317,6 +1481,51 @@ func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting {
return normalized
}
+func parseProviderDefaultGrantSettings(settings map[string]string, keys authSourceDefaultKeySet) ProviderDefaultGrantSettings {
+ result := ProviderDefaultGrantSettings{
+ Balance: defaultAuthSourceBalance,
+ Concurrency: defaultAuthSourceConcurrency,
+ Subscriptions: []DefaultSubscriptionSetting{},
+ GrantOnSignup: true,
+ GrantOnFirstBind: false,
+ }
+
+ if v, err := strconv.ParseFloat(strings.TrimSpace(settings[keys.balance]), 64); err == nil {
+ result.Balance = v
+ }
+ if v, err := strconv.Atoi(strings.TrimSpace(settings[keys.concurrency])); err == nil {
+ result.Concurrency = v
+ }
+ if items := parseDefaultSubscriptions(settings[keys.subscriptions]); items != nil {
+ result.Subscriptions = items
+ }
+ if raw, ok := settings[keys.grantOnSignup]; ok {
+ result.GrantOnSignup = raw == "true"
+ }
+ if raw, ok := settings[keys.grantOnFirstBind]; ok {
+ result.GrantOnFirstBind = raw == "true"
+ }
+
+ return result
+}
+
+func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSourceDefaultKeySet, settings ProviderDefaultGrantSettings) {
+ updates[keys.balance] = strconv.FormatFloat(settings.Balance, 'f', 8, 64)
+ updates[keys.concurrency] = strconv.Itoa(settings.Concurrency)
+
+ subscriptions := settings.Subscriptions
+ if subscriptions == nil {
+ subscriptions = []DefaultSubscriptionSetting{}
+ }
+ raw, err := json.Marshal(subscriptions)
+ if err != nil {
+ raw = []byte("[]")
+ }
+ updates[keys.subscriptions] = string(raw)
+ updates[keys.grantOnSignup] = strconv.FormatBool(settings.GrantOnSignup)
+ updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind)
+}
+
func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) {
defaultPageSize := 20
if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil {
@@ -1539,6 +1748,7 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
effective.RedirectURL = strings.TrimSpace(v)
}
+ effective.UsePKCE = true
if !effective.Enabled {
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
@@ -1587,9 +1797,6 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
}
case "none":
- if !effective.UsePKCE {
- return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
- }
default:
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
}
@@ -1737,6 +1944,8 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
effective.ValidateIDToken = raw == "true"
}
+ effective.UsePKCE = true
+ effective.ValidateIDToken = true
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
effective.AllowedSigningAlgs = strings.TrimSpace(v)
}
@@ -1864,9 +2073,6 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
}
case "none":
- if !effective.UsePKCE {
- return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
- }
default:
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
}
diff --git a/backend/internal/service/setting_service_auth_source_defaults_test.go b/backend/internal/service/setting_service_auth_source_defaults_test.go
new file mode 100644
index 00000000..097bf604
--- /dev/null
+++ b/backend/internal/service/setting_service_auth_source_defaults_test.go
@@ -0,0 +1,136 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+type authSourceDefaultsRepoStub struct {
+ values map[string]string
+ updates map[string]string
+}
+
+func (s *authSourceDefaultsRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *authSourceDefaultsRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ panic("unexpected GetValue call")
+}
+
+func (s *authSourceDefaultsRepoStub) Set(ctx context.Context, key, value string) error {
+ panic("unexpected Set call")
+}
+
+func (s *authSourceDefaultsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *authSourceDefaultsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ s.updates = make(map[string]string, len(settings))
+ for key, value := range settings {
+ s.updates[key] = value
+ if s.values == nil {
+ s.values = map[string]string{}
+ }
+ s.values[key] = value
+ }
+ return nil
+}
+
+func (s *authSourceDefaultsRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *authSourceDefaultsRepoStub) Delete(ctx context.Context, key string) error {
+ panic("unexpected Delete call")
+}
+
+func TestSettingService_GetAuthSourceDefaultSettings_ParsesValuesAndDefaults(t *testing.T) {
+ repo := &authSourceDefaultsRepoStub{
+ values: map[string]string{
+ SettingKeyAuthSourceDefaultEmailBalance: "12.5",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "7",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "true",
+ SettingKeyForceEmailOnThirdPartySignup: "true",
+ },
+ }
+ svc := NewSettingService(repo, &config.Config{})
+
+ got, err := svc.GetAuthSourceDefaultSettings(context.Background())
+ require.NoError(t, err)
+ require.Equal(t, 12.5, got.Email.Balance)
+ require.Equal(t, 7, got.Email.Concurrency)
+ require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 11, ValidityDays: 30}}, got.Email.Subscriptions)
+ require.False(t, got.Email.GrantOnSignup)
+ require.False(t, got.Email.GrantOnFirstBind)
+ require.Equal(t, 0.0, got.LinuxDo.Balance)
+ require.Equal(t, 5, got.LinuxDo.Concurrency)
+ require.Equal(t, []DefaultSubscriptionSetting{}, got.LinuxDo.Subscriptions)
+ require.True(t, got.LinuxDo.GrantOnSignup)
+ require.True(t, got.LinuxDo.GrantOnFirstBind)
+ require.Equal(t, 5, got.OIDC.Concurrency)
+ require.Equal(t, 5, got.WeChat.Concurrency)
+ require.True(t, got.ForceEmailOnThirdPartySignup)
+}
+
+func TestSettingService_UpdateAuthSourceDefaultSettings_PersistsAllKeys(t *testing.T) {
+ repo := &authSourceDefaultsRepoStub{}
+ svc := NewSettingService(repo, &config.Config{})
+
+ err := svc.UpdateAuthSourceDefaultSettings(context.Background(), &AuthSourceDefaultSettings{
+ Email: ProviderDefaultGrantSettings{
+ Balance: 1.25,
+ Concurrency: 3,
+ Subscriptions: []DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 14}},
+ GrantOnSignup: false,
+ GrantOnFirstBind: true,
+ },
+ LinuxDo: ProviderDefaultGrantSettings{
+ Balance: 2,
+ Concurrency: 4,
+ Subscriptions: []DefaultSubscriptionSetting{{GroupID: 22, ValidityDays: 30}},
+ GrantOnSignup: true,
+ GrantOnFirstBind: false,
+ },
+ OIDC: ProviderDefaultGrantSettings{
+ Balance: 3,
+ Concurrency: 5,
+ Subscriptions: []DefaultSubscriptionSetting{{GroupID: 23, ValidityDays: 60}},
+ GrantOnSignup: true,
+ GrantOnFirstBind: true,
+ },
+ WeChat: ProviderDefaultGrantSettings{
+ Balance: 4,
+ Concurrency: 6,
+ Subscriptions: []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}},
+ GrantOnSignup: false,
+ GrantOnFirstBind: false,
+ },
+ ForceEmailOnThirdPartySignup: true,
+ })
+ require.NoError(t, err)
+ require.Equal(t, "1.25000000", repo.updates[SettingKeyAuthSourceDefaultEmailBalance])
+ require.Equal(t, "3", repo.updates[SettingKeyAuthSourceDefaultEmailConcurrency])
+ require.Equal(t, "false", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnSignup])
+ require.Equal(t, "true", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnFirstBind])
+ require.Equal(t, "true", repo.updates[SettingKeyForceEmailOnThirdPartySignup])
+
+ var got []DefaultSubscriptionSetting
+ require.NoError(t, json.Unmarshal([]byte(repo.updates[SettingKeyAuthSourceDefaultWeChatSubscriptions]), &got))
+ require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}}, got)
+}
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index ab2eb274..e991ebef 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -152,6 +152,7 @@ type PublicSettings struct {
CustomEndpoints string // JSON array of custom endpoints
LinuxDoOAuthEnabled bool
+ WeChatOAuthEnabled bool
BackendModeEnabled bool
PaymentEnabled bool
OIDCOAuthEnabled bool
diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go
index 59f8aa6b..d8b5325c 100644
--- a/backend/internal/service/user.go
+++ b/backend/internal/service/user.go
@@ -7,19 +7,27 @@ import (
)
type User struct {
- ID int64
- Email string
- Username string
- Notes string
- PasswordHash string
- Role string
- Balance float64
- Concurrency int
- Status string
- AllowedGroups []int64
- TokenVersion int64 // Incremented on password change to invalidate existing tokens
- CreatedAt time.Time
- UpdatedAt time.Time
+ ID int64
+ Email string
+ Username string
+ Notes string
+ AvatarURL string
+ AvatarSource string
+ AvatarMIME string
+ AvatarByteSize int
+ AvatarSHA256 string
+ PasswordHash string
+ Role string
+ Balance float64
+ Concurrency int
+ Status string
+ AllowedGroups []int64
+ TokenVersion int64 // Incremented on password change to invalidate existing tokens
+ SignupSource string
+ LastLoginAt *time.Time
+ LastActiveAt *time.Time
+ CreatedAt time.Time
+ UpdatedAt time.Time
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index 3490e804..2f6d9427 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -2,9 +2,13 @@ package service
import (
"context"
+ "crypto/sha256"
"crypto/subtle"
+ "encoding/base64"
+ "encoding/hex"
"fmt"
"log/slog"
+ "net/url"
"strings"
"time"
@@ -17,10 +21,14 @@ var (
ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later")
+ ErrAvatarInvalid = infraerrors.BadRequest("AVATAR_INVALID", "avatar must be a valid image data URL or http(s) URL")
+ ErrAvatarTooLarge = infraerrors.BadRequest("AVATAR_TOO_LARGE", "avatar image must be 100KB or smaller")
+ ErrAvatarNotImage = infraerrors.BadRequest("AVATAR_NOT_IMAGE", "avatar content must be an image")
)
const (
- maxNotifyEmails = 3 // Maximum number of notification emails per user
+ maxNotifyEmails = 3 // Maximum number of notification emails per user
+ maxInlineAvatarBytes = 100 * 1024
// User-level rate limiting for notify email verification codes
notifyCodeUserRateLimit = 5
@@ -47,6 +55,9 @@ type UserRepository interface {
GetFirstAdmin(ctx context.Context) (*User, error)
Update(ctx context.Context, user *User) error
Delete(ctx context.Context, id int64) error
+ GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error)
+ UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error)
+ DeleteUserAvatar(ctx context.Context, userID int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error)
@@ -71,11 +82,30 @@ type UserRepository interface {
type UpdateProfileRequest struct {
Email *string `json:"email"`
Username *string `json:"username"`
+ AvatarURL *string `json:"avatar_url"`
Concurrency *int `json:"concurrency"`
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
}
+type UserAvatar struct {
+ StorageProvider string
+ StorageKey string
+ URL string
+ ContentType string
+ ByteSize int
+ SHA256 string
+}
+
+type UpsertUserAvatarInput struct {
+ StorageProvider string
+ StorageKey string
+ URL string
+ ContentType string
+ ByteSize int
+ SHA256 string
+}
+
// ChangePasswordRequest 修改密码请求
type ChangePasswordRequest struct {
CurrentPassword string `json:"current_password"`
@@ -115,6 +145,9 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, erro
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
+ if err := s.hydrateUserAvatar(ctx, user); err != nil {
+ return nil, fmt.Errorf("get user avatar: %w", err)
+ }
return user, nil
}
@@ -143,6 +176,27 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
user.Username = *req.Username
}
+ if req.AvatarURL != nil {
+ avatarValue := strings.TrimSpace(*req.AvatarURL)
+ switch {
+ case avatarValue == "":
+ if err := s.userRepo.DeleteUserAvatar(ctx, userID); err != nil {
+ return nil, fmt.Errorf("delete avatar: %w", err)
+ }
+ applyUserAvatar(user, nil)
+ default:
+ avatarInput, err := normalizeUserAvatarInput(avatarValue)
+ if err != nil {
+ return nil, err
+ }
+ avatar, err := s.userRepo.UpsertUserAvatar(ctx, userID, avatarInput)
+ if err != nil {
+ return nil, fmt.Errorf("upsert avatar: %w", err)
+ }
+ applyUserAvatar(user, avatar)
+ }
+ }
+
if req.Concurrency != nil {
user.Concurrency = *req.Concurrency
}
@@ -168,6 +222,87 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
return user, nil
}
+func applyUserAvatar(user *User, avatar *UserAvatar) {
+ if user == nil {
+ return
+ }
+ if avatar == nil {
+ user.AvatarURL = ""
+ user.AvatarSource = ""
+ user.AvatarMIME = ""
+ user.AvatarByteSize = 0
+ user.AvatarSHA256 = ""
+ return
+ }
+
+ user.AvatarURL = avatar.URL
+ user.AvatarSource = avatar.StorageProvider
+ user.AvatarMIME = avatar.ContentType
+ user.AvatarByteSize = avatar.ByteSize
+ user.AvatarSHA256 = avatar.SHA256
+}
+
+func normalizeUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return UpsertUserAvatarInput{}, ErrAvatarInvalid
+ }
+ if strings.HasPrefix(raw, "data:") {
+ return normalizeInlineUserAvatarInput(raw)
+ }
+
+ parsed, err := url.Parse(raw)
+ if err != nil || parsed == nil {
+ return UpsertUserAvatarInput{}, ErrAvatarInvalid
+ }
+ if !strings.EqualFold(parsed.Scheme, "http") && !strings.EqualFold(parsed.Scheme, "https") {
+ return UpsertUserAvatarInput{}, ErrAvatarInvalid
+ }
+ if strings.TrimSpace(parsed.Host) == "" {
+ return UpsertUserAvatarInput{}, ErrAvatarInvalid
+ }
+
+ return UpsertUserAvatarInput{
+ StorageProvider: "remote_url",
+ URL: raw,
+ }, nil
+}
+
+func normalizeInlineUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
+ body := strings.TrimPrefix(raw, "data:")
+ meta, encoded, ok := strings.Cut(body, ",")
+ if !ok {
+ return UpsertUserAvatarInput{}, ErrAvatarInvalid
+ }
+ meta = strings.TrimSpace(meta)
+ encoded = strings.TrimSpace(encoded)
+ if !strings.HasSuffix(strings.ToLower(meta), ";base64") {
+ return UpsertUserAvatarInput{}, ErrAvatarInvalid
+ }
+
+ contentType := strings.TrimSpace(meta[:len(meta)-len(";base64")])
+ if contentType == "" || !strings.HasPrefix(strings.ToLower(contentType), "image/") {
+ return UpsertUserAvatarInput{}, ErrAvatarNotImage
+ }
+
+ decoded, err := base64.StdEncoding.DecodeString(encoded)
+ if err != nil {
+ return UpsertUserAvatarInput{}, ErrAvatarInvalid
+ }
+ if len(decoded) > maxInlineAvatarBytes {
+ return UpsertUserAvatarInput{}, ErrAvatarTooLarge
+ }
+
+ sum := sha256.Sum256(decoded)
+ return UpsertUserAvatarInput{
+ StorageProvider: "inline",
+ URL: raw,
+ ContentType: contentType,
+ ByteSize: len(decoded),
+ SHA256: hex.EncodeToString(sum[:]),
+ }, nil
+}
+
// ChangePassword 修改密码
// Security: Increments TokenVersion to invalidate all existing JWT tokens
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
@@ -202,9 +337,25 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) {
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
+ if err := s.hydrateUserAvatar(ctx, user); err != nil {
+ return nil, fmt.Errorf("get user avatar: %w", err)
+ }
return user, nil
}
+func (s *UserService) hydrateUserAvatar(ctx context.Context, user *User) error {
+ if s == nil || s.userRepo == nil || user == nil || user.ID == 0 {
+ return nil
+ }
+
+ avatar, err := s.userRepo.GetUserAvatar(ctx, user.ID)
+ if err != nil {
+ return err
+ }
+ applyUserAvatar(user, avatar)
+ return nil
+}
+
// List 获取用户列表(管理员功能)
func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
users, pagination, err := s.userRepo.List(ctx, params)
diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go
index a998d5f4..7d63bb36 100644
--- a/backend/internal/service/user_service_test.go
+++ b/backend/internal/service/user_service_test.go
@@ -4,6 +4,9 @@ package service
import (
"context"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/hex"
"errors"
"sync"
"sync/atomic"
@@ -19,14 +22,65 @@ import (
type mockUserRepo struct {
updateBalanceErr error
updateBalanceFn func(ctx context.Context, id int64, amount float64) error
+ getByIDUser *User
+ getByIDErr error
+ updateFn func(ctx context.Context, user *User) error
+ updateCalls int
+ upsertAvatarFn func(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error)
+ upsertAvatarArgs []UpsertUserAvatarInput
+ deleteAvatarFn func(ctx context.Context, userID int64) error
+ deleteAvatarIDs []int64
+ getAvatarFn func(ctx context.Context, userID int64) (*UserAvatar, error)
}
-func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
-func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil }
+func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
+func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) {
+ if m.getByIDErr != nil {
+ return nil, m.getByIDErr
+ }
+ if m.getByIDUser != nil {
+ cloned := *m.getByIDUser
+ return &cloned, nil
+ }
+ return &User{}, nil
+}
func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil }
func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil }
-func (m *mockUserRepo) Update(context.Context, *User) error { return nil }
-func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
+func (m *mockUserRepo) Update(ctx context.Context, user *User) error {
+ m.updateCalls++
+ if m.updateFn != nil {
+ return m.updateFn(ctx, user)
+ }
+ return nil
+}
+func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
+func (m *mockUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
+ if m.getAvatarFn != nil {
+ return m.getAvatarFn(ctx, userID)
+ }
+ return nil, nil
+}
+func (m *mockUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
+ m.upsertAvatarArgs = append(m.upsertAvatarArgs, input)
+ if m.upsertAvatarFn != nil {
+ return m.upsertAvatarFn(ctx, userID, input)
+ }
+ return &UserAvatar{
+ StorageProvider: input.StorageProvider,
+ StorageKey: input.StorageKey,
+ URL: input.URL,
+ ContentType: input.ContentType,
+ ByteSize: input.ByteSize,
+ SHA256: input.SHA256,
+ }, nil
+}
+func (m *mockUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ m.deleteAvatarIDs = append(m.deleteAvatarIDs, userID)
+ if m.deleteAvatarFn != nil {
+ return m.deleteAvatarFn(ctx, userID)
+ }
+ return nil
+}
func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
@@ -200,3 +254,121 @@ func TestNewUserService_FieldsAssignment(t *testing.T) {
require.Equal(t, auth, svc.authCacheInvalidator)
require.Equal(t, cache, svc.billingCache)
}
+
+func TestUpdateProfile_StoresInlineAvatarWithinLimit(t *testing.T) {
+ raw := []byte("small-avatar")
+ dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw)
+ expectedSum := sha256.Sum256(raw)
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 7,
+ Email: "avatar@example.com",
+ Username: "avatar-user",
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ updated, err := svc.UpdateProfile(context.Background(), 7, UpdateProfileRequest{
+ AvatarURL: &dataURL,
+ })
+ require.NoError(t, err)
+ require.Len(t, repo.upsertAvatarArgs, 1)
+ require.Equal(t, "inline", repo.upsertAvatarArgs[0].StorageProvider)
+ require.Equal(t, "image/png", repo.upsertAvatarArgs[0].ContentType)
+ require.Equal(t, len(raw), repo.upsertAvatarArgs[0].ByteSize)
+ require.Equal(t, hex.EncodeToString(expectedSum[:]), repo.upsertAvatarArgs[0].SHA256)
+ require.Equal(t, dataURL, updated.AvatarURL)
+ require.Equal(t, "inline", updated.AvatarSource)
+ require.Equal(t, "image/png", updated.AvatarMIME)
+ require.Equal(t, len(raw), updated.AvatarByteSize)
+ require.Equal(t, hex.EncodeToString(expectedSum[:]), updated.AvatarSHA256)
+}
+
+func TestUpdateProfile_RejectsInlineAvatarOverLimit(t *testing.T) {
+ raw := make([]byte, maxInlineAvatarBytes+1)
+ dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw)
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 8,
+ Email: "large-avatar@example.com",
+ Username: "too-large",
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ _, err := svc.UpdateProfile(context.Background(), 8, UpdateProfileRequest{
+ AvatarURL: &dataURL,
+ })
+ require.ErrorIs(t, err, ErrAvatarTooLarge)
+ require.Empty(t, repo.upsertAvatarArgs)
+ require.Empty(t, repo.deleteAvatarIDs)
+ require.Zero(t, repo.updateCalls)
+}
+
+func TestUpdateProfile_StoresRemoteAvatarURL(t *testing.T) {
+ remoteURL := "https://cdn.example.com/avatar.png"
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 9,
+ Email: "remote-avatar@example.com",
+ Username: "remote-avatar",
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ updated, err := svc.UpdateProfile(context.Background(), 9, UpdateProfileRequest{
+ AvatarURL: &remoteURL,
+ })
+ require.NoError(t, err)
+ require.Len(t, repo.upsertAvatarArgs, 1)
+ require.Equal(t, "remote_url", repo.upsertAvatarArgs[0].StorageProvider)
+ require.Equal(t, remoteURL, repo.upsertAvatarArgs[0].URL)
+ require.Equal(t, remoteURL, updated.AvatarURL)
+ require.Equal(t, "remote_url", updated.AvatarSource)
+ require.Zero(t, updated.AvatarByteSize)
+}
+
+func TestUpdateProfile_DeletesAvatarOnEmptyString(t *testing.T) {
+ empty := ""
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 10,
+ Email: "delete-avatar@example.com",
+ Username: "delete-avatar",
+ AvatarURL: "https://cdn.example.com/old.png",
+ AvatarSource: "remote_url",
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ updated, err := svc.UpdateProfile(context.Background(), 10, UpdateProfileRequest{
+ AvatarURL: &empty,
+ })
+ require.NoError(t, err)
+ require.Equal(t, []int64{10}, repo.deleteAvatarIDs)
+ require.Empty(t, repo.upsertAvatarArgs)
+ require.Empty(t, updated.AvatarURL)
+ require.Empty(t, updated.AvatarSource)
+}
+
+func TestGetProfile_HydratesAvatarFromRepository(t *testing.T) {
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 12,
+ Email: "profile-avatar@example.com",
+ Username: "profile-avatar",
+ },
+ getAvatarFn: func(context.Context, int64) (*UserAvatar, error) {
+ return &UserAvatar{
+ StorageProvider: "remote_url",
+ URL: "https://cdn.example.com/profile.png",
+ }, nil
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ user, err := svc.GetProfile(context.Background(), 12)
+ require.NoError(t, err)
+ require.Equal(t, "https://cdn.example.com/profile.png", user.AvatarURL)
+ require.Equal(t, "remote_url", user.AvatarSource)
+}
diff --git a/backend/migrations/108_auth_identity_foundation_core.sql b/backend/migrations/108_auth_identity_foundation_core.sql
new file mode 100644
index 00000000..117e3ca3
--- /dev/null
+++ b/backend/migrations/108_auth_identity_foundation_core.sql
@@ -0,0 +1,141 @@
+ALTER TABLE users
+ADD COLUMN IF NOT EXISTS signup_source VARCHAR(20) NOT NULL DEFAULT 'email',
+ADD COLUMN IF NOT EXISTS last_login_at TIMESTAMPTZ NULL,
+ADD COLUMN IF NOT EXISTS last_active_at TIMESTAMPTZ NULL;
+
+UPDATE users
+SET signup_source = 'email'
+WHERE signup_source IS NULL OR signup_source = '';
+
+DO $$
+BEGIN
+ IF NOT EXISTS (
+ SELECT 1
+ FROM pg_constraint
+ WHERE conname = 'users_signup_source_check'
+ ) THEN
+ ALTER TABLE users
+ ADD CONSTRAINT users_signup_source_check
+ CHECK (signup_source IN ('email', 'linuxdo', 'wechat', 'oidc'));
+ END IF;
+END $$;
+
+CREATE TABLE IF NOT EXISTS auth_identities (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+ provider_type VARCHAR(20) NOT NULL,
+ provider_key TEXT NOT NULL,
+ provider_subject TEXT NOT NULL,
+ verified_at TIMESTAMPTZ NULL,
+ issuer TEXT NULL,
+ metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT auth_identities_provider_type_check
+ CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS auth_identities_provider_subject_key
+ ON auth_identities (provider_type, provider_key, provider_subject);
+
+CREATE INDEX IF NOT EXISTS auth_identities_user_id_idx
+ ON auth_identities (user_id);
+
+CREATE INDEX IF NOT EXISTS auth_identities_user_provider_idx
+ ON auth_identities (user_id, provider_type);
+
+CREATE TABLE IF NOT EXISTS auth_identity_channels (
+ id BIGSERIAL PRIMARY KEY,
+ identity_id BIGINT NOT NULL REFERENCES auth_identities(id) ON DELETE CASCADE,
+ provider_type VARCHAR(20) NOT NULL,
+ provider_key TEXT NOT NULL,
+ channel VARCHAR(20) NOT NULL,
+ channel_app_id TEXT NOT NULL,
+ channel_subject TEXT NOT NULL,
+ metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT auth_identity_channels_provider_type_check
+ CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_channels_channel_key
+ ON auth_identity_channels (provider_type, provider_key, channel, channel_app_id, channel_subject);
+
+CREATE INDEX IF NOT EXISTS auth_identity_channels_identity_id_idx
+ ON auth_identity_channels (identity_id);
+
+CREATE TABLE IF NOT EXISTS pending_auth_sessions (
+ id BIGSERIAL PRIMARY KEY,
+ session_token VARCHAR(255) NOT NULL,
+ intent VARCHAR(40) NOT NULL,
+ provider_type VARCHAR(20) NOT NULL,
+ provider_key TEXT NOT NULL,
+ provider_subject TEXT NOT NULL,
+ target_user_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL,
+ redirect_to TEXT NOT NULL DEFAULT '',
+ resolved_email TEXT NOT NULL DEFAULT '',
+ registration_password_hash TEXT NOT NULL DEFAULT '',
+ upstream_identity_claims JSONB NOT NULL DEFAULT '{}'::jsonb,
+ local_flow_state JSONB NOT NULL DEFAULT '{}'::jsonb,
+ browser_session_key TEXT NOT NULL DEFAULT '',
+ completion_code_hash TEXT NOT NULL DEFAULT '',
+ completion_code_expires_at TIMESTAMPTZ NULL,
+ email_verified_at TIMESTAMPTZ NULL,
+ password_verified_at TIMESTAMPTZ NULL,
+ totp_verified_at TIMESTAMPTZ NULL,
+ expires_at TIMESTAMPTZ NOT NULL,
+ consumed_at TIMESTAMPTZ NULL,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT pending_auth_sessions_intent_check
+ CHECK (intent IN ('login', 'bind_current_user', 'adopt_existing_user_by_email')),
+ CONSTRAINT pending_auth_sessions_provider_type_check
+ CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS pending_auth_sessions_session_token_key
+ ON pending_auth_sessions (session_token);
+
+CREATE INDEX IF NOT EXISTS pending_auth_sessions_target_user_id_idx
+ ON pending_auth_sessions (target_user_id);
+
+CREATE INDEX IF NOT EXISTS pending_auth_sessions_expires_at_idx
+ ON pending_auth_sessions (expires_at);
+
+CREATE INDEX IF NOT EXISTS pending_auth_sessions_provider_idx
+ ON pending_auth_sessions (provider_type, provider_key, provider_subject);
+
+CREATE INDEX IF NOT EXISTS pending_auth_sessions_completion_code_idx
+ ON pending_auth_sessions (completion_code_hash);
+
+CREATE TABLE IF NOT EXISTS identity_adoption_decisions (
+ id BIGSERIAL PRIMARY KEY,
+ pending_auth_session_id BIGINT NOT NULL REFERENCES pending_auth_sessions(id) ON DELETE CASCADE,
+ identity_id BIGINT NULL REFERENCES auth_identities(id) ON DELETE SET NULL,
+ adopt_display_name BOOLEAN NOT NULL DEFAULT FALSE,
+ adopt_avatar BOOLEAN NOT NULL DEFAULT FALSE,
+ decided_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS identity_adoption_decisions_pending_auth_session_id_key
+ ON identity_adoption_decisions (pending_auth_session_id);
+
+CREATE INDEX IF NOT EXISTS identity_adoption_decisions_identity_id_idx
+ ON identity_adoption_decisions (identity_id);
+
+CREATE TABLE IF NOT EXISTS auth_identity_migration_reports (
+ id BIGSERIAL PRIMARY KEY,
+ report_type VARCHAR(40) NOT NULL,
+ report_key TEXT NOT NULL,
+ details JSONB NOT NULL DEFAULT '{}'::jsonb,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS auth_identity_migration_reports_type_idx
+ ON auth_identity_migration_reports (report_type);
+
+CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_migration_reports_type_key
+ ON auth_identity_migration_reports (report_type, report_key);
diff --git a/backend/migrations/109_auth_identity_compat_backfill.sql b/backend/migrations/109_auth_identity_compat_backfill.sql
new file mode 100644
index 00000000..ddbbedbc
--- /dev/null
+++ b/backend/migrations/109_auth_identity_compat_backfill.sql
@@ -0,0 +1,125 @@
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ u.id,
+ 'email',
+ 'email',
+ LOWER(BTRIM(u.email)),
+ COALESCE(u.updated_at, u.created_at, NOW()),
+ jsonb_build_object(
+ 'backfill_source', 'users.email',
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND BTRIM(COALESCE(u.email, '')) <> ''
+ AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@linuxdo-connect.invalid')) <> '@linuxdo-connect.invalid'
+ AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@oidc-connect.invalid')) <> '@oidc-connect.invalid'
+ AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@wechat-connect.invalid')) <> '@wechat-connect.invalid'
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ u.id,
+ 'linuxdo',
+ 'linuxdo',
+ SUBSTRING(BTRIM(u.email) FROM '(?i)^linuxdo-(.+)@linuxdo-connect\.invalid$'),
+ COALESCE(u.updated_at, u.created_at, NOW()),
+ jsonb_build_object(
+ 'backfill_source', 'synthetic_email',
+ 'legacy_email', BTRIM(u.email),
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(u.email)) ~ '^linuxdo-.+@linuxdo-connect\.invalid$'
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ u.id,
+ 'wechat',
+ 'wechat',
+ SUBSTRING(BTRIM(u.email) FROM '(?i)^wechat-(.+)@wechat-connect\.invalid$'),
+ COALESCE(u.updated_at, u.created_at, NOW()),
+ jsonb_build_object(
+ 'backfill_source', 'synthetic_email',
+ 'legacy_email', BTRIM(u.email),
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$'
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+
+UPDATE users
+SET signup_source = 'linuxdo'
+WHERE deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^linuxdo-.+@linuxdo-connect\.invalid$';
+
+UPDATE users
+SET signup_source = 'wechat'
+WHERE deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^wechat-.+@wechat-connect\.invalid$';
+
+UPDATE users
+SET signup_source = 'oidc'
+WHERE deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^oidc-.+@oidc-connect\.invalid$';
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'oidc_synthetic_email_requires_manual_recovery',
+ CAST(u.id AS TEXT),
+ jsonb_build_object(
+ 'user_id', u.id,
+ 'email', LOWER(BTRIM(u.email)),
+ 'reason', 'cannot recover issuer_plus_sub deterministically from synthetic email alone',
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(u.email)) ~ '^oidc-.+@oidc-connect\.invalid$'
+ON CONFLICT (report_type, report_key) DO NOTHING;
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_openid_only_requires_remediation',
+ CAST(u.id AS TEXT),
+ jsonb_build_object(
+ 'user_id', u.id,
+ 'email', LOWER(BTRIM(u.email)),
+ 'reason', 'legacy wechat synthetic identity requires explicit unionid remediation if channel-only data exists',
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$'
+ AND NOT EXISTS (
+ SELECT 1
+ FROM auth_identities ai
+ WHERE ai.user_id = u.id
+ AND ai.provider_type = 'wechat'
+ AND ai.provider_key = 'wechat'
+ )
+ON CONFLICT (report_type, report_key) DO NOTHING;
diff --git a/backend/migrations/110_pending_auth_and_provider_default_grants.sql b/backend/migrations/110_pending_auth_and_provider_default_grants.sql
new file mode 100644
index 00000000..fbaed62e
--- /dev/null
+++ b/backend/migrations/110_pending_auth_and_provider_default_grants.sql
@@ -0,0 +1,60 @@
+CREATE TABLE IF NOT EXISTS user_provider_default_grants (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+ provider_type VARCHAR(20) NOT NULL,
+ grant_reason VARCHAR(20) NOT NULL DEFAULT 'first_bind',
+ granted_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT user_provider_default_grants_provider_type_check
+ CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')),
+ CONSTRAINT user_provider_default_grants_reason_check
+ CHECK (grant_reason IN ('signup', 'first_bind'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS user_provider_default_grants_user_provider_reason_key
+ ON user_provider_default_grants (user_id, provider_type, grant_reason);
+
+CREATE INDEX IF NOT EXISTS user_provider_default_grants_user_id_idx
+ ON user_provider_default_grants (user_id);
+
+CREATE TABLE IF NOT EXISTS user_avatars (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+ storage_provider VARCHAR(20) NOT NULL DEFAULT 'database',
+ storage_key TEXT NOT NULL DEFAULT '',
+ url TEXT NOT NULL DEFAULT '',
+ content_type VARCHAR(100) NOT NULL DEFAULT '',
+ byte_size INT NOT NULL DEFAULT 0,
+ sha256 VARCHAR(64) NOT NULL DEFAULT '',
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS user_avatars_user_id_key
+ ON user_avatars (user_id);
+
+INSERT INTO settings (key, value)
+VALUES
+ ('auth_source_default_email_balance', '0'),
+ ('auth_source_default_email_concurrency', '5'),
+ ('auth_source_default_email_subscriptions', '[]'),
+ ('auth_source_default_email_grant_on_signup', 'true'),
+ ('auth_source_default_email_grant_on_first_bind', 'false'),
+ ('auth_source_default_linuxdo_balance', '0'),
+ ('auth_source_default_linuxdo_concurrency', '5'),
+ ('auth_source_default_linuxdo_subscriptions', '[]'),
+ ('auth_source_default_linuxdo_grant_on_signup', 'true'),
+ ('auth_source_default_linuxdo_grant_on_first_bind', 'false'),
+ ('auth_source_default_oidc_balance', '0'),
+ ('auth_source_default_oidc_concurrency', '5'),
+ ('auth_source_default_oidc_subscriptions', '[]'),
+ ('auth_source_default_oidc_grant_on_signup', 'true'),
+ ('auth_source_default_oidc_grant_on_first_bind', 'false'),
+ ('auth_source_default_wechat_balance', '0'),
+ ('auth_source_default_wechat_concurrency', '5'),
+ ('auth_source_default_wechat_subscriptions', '[]'),
+ ('auth_source_default_wechat_grant_on_signup', 'true'),
+ ('auth_source_default_wechat_grant_on_first_bind', 'false'),
+ ('force_email_on_third_party_signup', 'false')
+ON CONFLICT (key) DO NOTHING;
+
diff --git a/backend/migrations/111_payment_routing_and_scheduler_flags.sql b/backend/migrations/111_payment_routing_and_scheduler_flags.sql
new file mode 100644
index 00000000..f222a8d4
--- /dev/null
+++ b/backend/migrations/111_payment_routing_and_scheduler_flags.sql
@@ -0,0 +1,8 @@
+INSERT INTO settings (key, value)
+VALUES
+ ('payment_visible_method_alipay_source', ''),
+ ('payment_visible_method_wxpay_source', ''),
+ ('payment_visible_method_alipay_enabled', 'false'),
+ ('payment_visible_method_wxpay_enabled', 'false'),
+ ('openai_advanced_scheduler_enabled', 'false')
+ON CONFLICT (key) DO NOTHING;
diff --git a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts
new file mode 100644
index 00000000..574e1e36
--- /dev/null
+++ b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts
@@ -0,0 +1,60 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+const post = vi.fn()
+
+vi.mock('@/api/client', () => ({
+ apiClient: {
+ post
+ }
+}))
+
+describe('oauth adoption auth api', () => {
+ beforeEach(() => {
+ post.mockReset()
+ post.mockResolvedValue({ data: {} })
+ })
+
+ it('posts adoption decisions when exchanging pending oauth completion', async () => {
+ const { exchangePendingOAuthCompletion } = await import('@/api/auth')
+
+ await exchangePendingOAuthCompletion({
+ adoptDisplayName: false,
+ adoptAvatar: true
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/pending/exchange', {
+ adopt_display_name: false,
+ adopt_avatar: true
+ })
+ })
+
+ it('posts linuxdo invitation completion with adoption decisions', async () => {
+ const { completeLinuxDoOAuthRegistration } = await import('@/api/auth')
+
+ await completeLinuxDoOAuthRegistration('invite-code', {
+ adoptDisplayName: true,
+ adoptAvatar: false
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: true,
+ adopt_avatar: false
+ })
+ })
+
+ it('posts oidc invitation completion with adoption decisions', async () => {
+ const { completeOIDCOAuthRegistration } = await import('@/api/auth')
+
+ await completeOIDCOAuthRegistration('invite-code', {
+ adoptDisplayName: false,
+ adoptAvatar: true
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: false,
+ adopt_avatar: true
+ })
+ })
+})
diff --git a/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts
new file mode 100644
index 00000000..8756146e
--- /dev/null
+++ b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts
@@ -0,0 +1,118 @@
+import { describe, expect, it } from 'vitest'
+
+import {
+ appendAuthSourceDefaultsToUpdateRequest,
+ buildAuthSourceDefaultsState,
+ type UpdateSettingsRequest,
+} from '@/api/admin/settings'
+
+describe('admin settings auth source defaults helpers', () => {
+ it('builds auth source defaults state from flat settings fields', () => {
+ const state = buildAuthSourceDefaultsState({
+ auth_source_default_email_balance: 9.5,
+ auth_source_default_email_concurrency: 3,
+ auth_source_default_email_subscriptions: [
+ { group_id: 1, validity_days: 30 },
+ ],
+ auth_source_default_email_grant_on_signup: false,
+ auth_source_default_email_grant_on_first_bind: true,
+ auth_source_default_linuxdo_balance: 6,
+ auth_source_default_linuxdo_concurrency: 8,
+ auth_source_default_linuxdo_subscriptions: [
+ { group_id: 2, validity_days: 60 },
+ ],
+ auth_source_default_linuxdo_grant_on_signup: true,
+ auth_source_default_linuxdo_grant_on_first_bind: false,
+ })
+
+ expect(state.email).toEqual({
+ balance: 9.5,
+ concurrency: 3,
+ subscriptions: [{ group_id: 1, validity_days: 30 }],
+ grant_on_signup: false,
+ grant_on_first_bind: true,
+ })
+ expect(state.linuxdo).toEqual({
+ balance: 6,
+ concurrency: 8,
+ subscriptions: [{ group_id: 2, validity_days: 60 }],
+ grant_on_signup: true,
+ grant_on_first_bind: false,
+ })
+ expect(state.oidc).toEqual({
+ balance: 0,
+ concurrency: 5,
+ subscriptions: [],
+ grant_on_signup: true,
+ grant_on_first_bind: false,
+ })
+ expect(state.wechat).toEqual({
+ balance: 0,
+ concurrency: 5,
+ subscriptions: [],
+ grant_on_signup: true,
+ grant_on_first_bind: false,
+ })
+ })
+
+ it('appends auth source defaults back onto update payload', () => {
+ const payload: UpdateSettingsRequest = {
+ site_name: 'Sub2API',
+ }
+
+ appendAuthSourceDefaultsToUpdateRequest(payload, {
+ email: {
+ balance: 1.25,
+ concurrency: 2,
+ subscriptions: [{ group_id: 3, validity_days: 7 }],
+ grant_on_signup: true,
+ grant_on_first_bind: false,
+ },
+ linuxdo: {
+ balance: 0,
+ concurrency: 6,
+ subscriptions: [],
+ grant_on_signup: false,
+ grant_on_first_bind: true,
+ },
+ oidc: {
+ balance: 4,
+ concurrency: 9,
+ subscriptions: [{ group_id: 9, validity_days: 90 }],
+ grant_on_signup: true,
+ grant_on_first_bind: true,
+ },
+ wechat: {
+ balance: 2,
+ concurrency: 5,
+ subscriptions: [],
+ grant_on_signup: false,
+ grant_on_first_bind: false,
+ },
+ })
+
+ expect(payload).toMatchObject({
+ site_name: 'Sub2API',
+ auth_source_default_email_balance: 1.25,
+ auth_source_default_email_concurrency: 2,
+ auth_source_default_email_subscriptions: [{ group_id: 3, validity_days: 7 }],
+ auth_source_default_email_grant_on_signup: true,
+ auth_source_default_email_grant_on_first_bind: false,
+ auth_source_default_linuxdo_balance: 0,
+ auth_source_default_linuxdo_concurrency: 6,
+ auth_source_default_linuxdo_subscriptions: [],
+ auth_source_default_linuxdo_grant_on_signup: false,
+ auth_source_default_linuxdo_grant_on_first_bind: true,
+ auth_source_default_oidc_balance: 4,
+ auth_source_default_oidc_concurrency: 9,
+ auth_source_default_oidc_subscriptions: [{ group_id: 9, validity_days: 90 }],
+ auth_source_default_oidc_grant_on_signup: true,
+ auth_source_default_oidc_grant_on_first_bind: true,
+ auth_source_default_wechat_balance: 2,
+ auth_source_default_wechat_concurrency: 5,
+ auth_source_default_wechat_subscriptions: [],
+ auth_source_default_wechat_grant_on_signup: false,
+ auth_source_default_wechat_grant_on_first_bind: false,
+ })
+ })
+})
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index 1e4a3053..8e182c1c 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -11,6 +11,81 @@ export interface DefaultSubscriptionSetting {
validity_days: number
}
+export type AuthSourceType = 'email' | 'linuxdo' | 'oidc' | 'wechat'
+
+export interface AuthSourceDefaultsValue {
+ balance: number
+ concurrency: number
+ subscriptions: DefaultSubscriptionSetting[]
+ grant_on_signup: boolean
+ grant_on_first_bind: boolean
+}
+
+export type AuthSourceDefaultsState = Record
+
+const AUTH_SOURCE_TYPES: AuthSourceType[] = ['email', 'linuxdo', 'oidc', 'wechat']
+const AUTH_SOURCE_DEFAULT_BALANCE = 0
+const AUTH_SOURCE_DEFAULT_CONCURRENCY = 5
+
+export function normalizeDefaultSubscriptionSettings(
+ subscriptions: DefaultSubscriptionSetting[] | null | undefined
+): DefaultSubscriptionSetting[] {
+ if (!Array.isArray(subscriptions)) return []
+
+ return subscriptions
+ .filter((item) => item.group_id > 0 && item.validity_days > 0)
+ .map((item) => ({
+ group_id: Math.floor(item.group_id),
+ validity_days: Math.min(36500, Math.max(1, Math.floor(item.validity_days)))
+ }))
+}
+
+export function buildAuthSourceDefaultsState(
+ settings: Partial
+): AuthSourceDefaultsState {
+ const raw = settings as Record
+
+ return AUTH_SOURCE_TYPES.reduce((acc, source) => {
+ const subscriptions = raw[`auth_source_default_${source}_subscriptions`]
+ acc[source] = {
+ balance: Number(raw[`auth_source_default_${source}_balance`] ?? AUTH_SOURCE_DEFAULT_BALANCE),
+ concurrency: Math.max(
+ 1,
+ Number(raw[`auth_source_default_${source}_concurrency`] ?? AUTH_SOURCE_DEFAULT_CONCURRENCY)
+ ),
+ subscriptions: normalizeDefaultSubscriptionSettings(
+ Array.isArray(subscriptions) ? (subscriptions as DefaultSubscriptionSetting[]) : []
+ ),
+ grant_on_signup: raw[`auth_source_default_${source}_grant_on_signup`] !== false,
+ grant_on_first_bind: raw[`auth_source_default_${source}_grant_on_first_bind`] === true,
+ }
+ return acc
+ }, {} as AuthSourceDefaultsState)
+}
+
+export function appendAuthSourceDefaultsToUpdateRequest(
+ payload: UpdateSettingsRequest,
+ authSourceDefaults: AuthSourceDefaultsState
+): UpdateSettingsRequest {
+ const target = payload as Record
+
+ for (const source of AUTH_SOURCE_TYPES) {
+ const current = authSourceDefaults[source]
+ target[`auth_source_default_${source}_balance`] = Number(current.balance) || 0
+ target[`auth_source_default_${source}_concurrency`] = Math.max(
+ 1,
+ Math.floor(Number(current.concurrency) || AUTH_SOURCE_DEFAULT_CONCURRENCY)
+ )
+ target[`auth_source_default_${source}_subscriptions`] = normalizeDefaultSubscriptionSettings(
+ current.subscriptions
+ )
+ target[`auth_source_default_${source}_grant_on_signup`] = current.grant_on_signup
+ target[`auth_source_default_${source}_grant_on_first_bind`] = current.grant_on_first_bind
+ }
+
+ return payload
+}
+
/**
* System settings interface
*/
@@ -29,6 +104,27 @@ export interface SystemSettings {
default_balance: number
default_concurrency: number
default_subscriptions: DefaultSubscriptionSetting[]
+ auth_source_default_email_balance?: number
+ auth_source_default_email_concurrency?: number
+ auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[]
+ auth_source_default_email_grant_on_signup?: boolean
+ auth_source_default_email_grant_on_first_bind?: boolean
+ auth_source_default_linuxdo_balance?: number
+ auth_source_default_linuxdo_concurrency?: number
+ auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[]
+ auth_source_default_linuxdo_grant_on_signup?: boolean
+ auth_source_default_linuxdo_grant_on_first_bind?: boolean
+ auth_source_default_oidc_balance?: number
+ auth_source_default_oidc_concurrency?: number
+ auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[]
+ auth_source_default_oidc_grant_on_signup?: boolean
+ auth_source_default_oidc_grant_on_first_bind?: boolean
+ auth_source_default_wechat_balance?: number
+ auth_source_default_wechat_concurrency?: number
+ auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[]
+ auth_source_default_wechat_grant_on_signup?: boolean
+ auth_source_default_wechat_grant_on_first_bind?: boolean
+ force_email_on_third_party_signup?: boolean
// OEM settings
site_name: string
site_logo: string
@@ -137,6 +233,11 @@ export interface SystemSettings {
payment_cancel_rate_limit_window: number
payment_cancel_rate_limit_unit: string
payment_cancel_rate_limit_window_mode: string
+ payment_visible_method_alipay_source?: string
+ payment_visible_method_wxpay_source?: string
+ payment_visible_method_alipay_enabled?: boolean
+ payment_visible_method_wxpay_enabled?: boolean
+ openai_advanced_scheduler_enabled?: boolean
// Balance & quota notification
balance_low_notify_enabled: boolean
@@ -158,6 +259,27 @@ export interface UpdateSettingsRequest {
default_balance?: number
default_concurrency?: number
default_subscriptions?: DefaultSubscriptionSetting[]
+ auth_source_default_email_balance?: number
+ auth_source_default_email_concurrency?: number
+ auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[]
+ auth_source_default_email_grant_on_signup?: boolean
+ auth_source_default_email_grant_on_first_bind?: boolean
+ auth_source_default_linuxdo_balance?: number
+ auth_source_default_linuxdo_concurrency?: number
+ auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[]
+ auth_source_default_linuxdo_grant_on_signup?: boolean
+ auth_source_default_linuxdo_grant_on_first_bind?: boolean
+ auth_source_default_oidc_balance?: number
+ auth_source_default_oidc_concurrency?: number
+ auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[]
+ auth_source_default_oidc_grant_on_signup?: boolean
+ auth_source_default_oidc_grant_on_first_bind?: boolean
+ auth_source_default_wechat_balance?: number
+ auth_source_default_wechat_concurrency?: number
+ auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[]
+ auth_source_default_wechat_grant_on_signup?: boolean
+ auth_source_default_wechat_grant_on_first_bind?: boolean
+ force_email_on_third_party_signup?: boolean
site_name?: string
site_logo?: string
site_subtitle?: string
@@ -245,6 +367,11 @@ export interface UpdateSettingsRequest {
payment_cancel_rate_limit_window?: number
payment_cancel_rate_limit_unit?: string
payment_cancel_rate_limit_window_mode?: string
+ payment_visible_method_alipay_source?: string
+ payment_visible_method_wxpay_source?: string
+ payment_visible_method_alipay_enabled?: boolean
+ payment_visible_method_wxpay_enabled?: boolean
+ openai_advanced_scheduler_enabled?: boolean
// Balance & quota notification
balance_low_notify_enabled?: boolean
balance_low_notify_threshold?: number
diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts
index d7abcd6a..10b6ca58 100644
--- a/frontend/src/api/auth.ts
+++ b/frontend/src/api/auth.ts
@@ -198,6 +198,26 @@ export interface PendingOAuthExchangeResponse {
suggested_avatar_url?: string
}
+export interface OAuthAdoptionDecision {
+ adoptDisplayName?: boolean
+ adoptAvatar?: boolean
+}
+
+function serializeOAuthAdoptionDecision(
+ decision?: OAuthAdoptionDecision
+): Record {
+ const payload: Record = {}
+
+ if (typeof decision?.adoptDisplayName === 'boolean') {
+ payload.adopt_display_name = decision.adoptDisplayName
+ }
+ if (typeof decision?.adoptAvatar === 'boolean') {
+ payload.adopt_avatar = decision.adoptAvatar
+ }
+
+ return payload
+}
+
/**
* Refresh the access token using the refresh token
* @returns New token pair
@@ -353,7 +373,8 @@ export async function resetPassword(request: ResetPasswordRequest): Promise {
const { data } = await apiClient.post<{
access_token: string
@@ -361,7 +382,8 @@ export async function completeLinuxDoOAuthRegistration(
expires_in: number
token_type: string
}>('/auth/oauth/linuxdo/complete-registration', {
- invitation_code: invitationCode
+ invitation_code: invitationCode,
+ ...serializeOAuthAdoptionDecision(decision)
})
return data
}
@@ -372,7 +394,8 @@ export async function completeLinuxDoOAuthRegistration(
* @returns Token pair on success
*/
export async function completeOIDCOAuthRegistration(
- invitationCode: string
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision
): Promise<{ access_token: string; refresh_token: string; expires_in: number; token_type: string }> {
const { data } = await apiClient.post<{
access_token: string
@@ -380,13 +403,19 @@ export async function completeOIDCOAuthRegistration(
expires_in: number
token_type: string
}>('/auth/oauth/oidc/complete-registration', {
- invitation_code: invitationCode
+ invitation_code: invitationCode,
+ ...serializeOAuthAdoptionDecision(decision)
})
return data
}
-export async function exchangePendingOAuthCompletion(): Promise {
- const { data } = await apiClient.post('/auth/oauth/pending/exchange', {})
+export async function exchangePendingOAuthCompletion(
+ decision?: OAuthAdoptionDecision
+): Promise {
+ const { data } = await apiClient.post(
+ '/auth/oauth/pending/exchange',
+ serializeOAuthAdoptionDecision(decision)
+ )
return data
}
diff --git a/frontend/src/components/auth/WechatOAuthSection.vue b/frontend/src/components/auth/WechatOAuthSection.vue
new file mode 100644
index 00000000..94e20222
--- /dev/null
+++ b/frontend/src/components/auth/WechatOAuthSection.vue
@@ -0,0 +1,53 @@
+
+
+
+
+ W
+
+ {{ t('auth.oidc.signIn', { providerName }) }}
+
+
+
+
+
+ {{ t('auth.oauthOrContinue') }}
+
+
+
+
+
+
+
diff --git a/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts b/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts
new file mode 100644
index 00000000..810832a0
--- /dev/null
+++ b/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts
@@ -0,0 +1,74 @@
+import { mount } from '@vue/test-utils'
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
+import WechatOAuthSection from '@/components/auth/WechatOAuthSection.vue'
+
+const routeState = vi.hoisted(() => ({
+ query: {} as Record,
+}))
+
+const locationState = vi.hoisted(() => ({
+ current: { href: 'http://localhost/login' } as { href: string },
+}))
+
+vi.mock('vue-router', () => ({
+ useRoute: () => routeState,
+}))
+
+vi.mock('vue-i18n', () => ({
+ useI18n: () => ({
+ t: (key: string, params?: Record) => {
+ if (key === 'auth.oidc.signIn') {
+ return `Continue with ${params?.providerName ?? ''}`.trim()
+ }
+ if (key === 'auth.oauthOrContinue') {
+ return 'or continue'
+ }
+ return key
+ },
+ }),
+}))
+
+describe('WechatOAuthSection', () => {
+ beforeEach(() => {
+ routeState.query = { redirect: '/billing?plan=pro' }
+ locationState.current = { href: 'http://localhost/login' }
+ Object.defineProperty(window, 'location', {
+ configurable: true,
+ value: locationState.current,
+ })
+ Object.defineProperty(window.navigator, 'userAgent', {
+ configurable: true,
+ value: 'Mozilla/5.0',
+ })
+ })
+
+ afterEach(() => {
+ vi.unstubAllGlobals()
+ })
+
+ it('starts the open WeChat OAuth flow with the current redirect target', async () => {
+ const wrapper = mount(WechatOAuthSection)
+
+ expect(wrapper.text()).toContain('WeChat')
+
+ await wrapper.get('button').trigger('click')
+
+ expect(locationState.current.href).toContain(
+ '/api/v1/auth/oauth/wechat/start?mode=open&redirect=%2Fbilling%3Fplan%3Dpro'
+ )
+ })
+
+ it('uses mp mode inside the WeChat browser', async () => {
+ Object.defineProperty(window.navigator, 'userAgent', {
+ configurable: true,
+ value: 'Mozilla/5.0 MicroMessenger',
+ })
+ const wrapper = mount(WechatOAuthSection)
+
+ await wrapper.get('button').trigger('click')
+
+ expect(locationState.current.href).toContain(
+ '/api/v1/auth/oauth/wechat/start?mode=mp&redirect=%2Fbilling%3Fplan%3Dpro'
+ )
+ })
+})
diff --git a/frontend/src/components/payment/PaymentStatusPanel.vue b/frontend/src/components/payment/PaymentStatusPanel.vue
index 974dee66..8f5a5666 100644
--- a/frontend/src/components/payment/PaymentStatusPanel.vue
+++ b/frontend/src/components/payment/PaymentStatusPanel.vue
@@ -141,7 +141,9 @@ const props = defineProps<{
orderType?: string
}>()
-const emit = defineEmits<{ done: []; success: [] }>()
+type PaymentOutcome = 'success' | 'cancelled' | 'expired'
+
+const emit = defineEmits<{ done: []; success: []; settled: [outcome: PaymentOutcome] }>()
const { t } = useI18n()
const paymentStore = usePaymentStore()
@@ -154,7 +156,7 @@ const cancelling = ref(false)
const paidOrder = ref(null)
// Terminal outcome: null = still active, 'success' | 'cancelled' | 'expired'
-const outcome = ref<'success' | 'cancelled' | 'expired' | null>(null)
+const outcome = ref(null)
let pollTimer: ReturnType | null = null
let countdownTimer: ReturnType | null = null
@@ -194,10 +196,19 @@ const countdownDisplay = computed(() => {
function reopenPopup() {
if (props.payUrl) {
- window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES)
+ const win = window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES)
+ if (!win || win.closed) {
+ window.location.href = props.payUrl
+ }
}
}
+function setOutcome(next: PaymentOutcome) {
+ if (outcome.value === next) return
+ outcome.value = next
+ emit('settled', next)
+}
+
async function renderQR() {
await nextTick()
if (!qrCanvas.value || !qrUrl.value) return
@@ -214,23 +225,23 @@ async function pollStatus() {
if (order.status === 'COMPLETED' || order.status === 'PAID') {
cleanup()
paidOrder.value = order
- outcome.value = 'success'
+ setOutcome('success')
emit('success')
} else if (order.status === 'CANCELLED') {
cleanup()
- outcome.value = 'cancelled'
+ setOutcome('cancelled')
} else if (order.status === 'EXPIRED' || order.status === 'FAILED') {
cleanup()
- outcome.value = 'expired'
+ setOutcome('expired')
}
}
function startCountdown(seconds: number) {
remainingSeconds.value = Math.max(0, seconds)
- if (remainingSeconds.value <= 0) { outcome.value = 'expired'; return }
+ if (remainingSeconds.value <= 0) { setOutcome('expired'); return }
countdownTimer = setInterval(() => {
remainingSeconds.value--
- if (remainingSeconds.value <= 0) { outcome.value = 'expired'; cleanup() }
+ if (remainingSeconds.value <= 0) { setOutcome('expired'); cleanup() }
}, 1000)
}
@@ -240,7 +251,7 @@ async function handleCancel() {
try {
await paymentAPI.cancelOrder(props.orderId)
cleanup()
- outcome.value = 'cancelled'
+ setOutcome('cancelled')
} catch (err: unknown) {
appStore.showError(extractApiErrorMessage(err, t('common.error')))
} finally {
diff --git a/frontend/src/components/payment/__tests__/paymentFlow.spec.ts b/frontend/src/components/payment/__tests__/paymentFlow.spec.ts
new file mode 100644
index 00000000..f5212f15
--- /dev/null
+++ b/frontend/src/components/payment/__tests__/paymentFlow.spec.ts
@@ -0,0 +1,163 @@
+import { describe, expect, it } from 'vitest'
+import type { CreateOrderResult, MethodLimit } from '@/types/payment'
+import {
+ decidePaymentLaunch,
+ getVisibleMethods,
+ readPaymentRecoverySnapshot,
+ type PaymentRecoverySnapshot,
+} from '@/components/payment/paymentFlow'
+
+function methodLimit(overrides: Partial = {}): MethodLimit {
+ return {
+ daily_limit: 0,
+ daily_used: 0,
+ daily_remaining: 0,
+ single_min: 0,
+ single_max: 0,
+ fee_rate: 0,
+ available: true,
+ ...overrides,
+ }
+}
+
+function createOrderResult(overrides: Partial = {}): CreateOrderResult {
+ return {
+ order_id: 101,
+ amount: 88,
+ pay_amount: 88,
+ fee_rate: 0,
+ expires_at: '2099-01-01T00:10:00.000Z',
+ ...overrides,
+ }
+}
+
+describe('getVisibleMethods', () => {
+ it('filters hidden provider methods and normalizes aliases', () => {
+ const visible = getVisibleMethods({
+ alipay_direct: methodLimit({ single_min: 5 }),
+ wxpay: methodLimit({ single_max: 100 }),
+ stripe: methodLimit({ fee_rate: 3 }),
+ })
+
+ expect(visible).toEqual({
+ alipay: methodLimit({ single_min: 5 }),
+ wxpay: methodLimit({ single_max: 100 }),
+ })
+ })
+
+ it('prefers canonical visible methods over aliases when both exist', () => {
+ const visible = getVisibleMethods({
+ alipay: methodLimit({ single_min: 2 }),
+ alipay_direct: methodLimit({ single_min: 9 }),
+ wxpay_direct: methodLimit({ fee_rate: 1.2 }),
+ })
+
+ expect(visible.alipay.single_min).toBe(2)
+ expect(visible.wxpay.fee_rate).toBe(1.2)
+ })
+})
+
+describe('decidePaymentLaunch', () => {
+ it('uses Stripe popup waiting flow for desktop Alipay client secret', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ client_secret: 'cs_test',
+ resume_token: 'resume-1',
+ }), {
+ visibleMethod: 'alipay',
+ orderType: 'balance',
+ isMobile: false,
+ })
+
+ expect(decision.kind).toBe('stripe_popup')
+ expect(decision.paymentState.paymentType).toBe('alipay')
+ expect(decision.stripeMethod).toBe('alipay')
+ expect(decision.recovery.resumeToken).toBe('resume-1')
+ })
+
+ it('uses Stripe route flow for mobile WeChat client secret', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ client_secret: 'cs_test',
+ }), {
+ visibleMethod: 'wxpay',
+ orderType: 'subscription',
+ isMobile: true,
+ })
+
+ expect(decision.kind).toBe('stripe_route')
+ expect(decision.stripeMethod).toBe('wechat_pay')
+ expect(decision.paymentState.orderType).toBe('subscription')
+ })
+
+ it('keeps hosted redirect metadata for recovery flows', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ pay_url: 'https://pay.example.com/session/abc',
+ payment_mode: 'popup',
+ resume_token: 'resume-2',
+ }), {
+ visibleMethod: 'wxpay',
+ orderType: 'balance',
+ isMobile: false,
+ })
+
+ expect(decision.kind).toBe('redirect_waiting')
+ expect(decision.paymentState.payUrl).toBe('https://pay.example.com/session/abc')
+ expect(decision.recovery.paymentMode).toBe('popup')
+ expect(decision.recovery.resumeToken).toBe('resume-2')
+ })
+})
+
+describe('readPaymentRecoverySnapshot', () => {
+ it('restores an unexpired snapshot when the resume token matches', () => {
+ const snapshot: PaymentRecoverySnapshot = {
+ orderId: 33,
+ amount: 18,
+ qrCode: '',
+ expiresAt: '2099-01-01T00:10:00.000Z',
+ paymentType: 'alipay',
+ payUrl: 'https://pay.example.com/session/33',
+ clientSecret: '',
+ payAmount: 18,
+ orderType: 'balance',
+ paymentMode: 'popup',
+ resumeToken: 'resume-33',
+ createdAt: Date.UTC(2099, 0, 1, 0, 0, 0),
+ }
+
+ const restored = readPaymentRecoverySnapshot(JSON.stringify(snapshot), {
+ now: Date.UTC(2099, 0, 1, 0, 1, 0),
+ resumeToken: 'resume-33',
+ })
+
+ expect(restored?.orderId).toBe(33)
+ })
+
+ it('drops expired or mismatched recovery snapshots', () => {
+ const expiredSnapshot: PaymentRecoverySnapshot = {
+ orderId: 55,
+ amount: 18,
+ qrCode: '',
+ expiresAt: '2024-01-01T00:10:00.000Z',
+ paymentType: 'wxpay',
+ payUrl: 'https://pay.example.com/session/55',
+ clientSecret: '',
+ payAmount: 18,
+ orderType: 'balance',
+ paymentMode: 'popup',
+ resumeToken: 'resume-55',
+ createdAt: Date.UTC(2024, 0, 1, 0, 0, 0),
+ }
+
+ expect(readPaymentRecoverySnapshot(JSON.stringify(expiredSnapshot), {
+ now: Date.UTC(2024, 0, 1, 0, 20, 0),
+ resumeToken: 'resume-55',
+ })).toBeNull()
+
+ expect(readPaymentRecoverySnapshot(JSON.stringify({
+ ...expiredSnapshot,
+ expiresAt: '2099-01-01T00:10:00.000Z',
+ }), {
+ now: Date.UTC(2099, 0, 1, 0, 1, 0),
+ resumeToken: 'other-token',
+ })).toBeNull()
+ })
+})
diff --git a/frontend/src/components/payment/paymentFlow.ts b/frontend/src/components/payment/paymentFlow.ts
new file mode 100644
index 00000000..70225a0c
--- /dev/null
+++ b/frontend/src/components/payment/paymentFlow.ts
@@ -0,0 +1,197 @@
+import type { CreateOrderResult, MethodLimit, OrderType } from '@/types/payment'
+
+export const PAYMENT_RECOVERY_STORAGE_KEY = 'payment.recovery.current'
+
+const VISIBLE_METHOD_ALIASES = {
+ alipay: 'alipay',
+ alipay_direct: 'alipay',
+ wxpay: 'wxpay',
+ wxpay_direct: 'wxpay',
+} as const
+
+export type VisiblePaymentMethod = 'alipay' | 'wxpay'
+export type StripeVisibleMethod = 'alipay' | 'wechat_pay'
+export type PaymentLaunchKind =
+ | 'qr_waiting'
+ | 'redirect_waiting'
+ | 'stripe_popup'
+ | 'stripe_route'
+ | 'unhandled'
+
+export interface PaymentRecoverySnapshot {
+ orderId: number
+ amount: number
+ qrCode: string
+ expiresAt: string
+ paymentType: string
+ payUrl: string
+ clientSecret: string
+ payAmount: number
+ orderType: OrderType | ''
+ paymentMode: string
+ resumeToken: string
+ createdAt: number
+}
+
+export interface PaymentLaunchContext {
+ visibleMethod: string
+ orderType: OrderType
+ isMobile: boolean
+ now?: number
+ stripePopupUrl?: string
+ stripeRouteUrl?: string
+}
+
+export interface PaymentLaunchDecision {
+ kind: PaymentLaunchKind
+ paymentState: PaymentRecoverySnapshot
+ recovery: PaymentRecoverySnapshot
+ stripeMethod?: StripeVisibleMethod
+}
+
+type CreateOrderFlowResult = CreateOrderResult & {
+ resume_token?: string
+}
+
+type StorageWriter = Pick
+
+export function normalizeVisibleMethod(method: string): VisiblePaymentMethod | '' {
+ const normalized = VISIBLE_METHOD_ALIASES[method.trim() as keyof typeof VISIBLE_METHOD_ALIASES]
+ return normalized ?? ''
+}
+
+export function getVisibleMethods(methods: Record): Record {
+ const visible: Record = {}
+
+ Object.entries(methods).forEach(([type, limit]) => {
+ const normalized = normalizeVisibleMethod(type)
+ if (!normalized) return
+
+ const isCanonical = type === normalized
+ const existing = visible[normalized]
+ if (!existing || isCanonical) {
+ visible[normalized] = { ...limit }
+ }
+ })
+
+ return visible
+}
+
+export function decidePaymentLaunch(
+ result: CreateOrderFlowResult,
+ context: PaymentLaunchContext,
+): PaymentLaunchDecision {
+ const visibleMethod = normalizeVisibleMethod(context.visibleMethod) || context.visibleMethod
+ const baseState = createPaymentRecoverySnapshot({
+ orderId: result.order_id,
+ amount: result.amount,
+ qrCode: result.qr_code || '',
+ expiresAt: result.expires_at || '',
+ paymentType: visibleMethod,
+ payUrl: result.pay_url || '',
+ clientSecret: result.client_secret || '',
+ payAmount: result.pay_amount,
+ orderType: context.orderType,
+ paymentMode: (result.payment_mode || '').trim(),
+ resumeToken: result.resume_token || '',
+ }, context.now)
+
+ if (baseState.clientSecret) {
+ const stripeMethod: StripeVisibleMethod = visibleMethod === 'wxpay' ? 'wechat_pay' : 'alipay'
+ const kind: PaymentLaunchKind = stripeMethod === 'alipay' && !context.isMobile
+ ? 'stripe_popup'
+ : 'stripe_route'
+ const payUrl = kind === 'stripe_popup'
+ ? context.stripePopupUrl || context.stripeRouteUrl || ''
+ : context.stripeRouteUrl || context.stripePopupUrl || ''
+ const paymentState = { ...baseState, payUrl }
+ return { kind, paymentState, recovery: paymentState, stripeMethod }
+ }
+
+ if (baseState.qrCode) {
+ return { kind: 'qr_waiting', paymentState: baseState, recovery: baseState }
+ }
+
+ if (baseState.payUrl) {
+ return { kind: 'redirect_waiting', paymentState: baseState, recovery: baseState }
+ }
+
+ return { kind: 'unhandled', paymentState: baseState, recovery: baseState }
+}
+
+export function createPaymentRecoverySnapshot(
+ state: Omit,
+ now = Date.now(),
+): PaymentRecoverySnapshot {
+ return {
+ ...state,
+ createdAt: now,
+ }
+}
+
+export function writePaymentRecoverySnapshot(
+ storage: StorageWriter,
+ snapshot: PaymentRecoverySnapshot,
+ key = PAYMENT_RECOVERY_STORAGE_KEY,
+): void {
+ storage.setItem(key, JSON.stringify(snapshot))
+}
+
+export function clearPaymentRecoverySnapshot(
+ storage: Pick,
+ key = PAYMENT_RECOVERY_STORAGE_KEY,
+): void {
+ storage.removeItem(key)
+}
+
+export function readPaymentRecoverySnapshot(
+ raw: string | null | undefined,
+ options: { now?: number; resumeToken?: string } = {},
+): PaymentRecoverySnapshot | null {
+ if (!raw) return null
+
+ try {
+ const parsed = JSON.parse(raw) as Partial
+ if (
+ typeof parsed.orderId !== 'number'
+ || typeof parsed.amount !== 'number'
+ || typeof parsed.qrCode !== 'string'
+ || typeof parsed.expiresAt !== 'string'
+ || typeof parsed.paymentType !== 'string'
+ || typeof parsed.payUrl !== 'string'
+ || typeof parsed.clientSecret !== 'string'
+ || typeof parsed.payAmount !== 'number'
+ || typeof parsed.paymentMode !== 'string'
+ || typeof parsed.resumeToken !== 'string'
+ || typeof parsed.createdAt !== 'number'
+ ) {
+ return null
+ }
+
+ const now = options.now ?? Date.now()
+ const expiresAt = Date.parse(parsed.expiresAt)
+ if (Number.isFinite(expiresAt) && expiresAt <= now) {
+ return null
+ }
+ if (options.resumeToken && parsed.resumeToken && parsed.resumeToken !== options.resumeToken) {
+ return null
+ }
+
+ return {
+ orderId: parsed.orderId,
+ amount: parsed.amount,
+ qrCode: parsed.qrCode,
+ expiresAt: parsed.expiresAt,
+ paymentType: parsed.paymentType,
+ payUrl: parsed.payUrl,
+ clientSecret: parsed.clientSecret,
+ payAmount: parsed.payAmount,
+ orderType: parsed.orderType === 'subscription' ? 'subscription' : 'balance',
+ paymentMode: parsed.paymentMode,
+ resumeToken: parsed.resumeToken,
+ createdAt: parsed.createdAt,
+ }
+ } catch {
+ return null
+ }
+}
diff --git a/frontend/src/router/__tests__/wechat-route.spec.ts b/frontend/src/router/__tests__/wechat-route.spec.ts
new file mode 100644
index 00000000..84b20452
--- /dev/null
+++ b/frontend/src/router/__tests__/wechat-route.spec.ts
@@ -0,0 +1,55 @@
+import { describe, expect, it, vi } from 'vitest'
+
+const authStore = vi.hoisted(() => ({
+ checkAuth: vi.fn(),
+ isAuthenticated: false,
+ isAdmin: false,
+ isSimpleMode: false,
+}))
+
+const appStore = vi.hoisted(() => ({
+ siteName: 'Sub2API',
+ backendModeEnabled: false,
+ cachedPublicSettings: null as null | Record,
+}))
+
+vi.mock('@/stores/auth', () => ({
+ useAuthStore: () => authStore,
+}))
+
+vi.mock('@/stores/app', () => ({
+ useAppStore: () => appStore,
+}))
+
+vi.mock('@/stores/adminSettings', () => ({
+ useAdminSettingsStore: () => ({
+ customMenuItems: [],
+ }),
+}))
+
+vi.mock('@/composables/useNavigationLoading', () => ({
+ useNavigationLoadingState: () => ({
+ startNavigation: vi.fn(),
+ endNavigation: vi.fn(),
+ isLoading: { value: false },
+ }),
+}))
+
+vi.mock('@/composables/useRoutePrefetch', () => ({
+ useRoutePrefetch: () => ({
+ triggerPrefetch: vi.fn(),
+ cancelPendingPrefetch: vi.fn(),
+ resetPrefetchState: vi.fn(),
+ }),
+}))
+
+describe('router WeChat OAuth route', () => {
+ it('registers the WeChat callback route as a public route', async () => {
+ const { default: router } = await import('@/router')
+ const route = router.getRoutes().find((record) => record.name === 'WeChatOAuthCallback')
+
+ expect(route?.path).toBe('/auth/wechat/callback')
+ expect(route?.meta.requiresAuth).toBe(false)
+ expect(route?.meta.title).toBe('WeChat OAuth Callback')
+ })
+})
diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts
index ad6e71c4..beaa1da2 100644
--- a/frontend/src/router/index.ts
+++ b/frontend/src/router/index.ts
@@ -83,6 +83,15 @@ const routes: RouteRecordRaw[] = [
title: 'LinuxDo OAuth Callback'
}
},
+ {
+ path: '/auth/wechat/callback',
+ name: 'WeChatOAuthCallback',
+ component: () => import('@/views/auth/WechatCallbackView.vue'),
+ meta: {
+ requiresAuth: false,
+ title: 'WeChat OAuth Callback'
+ }
+ },
{
path: '/auth/oidc/callback',
name: 'OIDCOAuthCallback',
diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts
index 1995383d..1b1af87b 100644
--- a/frontend/src/stores/app.ts
+++ b/frontend/src/stores/app.ts
@@ -336,6 +336,7 @@ export const useAppStore = defineStore('app', () => {
custom_menu_items: [],
custom_endpoints: [],
linuxdo_oauth_enabled: false,
+ wechat_oauth_enabled: false,
oidc_oauth_enabled: false,
oidc_oauth_provider_name: 'OIDC',
backend_mode_enabled: false,
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index 89fd777f..529eff55 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -123,6 +123,7 @@ export interface PublicSettings {
custom_menu_items: CustomMenuItem[]
custom_endpoints: CustomEndpoint[]
linuxdo_oauth_enabled: boolean
+ wechat_oauth_enabled: boolean
oidc_oauth_enabled: boolean
oidc_oauth_provider_name: string
backend_mode_enabled: boolean
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index 8bfa0f2b..0d23baa5 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -1586,6 +1586,221 @@
+
+
+
+
+ {{ localText('认证来源默认值', 'Auth Source Defaults') }}
+
+
+ {{
+ localText(
+ '按注册来源配置新用户默认余额、并发、订阅与授权策略。',
+ 'Configure per-source default balance, concurrency, subscriptions, and grant rules.'
+ )
+ }}
+
+
+
+
+
+
+ {{ localText('第三方注册强制补充邮箱', 'Require email on third-party signup') }}
+
+
+ {{
+ localText(
+ '启用后,Linux DO、OIDC、微信注册缺少邮箱时必须先补充邮箱地址。',
+ 'When enabled, Linux DO, OIDC, and WeChat signups must provide an email before account creation.'
+ )
+ }}
+
+
+
+
+
+
+
+
+
{{ authSource.title }}
+
+ {{ authSource.description }}
+
+
+
+
+
+
+
+
+
+ {{ localText('注册即授权', 'Grant on signup') }}
+
+
+ {{
+ localText(
+ '来源首次注册成功后立即发放默认权益。',
+ 'Grant default entitlements immediately after signup.'
+ )
+ }}
+
+
+
+
+
+
+
+
+ {{ localText('首次绑定时授权', 'Grant on first bind') }}
+
+
+ {{
+ localText(
+ '来源首次绑定到现有账号时发放默认权益。',
+ 'Grant default entitlements when the source is first bound to an existing user.'
+ )
+ }}
+
+
+
+
+
+
+
+
+
+
+ {{ localText('默认订阅', 'Default subscriptions') }}
+
+
+ {{
+ localText(
+ '仅对当前认证来源生效,未配置时不追加来源专属订阅。',
+ 'Applies only to this auth source. Leave empty to skip source-specific subscriptions.'
+ )
+ }}
+
+
+
+ {{ t('admin.settings.defaults.addDefaultSubscription') }}
+
+
+
+
+ {{
+ localText(
+ '当前来源未配置专属默认订阅。',
+ 'No source-specific default subscriptions configured.'
+ )
+ }}
+
+
+
+
+
+
+ {{ t('admin.settings.defaults.subscriptionGroup') }}
+
+
+
+
+
+ {{ t('admin.settings.defaults.subscriptionGroup') }}
+
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.defaults.subscriptionValidityDays') }}
+
+
+
+
+
+ {{ t('common.delete') }}
+
+
+
+
+
+
+
+
+
@@ -1643,19 +1858,38 @@
-
-
-
- {{ t('admin.settings.scheduling.allowUngroupedKey') }}
+
+
+
+
+ {{ t('admin.settings.scheduling.allowUngroupedKey') }}
+
+
+ {{ t('admin.settings.scheduling.allowUngroupedKeyHint') }}
+
+
+
+
+
-
- {{ t('admin.settings.scheduling.allowUngroupedKeyHint') }}
-
-
-
-
-
+
+
+
+
+ {{ localText('OpenAI 高级调度器', 'OpenAI advanced scheduler') }}
+
+
+ {{
+ localText(
+ '切换 OpenAI 侧新增的高级调度开关,供当前分支实验性调度逻辑使用。',
+ 'Toggles the new OpenAI advanced scheduler flag for the experimental routing logic on this branch.'
+ )
+ }}
+
+
+
+
@@ -2450,6 +2684,59 @@
+
+
+
+
+
+ {{
+ localText(
+ `${visibleMethod.title} 可见方式`,
+ `${visibleMethod.title} visible method`
+ )
+ }}
+
+
+ {{
+ localText(
+ '控制前台结算页是否展示该方式,以及展示时使用的来源键。',
+ 'Controls whether checkout shows this method and which source key it exposes.'
+ )
+ }}
+
+
+
+
+
+
+
+ {{ localText('来源键', 'Source key') }}
+
+
+
+ {{
+ localText(
+ '留空表示由后端使用默认来源;可填 easypay、alipay、wxpay 等来源标识。',
+ 'Leave blank to let the backend decide. Typical values are easypay, alipay, or wxpay.'
+ )
+ }}
+
+
+
+
@@ -2827,7 +3114,14 @@
import { ref, reactive, computed, onMounted } from 'vue'
import { useI18n } from 'vue-i18n'
import { adminAPI } from '@/api'
+import {
+ appendAuthSourceDefaultsToUpdateRequest,
+ buildAuthSourceDefaultsState,
+ normalizeDefaultSubscriptionSettings,
+} from '@/api/admin/settings'
import type {
+ AuthSourceDefaultsState,
+ AuthSourceType,
SystemSettings,
UpdateSettingsRequest,
DefaultSubscriptionSetting,
@@ -2864,6 +3158,10 @@ const { t, locale } = useI18n()
const appStore = useAppStore()
const adminSettingsStore = useAdminSettingsStore()
+function localText(zh: string, en: string): string {
+ return locale.value.startsWith('zh') ? zh : en
+}
+
type SettingsTab = 'general' | 'security' | 'users' | 'gateway' | 'payment' | 'email' | 'backup'
const activeTab = ref
('general')
const settingsTabs = [
@@ -2960,6 +3258,12 @@ type SettingsForm = SystemSettings & {
turnstile_secret_key: string
linuxdo_connect_client_secret: string
oidc_connect_client_secret: string
+ force_email_on_third_party_signup: boolean
+ payment_visible_method_alipay_source: string
+ payment_visible_method_wxpay_source: string
+ payment_visible_method_alipay_enabled: boolean
+ payment_visible_method_wxpay_enabled: boolean
+ openai_advanced_scheduler_enabled: boolean
}
const form = reactive({
@@ -2974,6 +3278,7 @@ const form = reactive({
default_balance: 0,
default_concurrency: 1,
default_subscriptions: [],
+ force_email_on_third_party_signup: false,
site_name: 'Sub2API',
site_logo: '',
site_subtitle: 'Subscription to API Conversion Platform',
@@ -2983,7 +3288,7 @@ const form = reactive({
home_content: '',
backend_mode_enabled: false,
hide_ccs_import_button: false,
- payment_enabled: false, payment_min_amount: 1, payment_max_amount: 10000, payment_daily_limit: 50000, payment_max_pending_orders: 3, payment_order_timeout_minutes: 30, payment_balance_disabled: false, payment_balance_recharge_multiplier: 1, payment_recharge_fee_rate: 0, payment_enabled_types: [], payment_help_image_url: '', payment_help_text: '', payment_product_name_prefix: '', payment_product_name_suffix: '', payment_load_balance_strategy: 'round-robin', payment_cancel_rate_limit_enabled: false, payment_cancel_rate_limit_max: 10, payment_cancel_rate_limit_window: 1, payment_cancel_rate_limit_unit: 'day', payment_cancel_rate_limit_window_mode: 'rolling',
+ payment_enabled: false, payment_min_amount: 1, payment_max_amount: 10000, payment_daily_limit: 50000, payment_max_pending_orders: 3, payment_order_timeout_minutes: 30, payment_balance_disabled: false, payment_balance_recharge_multiplier: 1, payment_recharge_fee_rate: 0, payment_enabled_types: [], payment_help_image_url: '', payment_help_text: '', payment_product_name_prefix: '', payment_product_name_suffix: '', payment_load_balance_strategy: 'round-robin', payment_cancel_rate_limit_enabled: false, payment_cancel_rate_limit_max: 10, payment_cancel_rate_limit_window: 1, payment_cancel_rate_limit_unit: 'day', payment_cancel_rate_limit_window_mode: 'rolling', payment_visible_method_alipay_source: '', payment_visible_method_wxpay_source: '', payment_visible_method_alipay_enabled: false, payment_visible_method_wxpay_enabled: false,
table_default_page_size: tablePageSizeDefault,
table_page_size_options: [10, 20, 50, 100],
custom_menu_items: [] as Array<{id: string; label: string; icon_svg: string; url: string; visibility: 'user' | 'admin'; sort_order: number}>,
@@ -3051,6 +3356,7 @@ const form = reactive({
max_claude_code_version: '',
// 分组隔离
allow_ungrouped_key_scheduling: false,
+ openai_advanced_scheduler_enabled: false,
// Gateway forwarding behavior
enable_fingerprint_unification: true,
enable_metadata_passthrough: false,
@@ -3063,6 +3369,74 @@ const form = reactive({
account_quota_notify_emails: [] as NotifyEmailEntry[]
})
+const authSourceDefaults = reactive(buildAuthSourceDefaultsState({}))
+
+const authSourceDefaultsMeta = computed(() => [
+ {
+ source: 'email' as AuthSourceType,
+ title: localText('邮箱注册', 'Email signup'),
+ description: localText('适用于邮箱密码注册的新用户默认配额。', 'Default quota grants for email-password signups.')
+ },
+ {
+ source: 'linuxdo' as AuthSourceType,
+ title: localText('Linux DO 登录', 'Linux DO signup'),
+ description: localText('适用于 Linux DO 第三方注册的新用户默认配额。', 'Default quota grants for Linux DO signups.')
+ },
+ {
+ source: 'oidc' as AuthSourceType,
+ title: localText('OIDC 登录', 'OIDC signup'),
+ description: localText('适用于 OIDC 第三方注册的新用户默认配额。', 'Default quota grants for OIDC signups.')
+ },
+ {
+ source: 'wechat' as AuthSourceType,
+ title: localText('微信登录', 'WeChat signup'),
+ description: localText('适用于微信第三方注册的新用户默认配额。', 'Default quota grants for WeChat signups.')
+ },
+])
+
+const paymentVisibleMethodCards = computed(() => [
+ {
+ key: 'alipay' as const,
+ title: t('payment.methods.alipay'),
+ enabledField: 'payment_visible_method_alipay_enabled' as const,
+ sourceField: 'payment_visible_method_alipay_source' as const,
+ },
+ {
+ key: 'wxpay' as const,
+ title: t('payment.methods.wxpay'),
+ enabledField: 'payment_visible_method_wxpay_enabled' as const,
+ sourceField: 'payment_visible_method_wxpay_source' as const,
+ },
+])
+
+function getPaymentVisibleMethodEnabled(method: 'alipay' | 'wxpay'): boolean {
+ return method === 'alipay'
+ ? form.payment_visible_method_alipay_enabled
+ : form.payment_visible_method_wxpay_enabled
+}
+
+function setPaymentVisibleMethodEnabled(method: 'alipay' | 'wxpay', enabled: boolean) {
+ if (method === 'alipay') {
+ form.payment_visible_method_alipay_enabled = enabled
+ return
+ }
+ form.payment_visible_method_wxpay_enabled = enabled
+}
+
+function getPaymentVisibleMethodSource(method: 'alipay' | 'wxpay'): string {
+ return method === 'alipay'
+ ? form.payment_visible_method_alipay_source
+ : form.payment_visible_method_wxpay_source
+}
+
+function setPaymentVisibleMethodSource(method: 'alipay' | 'wxpay', source: string) {
+ if (method === 'alipay') {
+ form.payment_visible_method_alipay_source = source
+ return
+ }
+ form.payment_visible_method_wxpay_source = source
+}
+
// Proxies for web search emulation ProxySelector
const webSearchProxies = ref([])
@@ -3428,15 +3802,9 @@ async function loadSettings() {
(form as Record)[key] = value
}
}
+ Object.assign(authSourceDefaults, buildAuthSourceDefaultsState(settings))
form.backend_mode_enabled = settings.backend_mode_enabled
- form.default_subscriptions = Array.isArray(settings.default_subscriptions)
- ? settings.default_subscriptions
- .filter((item) => item.group_id > 0 && item.validity_days > 0)
- .map((item) => ({
- group_id: item.group_id,
- validity_days: item.validity_days
- }))
- : []
+ form.default_subscriptions = normalizeDefaultSubscriptionSettings(settings.default_subscriptions)
registrationEmailSuffixWhitelistTags.value = normalizeRegistrationEmailSuffixDomains(
settings.registration_email_suffix_whitelist
)
@@ -3471,10 +3839,18 @@ async function loadSubscriptionGroups() {
}
}
+function findNextAvailableSubscriptionGroup(
+ existingGroupIDs: number[]
+): AdminGroup | undefined {
+ const existing = new Set(existingGroupIDs)
+ return subscriptionGroups.value.find((group) => !existing.has(group.id))
+}
+
function addDefaultSubscription() {
if (subscriptionGroups.value.length === 0) return
- const existing = new Set(form.default_subscriptions.map((item) => item.group_id))
- const candidate = subscriptionGroups.value.find((group) => !existing.has(group.id))
+ const candidate = findNextAvailableSubscriptionGroup(
+ form.default_subscriptions.map((item) => item.group_id)
+ )
if (!candidate) return
form.default_subscriptions.push({
group_id: candidate.id,
@@ -3486,6 +3862,36 @@ function removeDefaultSubscription(index: number) {
form.default_subscriptions.splice(index, 1)
}
+function addAuthSourceDefaultSubscription(source: AuthSourceType) {
+ if (subscriptionGroups.value.length === 0) return
+ const candidate = findNextAvailableSubscriptionGroup(
+ authSourceDefaults[source].subscriptions.map((item) => item.group_id)
+ )
+ if (!candidate) return
+ authSourceDefaults[source].subscriptions.push({
+ group_id: candidate.id,
+ validity_days: 30
+ })
+}
+
+function removeAuthSourceDefaultSubscription(source: AuthSourceType, index: number) {
+ authSourceDefaults[source].subscriptions.splice(index, 1)
+}
+
+function findDuplicateDefaultSubscription(
+ subscriptions: DefaultSubscriptionSetting[]
+): DefaultSubscriptionSetting | undefined {
+ const seenGroupIDs = new Set()
+
+ return subscriptions.find((item) => {
+ if (seenGroupIDs.has(item.group_id)) {
+ return true
+ }
+ seenGroupIDs.add(item.group_id)
+ return false
+ })
+}
+
async function saveSettings() {
saving.value = true
try {
@@ -3520,21 +3926,12 @@ async function saveSettings() {
form.table_default_page_size = normalizedTableDefaultPageSize
form.table_page_size_options = normalizedTablePageSizeOptions
- const normalizedDefaultSubscriptions = form.default_subscriptions
- .filter((item) => item.group_id > 0 && item.validity_days > 0)
- .map((item: DefaultSubscriptionSetting) => ({
- group_id: item.group_id,
- validity_days: Math.min(36500, Math.max(1, Math.floor(item.validity_days)))
- }))
-
- const seenGroupIDs = new Set()
- const duplicateDefaultSubscription = normalizedDefaultSubscriptions.find((item) => {
- if (seenGroupIDs.has(item.group_id)) {
- return true
- }
- seenGroupIDs.add(item.group_id)
- return false
- })
+ const normalizedDefaultSubscriptions = normalizeDefaultSubscriptionSettings(
+ form.default_subscriptions
+ )
+ const duplicateDefaultSubscription = findDuplicateDefaultSubscription(
+ normalizedDefaultSubscriptions
+ )
if (duplicateDefaultSubscription) {
appStore.showError(
t('admin.settings.defaults.defaultSubscriptionsDuplicate', {
@@ -3544,6 +3941,23 @@ async function saveSettings() {
return
}
+ for (const authSource of authSourceDefaultsMeta.value) {
+ authSourceDefaults[authSource.source].subscriptions = normalizeDefaultSubscriptionSettings(
+ authSourceDefaults[authSource.source].subscriptions
+ )
+ const duplicate = findDuplicateDefaultSubscription(
+ authSourceDefaults[authSource.source].subscriptions
+ )
+ if (duplicate) {
+ appStore.showError(
+ `${authSource.title}: ${t('admin.settings.defaults.defaultSubscriptionsDuplicate', {
+ groupId: duplicate.group_id
+ })}`
+ )
+ return
+ }
+ }
+
// Validate URL fields — novalidate disables browser-native checks, so we validate here
const isValidHttpUrl = (url: string): boolean => {
if (!url) return true
@@ -3571,6 +3985,7 @@ async function saveSettings() {
default_balance: form.default_balance,
default_concurrency: form.default_concurrency,
default_subscriptions: normalizedDefaultSubscriptions,
+ force_email_on_third_party_signup: form.force_email_on_third_party_signup,
site_name: form.site_name,
site_logo: form.site_logo,
site_subtitle: form.site_subtitle,
@@ -3655,6 +4070,11 @@ async function saveSettings() {
payment_cancel_rate_limit_window: Number(form.payment_cancel_rate_limit_window) || 1,
payment_cancel_rate_limit_unit: form.payment_cancel_rate_limit_unit,
payment_cancel_rate_limit_window_mode: form.payment_cancel_rate_limit_window_mode,
+ payment_visible_method_alipay_source: form.payment_visible_method_alipay_source,
+ payment_visible_method_wxpay_source: form.payment_visible_method_wxpay_source,
+ payment_visible_method_alipay_enabled: form.payment_visible_method_alipay_enabled,
+ payment_visible_method_wxpay_enabled: form.payment_visible_method_wxpay_enabled,
+ openai_advanced_scheduler_enabled: form.openai_advanced_scheduler_enabled,
// Balance & quota notification
balance_low_notify_enabled: form.balance_low_notify_enabled,
balance_low_notify_threshold: Number(form.balance_low_notify_threshold) || 0,
@@ -3663,12 +4083,15 @@ async function saveSettings() {
account_quota_notify_emails: (form.account_quota_notify_emails || []).filter((e) => e.email.trim() !== ''),
}
+ appendAuthSourceDefaultsToUpdateRequest(payload, authSourceDefaults)
+
const updated = await adminAPI.settings.updateSettings(payload)
for (const [key, value] of Object.entries(updated)) {
if (value !== null && value !== undefined) {
(form as Record)[key] = value
}
}
+ Object.assign(authSourceDefaults, buildAuthSourceDefaultsState(updated))
registrationEmailSuffixWhitelistTags.value = normalizeRegistrationEmailSuffixDomains(
updated.registration_email_suffix_whitelist
)
diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue
index af48959b..0a125def 100644
--- a/frontend/src/views/auth/LinuxDoCallbackView.vue
+++ b/frontend/src/views/auth/LinuxDoCallbackView.vue
@@ -11,32 +11,94 @@
-
-
- {{ t('auth.linuxdo.invitationRequired') }}
-
-
-
-
-
-
- {{ invitationError }}
-
-
-
+
- {{ isSubmitting ? t('auth.linuxdo.completing') : t('auth.linuxdo.completeRegistration') }}
-
+
+
+
+ Use LinuxDo profile details
+
+
+ Choose whether to apply the nickname or avatar from LinuxDo to this account.
+
+
+
+
+
+
+
+ Use display name
+
+
+ {{ suggestedDisplayName }}
+
+
+
+
+
+
+
+
+
+ Use avatar
+
+
+ {{ suggestedAvatarUrl }}
+
+
+
+
+
+
+
+
+ {{ t('auth.linuxdo.invitationRequired') }}
+
+
+
+
+
+
+ {{ invitationError }}
+
+
+
+ {{ isSubmitting ? t('auth.linuxdo.completing') : t('auth.linuxdo.completeRegistration') }}
+
+
+
+
+
+ Review the LinuxDo profile details before continuing.
+
+
+ {{ isSubmitting ? t('common.processing') : 'Continue' }}
+
+
@@ -71,7 +133,12 @@ import { useI18n } from 'vue-i18n'
import { AuthLayout } from '@/components/layout'
import Icon from '@/components/icons/Icon.vue'
import { useAuthStore, useAppStore } from '@/stores'
-import { completeLinuxDoOAuthRegistration } from '@/api/auth'
+import {
+ completeLinuxDoOAuthRegistration,
+ exchangePendingOAuthCompletion,
+ type OAuthAdoptionDecision,
+ type PendingOAuthExchangeResponse
+} from '@/api/auth'
const route = useRoute()
const router = useRouter()
@@ -85,11 +152,16 @@ const errorMessage = ref('')
// Invitation code flow state
const needsInvitation = ref(false)
-const pendingOAuthToken = ref('')
const invitationCode = ref('')
const isSubmitting = ref(false)
const invitationError = ref('')
const redirectTo = ref('/dashboard')
+const adoptionRequired = ref(false)
+const suggestedDisplayName = ref('')
+const suggestedAvatarUrl = ref('')
+const adoptDisplayName = ref(true)
+const adoptAvatar = ref(true)
+const needsAdoptionConfirmation = ref(false)
function parseFragmentParams(): URLSearchParams {
const raw = typeof window !== 'undefined' ? window.location.hash : ''
@@ -106,6 +178,54 @@ function sanitizeRedirectPath(path: string | null | undefined): string {
return path
}
+function currentAdoptionDecision(): OAuthAdoptionDecision {
+ return {
+ adoptDisplayName: adoptDisplayName.value,
+ adoptAvatar: adoptAvatar.value
+ }
+}
+
+function applyAdoptionSuggestionState(completion: {
+ adoption_required?: boolean
+ suggested_display_name?: string
+ suggested_avatar_url?: string
+}) {
+ adoptionRequired.value = completion.adoption_required === true
+ suggestedDisplayName.value = completion.suggested_display_name || ''
+ suggestedAvatarUrl.value = completion.suggested_avatar_url || ''
+
+ if (!suggestedDisplayName.value) {
+ adoptDisplayName.value = false
+ }
+ if (!suggestedAvatarUrl.value) {
+ adoptAvatar.value = false
+ }
+}
+
+function hasSuggestedProfile(completion: {
+ suggested_display_name?: string
+ suggested_avatar_url?: string
+}): boolean {
+ return Boolean(completion.suggested_display_name || completion.suggested_avatar_url)
+}
+
+async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) {
+ if (!completion.access_token) {
+ throw new Error(t('auth.linuxdo.callbackMissingToken'))
+ }
+
+ if (completion.refresh_token) {
+ localStorage.setItem('refresh_token', completion.refresh_token)
+ }
+ if (completion.expires_in) {
+ localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000))
+ }
+
+ await authStore.setToken(completion.access_token)
+ appStore.showSuccess(t('auth.loginSuccess'))
+ await router.replace(redirect)
+}
+
async function handleSubmitInvitation() {
invitationError.value = ''
if (!invitationCode.value.trim()) return
@@ -113,8 +233,8 @@ async function handleSubmitInvitation() {
isSubmitting.value = true
try {
const tokenData = await completeLinuxDoOAuthRegistration(
- pendingOAuthToken.value,
- invitationCode.value.trim()
+ invitationCode.value.trim(),
+ currentAdoptionDecision()
)
if (tokenData.refresh_token) {
localStorage.setItem('refresh_token', tokenData.refresh_token)
@@ -134,63 +254,65 @@ async function handleSubmitInvitation() {
}
}
+async function handleContinueLogin() {
+ isSubmitting.value = true
+ try {
+ const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision())
+ await finalizeLogin(completion, redirectTo.value)
+ } catch (e: unknown) {
+ const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } }
+ errorMessage.value =
+ err.response?.data?.detail ||
+ err.response?.data?.message ||
+ err.message ||
+ t('auth.loginFailed')
+ appStore.showError(errorMessage.value)
+ needsAdoptionConfirmation.value = false
+ } finally {
+ isSubmitting.value = false
+ }
+}
+
onMounted(async () => {
const params = parseFragmentParams()
-
- const token = params.get('access_token') || ''
- const refreshToken = params.get('refresh_token') || ''
- const expiresInStr = params.get('expires_in') || ''
- const redirect = sanitizeRedirectPath(
- params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard'
- )
const error = params.get('error')
const errorDesc = params.get('error_description') || params.get('error_message') || ''
if (error) {
- if (error === 'invitation_required') {
- pendingOAuthToken.value = params.get('pending_oauth_token') || ''
- redirectTo.value = sanitizeRedirectPath(params.get('redirect'))
- if (!pendingOAuthToken.value) {
- errorMessage.value = t('auth.linuxdo.invalidPendingToken')
- appStore.showError(errorMessage.value)
- isProcessing.value = false
- return
- }
- needsInvitation.value = true
- isProcessing.value = false
- return
- }
errorMessage.value = errorDesc || error
appStore.showError(errorMessage.value)
isProcessing.value = false
return
}
- if (!token) {
- errorMessage.value = t('auth.linuxdo.callbackMissingToken')
- appStore.showError(errorMessage.value)
- isProcessing.value = false
- return
- }
-
try {
- // Store refresh token and expires_at (convert to timestamp) if provided
- if (refreshToken) {
- localStorage.setItem('refresh_token', refreshToken)
- }
- if (expiresInStr) {
- const expiresIn = parseInt(expiresInStr, 10)
- if (!isNaN(expiresIn)) {
- localStorage.setItem('token_expires_at', String(Date.now() + expiresIn * 1000))
- }
+ const completion = await exchangePendingOAuthCompletion()
+ const redirect = sanitizeRedirectPath(
+ completion.redirect || (route.query.redirect as string | undefined) || '/dashboard'
+ )
+ applyAdoptionSuggestionState(completion)
+ redirectTo.value = redirect
+
+ if (completion.error === 'invitation_required') {
+ needsInvitation.value = true
+ isProcessing.value = false
+ return
}
- await authStore.setToken(token)
- appStore.showSuccess(t('auth.loginSuccess'))
- await router.replace(redirect)
+ if (adoptionRequired.value && hasSuggestedProfile(completion)) {
+ needsAdoptionConfirmation.value = true
+ isProcessing.value = false
+ return
+ }
+
+ await finalizeLogin(completion, redirect)
} catch (e: unknown) {
- const err = e as { message?: string; response?: { data?: { detail?: string } } }
- errorMessage.value = err.response?.data?.detail || err.message || t('auth.loginFailed')
+ const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } }
+ errorMessage.value =
+ err.response?.data?.detail ||
+ err.response?.data?.message ||
+ err.message ||
+ t('auth.loginFailed')
appStore.showError(errorMessage.value)
isProcessing.value = false
}
@@ -209,4 +331,3 @@ onMounted(async () => {
transform: translateY(-8px);
}
-
diff --git a/frontend/src/views/auth/LoginView.vue b/frontend/src/views/auth/LoginView.vue
index 70b64e3f..fa4ac34c 100644
--- a/frontend/src/views/auth/LoginView.vue
+++ b/frontend/src/views/auth/LoginView.vue
@@ -11,12 +11,17 @@
-
+
+
(false)
const turnstileEnabled = ref(false)
const turnstileSiteKey = ref('')
const linuxdoOAuthEnabled = ref(false)
+const wechatOAuthEnabled = ref(false)
const backendModeEnabled = ref(false)
const oidcOAuthEnabled = ref(false)
const oidcOAuthProviderName = ref('OIDC')
@@ -267,6 +274,7 @@ onMounted(async () => {
turnstileEnabled.value = settings.turnstile_enabled
turnstileSiteKey.value = settings.turnstile_site_key || ''
linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled
+ wechatOAuthEnabled.value = settings.wechat_oauth_enabled
backendModeEnabled.value = settings.backend_mode_enabled
oidcOAuthEnabled.value = settings.oidc_oauth_enabled
oidcOAuthProviderName.value = settings.oidc_oauth_provider_name || 'OIDC'
diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue
index a6cb6c12..55f8af6e 100644
--- a/frontend/src/views/auth/OidcCallbackView.vue
+++ b/frontend/src/views/auth/OidcCallbackView.vue
@@ -15,36 +15,99 @@
-
-
- {{ t('auth.oidc.invitationRequired', { providerName }) }}
-
-
-
-
-
-
- {{ invitationError }}
-
-
-
+
- {{
- isSubmitting
- ? t('auth.oidc.completing')
- : t('auth.oidc.completeRegistration')
- }}
-
+
+
+
+ Use {{ providerName }} profile details
+
+
+ Choose whether to apply the nickname or avatar from {{ providerName }} to this
+ account.
+
+
+
+
+
+
+
+ Use display name
+
+
+ {{ suggestedDisplayName }}
+
+
+
+
+
+
+
+
+
+ Use avatar
+
+
+ {{ suggestedAvatarUrl }}
+
+
+
+
+
+
+
+
+ {{ t('auth.oidc.invitationRequired', { providerName }) }}
+
+
+
+
+
+
+ {{ invitationError }}
+
+
+
+ {{
+ isSubmitting
+ ? t('auth.oidc.completing')
+ : t('auth.oidc.completeRegistration')
+ }}
+
+
+
+
+
+ Review the {{ providerName }} profile details before continuing.
+
+
+ {{ isSubmitting ? t('common.processing') : 'Continue' }}
+
+
@@ -81,7 +144,10 @@ import Icon from '@/components/icons/Icon.vue'
import { useAuthStore, useAppStore } from '@/stores'
import {
completeOIDCOAuthRegistration,
- getPublicSettings
+ exchangePendingOAuthCompletion,
+ getPublicSettings,
+ type OAuthAdoptionDecision,
+ type PendingOAuthExchangeResponse
} from '@/api/auth'
const route = useRoute()
@@ -95,12 +161,17 @@ const isProcessing = ref(true)
const errorMessage = ref('')
const needsInvitation = ref(false)
-const pendingOAuthToken = ref('')
const invitationCode = ref('')
const isSubmitting = ref(false)
const invitationError = ref('')
const redirectTo = ref('/dashboard')
const providerName = ref('OIDC')
+const adoptionRequired = ref(false)
+const suggestedDisplayName = ref('')
+const suggestedAvatarUrl = ref('')
+const adoptDisplayName = ref(true)
+const adoptAvatar = ref(true)
+const needsAdoptionConfirmation = ref(false)
function parseFragmentParams(): URLSearchParams {
const raw = typeof window !== 'undefined' ? window.location.hash : ''
@@ -129,6 +200,54 @@ async function loadProviderName() {
}
}
+function currentAdoptionDecision(): OAuthAdoptionDecision {
+ return {
+ adoptDisplayName: adoptDisplayName.value,
+ adoptAvatar: adoptAvatar.value
+ }
+}
+
+function applyAdoptionSuggestionState(completion: {
+ adoption_required?: boolean
+ suggested_display_name?: string
+ suggested_avatar_url?: string
+}) {
+ adoptionRequired.value = completion.adoption_required === true
+ suggestedDisplayName.value = completion.suggested_display_name || ''
+ suggestedAvatarUrl.value = completion.suggested_avatar_url || ''
+
+ if (!suggestedDisplayName.value) {
+ adoptDisplayName.value = false
+ }
+ if (!suggestedAvatarUrl.value) {
+ adoptAvatar.value = false
+ }
+}
+
+function hasSuggestedProfile(completion: {
+ suggested_display_name?: string
+ suggested_avatar_url?: string
+}): boolean {
+ return Boolean(completion.suggested_display_name || completion.suggested_avatar_url)
+}
+
+async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) {
+ if (!completion.access_token) {
+ throw new Error(t('auth.oidc.callbackMissingToken'))
+ }
+
+ if (completion.refresh_token) {
+ localStorage.setItem('refresh_token', completion.refresh_token)
+ }
+ if (completion.expires_in) {
+ localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000))
+ }
+
+ await authStore.setToken(completion.access_token)
+ appStore.showSuccess(t('auth.loginSuccess'))
+ await router.replace(redirect)
+}
+
async function handleSubmitInvitation() {
invitationError.value = ''
if (!invitationCode.value.trim()) return
@@ -136,8 +255,8 @@ async function handleSubmitInvitation() {
isSubmitting.value = true
try {
const tokenData = await completeOIDCOAuthRegistration(
- pendingOAuthToken.value,
- invitationCode.value.trim()
+ invitationCode.value.trim(),
+ currentAdoptionDecision()
)
if (tokenData.refresh_token) {
localStorage.setItem('refresh_token', tokenData.refresh_token)
@@ -157,63 +276,67 @@ async function handleSubmitInvitation() {
}
}
+async function handleContinueLogin() {
+ isSubmitting.value = true
+ try {
+ const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision())
+ await finalizeLogin(completion, redirectTo.value)
+ } catch (e: unknown) {
+ const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } }
+ errorMessage.value =
+ err.response?.data?.detail ||
+ err.response?.data?.message ||
+ err.message ||
+ t('auth.loginFailed')
+ appStore.showError(errorMessage.value)
+ needsAdoptionConfirmation.value = false
+ } finally {
+ isSubmitting.value = false
+ }
+}
+
onMounted(async () => {
void loadProviderName()
const params = parseFragmentParams()
- const token = params.get('access_token') || ''
- const refreshToken = params.get('refresh_token') || ''
- const expiresInStr = params.get('expires_in') || ''
- const redirect = sanitizeRedirectPath(
- params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard'
- )
const error = params.get('error')
const errorDesc = params.get('error_description') || params.get('error_message') || ''
if (error) {
- if (error === 'invitation_required') {
- pendingOAuthToken.value = params.get('pending_oauth_token') || ''
- redirectTo.value = sanitizeRedirectPath(params.get('redirect'))
- if (!pendingOAuthToken.value) {
- errorMessage.value = t('auth.oidc.invalidPendingToken')
- appStore.showError(errorMessage.value)
- isProcessing.value = false
- return
- }
- needsInvitation.value = true
- isProcessing.value = false
- return
- }
errorMessage.value = errorDesc || error
appStore.showError(errorMessage.value)
isProcessing.value = false
return
}
- if (!token) {
- errorMessage.value = t('auth.oidc.callbackMissingToken')
- appStore.showError(errorMessage.value)
- isProcessing.value = false
- return
- }
-
try {
- if (refreshToken) {
- localStorage.setItem('refresh_token', refreshToken)
- }
- if (expiresInStr) {
- const expiresIn = parseInt(expiresInStr, 10)
- if (!isNaN(expiresIn)) {
- localStorage.setItem('token_expires_at', String(Date.now() + expiresIn * 1000))
- }
+ const completion = await exchangePendingOAuthCompletion()
+ const redirect = sanitizeRedirectPath(
+ completion.redirect || (route.query.redirect as string | undefined) || '/dashboard'
+ )
+ applyAdoptionSuggestionState(completion)
+ redirectTo.value = redirect
+
+ if (completion.error === 'invitation_required') {
+ needsInvitation.value = true
+ isProcessing.value = false
+ return
}
- await authStore.setToken(token)
- appStore.showSuccess(t('auth.loginSuccess'))
- await router.replace(redirect)
+ if (adoptionRequired.value && hasSuggestedProfile(completion)) {
+ needsAdoptionConfirmation.value = true
+ isProcessing.value = false
+ return
+ }
+
+ await finalizeLogin(completion, redirect)
} catch (e: unknown) {
- const err = e as { message?: string; response?: { data?: { detail?: string } } }
- errorMessage.value = err.response?.data?.detail || err.message || t('auth.loginFailed')
+ const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } }
+ errorMessage.value =
+ err.response?.data?.detail ||
+ err.response?.data?.message ||
+ err.message ||
+ t('auth.loginFailed')
appStore.showError(errorMessage.value)
isProcessing.value = false
}
diff --git a/frontend/src/views/auth/RegisterView.vue b/frontend/src/views/auth/RegisterView.vue
index bc8b8dce..378f9d8a 100644
--- a/frontend/src/views/auth/RegisterView.vue
+++ b/frontend/src/views/auth/RegisterView.vue
@@ -11,12 +11,17 @@
-
+
+
(false)
const turnstileSiteKey = ref('')
const siteName = ref('Sub2API')
const linuxdoOAuthEnabled = ref(false)
+const wechatOAuthEnabled = ref(false)
const oidcOAuthEnabled = ref(false)
const oidcOAuthProviderName = ref('OIDC')
const registrationEmailSuffixWhitelist = ref([])
@@ -397,6 +404,7 @@ onMounted(async () => {
turnstileSiteKey.value = settings.turnstile_site_key || ''
siteName.value = settings.site_name || 'Sub2API'
linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled
+ wechatOAuthEnabled.value = settings.wechat_oauth_enabled
oidcOAuthEnabled.value = settings.oidc_oauth_enabled
oidcOAuthProviderName.value = settings.oidc_oauth_provider_name || 'OIDC'
registrationEmailSuffixWhitelist.value = normalizeRegistrationEmailSuffixWhitelist(
diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue
new file mode 100644
index 00000000..407b395b
--- /dev/null
+++ b/frontend/src/views/auth/WechatCallbackView.vue
@@ -0,0 +1,361 @@
+
+
+
+
+
+ {{ t('auth.oidc.callbackTitle', { providerName }) }}
+
+
+ {{
+ isProcessing
+ ? t('auth.oidc.callbackProcessing', { providerName })
+ : t('auth.oidc.callbackHint')
+ }}
+
+
+
+
+
+
+
+
+
+ Use {{ providerName }} profile details
+
+
+ Choose whether to apply the nickname or avatar from {{ providerName }} to this account.
+
+
+
+
+
+
+
+ Use display name
+
+
+ {{ suggestedDisplayName }}
+
+
+
+
+
+
+
+
+
+ Use avatar
+
+
+ {{ suggestedAvatarUrl }}
+
+
+
+
+
+
+
+
+ {{ t('auth.oidc.invitationRequired', { providerName }) }}
+
+
+
+
+
+
+ {{ invitationError }}
+
+
+
+ {{
+ isSubmitting
+ ? t('auth.oidc.completing')
+ : t('auth.oidc.completeRegistration')
+ }}
+
+
+
+
+
+ Review the {{ providerName }} profile details before continuing.
+
+
+ {{ isSubmitting ? t('common.processing') : 'Continue' }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ errorMessage }}
+
+
+ {{ t('auth.oidc.backToLogin') }}
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
new file mode 100644
index 00000000..60a40474
--- /dev/null
+++ b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
@@ -0,0 +1,180 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { flushPromises, mount } from '@vue/test-utils'
+
+import LinuxDoCallbackView from '../LinuxDoCallbackView.vue'
+
+const replace = vi.fn()
+const showSuccess = vi.fn()
+const showError = vi.fn()
+const setToken = vi.fn()
+const exchangePendingOAuthCompletion = vi.fn()
+const completeLinuxDoOAuthRegistration = vi.fn()
+
+vi.mock('vue-router', () => ({
+ useRoute: () => ({
+ query: {}
+ }),
+ useRouter: () => ({
+ replace
+ })
+}))
+
+vi.mock('vue-i18n', async () => {
+ const actual = await vi.importActual('vue-i18n')
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string) => key
+ })
+ }
+})
+
+vi.mock('@/stores', () => ({
+ useAuthStore: () => ({
+ setToken
+ }),
+ useAppStore: () => ({
+ showSuccess,
+ showError
+ })
+}))
+
+vi.mock('@/api/auth', () => ({
+ exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args),
+ completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args)
+}))
+
+describe('LinuxDoCallbackView', () => {
+ beforeEach(() => {
+ replace.mockReset()
+ showSuccess.mockReset()
+ showError.mockReset()
+ setToken.mockReset()
+ exchangePendingOAuthCompletion.mockReset()
+ completeLinuxDoOAuthRegistration.mockReset()
+ })
+
+ it('does not send adoption decisions during the initial exchange', async () => {
+ exchangePendingOAuthCompletion.mockResolvedValue({
+ access_token: 'access-token',
+ refresh_token: 'refresh-token',
+ expires_in: 3600,
+ redirect: '/dashboard',
+ adoption_required: true
+ })
+ setToken.mockResolvedValue({})
+
+ mount(LinuxDoCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1)
+ expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith()
+ })
+
+ it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => {
+ exchangePendingOAuthCompletion
+ .mockResolvedValueOnce({
+ redirect: '/dashboard',
+ adoption_required: true,
+ suggested_display_name: 'LinuxDo Nick',
+ suggested_avatar_url: 'https://cdn.example/linuxdo.png'
+ })
+ .mockResolvedValueOnce({
+ access_token: 'access-token',
+ refresh_token: 'refresh-token',
+ expires_in: 3600,
+ redirect: '/dashboard'
+ })
+ setToken.mockResolvedValue({})
+
+ const wrapper = mount(LinuxDoCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(wrapper.text()).toContain('LinuxDo Nick')
+ expect(setToken).not.toHaveBeenCalled()
+ expect(replace).not.toHaveBeenCalled()
+
+ const checkboxes = wrapper.findAll('input[type="checkbox"]')
+ await checkboxes[1].setValue(false)
+
+ const buttons = wrapper.findAll('button')
+ expect(buttons).toHaveLength(1)
+ await buttons[0].trigger('click')
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(2)
+ expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(1)
+ expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, {
+ adoptDisplayName: true,
+ adoptAvatar: false
+ })
+ expect(setToken).toHaveBeenCalledWith('access-token')
+ expect(replace).toHaveBeenCalledWith('/dashboard')
+ })
+
+ it('renders adoption choices for invitation flow and submits the selected values', async () => {
+ exchangePendingOAuthCompletion.mockResolvedValue({
+ error: 'invitation_required',
+ redirect: '/dashboard',
+ adoption_required: true,
+ suggested_display_name: 'LinuxDo Nick',
+ suggested_avatar_url: 'https://cdn.example/linuxdo.png'
+ })
+ completeLinuxDoOAuthRegistration.mockResolvedValue({
+ access_token: 'access-token',
+ refresh_token: 'refresh-token',
+ expires_in: 3600,
+ token_type: 'Bearer'
+ })
+ setToken.mockResolvedValue({})
+
+ const wrapper = mount(LinuxDoCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(wrapper.text()).toContain('LinuxDo Nick')
+ expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1)
+ expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith()
+
+ const checkboxes = wrapper.findAll('input[type="checkbox"]')
+ expect(checkboxes).toHaveLength(2)
+
+ await checkboxes[0].setValue(false)
+ await wrapper.find('input[type="text"]').setValue('invite-code')
+ await wrapper.find('button').trigger('click')
+
+ expect(completeLinuxDoOAuthRegistration).toHaveBeenCalledWith('invite-code', {
+ adoptDisplayName: false,
+ adoptAvatar: true
+ })
+ })
+})
diff --git a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
new file mode 100644
index 00000000..299c0746
--- /dev/null
+++ b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
@@ -0,0 +1,191 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { flushPromises, mount } from '@vue/test-utils'
+
+import OidcCallbackView from '../OidcCallbackView.vue'
+
+const replace = vi.fn()
+const showSuccess = vi.fn()
+const showError = vi.fn()
+const setToken = vi.fn()
+const exchangePendingOAuthCompletion = vi.fn()
+const completeOIDCOAuthRegistration = vi.fn()
+const getPublicSettings = vi.fn()
+
+vi.mock('vue-router', () => ({
+ useRoute: () => ({
+ query: {}
+ }),
+ useRouter: () => ({
+ replace
+ })
+}))
+
+vi.mock('vue-i18n', async () => {
+ const actual = await vi.importActual('vue-i18n')
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string, params?: Record) => {
+ if (!params?.providerName) {
+ return key
+ }
+ return `${key}:${params.providerName}`
+ }
+ })
+ }
+})
+
+vi.mock('@/stores', () => ({
+ useAuthStore: () => ({
+ setToken
+ }),
+ useAppStore: () => ({
+ showSuccess,
+ showError
+ })
+}))
+
+vi.mock('@/api/auth', () => ({
+ exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args),
+ completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args),
+ getPublicSettings: (...args: any[]) => getPublicSettings(...args)
+}))
+
+describe('OidcCallbackView', () => {
+ beforeEach(() => {
+ replace.mockReset()
+ showSuccess.mockReset()
+ showError.mockReset()
+ setToken.mockReset()
+ exchangePendingOAuthCompletion.mockReset()
+ completeOIDCOAuthRegistration.mockReset()
+ getPublicSettings.mockReset()
+ getPublicSettings.mockResolvedValue({
+ oidc_oauth_provider_name: 'ExampleID'
+ })
+ })
+
+ it('does not send adoption decisions during the initial exchange', async () => {
+ exchangePendingOAuthCompletion.mockResolvedValue({
+ access_token: 'access-token',
+ refresh_token: 'refresh-token',
+ expires_in: 3600,
+ redirect: '/dashboard',
+ adoption_required: true
+ })
+ setToken.mockResolvedValue({})
+
+ mount(OidcCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1)
+ expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith()
+ })
+
+ it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => {
+ exchangePendingOAuthCompletion
+ .mockResolvedValueOnce({
+ redirect: '/dashboard',
+ adoption_required: true,
+ suggested_display_name: 'OIDC Nick',
+ suggested_avatar_url: 'https://cdn.example/oidc.png'
+ })
+ .mockResolvedValueOnce({
+ access_token: 'access-token',
+ refresh_token: 'refresh-token',
+ expires_in: 3600,
+ redirect: '/dashboard'
+ })
+ setToken.mockResolvedValue({})
+
+ const wrapper = mount(OidcCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(wrapper.text()).toContain('OIDC Nick')
+ expect(setToken).not.toHaveBeenCalled()
+ expect(replace).not.toHaveBeenCalled()
+
+ const checkboxes = wrapper.findAll('input[type="checkbox"]')
+ await checkboxes[0].setValue(false)
+
+ const buttons = wrapper.findAll('button')
+ expect(buttons).toHaveLength(1)
+ await buttons[0].trigger('click')
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(2)
+ expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(1)
+ expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, {
+ adoptDisplayName: false,
+ adoptAvatar: true
+ })
+ expect(setToken).toHaveBeenCalledWith('access-token')
+ expect(replace).toHaveBeenCalledWith('/dashboard')
+ })
+
+ it('renders adoption choices for invitation flow and submits the selected values', async () => {
+ exchangePendingOAuthCompletion.mockResolvedValue({
+ error: 'invitation_required',
+ redirect: '/dashboard',
+ adoption_required: true,
+ suggested_display_name: 'OIDC Nick',
+ suggested_avatar_url: 'https://cdn.example/oidc.png'
+ })
+ completeOIDCOAuthRegistration.mockResolvedValue({
+ access_token: 'access-token',
+ refresh_token: 'refresh-token',
+ expires_in: 3600,
+ token_type: 'Bearer'
+ })
+ setToken.mockResolvedValue({})
+
+ const wrapper = mount(OidcCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(wrapper.text()).toContain('OIDC Nick')
+ expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1)
+ expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith()
+
+ const checkboxes = wrapper.findAll('input[type="checkbox"]')
+ expect(checkboxes).toHaveLength(2)
+
+ await checkboxes[1].setValue(false)
+ await wrapper.find('input[type="text"]').setValue('invite-code')
+ await wrapper.find('button').trigger('click')
+
+ expect(completeOIDCOAuthRegistration).toHaveBeenCalledWith('invite-code', {
+ adoptDisplayName: true,
+ adoptAvatar: false
+ })
+ })
+})
diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
new file mode 100644
index 00000000..a9e2ada2
--- /dev/null
+++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
@@ -0,0 +1,241 @@
+import { flushPromises, mount } from '@vue/test-utils'
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import WechatCallbackView from '@/views/auth/WechatCallbackView.vue'
+
+const {
+ postMock,
+ replaceMock,
+ setTokenMock,
+ showSuccessMock,
+ showErrorMock,
+ routeState,
+} = vi.hoisted(() => ({
+ postMock: vi.fn(),
+ replaceMock: vi.fn(),
+ setTokenMock: vi.fn(),
+ showSuccessMock: vi.fn(),
+ showErrorMock: vi.fn(),
+ routeState: {
+ query: {} as Record,
+ },
+}))
+
+vi.mock('vue-router', () => ({
+ useRoute: () => routeState,
+ useRouter: () => ({
+ replace: replaceMock,
+ }),
+}))
+
+vi.mock('vue-i18n', () => ({
+ createI18n: () => ({
+ global: {
+ t: (key: string) => key,
+ },
+ }),
+ useI18n: () => ({
+ t: (key: string, params?: Record) => {
+ if (key === 'auth.oidc.callbackTitle') {
+ return `Signing you in with ${params?.providerName ?? ''}`.trim()
+ }
+ if (key === 'auth.oidc.callbackProcessing') {
+ return `Completing login with ${params?.providerName ?? ''}`.trim()
+ }
+ if (key === 'auth.oidc.invitationRequired') {
+ return `${params?.providerName ?? ''} invitation required`.trim()
+ }
+ if (key === 'auth.oidc.completeRegistration') {
+ return 'Complete registration'
+ }
+ if (key === 'auth.oidc.completing') {
+ return 'Completing'
+ }
+ if (key === 'auth.oidc.backToLogin') {
+ return 'Back to login'
+ }
+ if (key === 'auth.invitationCodePlaceholder') {
+ return 'Invitation code'
+ }
+ if (key === 'auth.loginSuccess') {
+ return 'Login success'
+ }
+ if (key === 'auth.loginFailed') {
+ return 'Login failed'
+ }
+ if (key === 'auth.oidc.callbackHint') {
+ return 'Callback hint'
+ }
+ if (key === 'auth.oidc.callbackMissingToken') {
+ return 'Missing login token'
+ }
+ if (key === 'auth.oidc.completeRegistrationFailed') {
+ return 'Complete registration failed'
+ }
+ return key
+ },
+ }),
+}))
+
+vi.mock('@/stores', () => ({
+ useAuthStore: () => ({
+ setToken: setTokenMock,
+ }),
+ useAppStore: () => ({
+ showSuccess: showSuccessMock,
+ showError: showErrorMock,
+ }),
+}))
+
+vi.mock('@/api/client', () => ({
+ apiClient: {
+ post: postMock,
+ },
+}))
+
+describe('WechatCallbackView', () => {
+ beforeEach(() => {
+ postMock.mockReset()
+ replaceMock.mockReset()
+ setTokenMock.mockReset()
+ showSuccessMock.mockReset()
+ showErrorMock.mockReset()
+ routeState.query = {}
+ localStorage.clear()
+ })
+
+ it('does not send adoption decisions during the initial exchange', async () => {
+ postMock.mockResolvedValueOnce({
+ data: {
+ access_token: 'access-token',
+ refresh_token: 'refresh-token',
+ expires_in: 3600,
+ redirect: '/dashboard',
+ adoption_required: true,
+ },
+ })
+ setTokenMock.mockResolvedValue({})
+
+ mount(WechatCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(postMock).toHaveBeenCalledWith('/auth/oauth/pending/exchange', {})
+ expect(postMock).toHaveBeenCalledTimes(1)
+ })
+
+ it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => {
+ postMock
+ .mockResolvedValueOnce({
+ data: {
+ redirect: '/dashboard',
+ adoption_required: true,
+ suggested_display_name: 'WeChat Nick',
+ suggested_avatar_url: 'https://cdn.example/wechat.png',
+ },
+ })
+ .mockResolvedValueOnce({
+ data: {
+ access_token: 'wechat-access-token',
+ refresh_token: 'wechat-refresh-token',
+ expires_in: 3600,
+ token_type: 'Bearer',
+ redirect: '/dashboard',
+ },
+ })
+ setTokenMock.mockResolvedValue({})
+
+ const wrapper = mount(WechatCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(wrapper.text()).toContain('WeChat Nick')
+ expect(setTokenMock).not.toHaveBeenCalled()
+ expect(replaceMock).not.toHaveBeenCalled()
+
+ const checkboxes = wrapper.findAll('input[type="checkbox"]')
+ expect(checkboxes).toHaveLength(2)
+ await checkboxes[1].setValue(false)
+
+ const buttons = wrapper.findAll('button')
+ expect(buttons).toHaveLength(1)
+ await buttons[0].trigger('click')
+ await flushPromises()
+
+ expect(postMock).toHaveBeenNthCalledWith(1, '/auth/oauth/pending/exchange', {})
+ expect(postMock).toHaveBeenNthCalledWith(2, '/auth/oauth/pending/exchange', {
+ adopt_display_name: true,
+ adopt_avatar: false,
+ })
+ expect(setTokenMock).toHaveBeenCalledWith('wechat-access-token')
+ expect(replaceMock).toHaveBeenCalledWith('/dashboard')
+ expect(localStorage.getItem('refresh_token')).toBe('wechat-refresh-token')
+ })
+
+ it('renders adoption choices for invitation flow and submits the selected values', async () => {
+ postMock
+ .mockResolvedValueOnce({
+ data: {
+ error: 'invitation_required',
+ redirect: '/subscriptions',
+ adoption_required: true,
+ suggested_display_name: 'WeChat Nick',
+ suggested_avatar_url: 'https://cdn.example/wechat.png',
+ },
+ })
+ .mockResolvedValueOnce({
+ data: {
+ access_token: 'wechat-invite-token',
+ refresh_token: 'wechat-invite-refresh',
+ expires_in: 600,
+ token_type: 'Bearer',
+ },
+ })
+
+ const wrapper = mount(WechatCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(wrapper.text()).toContain('WeChat Nick')
+ const checkboxes = wrapper.findAll('input[type="checkbox"]')
+ expect(checkboxes).toHaveLength(2)
+ await checkboxes[0].setValue(false)
+ await wrapper.get('input[type="text"]').setValue(' INVITE-CODE ')
+ await wrapper.get('button').trigger('click')
+ await flushPromises()
+
+ expect(postMock).toHaveBeenNthCalledWith(2, '/auth/oauth/wechat/complete-registration', {
+ invitation_code: 'INVITE-CODE',
+ adopt_display_name: false,
+ adopt_avatar: true,
+ })
+ expect(setTokenMock).toHaveBeenCalledWith('wechat-invite-token')
+ expect(replaceMock).toHaveBeenCalledWith('/subscriptions')
+ })
+})
diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue
index e91df5da..838f3000 100644
--- a/frontend/src/views/user/PaymentView.vue
+++ b/frontend/src/views/user/PaymentView.vue
@@ -23,20 +23,7 @@
:order-type="paymentState.orderType"
@done="onPaymentDone"
@success="onPaymentSuccess"
- />
-
-
-
@@ -265,7 +252,7 @@
diff --git a/frontend/src/components/user/profile/ProfileInfoCard.vue b/frontend/src/components/user/profile/ProfileInfoCard.vue
index b6f6022d..e82ae229 100644
--- a/frontend/src/components/user/profile/ProfileInfoCard.vue
+++ b/frontend/src/components/user/profile/ProfileInfoCard.vue
@@ -4,11 +4,16 @@
class="border-b border-gray-100 bg-gradient-to-r from-primary-500/10 to-primary-600/5 px-6 py-5 dark:border-dark-700 dark:from-primary-500/20 dark:to-primary-600/10"
>
-
- {{ user?.email?.charAt(0).toUpperCase() || 'U' }}
+
+
{{ avatarInitial }}
@@ -41,18 +46,163 @@
{{ user.username }}
+
+
+
+
+ {{ hint.text }}
+
+
+
+
diff --git a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
new file mode 100644
index 00000000..1c9531e3
--- /dev/null
+++ b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
@@ -0,0 +1,120 @@
+import { mount } from '@vue/test-utils'
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
+import ProfileIdentityBindingsSection from '@/components/user/profile/ProfileIdentityBindingsSection.vue'
+import type { User } from '@/types'
+
+const routeState = vi.hoisted(() => ({
+ fullPath: '/profile',
+}))
+
+const locationState = vi.hoisted(() => ({
+ current: { href: 'http://localhost/profile' } as { href: string },
+}))
+
+vi.mock('vue-router', () => ({
+ useRoute: () => routeState,
+}))
+
+vi.mock('vue-i18n', async (importOriginal) => {
+ const actual = await importOriginal()
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string, params?: Record) => {
+ if (key === 'profile.authBindings.title') return 'Connected sign-in methods'
+ if (key === 'profile.authBindings.description') return 'Manage bound providers'
+ if (key === 'profile.authBindings.status.bound') return 'Bound'
+ if (key === 'profile.authBindings.status.notBound') return 'Not bound'
+ if (key === 'profile.authBindings.providers.email') return 'Email'
+ if (key === 'profile.authBindings.providers.linuxdo') return 'LinuxDo'
+ if (key === 'profile.authBindings.providers.wechat') return 'WeChat'
+ if (key === 'profile.authBindings.providers.oidc') return params?.providerName || 'OIDC'
+ if (key === 'profile.authBindings.bindAction') return `Bind ${params?.providerName || ''}`.trim()
+ return key
+ },
+ }),
+ }
+})
+
+function createUser(overrides: Partial = {}): User {
+ return {
+ id: 7,
+ username: 'alice',
+ email: 'alice@example.com',
+ role: 'user',
+ balance: 10,
+ concurrency: 2,
+ status: 'active',
+ allowed_groups: null,
+ balance_notify_enabled: true,
+ balance_notify_threshold: null,
+ balance_notify_extra_emails: [],
+ created_at: '2026-04-20T00:00:00Z',
+ updated_at: '2026-04-20T00:00:00Z',
+ ...overrides,
+ }
+}
+
+describe('ProfileIdentityBindingsSection', () => {
+ beforeEach(() => {
+ routeState.fullPath = '/profile'
+ locationState.current = { href: 'http://localhost/profile' }
+ Object.defineProperty(window, 'location', {
+ configurable: true,
+ value: locationState.current,
+ })
+ Object.defineProperty(window.navigator, 'userAgent', {
+ configurable: true,
+ value: 'Mozilla/5.0',
+ })
+ })
+
+ afterEach(() => {
+ vi.unstubAllGlobals()
+ })
+
+ it('renders provider binding states and provider-specific bind actions', () => {
+ const wrapper = mount(ProfileIdentityBindingsSection, {
+ props: {
+ user: createUser({
+ auth_bindings: {
+ email: { bound: true },
+ linuxdo: { bound: true },
+ oidc: { bound: false },
+ wechat: false,
+ },
+ }),
+ linuxdoEnabled: true,
+ oidcEnabled: true,
+ oidcProviderName: 'ExampleID',
+ wechatEnabled: true,
+ },
+ })
+
+ expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Bound')
+ expect(wrapper.get('[data-testid="profile-binding-linuxdo-status"]').text()).toBe('Bound')
+ expect(wrapper.get('[data-testid="profile-binding-oidc-status"]').text()).toBe('Not bound')
+ expect(wrapper.get('[data-testid="profile-binding-oidc-action"]').text()).toBe(
+ 'Bind ExampleID'
+ )
+ expect(wrapper.get('[data-testid="profile-binding-wechat-action"]').text()).toBe('Bind WeChat')
+ })
+
+ it('starts the WeChat bind flow for the current profile page', async () => {
+ const wrapper = mount(ProfileIdentityBindingsSection, {
+ props: {
+ user: createUser(),
+ linuxdoEnabled: false,
+ oidcEnabled: false,
+ wechatEnabled: true,
+ },
+ })
+
+ await wrapper.get('[data-testid="profile-binding-wechat-action"]').trigger('click')
+
+ expect(locationState.current.href).toContain('/api/v1/auth/oauth/wechat/start?')
+ expect(locationState.current.href).toContain('mode=open')
+ expect(locationState.current.href).toContain('intent=bind_current_user')
+ expect(locationState.current.href).toContain('redirect=%2Fprofile')
+ })
+})
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 684c196f..7d058a74 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -940,6 +940,26 @@ export default {
maxEmailsReached: 'Maximum number of notification emails reached',
unverified: 'Unverified',
verified: 'Verified',
+ },
+ authBindings: {
+ title: 'Connected Sign-In Methods',
+ description: 'View current bindings and connect another provider to this account.',
+ bindAction: 'Bind {providerName}',
+ bindSuccess: 'Account linked successfully',
+ status: {
+ bound: 'Bound',
+ notBound: 'Not bound',
+ },
+ providers: {
+ email: 'Email',
+ linuxdo: 'LinuxDo',
+ oidc: '{providerName}',
+ wechat: 'WeChat',
+ },
+ source: {
+ avatar: 'Avatar is currently synced from {providerName}',
+ username: 'Nickname is currently synced from {providerName}',
+ },
}
},
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 2a4c69a5..6dd74334 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -944,6 +944,26 @@ export default {
maxEmailsReached: '已达到通知邮箱数量上限',
unverified: '未验证',
verified: '已验证',
+ },
+ authBindings: {
+ title: '登录方式绑定',
+ description: '查看当前绑定状态,并将更多第三方登录方式关联到这个账号。',
+ bindAction: '绑定 {providerName}',
+ bindSuccess: '账号绑定成功',
+ status: {
+ bound: '已绑定',
+ notBound: '未绑定',
+ },
+ providers: {
+ email: '邮箱',
+ linuxdo: 'LinuxDo',
+ oidc: '{providerName}',
+ wechat: '微信',
+ },
+ source: {
+ avatar: '头像当前来自 {providerName}',
+ username: '昵称当前来自 {providerName}',
+ },
}
},
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index 9c9722a9..a19d6c26 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -34,10 +34,47 @@ export interface NotifyEmailEntry {
// ==================== User & Auth Types ====================
+export type UserAuthProvider = 'email' | 'linuxdo' | 'oidc' | 'wechat'
+
+export interface UserAuthBindingStatus {
+ bound?: boolean
+ provider?: UserAuthProvider | string
+ provider_key?: string | null
+ provider_subject?: string | null
+ issuer?: string | null
+ label?: string | null
+ provider_label?: string | null
+ metadata?: Record
+}
+
+export interface UserProfileSourceContext {
+ provider?: UserAuthProvider | string
+ source?: string | null
+ label?: string | null
+ provider_label?: string | null
+}
+
export interface User {
id: number
username: string
email: string
+ avatar_url?: string | null
+ avatar_source?: string | UserProfileSourceContext | null
+ username_source?: string | UserProfileSourceContext | null
+ display_name_source?: string | UserProfileSourceContext | null
+ nickname_source?: string | UserProfileSourceContext | null
+ profile_sources?: {
+ avatar?: string | UserProfileSourceContext | null
+ username?: string | UserProfileSourceContext | null
+ display_name?: string | UserProfileSourceContext | null
+ nickname?: string | UserProfileSourceContext | null
+ }
+ auth_bindings?: Partial>
+ identity_bindings?: Partial>
+ email_bound?: boolean
+ linuxdo_bound?: boolean
+ oidc_bound?: boolean
+ wechat_bound?: boolean
role: 'admin' | 'user' // User role for authorization
balance: number // User balance for API usage
concurrency: number // Allowed concurrent requests
diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue
index 0a125def..6dc8f242 100644
--- a/frontend/src/views/auth/LinuxDoCallbackView.vue
+++ b/frontend/src/views/auth/LinuxDoCallbackView.vue
@@ -136,6 +136,9 @@ import { useAuthStore, useAppStore } from '@/stores'
import {
completeLinuxDoOAuthRegistration,
exchangePendingOAuthCompletion,
+ getOAuthCompletionKind,
+ isOAuthLoginCompletion,
+ persistOAuthTokenContext,
type OAuthAdoptionDecision,
type PendingOAuthExchangeResponse
} from '@/api/auth'
@@ -162,6 +165,7 @@ const suggestedAvatarUrl = ref('')
const adoptDisplayName = ref(true)
const adoptAvatar = ref(true)
const needsAdoptionConfirmation = ref(false)
+const bindSuccessMessage = t('profile.authBindings.bindSuccess')
function parseFragmentParams(): URLSearchParams {
const raw = typeof window !== 'undefined' ? window.location.hash : ''
@@ -209,18 +213,19 @@ function hasSuggestedProfile(completion: {
return Boolean(completion.suggested_display_name || completion.suggested_avatar_url)
}
-async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) {
- if (!completion.access_token) {
+async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) {
+ if (getOAuthCompletionKind(completion) === 'bind') {
+ const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile')
+ appStore.showSuccess(bindSuccessMessage)
+ await router.replace(bindRedirect)
+ return
+ }
+
+ if (!isOAuthLoginCompletion(completion)) {
throw new Error(t('auth.linuxdo.callbackMissingToken'))
}
- if (completion.refresh_token) {
- localStorage.setItem('refresh_token', completion.refresh_token)
- }
- if (completion.expires_in) {
- localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000))
- }
-
+ persistOAuthTokenContext(completion)
await authStore.setToken(completion.access_token)
appStore.showSuccess(t('auth.loginSuccess'))
await router.replace(redirect)
@@ -236,12 +241,7 @@ async function handleSubmitInvitation() {
invitationCode.value.trim(),
currentAdoptionDecision()
)
- if (tokenData.refresh_token) {
- localStorage.setItem('refresh_token', tokenData.refresh_token)
- }
- if (tokenData.expires_in) {
- localStorage.setItem('token_expires_at', String(Date.now() + tokenData.expires_in * 1000))
- }
+ persistOAuthTokenContext(tokenData)
await authStore.setToken(tokenData.access_token)
appStore.showSuccess(t('auth.loginSuccess'))
await router.replace(redirectTo.value)
@@ -258,7 +258,7 @@ async function handleContinueLogin() {
isSubmitting.value = true
try {
const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision())
- await finalizeLogin(completion, redirectTo.value)
+ await finalizeCompletion(completion, redirectTo.value)
} catch (e: unknown) {
const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } }
errorMessage.value =
@@ -305,7 +305,7 @@ onMounted(async () => {
return
}
- await finalizeLogin(completion, redirect)
+ await finalizeCompletion(completion, redirect)
} catch (e: unknown) {
const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } }
errorMessage.value =
diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue
index 55f8af6e..83304226 100644
--- a/frontend/src/views/auth/OidcCallbackView.vue
+++ b/frontend/src/views/auth/OidcCallbackView.vue
@@ -145,7 +145,10 @@ import { useAuthStore, useAppStore } from '@/stores'
import {
completeOIDCOAuthRegistration,
exchangePendingOAuthCompletion,
+ getOAuthCompletionKind,
getPublicSettings,
+ isOAuthLoginCompletion,
+ persistOAuthTokenContext,
type OAuthAdoptionDecision,
type PendingOAuthExchangeResponse
} from '@/api/auth'
@@ -172,6 +175,7 @@ const suggestedAvatarUrl = ref('')
const adoptDisplayName = ref(true)
const adoptAvatar = ref(true)
const needsAdoptionConfirmation = ref(false)
+const bindSuccessMessage = t('profile.authBindings.bindSuccess')
function parseFragmentParams(): URLSearchParams {
const raw = typeof window !== 'undefined' ? window.location.hash : ''
@@ -231,18 +235,19 @@ function hasSuggestedProfile(completion: {
return Boolean(completion.suggested_display_name || completion.suggested_avatar_url)
}
-async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) {
- if (!completion.access_token) {
+async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) {
+ if (getOAuthCompletionKind(completion) === 'bind') {
+ const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile')
+ appStore.showSuccess(bindSuccessMessage)
+ await router.replace(bindRedirect)
+ return
+ }
+
+ if (!isOAuthLoginCompletion(completion)) {
throw new Error(t('auth.oidc.callbackMissingToken'))
}
- if (completion.refresh_token) {
- localStorage.setItem('refresh_token', completion.refresh_token)
- }
- if (completion.expires_in) {
- localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000))
- }
-
+ persistOAuthTokenContext(completion)
await authStore.setToken(completion.access_token)
appStore.showSuccess(t('auth.loginSuccess'))
await router.replace(redirect)
@@ -258,12 +263,7 @@ async function handleSubmitInvitation() {
invitationCode.value.trim(),
currentAdoptionDecision()
)
- if (tokenData.refresh_token) {
- localStorage.setItem('refresh_token', tokenData.refresh_token)
- }
- if (tokenData.expires_in) {
- localStorage.setItem('token_expires_at', String(Date.now() + tokenData.expires_in * 1000))
- }
+ persistOAuthTokenContext(tokenData)
await authStore.setToken(tokenData.access_token)
appStore.showSuccess(t('auth.loginSuccess'))
await router.replace(redirectTo.value)
@@ -280,7 +280,7 @@ async function handleContinueLogin() {
isSubmitting.value = true
try {
const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision())
- await finalizeLogin(completion, redirectTo.value)
+ await finalizeCompletion(completion, redirectTo.value)
} catch (e: unknown) {
const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } }
errorMessage.value =
@@ -329,7 +329,7 @@ onMounted(async () => {
return
}
- await finalizeLogin(completion, redirect)
+ await finalizeCompletion(completion, redirect)
} catch (e: unknown) {
const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } }
errorMessage.value =
diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue
index 407b395b..ac0ede4c 100644
--- a/frontend/src/views/auth/WechatCallbackView.vue
+++ b/frontend/src/views/auth/WechatCallbackView.vue
@@ -140,27 +140,16 @@ import { useRoute, useRouter } from 'vue-router'
import { useI18n } from 'vue-i18n'
import { AuthLayout } from '@/components/layout'
import Icon from '@/components/icons/Icon.vue'
-import { apiClient } from '@/api/client'
import { useAuthStore, useAppStore } from '@/stores'
-
-interface OAuthTokenResponse {
- access_token: string
- refresh_token: string
- expires_in: number
- token_type: string
-}
-
-interface PendingOAuthExchangeResponse {
- access_token?: string
- refresh_token?: string
- expires_in?: number
- token_type?: string
- redirect?: string
- error?: string
- adoption_required?: boolean
- suggested_display_name?: string
- suggested_avatar_url?: string
-}
+import {
+ completeWeChatOAuthRegistration,
+ exchangePendingOAuthCompletion,
+ getOAuthCompletionKind,
+ isOAuthLoginCompletion,
+ persistOAuthTokenContext,
+ type OAuthAdoptionDecision,
+ type PendingOAuthExchangeResponse
+} from '@/api/auth'
const route = useRoute()
const router = useRouter()
@@ -182,6 +171,7 @@ const suggestedAvatarUrl = ref('')
const adoptDisplayName = ref(true)
const adoptAvatar = ref(true)
const needsAdoptionConfirmation = ref(false)
+const bindSuccessMessage = t('profile.authBindings.bindSuccess')
const providerName = 'WeChat'
@@ -200,10 +190,10 @@ function sanitizeRedirectPath(path: string | null | undefined): string {
return path
}
-function currentAdoptionDecision(): Record {
+function currentAdoptionDecision(): OAuthAdoptionDecision {
return {
- adopt_display_name: adoptDisplayName.value,
- adopt_avatar: adoptAvatar.value,
+ adoptDisplayName: adoptDisplayName.value,
+ adoptAvatar: adoptAvatar.value
}
}
@@ -224,49 +214,35 @@ function hasSuggestedProfile(completion: PendingOAuthExchangeResponse): boolean
return Boolean(completion.suggested_display_name || completion.suggested_avatar_url)
}
-async function exchangePendingOAuthCompletion(): Promise {
- const { data } = await apiClient.post('/auth/oauth/pending/exchange', {})
- return data
-}
+async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) {
+ if (getOAuthCompletionKind(completion) === 'bind') {
+ const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile')
+ appStore.showSuccess(bindSuccessMessage)
+ await router.replace(bindRedirect)
+ return
+ }
-async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) {
- if (!completion.access_token) {
+ if (!isOAuthLoginCompletion(completion)) {
throw new Error(t('auth.oidc.callbackMissingToken'))
}
- if (completion.refresh_token) {
- localStorage.setItem('refresh_token', completion.refresh_token)
- }
- if (completion.expires_in) {
- localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000))
- }
-
+ persistOAuthTokenContext(completion)
await authStore.setToken(completion.access_token)
appStore.showSuccess(t('auth.loginSuccess'))
await router.replace(redirect)
}
-async function completeWeChatOAuthRegistration(invitation: string): Promise {
- const { data } = await apiClient.post('/auth/oauth/wechat/complete-registration', {
- invitation_code: invitation,
- ...currentAdoptionDecision(),
- })
- return data
-}
-
async function handleSubmitInvitation() {
invitationError.value = ''
if (!invitationCode.value.trim()) return
isSubmitting.value = true
try {
- const tokenData = await completeWeChatOAuthRegistration(invitationCode.value.trim())
- if (tokenData.refresh_token) {
- localStorage.setItem('refresh_token', tokenData.refresh_token)
- }
- if (tokenData.expires_in) {
- localStorage.setItem('token_expires_at', String(Date.now() + tokenData.expires_in * 1000))
- }
+ const tokenData = await completeWeChatOAuthRegistration(
+ invitationCode.value.trim(),
+ currentAdoptionDecision()
+ )
+ persistOAuthTokenContext(tokenData)
await authStore.setToken(tokenData.access_token)
appStore.showSuccess(t('auth.loginSuccess'))
await router.replace(redirectTo.value)
@@ -282,11 +258,8 @@ async function handleSubmitInvitation() {
async function handleContinueLogin() {
isSubmitting.value = true
try {
- const { data } = await apiClient.post(
- '/auth/oauth/pending/exchange',
- currentAdoptionDecision()
- )
- await finalizeLogin(data, redirectTo.value)
+ const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision())
+ await finalizeCompletion(completion, redirectTo.value)
} catch (e: unknown) {
const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } }
errorMessage.value =
@@ -333,7 +306,7 @@ onMounted(async () => {
return
}
- await finalizeLogin(completion, redirect)
+ await finalizeCompletion(completion, redirect)
} catch (e: unknown) {
const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } }
errorMessage.value =
diff --git a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
index 60a40474..7ffdcd19 100644
--- a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
@@ -39,10 +39,14 @@ vi.mock('@/stores', () => ({
})
}))
-vi.mock('@/api/auth', () => ({
- exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args),
- completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args)
-}))
+vi.mock('@/api/auth', async () => {
+ const actual = await vi.importActual('@/api/auth')
+ return {
+ ...actual,
+ exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args),
+ completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args)
+ }
+})
describe('LinuxDoCallbackView', () => {
beforeEach(() => {
@@ -132,6 +136,64 @@ describe('LinuxDoCallbackView', () => {
expect(replace).toHaveBeenCalledWith('/dashboard')
})
+ it('treats a completion without token as bind success and returns to profile', async () => {
+ exchangePendingOAuthCompletion.mockResolvedValue({})
+
+ mount(LinuxDoCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(setToken).not.toHaveBeenCalled()
+ expect(showSuccess).toHaveBeenCalledWith('profile.authBindings.bindSuccess')
+ expect(replace).toHaveBeenCalledWith('/profile')
+ })
+
+ it('supports bind completion after adoption confirmation', async () => {
+ exchangePendingOAuthCompletion
+ .mockResolvedValueOnce({
+ redirect: '/dashboard',
+ adoption_required: true,
+ suggested_display_name: 'LinuxDo Nick',
+ suggested_avatar_url: 'https://cdn.example/linuxdo.png'
+ })
+ .mockResolvedValueOnce({
+ redirect: '/profile/security'
+ })
+
+ const wrapper = mount(LinuxDoCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ await wrapper.findAll('button')[0].trigger('click')
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, {
+ adoptDisplayName: true,
+ adoptAvatar: true
+ })
+ expect(setToken).not.toHaveBeenCalled()
+ expect(showSuccess).toHaveBeenCalledWith('profile.authBindings.bindSuccess')
+ expect(replace).toHaveBeenCalledWith('/profile/security')
+ })
+
it('renders adoption choices for invitation flow and submits the selected values', async () => {
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'invitation_required',
diff --git a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
index 299c0746..f8de79f2 100644
--- a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
@@ -45,11 +45,15 @@ vi.mock('@/stores', () => ({
})
}))
-vi.mock('@/api/auth', () => ({
- exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args),
- completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args),
- getPublicSettings: (...args: any[]) => getPublicSettings(...args)
-}))
+vi.mock('@/api/auth', async () => {
+ const actual = await vi.importActual('@/api/auth')
+ return {
+ ...actual,
+ exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args),
+ completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args),
+ getPublicSettings: (...args: any[]) => getPublicSettings(...args)
+ }
+})
describe('OidcCallbackView', () => {
beforeEach(() => {
@@ -143,6 +147,43 @@ describe('OidcCallbackView', () => {
expect(replace).toHaveBeenCalledWith('/dashboard')
})
+ it('supports bind completion after adoption confirmation', async () => {
+ exchangePendingOAuthCompletion
+ .mockResolvedValueOnce({
+ redirect: '/dashboard',
+ adoption_required: true,
+ suggested_display_name: 'OIDC Nick',
+ suggested_avatar_url: 'https://cdn.example/oidc.png'
+ })
+ .mockResolvedValueOnce({
+ redirect: '/profile'
+ })
+
+ const wrapper = mount(OidcCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ await wrapper.findAll('button')[0].trigger('click')
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, {
+ adoptDisplayName: true,
+ adoptAvatar: true
+ })
+ expect(setToken).not.toHaveBeenCalled()
+ expect(showSuccess).toHaveBeenCalledWith('profile.authBindings.bindSuccess')
+ expect(replace).toHaveBeenCalledWith('/profile')
+ })
+
it('renders adoption choices for invitation flow and submits the selected values', async () => {
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'invitation_required',
diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
index a9e2ada2..896bf15d 100644
--- a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
@@ -3,14 +3,16 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
import WechatCallbackView from '@/views/auth/WechatCallbackView.vue'
const {
- postMock,
+ exchangePendingOAuthCompletionMock,
+ completeWeChatOAuthRegistrationMock,
replaceMock,
setTokenMock,
showSuccessMock,
showErrorMock,
routeState,
} = vi.hoisted(() => ({
- postMock: vi.fn(),
+ exchangePendingOAuthCompletionMock: vi.fn(),
+ completeWeChatOAuthRegistrationMock: vi.fn(),
replaceMock: vi.fn(),
setTokenMock: vi.fn(),
showSuccessMock: vi.fn(),
@@ -86,15 +88,19 @@ vi.mock('@/stores', () => ({
}),
}))
-vi.mock('@/api/client', () => ({
- apiClient: {
- post: postMock,
- },
-}))
+vi.mock('@/api/auth', async () => {
+ const actual = await vi.importActual('@/api/auth')
+ return {
+ ...actual,
+ exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletionMock(...args),
+ completeWeChatOAuthRegistration: (...args: any[]) => completeWeChatOAuthRegistrationMock(...args),
+ }
+})
describe('WechatCallbackView', () => {
beforeEach(() => {
- postMock.mockReset()
+ exchangePendingOAuthCompletionMock.mockReset()
+ completeWeChatOAuthRegistrationMock.mockReset()
replaceMock.mockReset()
setTokenMock.mockReset()
showSuccessMock.mockReset()
@@ -104,14 +110,12 @@ describe('WechatCallbackView', () => {
})
it('does not send adoption decisions during the initial exchange', async () => {
- postMock.mockResolvedValueOnce({
- data: {
- access_token: 'access-token',
- refresh_token: 'refresh-token',
- expires_in: 3600,
- redirect: '/dashboard',
- adoption_required: true,
- },
+ exchangePendingOAuthCompletionMock.mockResolvedValue({
+ access_token: 'access-token',
+ refresh_token: 'refresh-token',
+ expires_in: 3600,
+ redirect: '/dashboard',
+ adoption_required: true,
})
setTokenMock.mockResolvedValue({})
@@ -128,28 +132,24 @@ describe('WechatCallbackView', () => {
await flushPromises()
- expect(postMock).toHaveBeenCalledWith('/auth/oauth/pending/exchange', {})
- expect(postMock).toHaveBeenCalledTimes(1)
+ expect(exchangePendingOAuthCompletionMock).toHaveBeenCalledWith()
+ expect(exchangePendingOAuthCompletionMock).toHaveBeenCalledTimes(1)
})
it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => {
- postMock
+ exchangePendingOAuthCompletionMock
.mockResolvedValueOnce({
- data: {
- redirect: '/dashboard',
- adoption_required: true,
- suggested_display_name: 'WeChat Nick',
- suggested_avatar_url: 'https://cdn.example/wechat.png',
- },
+ redirect: '/dashboard',
+ adoption_required: true,
+ suggested_display_name: 'WeChat Nick',
+ suggested_avatar_url: 'https://cdn.example/wechat.png',
})
.mockResolvedValueOnce({
- data: {
- access_token: 'wechat-access-token',
- refresh_token: 'wechat-refresh-token',
- expires_in: 3600,
- token_type: 'Bearer',
- redirect: '/dashboard',
- },
+ access_token: 'wechat-access-token',
+ refresh_token: 'wechat-refresh-token',
+ expires_in: 3600,
+ token_type: 'Bearer',
+ redirect: '/dashboard',
})
setTokenMock.mockResolvedValue({})
@@ -179,35 +179,67 @@ describe('WechatCallbackView', () => {
await buttons[0].trigger('click')
await flushPromises()
- expect(postMock).toHaveBeenNthCalledWith(1, '/auth/oauth/pending/exchange', {})
- expect(postMock).toHaveBeenNthCalledWith(2, '/auth/oauth/pending/exchange', {
- adopt_display_name: true,
- adopt_avatar: false,
+ expect(exchangePendingOAuthCompletionMock).toHaveBeenNthCalledWith(1)
+ expect(exchangePendingOAuthCompletionMock).toHaveBeenNthCalledWith(2, {
+ adoptDisplayName: true,
+ adoptAvatar: false,
})
expect(setTokenMock).toHaveBeenCalledWith('wechat-access-token')
expect(replaceMock).toHaveBeenCalledWith('/dashboard')
expect(localStorage.getItem('refresh_token')).toBe('wechat-refresh-token')
})
+ it('supports bind completion after adoption confirmation', async () => {
+ exchangePendingOAuthCompletionMock
+ .mockResolvedValueOnce({
+ redirect: '/dashboard',
+ adoption_required: true,
+ suggested_display_name: 'WeChat Nick',
+ suggested_avatar_url: 'https://cdn.example/wechat.png',
+ })
+ .mockResolvedValueOnce({
+ redirect: '/profile/connections',
+ })
+
+ const wrapper = mount(WechatCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ await wrapper.findAll('button')[0].trigger('click')
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletionMock).toHaveBeenNthCalledWith(2, {
+ adoptDisplayName: true,
+ adoptAvatar: true,
+ })
+ expect(setTokenMock).not.toHaveBeenCalled()
+ expect(showSuccessMock).toHaveBeenCalledWith('profile.authBindings.bindSuccess')
+ expect(replaceMock).toHaveBeenCalledWith('/profile/connections')
+ })
+
it('renders adoption choices for invitation flow and submits the selected values', async () => {
- postMock
- .mockResolvedValueOnce({
- data: {
- error: 'invitation_required',
- redirect: '/subscriptions',
- adoption_required: true,
- suggested_display_name: 'WeChat Nick',
- suggested_avatar_url: 'https://cdn.example/wechat.png',
- },
- })
- .mockResolvedValueOnce({
- data: {
- access_token: 'wechat-invite-token',
- refresh_token: 'wechat-invite-refresh',
- expires_in: 600,
- token_type: 'Bearer',
- },
- })
+ exchangePendingOAuthCompletionMock.mockResolvedValue({
+ error: 'invitation_required',
+ redirect: '/subscriptions',
+ adoption_required: true,
+ suggested_display_name: 'WeChat Nick',
+ suggested_avatar_url: 'https://cdn.example/wechat.png',
+ })
+ completeWeChatOAuthRegistrationMock.mockResolvedValue({
+ access_token: 'wechat-invite-token',
+ refresh_token: 'wechat-invite-refresh',
+ expires_in: 600,
+ token_type: 'Bearer',
+ })
const wrapper = mount(WechatCallbackView, {
global: {
@@ -230,10 +262,9 @@ describe('WechatCallbackView', () => {
await wrapper.get('button').trigger('click')
await flushPromises()
- expect(postMock).toHaveBeenNthCalledWith(2, '/auth/oauth/wechat/complete-registration', {
- invitation_code: 'INVITE-CODE',
- adopt_display_name: false,
- adopt_avatar: true,
+ expect(completeWeChatOAuthRegistrationMock).toHaveBeenCalledWith('INVITE-CODE', {
+ adoptDisplayName: false,
+ adoptAvatar: true,
})
expect(setTokenMock).toHaveBeenCalledWith('wechat-invite-token')
expect(replaceMock).toHaveBeenCalledWith('/subscriptions')
diff --git a/frontend/src/views/user/ProfileView.vue b/frontend/src/views/user/ProfileView.vue
index e7418ebb..f7418be9 100644
--- a/frontend/src/views/user/ProfileView.vue
+++ b/frontend/src/views/user/ProfileView.vue
@@ -2,18 +2,53 @@
-
-
-
+
+
+
-
-
+
+
+
+
-
-
{{ t('common.contactSupport') }} {{ contactInfo }}
+
+
+
+
+
+ {{ t('common.contactSupport') }}
+
+
{{ contactInfo }}
+
+
+
+
@@ -29,26 +65,78 @@
\ No newline at end of file
+onMounted(async () => {
+ const profileRefresh = authStore.refreshUser().catch((error) => {
+ console.error('Failed to refresh profile:', error)
+ })
+
+ const settingsLoad = authAPI.getPublicSettings()
+ .then((settings) => {
+ contactInfo.value = settings.contact_info || ''
+ balanceLowNotifyEnabled.value = settings.balance_low_notify_enabled ?? false
+ systemDefaultThreshold.value = settings.balance_low_notify_threshold ?? 0
+ linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled ?? false
+ wechatOAuthEnabled.value = settings.wechat_oauth_enabled ?? false
+ oidcOAuthEnabled.value = settings.oidc_oauth_enabled ?? false
+ oidcOAuthProviderName.value = settings.oidc_oauth_provider_name || 'OIDC'
+ })
+ .catch((error) => {
+ console.error('Failed to load settings:', error)
+ })
+
+ await Promise.all([profileRefresh, settingsLoad])
+})
+
+const formatCurrency = (value: number) => `$${value.toFixed(2)}`
+
From 4e0e69154649def4f4149054e09ff03fc8a3e50a Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 18:39:53 +0800
Subject: [PATCH 025/189] feat: apply auth source signup defaults
---
.../service/admin_service_apikey_test.go | 3 +
.../service/admin_service_delete_test.go | 51 ++++--
backend/internal/service/auth_service.go | 117 +++++++++----
.../service/auth_service_register_test.go | 159 +++++++++++++++++-
4 files changed, 283 insertions(+), 47 deletions(-)
diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go
index b802a9c2..487fb5f1 100644
--- a/backend/internal/service/admin_service_apikey_test.go
+++ b/backend/internal/service/admin_service_apikey_test.go
@@ -79,6 +79,9 @@ func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *s
}
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
+ panic("unexpected")
+}
func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
panic("unexpected")
}
diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go
index 323286b0..ac1d8ee7 100644
--- a/backend/internal/service/admin_service_delete_test.go
+++ b/backend/internal/service/admin_service_delete_test.go
@@ -13,15 +13,18 @@ import (
)
type userRepoStub struct {
- user *User
- getErr error
- createErr error
- deleteErr error
- exists bool
- existsErr error
- nextID int64
- created []*User
- deletedIDs []int64
+ user *User
+ getErr error
+ createErr error
+ deleteErr error
+ exists bool
+ existsErr error
+ nextID int64
+ created []*User
+ updated []*User
+ deletedIDs []int64
+ usersByEmail map[string]*User
+ getByEmailErr error
}
func (s *userRepoStub) Create(ctx context.Context, user *User) error {
@@ -32,6 +35,11 @@ func (s *userRepoStub) Create(ctx context.Context, user *User) error {
user.ID = s.nextID
}
s.created = append(s.created, user)
+ if s.usersByEmail == nil {
+ s.usersByEmail = make(map[string]*User)
+ }
+ s.usersByEmail[user.Email] = user
+ s.user = user
return nil
}
@@ -46,7 +54,18 @@ func (s *userRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
}
func (s *userRepoStub) GetByEmail(ctx context.Context, email string) (*User, error) {
- panic("unexpected GetByEmail call")
+ if s.getByEmailErr != nil {
+ return nil, s.getByEmailErr
+ }
+ if s.usersByEmail != nil {
+ if user, ok := s.usersByEmail[email]; ok {
+ return user, nil
+ }
+ }
+ if s.user != nil && s.user.Email == email {
+ return s.user, nil
+ }
+ return nil, ErrUserNotFound
}
func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) {
@@ -54,7 +73,13 @@ func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) {
}
func (s *userRepoStub) Update(ctx context.Context, user *User) error {
- panic("unexpected Update call")
+ s.updated = append(s.updated, user)
+ if s.usersByEmail == nil {
+ s.usersByEmail = make(map[string]*User)
+ }
+ s.usersByEmail[user.Email] = user
+ s.user = user
+ return nil
}
func (s *userRepoStub) Delete(ctx context.Context, id int64) error {
@@ -113,6 +138,10 @@ func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64
panic("unexpected AddGroupToAllowedGroups call")
}
+func (s *userRepoStub) ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) {
+ panic("unexpected ListUserAuthIdentities call")
+}
+
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index 962009ce..40753139 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -78,6 +78,12 @@ type DefaultSubscriptionAssigner interface {
AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error)
}
+type signupGrantPlan struct {
+ Balance float64
+ Concurrency int
+ Subscriptions []DefaultSubscriptionSetting
+}
+
// NewAuthService 创建认证服务实例
func NewAuthService(
entClient *dbent.Client,
@@ -187,21 +193,15 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, fmt.Errorf("hash password: %w", err)
}
- // 获取默认配置
- defaultBalance := s.cfg.Default.UserBalance
- defaultConcurrency := s.cfg.Default.UserConcurrency
- if s.settingService != nil {
- defaultBalance = s.settingService.GetDefaultBalance(ctx)
- defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
- }
+ grantPlan := s.resolveSignupGrantPlan(ctx, "email")
// 创建用户
user := &User{
Email: email,
PasswordHash: hashedPassword,
Role: RoleUser,
- Balance: defaultBalance,
- Concurrency: defaultConcurrency,
+ Balance: grantPlan.Balance,
+ Concurrency: grantPlan.Concurrency,
Status: StatusActive,
}
@@ -214,7 +214,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrServiceUnavailable
}
s.postAuthUserBootstrap(ctx, user, "email", true)
- s.assignDefaultSubscriptions(ctx, user.ID)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
// 标记邀请码为已使用(如果使用了邀请码)
if invitationRedeemCode != nil {
@@ -479,21 +479,16 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
return "", nil, fmt.Errorf("hash password: %w", err)
}
- // 新用户默认值。
- defaultBalance := s.cfg.Default.UserBalance
- defaultConcurrency := s.cfg.Default.UserConcurrency
- if s.settingService != nil {
- defaultBalance = s.settingService.GetDefaultBalance(ctx)
- defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
- }
+ signupSource := inferLegacySignupSource(email)
+ grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
newUser := &User{
Email: email,
Username: username,
PasswordHash: hashedPassword,
Role: RoleUser,
- Balance: defaultBalance,
- Concurrency: defaultConcurrency,
+ Balance: grantPlan.Balance,
+ Concurrency: grantPlan.Concurrency,
Status: StatusActive,
}
@@ -511,8 +506,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
}
} else {
user = newUser
- s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
- s.assignDefaultSubscriptions(ctx, user.ID)
+ s.postAuthUserBootstrap(ctx, user, signupSource, true)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
@@ -596,20 +591,16 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, fmt.Errorf("hash password: %w", err)
}
- defaultBalance := s.cfg.Default.UserBalance
- defaultConcurrency := s.cfg.Default.UserConcurrency
- if s.settingService != nil {
- defaultBalance = s.settingService.GetDefaultBalance(ctx)
- defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
- }
+ signupSource := inferLegacySignupSource(email)
+ grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
newUser := &User{
Email: email,
Username: username,
PasswordHash: hashedPassword,
Role: RoleUser,
- Balance: defaultBalance,
- Concurrency: defaultConcurrency,
+ Balance: grantPlan.Balance,
+ Concurrency: grantPlan.Concurrency,
Status: StatusActive,
}
@@ -642,8 +633,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, ErrServiceUnavailable
}
user = newUser
- s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
- s.assignDefaultSubscriptions(ctx, user.ID)
+ s.postAuthUserBootstrap(ctx, user, signupSource, true)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
}
} else {
if err := s.userRepo.Create(ctx, newUser); err != nil {
@@ -659,8 +650,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
}
} else {
user = newUser
- s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
- s.assignDefaultSubscriptions(ctx, user.ID)
+ s.postAuthUserBootstrap(ctx, user, signupSource, true)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return nil, nil, ErrInvitationCodeInvalid
@@ -694,22 +685,78 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
}
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
+ if s.settingService == nil {
+ return
+ }
+ s.assignSubscriptions(ctx, userID, s.settingService.GetDefaultSubscriptions(ctx), "auto assigned by default user subscriptions setting")
+}
+
+func (s *AuthService) assignSubscriptions(ctx context.Context, userID int64, items []DefaultSubscriptionSetting, notes string) {
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
return
}
- items := s.settingService.GetDefaultSubscriptions(ctx)
for _, item := range items {
if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
UserID: userID,
GroupID: item.GroupID,
ValidityDays: item.ValidityDays,
- Notes: "auto assigned by default user subscriptions setting",
+ Notes: notes,
}); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err)
}
}
}
+func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource string) signupGrantPlan {
+ plan := signupGrantPlan{}
+ if s != nil && s.cfg != nil {
+ plan.Balance = s.cfg.Default.UserBalance
+ plan.Concurrency = s.cfg.Default.UserConcurrency
+ }
+ if s == nil || s.settingService == nil {
+ return plan
+ }
+
+ plan.Balance = s.settingService.GetDefaultBalance(ctx)
+ plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx)
+ plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx)
+
+ defaults, err := s.settingService.GetAuthSourceDefaultSettings(ctx)
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err)
+ return plan
+ }
+
+ providerDefaults, ok := authSourceSignupSettings(defaults, signupSource)
+ if !ok || !providerDefaults.GrantOnSignup {
+ return plan
+ }
+
+ plan.Balance = providerDefaults.Balance
+ plan.Concurrency = providerDefaults.Concurrency
+ plan.Subscriptions = providerDefaults.Subscriptions
+ return plan
+}
+
+func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource string) (ProviderDefaultGrantSettings, bool) {
+ if defaults == nil {
+ return ProviderDefaultGrantSettings{}, false
+ }
+
+ switch strings.ToLower(strings.TrimSpace(signupSource)) {
+ case "email":
+ return defaults.Email, true
+ case "linuxdo":
+ return defaults.LinuxDo, true
+ case "oidc":
+ return defaults.OIDC, true
+ case "wechat":
+ return defaults.WeChat, true
+ default:
+ return ProviderDefaultGrantSettings{}, false
+ }
+}
+
func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
if user == nil || user.ID <= 0 {
return
diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go
index 103bafe7..901b3db3 100644
--- a/backend/internal/service/auth_service_register_test.go
+++ b/backend/internal/service/auth_service_register_test.go
@@ -37,7 +37,16 @@ func (s *settingRepoStub) Set(ctx context.Context, key, value string) error {
}
func (s *settingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
- panic("unexpected GetMultiple call")
+ if s.err != nil {
+ return nil, s.err
+ }
+ result := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if v, ok := s.values[key]; ok {
+ result[key] = v
+ }
+ }
+ return result, nil
}
func (s *settingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
@@ -62,6 +71,8 @@ type defaultSubscriptionAssignerStub struct {
err error
}
+type refreshTokenCacheStub struct{}
+
func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
if input != nil {
s.calls = append(s.calls, *input)
@@ -72,6 +83,46 @@ func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.C
return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
}
+func (s *refreshTokenCacheStub) StoreRefreshToken(context.Context, string, *RefreshTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) GetRefreshToken(context.Context, string) (*RefreshTokenData, error) {
+ return nil, ErrRefreshTokenNotFound
+}
+
+func (s *refreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+
+func (s *refreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
+ return nil, nil
+}
+
+func (s *refreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
+ return false, nil
+}
+
func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) {
if s.err != nil {
return nil, s.err
@@ -484,3 +535,109 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
require.Equal(t, int64(12), assigner.calls[1].GroupID)
require.Equal(t, 7, assigner.calls[1].ValidityDays)
}
+
+func TestAuthService_Register_UsesEmailAuthSourceDefaultsWhenGrantEnabled(t *testing.T) {
+ repo := &userRepoStub{nextID: 52}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":91,"validity_days":3}]`,
+ SettingKeyAuthSourceDefaultEmailBalance: "12.5",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "7",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
+ }, nil)
+ service.defaultSubAssigner = assigner
+
+ _, user, err := service.Register(context.Background(), "email-defaults@test.com", "password")
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.Equal(t, 12.5, user.Balance)
+ require.Equal(t, 7, user.Concurrency)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(11), assigner.calls[0].GroupID)
+ require.Equal(t, 30, assigner.calls[0].ValidityDays)
+}
+
+func TestAuthService_Register_GrantOnSignupFalseFallsBackToGlobalDefaults(t *testing.T) {
+ repo := &userRepoStub{nextID: 53}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`,
+ SettingKeyAuthSourceDefaultEmailBalance: "99",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "88",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":32,"validity_days":9}]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
+ }, nil)
+ service.defaultSubAssigner = assigner
+
+ _, user, err := service.Register(context.Background(), "email-global@test.com", "password")
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.Equal(t, 3.5, user.Balance)
+ require.Equal(t, 2, user.Concurrency)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(31), assigner.calls[0].GroupID)
+ require.Equal(t, 5, assigner.calls[0].ValidityDays)
+}
+
+func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefaultsOnSignup(t *testing.T) {
+ repo := &userRepoStub{nextID: 61}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":81,"validity_days":1}]`,
+ SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75",
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9",
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`,
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
+ }, nil)
+ service.defaultSubAssigner = assigner
+ service.refreshTokenCache = &refreshTokenCacheStub{}
+
+ tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "")
+ require.NoError(t, err)
+ require.NotNil(t, tokenPair)
+ require.NotNil(t, user)
+ require.Equal(t, int64(61), user.ID)
+ require.Equal(t, 21.75, user.Balance)
+ require.Equal(t, 9, user.Concurrency)
+ require.Len(t, repo.created, 1)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(22), assigner.calls[0].GroupID)
+ require.Equal(t, 14, assigner.calls[0].ValidityDays)
+}
+
+func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantAgain(t *testing.T) {
+ existing := &User{
+ ID: 88,
+ Email: "linuxdo-123@linuxdo-connect.invalid",
+ Username: "existing-linuxdo",
+ Role: RoleUser,
+ Status: StatusActive,
+ Balance: 4,
+ Concurrency: 1,
+ TokenVersion: 2,
+ }
+ repo := &userRepoStub{user: existing}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75",
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9",
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`,
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
+ }, nil)
+ service.defaultSubAssigner = assigner
+ service.refreshTokenCache = &refreshTokenCacheStub{}
+
+ tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "")
+ require.NoError(t, err)
+ require.NotNil(t, tokenPair)
+ require.Equal(t, existing.ID, user.ID)
+ require.Equal(t, 4.0, user.Balance)
+ require.Equal(t, 1, user.Concurrency)
+ require.Empty(t, repo.created)
+ require.Empty(t, assigner.calls)
+}
From 0353c3870f938f5d57d4060232dfe9e3bc4a01ff Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 18:40:34 +0800
Subject: [PATCH 026/189] test: update user service stubs for identity
summaries
---
.../service/billing_cache_service_singleflight_test.go | 4 ++++
backend/internal/service/user_service_test.go | 9 ++++++---
2 files changed, 10 insertions(+), 3 deletions(-)
diff --git a/backend/internal/service/billing_cache_service_singleflight_test.go b/backend/internal/service/billing_cache_service_singleflight_test.go
index 4a8b8f03..2ebc2f04 100644
--- a/backend/internal/service/billing_cache_service_singleflight_test.go
+++ b/backend/internal/service/billing_cache_service_singleflight_test.go
@@ -86,6 +86,10 @@ func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User,
return &User{ID: id, Balance: s.balance}, nil
}
+func (s *balanceLoadUserRepoStub) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
+ return nil, nil
+}
+
func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
cache := &billingCacheMissStub{}
userRepo := &balanceLoadUserRepoStub{
diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go
index 7d63bb36..89f0362e 100644
--- a/backend/internal/service/user_service_test.go
+++ b/backend/internal/service/user_service_test.go
@@ -100,9 +100,12 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int
return 0, nil
}
func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
-func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
-func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
-func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
+func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
+ return nil, nil
+}
+func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
+func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
return nil
}
From d47580a144f083965b04e4612e4fd9ceb4779e35 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 18:42:28 +0800
Subject: [PATCH 027/189] test: pin email signup defaults in register tests
---
backend/internal/service/auth_service_register_test.go | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go
index 901b3db3..e0dce982 100644
--- a/backend/internal/service/auth_service_register_test.go
+++ b/backend/internal/service/auth_service_register_test.go
@@ -373,7 +373,8 @@ func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) {
func TestAuthService_Register_Success(t *testing.T) {
repo := &userRepoStub{nextID: 5}
service := newAuthService(repo, map[string]string{
- SettingKeyRegistrationEnabled: "true",
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
}, nil)
token, user, err := service.Register(context.Background(), "user@test.com", "password")
@@ -520,8 +521,9 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
repo := &userRepoStub{nextID: 42}
assigner := &defaultSubscriptionAssignerStub{}
service := newAuthService(repo, map[string]string{
- SettingKeyRegistrationEnabled: "true",
- SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
}, nil)
service.defaultSubAssigner = assigner
From 6a75bd77e3cbf11a12a624a474a979f684be5fb6 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 19:30:09 +0800
Subject: [PATCH 028/189] feat: add pending oauth email onboarding flow
---
.../internal/handler/auth_linuxdo_oauth.go | 125 +++--
.../handler/auth_oauth_pending_flow.go | 426 +++++++++++++++++-
.../handler/auth_oauth_pending_flow_test.go | 371 ++++++++++++++-
backend/internal/handler/auth_oidc_oauth.go | 137 +++---
backend/internal/handler/auth_wechat_oauth.go | 33 ++
backend/internal/handler/dto/settings.go | 1 +
backend/internal/handler/setting_handler.go | 1 +
.../handler/setting_handler_public_test.go | 83 ++++
backend/internal/server/routes/auth.go | 48 ++
.../internal/service/auth_oauth_email_flow.go | 151 +++++++
backend/internal/service/setting_service.go | 2 +
.../service/setting_service_public_test.go | 13 +
backend/internal/service/settings_view.go | 1 +
13 files changed, 1273 insertions(+), 119 deletions(-)
create mode 100644 backend/internal/handler/setting_handler_public_test.go
create mode 100644 backend/internal/service/auth_oauth_email_flow.go
diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go
index 835e5fd8..c4ecb8fa 100644
--- a/backend/internal/handler/auth_linuxdo_oauth.go
+++ b/backend/internal/handler/auth_linuxdo_oauth.go
@@ -243,6 +243,18 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
if subject != "" {
email = linuxDoSyntheticEmail(subject)
}
+ identityKey := service.PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: subject,
+ }
+ upstreamClaims := map[string]any{
+ "email": email,
+ "username": username,
+ "subject": subject,
+ "suggested_display_name": displayName,
+ "suggested_avatar_url": avatarURL,
+ }
if intent == oauthIntentBindCurrentUser {
targetUserID, err := h.readOAuthBindUserIDFromCookie(c, linuxDoOAuthBindUserCookieName)
if err != nil {
@@ -250,23 +262,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
return
}
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
- Intent: oauthIntentBindCurrentUser,
- Identity: service.PendingAuthIdentityKey{
- ProviderType: "linuxdo",
- ProviderKey: "linuxdo",
- ProviderSubject: subject,
- },
- TargetUserID: &targetUserID,
- ResolvedEmail: email,
- RedirectTo: redirectTo,
- BrowserSessionKey: browserSessionKey,
- UpstreamIdentityClaims: map[string]any{
- "email": email,
- "username": username,
- "subject": subject,
- "suggested_display_name": displayName,
- "suggested_avatar_url": avatarURL,
- },
+ Intent: oauthIntentBindCurrentUser,
+ Identity: identityKey,
+ TargetUserID: &targetUserID,
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"redirect": redirectTo,
},
@@ -278,27 +280,60 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
return
}
+ existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityKey)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if existingIdentityUser != nil {
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin,
+ Identity: identityKey,
+ TargetUserID: &user.ID,
+ ResolvedEmail: existingIdentityUser.Email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{
+ "access_token": tokenPair.AccessToken,
+ "refresh_token": tokenPair.RefreshToken,
+ "expires_in": tokenPair.ExpiresIn,
+ "token_type": "Bearer",
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
+ if err := h.createOAuthEmailRequiredPendingSession(c, identityKey, redirectTo, browserSessionKey, upstreamClaims); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
- Intent: "login",
- Identity: service.PendingAuthIdentityKey{
- ProviderType: "linuxdo",
- ProviderKey: "linuxdo",
- ProviderSubject: subject,
- },
- ResolvedEmail: email,
- RedirectTo: redirectTo,
- BrowserSessionKey: browserSessionKey,
- UpstreamIdentityClaims: map[string]any{
- "email": email,
- "username": username,
- "subject": subject,
- "suggested_display_name": displayName,
- "suggested_avatar_url": avatarURL,
- },
+ Intent: "login",
+ Identity: identityKey,
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"error": "invitation_required",
"redirect": redirectTo,
@@ -316,23 +351,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
}
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
- Intent: "login",
- Identity: service.PendingAuthIdentityKey{
- ProviderType: "linuxdo",
- ProviderKey: "linuxdo",
- ProviderSubject: subject,
- },
- TargetUserID: &user.ID,
- ResolvedEmail: email,
- RedirectTo: redirectTo,
- BrowserSessionKey: browserSessionKey,
- UpstreamIdentityClaims: map[string]any{
- "email": email,
- "username": username,
- "subject": subject,
- "suggested_display_name": displayName,
- "suggested_avatar_url": avatarURL,
- },
+ Intent: "login",
+ Identity: identityKey,
+ TargetUserID: &user.ID,
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
index 2d6c3714..99b9b406 100644
--- a/backend/internal/handler/auth_oauth_pending_flow.go
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -46,6 +46,36 @@ type oauthAdoptionDecisionRequest struct {
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
+type bindPendingOAuthLoginRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ Password string `json:"password" binding:"required"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
+}
+
+type createPendingOAuthAccountRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ VerifyCode string `json:"verify_code,omitempty"`
+ Password string `json:"password" binding:"required,min=6"`
+ InvitationCode string `json:"invitation_code,omitempty"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
+}
+
+func (r bindPendingOAuthLoginRequest) adoptionDecision() oauthAdoptionDecisionRequest {
+ return oauthAdoptionDecisionRequest{
+ AdoptDisplayName: r.AdoptDisplayName,
+ AdoptAvatar: r.AdoptAvatar,
+ }
+}
+
+func (r createPendingOAuthAccountRequest) adoptionDecision() oauthAdoptionDecisionRequest {
+ return oauthAdoptionDecisionRequest{
+ AdoptDisplayName: r.AdoptDisplayName,
+ AdoptAvatar: r.AdoptAvatar,
+ }
+}
+
func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) {
if h == nil || h.authService == nil || h.authService.EntClient() == nil {
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
@@ -170,6 +200,36 @@ func readCompletionResponse(session map[string]any) (map[string]any, bool) {
return result, true
}
+func clonePendingMap(values map[string]any) map[string]any {
+ if len(values) == 0 {
+ return map[string]any{}
+ }
+ cloned := make(map[string]any, len(values))
+ for key, value := range values {
+ cloned[key] = value
+ }
+ return cloned
+}
+
+func mergePendingCompletionResponse(session *dbent.PendingAuthSession, overrides map[string]any) map[string]any {
+ payload, _ := readCompletionResponse(session.LocalFlowState)
+ merged := clonePendingMap(payload)
+ if strings.TrimSpace(session.RedirectTo) != "" {
+ if _, exists := merged["redirect"]; !exists {
+ merged["redirect"] = session.RedirectTo
+ }
+ }
+ for key, value := range overrides {
+ if value == nil {
+ delete(merged, key)
+ continue
+ }
+ merged[key] = value
+ }
+ applySuggestedProfileToCompletionResponse(merged, session.UpstreamIdentityClaims)
+ return merged
+}
+
func pendingSessionStringValue(values map[string]any, key string) string {
if len(values) == 0 {
return ""
@@ -264,6 +324,89 @@ func (h *AuthHandler) entClient() *dbent.Client {
return h.authService.EntClient()
}
+func (h *AuthHandler) isForceEmailOnThirdPartySignup(ctx context.Context) bool {
+ if h == nil || h.settingSvc == nil {
+ return false
+ }
+ defaults, err := h.settingSvc.GetAuthSourceDefaultSettings(ctx)
+ if err != nil || defaults == nil {
+ return false
+ }
+ return defaults.ForceEmailOnThirdPartySignup
+}
+
+func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity service.PendingAuthIdentityKey) (*dbent.User, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ record, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)),
+ authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(identity.ProviderSubject)),
+ ).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, nil
+ }
+ return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
+ }
+
+ userEntity, err := client.User.Get(ctx, record.UserID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, nil
+ }
+ return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err)
+ }
+ return userEntity, nil
+}
+
+func (h *AuthHandler) createOAuthEmailRequiredPendingSession(
+ c *gin.Context,
+ identity service.PendingAuthIdentityKey,
+ redirectTo string,
+ browserSessionKey string,
+ upstreamClaims map[string]any,
+) error {
+ return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin,
+ Identity: identity,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{
+ "redirect": redirectTo,
+ "step": "email_required",
+ "force_email_on_signup": true,
+ "email_binding_required": true,
+ "existing_account_bindable": true,
+ },
+ })
+}
+
+func (h *AuthHandler) BindLinuxDoOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "linuxdo") }
+func (h *AuthHandler) BindOIDCOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "oidc") }
+func (h *AuthHandler) BindWeChatOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "wechat") }
+func (h *AuthHandler) BindPendingOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "") }
+
+func (h *AuthHandler) CreateLinuxDoOAuthAccount(c *gin.Context) {
+ h.createPendingOAuthAccount(c, "linuxdo")
+}
+
+func (h *AuthHandler) CreateOIDCOAuthAccount(c *gin.Context) { h.createPendingOAuthAccount(c, "oidc") }
+
+func (h *AuthHandler) CreateWeChatOAuthAccount(c *gin.Context) {
+ h.createPendingOAuthAccount(c, "wechat")
+}
+
+func (h *AuthHandler) CreatePendingOAuthAccount(c *gin.Context) {
+ h.createPendingOAuthAccount(c, "")
+}
+
func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
c *gin.Context,
sessionID int64,
@@ -313,6 +456,60 @@ func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
return decision, nil
}
+func (h *AuthHandler) ensurePendingOAuthAdoptionDecision(
+ c *gin.Context,
+ sessionID int64,
+ req oauthAdoptionDecisionRequest,
+) (*dbent.IdentityAdoptionDecision, error) {
+ decision, err := h.upsertPendingOAuthAdoptionDecision(c, sessionID, req)
+ if err != nil {
+ return nil, err
+ }
+ if decision != nil {
+ return decision, nil
+ }
+
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ return nil, err
+ }
+ decision, err = svc.UpsertAdoptionDecision(c.Request.Context(), service.PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: sessionID,
+ })
+ if err != nil {
+ return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
+ }
+ return decision, nil
+}
+
+func updatePendingOAuthSessionProgress(
+ ctx context.Context,
+ client *dbent.Client,
+ session *dbent.PendingAuthSession,
+ intent string,
+ resolvedEmail string,
+ targetUserID *int64,
+ completionResponse map[string]any,
+) (*dbent.PendingAuthSession, error) {
+ if client == nil || session == nil {
+ return nil, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
+ }
+
+ localFlowState := clonePendingMap(session.LocalFlowState)
+ localFlowState[oauthCompletionResponseKey] = clonePendingMap(completionResponse)
+
+ update := client.PendingAuthSession.UpdateOneID(session.ID).
+ SetIntent(strings.TrimSpace(intent)).
+ SetResolvedEmail(strings.TrimSpace(resolvedEmail)).
+ SetLocalFlowState(localFlowState)
+ if targetUserID != nil && *targetUserID > 0 {
+ update = update.SetTargetUserID(*targetUserID)
+ } else {
+ update = update.ClearTargetUserID()
+ }
+ return update.Save(ctx)
+}
+
func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) {
if session == nil {
return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
@@ -401,17 +598,18 @@ func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision
return decision.AdoptDisplayName || decision.AdoptAvatar
}
-func applyPendingOAuthAdoption(
+func applyPendingOAuthBinding(
ctx context.Context,
client *dbent.Client,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
overrideUserID *int64,
+ forceBind bool,
) error {
- if client == nil || session == nil || decision == nil {
+ if client == nil || session == nil {
return nil
}
- if !shouldBindPendingOAuthIdentity(session, decision) {
+ if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) {
return nil
}
@@ -427,11 +625,11 @@ func applyPendingOAuthAdoption(
}
adoptedDisplayName := ""
- if decision.AdoptDisplayName {
+ if decision != nil && decision.AdoptDisplayName {
adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name"))
}
adoptedAvatarURL := ""
- if decision.AdoptAvatar {
+ if decision != nil && decision.AdoptAvatar {
adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url")
}
@@ -441,7 +639,7 @@ func applyPendingOAuthAdoption(
}
defer func() { _ = tx.Rollback() }()
- if decision.AdoptDisplayName && adoptedDisplayName != "" {
+ if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
if err := tx.Client().User.UpdateOneID(targetUserID).
SetUsername(adoptedDisplayName).
Exec(ctx); err != nil {
@@ -458,10 +656,10 @@ func applyPendingOAuthAdoption(
for key, value := range session.UpstreamIdentityClaims {
metadata[key] = value
}
- if decision.AdoptDisplayName && adoptedDisplayName != "" {
+ if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
metadata["display_name"] = adoptedDisplayName
}
- if decision.AdoptAvatar && adoptedAvatarURL != "" {
+ if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" {
metadata["avatar_url"] = adoptedAvatarURL
}
@@ -473,7 +671,7 @@ func applyPendingOAuthAdoption(
return err
}
- if decision.IdentityID == nil || *decision.IdentityID != identity.ID {
+ if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) {
if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
SetIdentityID(identity.ID).
Save(ctx); err != nil {
@@ -484,6 +682,16 @@ func applyPendingOAuthAdoption(
return tx.Commit()
}
+func applyPendingOAuthAdoption(
+ ctx context.Context,
+ client *dbent.Client,
+ session *dbent.PendingAuthSession,
+ decision *dbent.IdentityAdoptionDecision,
+ overrideUserID *int64,
+) error {
+ return applyPendingOAuthBinding(ctx, client, session, decision, overrideUserID, false)
+}
+
func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) {
if len(payload) == 0 || len(upstream) == 0 {
return
@@ -507,6 +715,206 @@ func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream
}
}
+func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.AuthPendingIdentityService, *dbent.PendingAuthSession, func(), error) {
+ secureCookie := isRequestHTTPS(c)
+ clearCookies := func() {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ }
+
+ sessionToken, err := readOAuthPendingSessionCookie(c)
+ if err != nil || strings.TrimSpace(sessionToken) == "" {
+ clearCookies()
+ return nil, nil, clearCookies, service.ErrPendingAuthSessionNotFound
+ }
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil || strings.TrimSpace(browserSessionKey) == "" {
+ clearCookies()
+ return nil, nil, clearCookies, service.ErrPendingAuthBrowserMismatch
+ }
+
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ clearCookies()
+ return nil, nil, clearCookies, err
+ }
+
+ session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+ if err != nil {
+ clearCookies()
+ return nil, nil, clearCookies, err
+ }
+
+ return svc, session, clearCookies, nil
+}
+
+func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H {
+ payload := gin.H{
+ "auth_result": "pending_session",
+ "provider": strings.TrimSpace(session.ProviderType),
+ "intent": strings.TrimSpace(session.Intent),
+ }
+ for key, value := range mergePendingCompletionResponse(session, nil) {
+ payload[key] = value
+ }
+ if email := strings.TrimSpace(session.ResolvedEmail); email != "" {
+ payload["email"] = email
+ }
+ return payload
+}
+
+func writeOAuthTokenPairResponse(c *gin.Context, tokenPair *service.TokenPair) {
+ c.JSON(http.StatusOK, gin.H{
+ "access_token": tokenPair.AccessToken,
+ "refresh_token": tokenPair.RefreshToken,
+ "expires_in": tokenPair.ExpiresIn,
+ "token_type": "Bearer",
+ })
+}
+
+func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
+ var req bindPendingOAuthLoginRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
+ response.BadRequest(c, "Pending oauth session provider mismatch")
+ return
+ }
+
+ user, err := h.authService.ValidatePasswordCredentials(c.Request.Context(), strings.TrimSpace(req.Email), req.Password)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if session.TargetUserID != nil && *session.TargetUserID > 0 && user.ID != *session.TargetUserID {
+ response.ErrorFrom(c, infraerrors.Conflict("PENDING_AUTH_TARGET_USER_MISMATCH", "pending oauth session must be completed by the targeted user"))
+ return
+ }
+
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), session, decision, &user.ID, true); err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+ return
+ }
+
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+ tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
+ if err != nil {
+ response.InternalError(c, "Failed to generate token pair")
+ return
+ }
+ if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ clearCookies()
+ writeOAuthTokenPairResponse(c, tokenPair)
+}
+
+func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) {
+ var req createPendingOAuthAccountRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
+ response.BadRequest(c, "Pending oauth session provider mismatch")
+ return
+ }
+
+ client := h.entClient()
+ if client == nil {
+ response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
+ return
+ }
+
+ email := strings.TrimSpace(strings.ToLower(req.Email))
+ existingUser, err := client.User.Query().Where(dbuser.EmailEQ(email)).Only(c.Request.Context())
+ if err != nil && !dbent.IsNotFound(err) {
+ response.ErrorFrom(c, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable"))
+ return
+ }
+ if existingUser != nil {
+ completionResponse := mergePendingCompletionResponse(session, map[string]any{
+ "step": "bind_login_required",
+ "email": email,
+ })
+ session, err = updatePendingOAuthSessionProgress(
+ c.Request.Context(),
+ client,
+ session,
+ "adopt_existing_user_by_email",
+ email,
+ &existingUser.ID,
+ completionResponse,
+ )
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err))
+ return
+ }
+
+ if _, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
+ return
+ }
+
+ tokenPair, user, err := h.authService.RegisterOAuthEmailAccount(
+ c.Request.Context(),
+ email,
+ req.Password,
+ strings.TrimSpace(req.VerifyCode),
+ strings.TrimSpace(req.InvitationCode),
+ strings.TrimSpace(session.ProviderType),
+ )
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthBinding(c.Request.Context(), client, session, decision, &user.ID, true); err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+ return
+ }
+
+ if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ clearCookies()
+ writeOAuthTokenPairResponse(c, tokenPair)
+}
+
// ExchangePendingOAuthCompletion redeems a pending OAuth browser session into a frontend-safe payload.
// POST /api/v1/auth/oauth/pending/exchange
func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go
index 3afb4fb7..80338b8a 100644
--- a/backend/internal/handler/auth_oauth_pending_flow_test.go
+++ b/backend/internal/handler/auth_oauth_pending_flow_test.go
@@ -509,9 +509,305 @@ func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecis
require.Nil(t, storedSession.ConsumedAt)
}
+func TestCreateOIDCOAuthAccountCreatesUserBindsIdentityAndConsumesSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "fresh@example.com", "246810")
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("create-account-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-create-123").
+ SetBrowserSessionKey("create-account-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Fresh OIDC User",
+ "suggested_avatar_url": "https://cdn.example/fresh.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.NotEmpty(t, payload["access_token"])
+ require.NotEmpty(t, payload["refresh_token"])
+ require.Equal(t, "Bearer", payload["token_type"])
+
+ createdUser, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, service.StatusActive, createdUser.Status)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-create-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, createdUser.ID, identity.UserID)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+}
+
+func TestCreateOIDCOAuthAccountExistingEmailReturnsAdoptExistingUserByEmailState(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("existing-email-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-existing-123").
+ SetBrowserSessionKey("existing-email-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Existing OIDC User",
+ "suggested_avatar_url": "https://cdn.example/existing.png",
+ }).
+ SetRedirectTo("/dashboard").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","verify_code":"135790","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.Equal(t, "pending_session", payload["auth_result"])
+ require.Equal(t, "adopt_existing_user_by_email", payload["intent"])
+ require.Equal(t, "oidc", payload["provider"])
+ require.Equal(t, "/dashboard", payload["redirect"])
+ require.Equal(t, true, payload["adoption_required"])
+ require.Equal(t, "Existing OIDC User", payload["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/existing.png", payload["suggested_avatar_url"])
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Equal(t, "adopt_existing_user_by_email", storedSession.Intent)
+ require.NotNil(t, storedSession.TargetUserID)
+ require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
+ require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
+ require.Nil(t, storedSession.ConsumedAt)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-existing-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+}
+
+func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-login-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("bind-login-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Bound OIDC User",
+ "suggested_avatar_url": "https://cdn.example/bound.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.BindOIDCOAuthLogin(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.NotEmpty(t, payload["access_token"])
+ require.NotEmpty(t, payload["refresh_token"])
+ require.Equal(t, "Bearer", payload["token_type"])
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-bind-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, existingUser.ID, identity.UserID)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+}
+
+func TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-login-invalid-password-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-invalid-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("bind-login-invalid-password-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Bound OIDC User",
+ "suggested_avatar_url": "https://cdn.example/bound.png",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","password":"wrong-password"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-invalid-password-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.BindOIDCOAuthLogin(ginCtx)
+
+ require.Equal(t, http.StatusUnauthorized, recorder.Code)
+ payload := decodeJSONBody(t, recorder)
+ require.Equal(t, "INVALID_CREDENTIALS", payload["reason"])
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-bind-invalid-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
t.Helper()
+ return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, false, nil)
+}
+
+func newOAuthPendingFlowTestHandlerWithEmailVerification(
+ t *testing.T,
+ invitationEnabled bool,
+ email string,
+ code string,
+) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+
+ cache := &oauthPendingFlowEmailCacheStub{
+ verificationCodes: map[string]*service.VerificationCodeData{
+ email: {
+ Code: code,
+ Attempts: 0,
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+ },
+ },
+ }
+ return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, true, cache)
+}
+
+func newOAuthPendingFlowTestHandlerWithOptions(
+ t *testing.T,
+ invitationEnabled bool,
+ emailVerifyEnabled bool,
+ emailCache service.EmailCache,
+) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+
db, err := sql.Open("sqlite", "file:auth_oauth_pending_flow_handler?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
@@ -538,9 +834,18 @@ func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*Auth
values: map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
+ service.SettingKeyEmailVerifyEnabled: boolSettingValue(emailVerifyEnabled),
},
}, cfg)
userRepo := &oauthPendingFlowUserRepo{client: client}
+ var emailService *service.EmailService
+ if emailCache != nil {
+ emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{
+ values: map[string]string{
+ service.SettingKeyEmailVerifyEnabled: boolSettingValue(emailVerifyEnabled),
+ },
+ }, emailCache)
+ }
authSvc := service.NewAuthService(
client,
userRepo,
@@ -548,7 +853,7 @@ func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*Auth
&oauthPendingFlowRefreshTokenCacheStub{},
cfg,
settingSvc,
- nil,
+ emailService,
nil,
nil,
nil,
@@ -622,6 +927,70 @@ func (s *oauthPendingFlowSettingRepoStub) Delete(context.Context, string) error
type oauthPendingFlowRefreshTokenCacheStub struct{}
+type oauthPendingFlowEmailCacheStub struct {
+ verificationCodes map[string]*service.VerificationCodeData
+}
+
+func (s *oauthPendingFlowEmailCacheStub) GetVerificationCode(_ context.Context, email string) (*service.VerificationCodeData, error) {
+ if s == nil || s.verificationCodes == nil {
+ return nil, nil
+ }
+ return s.verificationCodes[email], nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) SetVerificationCode(_ context.Context, email string, data *service.VerificationCodeData, _ time.Duration) error {
+ if s.verificationCodes == nil {
+ s.verificationCodes = map[string]*service.VerificationCodeData{}
+ }
+ s.verificationCodes[email] = data
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) DeleteVerificationCode(_ context.Context, email string) error {
+ delete(s.verificationCodes, email)
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
+ return nil, nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
+ return nil, nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) DeletePasswordResetToken(context.Context, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
+ return false
+}
+
+func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
+ return 0, nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+
func (s *oauthPendingFlowRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
return nil
}
diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go
index 0f79759e..909d6379 100644
--- a/backend/internal/handler/auth_oidc_oauth.go
+++ b/backend/internal/handler/auth_oidc_oauth.go
@@ -342,6 +342,21 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
idClaims.Name,
oidcFallbackUsername(subject),
)
+ identityRef := service.PendingAuthIdentityKey{
+ ProviderType: "oidc",
+ ProviderKey: issuer,
+ ProviderSubject: subject,
+ }
+ upstreamClaims := map[string]any{
+ "email": email,
+ "username": username,
+ "subject": subject,
+ "issuer": issuer,
+ "email_verified": emailVerified != nil && *emailVerified,
+ "provider_fallback": strings.TrimSpace(cfg.ProviderName),
+ "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
+ "suggested_avatar_url": userInfoClaims.AvatarURL,
+ }
if intent == oauthIntentBindCurrentUser {
targetUserID, err := h.readOAuthBindUserIDFromCookie(c, oidcOAuthBindUserCookieName)
if err != nil {
@@ -349,26 +364,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
- Intent: oauthIntentBindCurrentUser,
- Identity: service.PendingAuthIdentityKey{
- ProviderType: "oidc",
- ProviderKey: issuer,
- ProviderSubject: subject,
- },
- TargetUserID: &targetUserID,
- ResolvedEmail: email,
- RedirectTo: redirectTo,
- BrowserSessionKey: browserSessionKey,
- UpstreamIdentityClaims: map[string]any{
- "email": email,
- "username": username,
- "subject": subject,
- "issuer": issuer,
- "email_verified": emailVerified != nil && *emailVerified,
- "provider_fallback": strings.TrimSpace(cfg.ProviderName),
- "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
- "suggested_avatar_url": userInfoClaims.AvatarURL,
- },
+ Intent: oauthIntentBindCurrentUser,
+ Identity: identityRef,
+ TargetUserID: &targetUserID,
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"redirect": redirectTo,
},
@@ -380,30 +382,60 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
+ existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if existingIdentityUser != nil {
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin,
+ Identity: identityRef,
+ TargetUserID: &user.ID,
+ ResolvedEmail: existingIdentityUser.Email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{
+ "access_token": tokenPair.AccessToken,
+ "refresh_token": tokenPair.RefreshToken,
+ "expires_in": tokenPair.ExpiresIn,
+ "token_type": "Bearer",
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
+ if err := h.createOAuthEmailRequiredPendingSession(c, identityRef, redirectTo, browserSessionKey, upstreamClaims); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
- Intent: "login",
- Identity: service.PendingAuthIdentityKey{
- ProviderType: "oidc",
- ProviderKey: issuer,
- ProviderSubject: subject,
- },
- ResolvedEmail: email,
- RedirectTo: redirectTo,
- BrowserSessionKey: browserSessionKey,
- UpstreamIdentityClaims: map[string]any{
- "email": email,
- "username": username,
- "subject": subject,
- "issuer": issuer,
- "email_verified": emailVerified != nil && *emailVerified,
- "provider_fallback": strings.TrimSpace(cfg.ProviderName),
- "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
- "suggested_avatar_url": userInfoClaims.AvatarURL,
- },
+ Intent: "login",
+ Identity: identityRef,
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"error": "invitation_required",
"redirect": redirectTo,
@@ -420,26 +452,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
}
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
- Intent: "login",
- Identity: service.PendingAuthIdentityKey{
- ProviderType: "oidc",
- ProviderKey: issuer,
- ProviderSubject: subject,
- },
- TargetUserID: &user.ID,
- ResolvedEmail: email,
- RedirectTo: redirectTo,
- BrowserSessionKey: browserSessionKey,
- UpstreamIdentityClaims: map[string]any{
- "email": email,
- "username": username,
- "subject": subject,
- "issuer": issuer,
- "email_verified": emailVerified != nil && *emailVerified,
- "provider_fallback": strings.TrimSpace(cfg.ProviderName),
- "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
- "suggested_avatar_url": userInfoClaims.AvatarURL,
- },
+ Intent: "login",
+ Identity: identityRef,
+ TargetUserID: &user.ID,
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go
index 45ac6cad..6d37c799 100644
--- a/backend/internal/handler/auth_wechat_oauth.go
+++ b/backend/internal/handler/auth_wechat_oauth.go
@@ -214,6 +214,11 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
"suggested_display_name": strings.TrimSpace(userInfo.Nickname),
"suggested_avatar_url": strings.TrimSpace(userInfo.HeadImgURL),
}
+ identityRef := service.PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: wechatOAuthProviderKey,
+ ProviderSubject: providerSubject,
+ }
normalizedIntent := normalizeWeChatOAuthIntent(intent)
if normalizedIntent == wechatOAuthIntentBind {
@@ -232,6 +237,34 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
return
}
+ existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if existingIdentityUser != nil {
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil, &user.ID); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
+ if err := h.createOAuthEmailRequiredPendingSession(c, identityRef, redirectTo, browserSessionKey, upstreamClaims); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, err, nil); err != nil {
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index f44b3e3b..637e317b 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -167,6 +167,7 @@ type DefaultSubscriptionSetting struct {
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
+ ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go
index c7bc3e2a..9925f066 100644
--- a/backend/internal/handler/setting_handler.go
+++ b/backend/internal/handler/setting_handler.go
@@ -34,6 +34,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
+ ForceEmailOnThirdPartySignup: settings.ForceEmailOnThirdPartySignup,
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
diff --git a/backend/internal/handler/setting_handler_public_test.go b/backend/internal/handler/setting_handler_public_test.go
new file mode 100644
index 00000000..114c7245
--- /dev/null
+++ b/backend/internal/handler/setting_handler_public_test.go
@@ -0,0 +1,83 @@
+//go:build unit
+
+package handler
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type settingHandlerPublicRepoStub struct {
+ values map[string]string
+}
+
+func (s *settingHandlerPublicRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *settingHandlerPublicRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ panic("unexpected GetValue call")
+}
+
+func (s *settingHandlerPublicRepoStub) Set(ctx context.Context, key, value string) error {
+ panic("unexpected Set call")
+}
+
+func (s *settingHandlerPublicRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *settingHandlerPublicRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ panic("unexpected SetMultiple call")
+}
+
+func (s *settingHandlerPublicRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *settingHandlerPublicRepoStub) Delete(ctx context.Context, key string) error {
+ panic("unexpected Delete call")
+}
+
+func TestSettingHandler_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &settingHandlerPublicRepoStub{
+ values: map[string]string{
+ service.SettingKeyForceEmailOnThirdPartySignup: "true",
+ },
+ }
+ h := NewSettingHandler(service.NewSettingService(repo, &config.Config{}), "test-version")
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/settings/public", nil)
+
+ h.GetPublicSettings(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.True(t, resp.Data.ForceEmailOnThirdPartySignup)
+}
diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go
index 7a34834d..1f28e9c3 100644
--- a/backend/internal/server/routes/auth.go
+++ b/backend/internal/server/routes/auth.go
@@ -72,18 +72,54 @@ func RegisterAuthRoutes(
}),
h.Auth.ExchangePendingOAuthCompletion,
)
+ auth.POST("/oauth/pending/create-account",
+ rateLimiter.LimitWithOptions("oauth-pending-create-account", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CreatePendingOAuthAccount,
+ )
+ auth.POST("/oauth/pending/bind-login",
+ rateLimiter.LimitWithOptions("oauth-pending-bind-login", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.BindPendingOAuthLogin,
+ )
auth.POST("/oauth/linuxdo/complete-registration",
rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteLinuxDoOAuthRegistration,
)
+ auth.POST("/oauth/linuxdo/bind-login",
+ rateLimiter.LimitWithOptions("oauth-linuxdo-bind-login", 20, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.BindLinuxDoOAuthLogin,
+ )
+ auth.POST("/oauth/linuxdo/create-account",
+ rateLimiter.LimitWithOptions("oauth-linuxdo-create-account", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CreateLinuxDoOAuthAccount,
+ )
auth.POST("/oauth/wechat/complete-registration",
rateLimiter.LimitWithOptions("oauth-wechat-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteWeChatOAuthRegistration,
)
+ auth.POST("/oauth/wechat/bind-login",
+ rateLimiter.LimitWithOptions("oauth-wechat-bind-login", 20, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.BindWeChatOAuthLogin,
+ )
+ auth.POST("/oauth/wechat/create-account",
+ rateLimiter.LimitWithOptions("oauth-wechat-create-account", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CreateWeChatOAuthAccount,
+ )
auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart)
auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback)
auth.POST("/oauth/oidc/complete-registration",
@@ -92,6 +128,18 @@ func RegisterAuthRoutes(
}),
h.Auth.CompleteOIDCOAuthRegistration,
)
+ auth.POST("/oauth/oidc/bind-login",
+ rateLimiter.LimitWithOptions("oauth-oidc-bind-login", 20, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.BindOIDCOAuthLogin,
+ )
+ auth.POST("/oauth/oidc/create-account",
+ rateLimiter.LimitWithOptions("oauth-oidc-create-account", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CreateOIDCOAuthAccount,
+ )
}
// 公开设置(无需认证)
diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go
new file mode 100644
index 00000000..ca3403d4
--- /dev/null
+++ b/backend/internal/service/auth_oauth_email_flow.go
@@ -0,0 +1,151 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strings"
+)
+
+// VerifyOAuthEmailCode verifies the locally entered email verification code for
+// third-party signup and binding flows. This is intentionally independent from
+// the global registration email verification toggle.
+func (s *AuthService) VerifyOAuthEmailCode(ctx context.Context, email, verifyCode string) error {
+ email = strings.TrimSpace(strings.ToLower(email))
+ verifyCode = strings.TrimSpace(verifyCode)
+
+ if email == "" {
+ return ErrEmailVerifyRequired
+ }
+ if verifyCode == "" {
+ return ErrEmailVerifyRequired
+ }
+ if s == nil || s.emailService == nil {
+ return ErrServiceUnavailable
+ }
+ return s.emailService.VerifyCode(ctx, email, verifyCode)
+}
+
+// RegisterOAuthEmailAccount creates a local account from a third-party first
+// login after the user has verified a local email address.
+func (s *AuthService) RegisterOAuthEmailAccount(
+ ctx context.Context,
+ email string,
+ password string,
+ verifyCode string,
+ invitationCode string,
+ signupSource string,
+) (*TokenPair, *User, error) {
+ if s == nil {
+ return nil, nil, ErrServiceUnavailable
+ }
+ if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
+ return nil, nil, ErrRegDisabled
+ }
+
+ email = strings.TrimSpace(strings.ToLower(email))
+ if isReservedEmail(email) {
+ return nil, nil, ErrEmailReserved
+ }
+ if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
+ return nil, nil, err
+ }
+ if err := s.VerifyOAuthEmailCode(ctx, email, verifyCode); err != nil {
+ return nil, nil, err
+ }
+
+ var invitationRedeemCode *RedeemCode
+ if s.settingService.IsInvitationCodeEnabled(ctx) {
+ if invitationCode == "" {
+ return nil, nil, ErrInvitationCodeRequired
+ }
+ redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
+ if err != nil {
+ return nil, nil, ErrInvitationCodeInvalid
+ }
+ if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
+ return nil, nil, ErrInvitationCodeInvalid
+ }
+ invitationRedeemCode = redeemCode
+ }
+
+ existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
+ if err != nil {
+ return nil, nil, ErrServiceUnavailable
+ }
+ if existsEmail {
+ return nil, nil, ErrEmailExists
+ }
+
+ hashedPassword, err := s.HashPassword(password)
+ if err != nil {
+ return nil, nil, fmt.Errorf("hash password: %w", err)
+ }
+
+ signupSource = strings.TrimSpace(strings.ToLower(signupSource))
+ if signupSource == "" {
+ signupSource = "email"
+ }
+ grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
+
+ user := &User{
+ Email: email,
+ PasswordHash: hashedPassword,
+ Role: RoleUser,
+ Balance: grantPlan.Balance,
+ Concurrency: grantPlan.Concurrency,
+ Status: StatusActive,
+ }
+
+ if err := s.userRepo.Create(ctx, user); err != nil {
+ if errors.Is(err, ErrEmailExists) {
+ return nil, nil, ErrEmailExists
+ }
+ return nil, nil, ErrServiceUnavailable
+ }
+
+ s.postAuthUserBootstrap(ctx, user, signupSource, true)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
+
+ if invitationRedeemCode != nil {
+ if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
+ return nil, nil, ErrInvitationCodeInvalid
+ }
+ }
+
+ tokenPair, err := s.GenerateTokenPair(ctx, user, "")
+ if err != nil {
+ return nil, nil, fmt.Errorf("generate token pair: %w", err)
+ }
+ return tokenPair, user, nil
+}
+
+// ValidatePasswordCredentials checks the local password without completing the
+// login flow. This is used by pending third-party account adoption flows before
+// the external identity has been bound.
+func (s *AuthService) ValidatePasswordCredentials(ctx context.Context, email, password string) (*User, error) {
+ if s == nil {
+ return nil, ErrServiceUnavailable
+ }
+
+ user, err := s.userRepo.GetByEmail(ctx, strings.TrimSpace(strings.ToLower(email)))
+ if err != nil {
+ if errors.Is(err, ErrUserNotFound) {
+ return nil, ErrInvalidCredentials
+ }
+ return nil, ErrServiceUnavailable
+ }
+ if !user.IsActive() {
+ return nil, ErrUserNotActive
+ }
+ if !s.CheckPassword(password, user.PasswordHash) {
+ return nil, ErrInvalidCredentials
+ }
+ return user, nil
+}
+
+// RecordSuccessfulLogin updates last-login activity after a non-standard login
+// flow finishes with a real session.
+func (s *AuthService) RecordSuccessfulLogin(ctx context.Context, userID int64) {
+ s.touchUserLogin(ctx, userID)
+}
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index de555478..a2644fcd 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -217,6 +217,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
keys := []string{
SettingKeyRegistrationEnabled,
SettingKeyEmailVerifyEnabled,
+ SettingKeyForceEmailOnThirdPartySignup,
SettingKeyRegistrationEmailSuffixWhitelist,
SettingKeyPromoCodeEnabled,
SettingKeyPasswordResetEnabled,
@@ -294,6 +295,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
return &PublicSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: emailVerifyEnabled,
+ ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
RegistrationEmailSuffixWhitelist: registrationEmailSuffixWhitelist,
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
PasswordResetEnabled: passwordResetEnabled,
diff --git a/backend/internal/service/setting_service_public_test.go b/backend/internal/service/setting_service_public_test.go
index 5cf1e860..bb97c2aa 100644
--- a/backend/internal/service/setting_service_public_test.go
+++ b/backend/internal/service/setting_service_public_test.go
@@ -77,3 +77,16 @@ func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T)
require.Equal(t, 50, settings.TableDefaultPageSize)
require.Equal(t, []int{20, 50, 100}, settings.TablePageSizeOptions)
}
+
+func TestSettingService_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) {
+ repo := &settingPublicRepoStub{
+ values: map[string]string{
+ SettingKeyForceEmailOnThirdPartySignup: "true",
+ },
+ }
+ svc := NewSettingService(repo, &config.Config{})
+
+ settings, err := svc.GetPublicSettings(context.Background())
+ require.NoError(t, err)
+ require.True(t, settings.ForceEmailOnThirdPartySignup)
+}
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index e991ebef..72db4e31 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -128,6 +128,7 @@ type DefaultSubscriptionSetting struct {
type PublicSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
+ ForceEmailOnThirdPartySignup bool
RegistrationEmailSuffixWhitelist []string
PromoCodeEnabled bool
PasswordResetEnabled bool
From 6ea3f42e2f825f165c98a63fa6b5f472f6853b33 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 19:30:19 +0800
Subject: [PATCH 029/189] feat: add oauth callback email binding ui
---
.../api/__tests__/auth-oauth-adoption.spec.ts | 91 +++++++
frontend/src/api/auth.ts | 102 +++++--
frontend/src/stores/app.ts | 1 +
frontend/src/types/index.ts | 1 +
.../src/views/auth/LinuxDoCallbackView.vue | 257 +++++++++++++++++-
frontend/src/views/auth/OidcCallbackView.vue | 128 ++++++++-
.../src/views/auth/WechatCallbackView.vue | 108 ++++++++
.../__tests__/LinuxDoCallbackView.spec.ts | 105 +++++++
.../auth/__tests__/OidcCallbackView.spec.ts | 71 +++++
.../auth/__tests__/WechatCallbackView.spec.ts | 88 ++++++
10 files changed, 916 insertions(+), 36 deletions(-)
diff --git a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts
index 9c0b4d55..f95332fb 100644
--- a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts
+++ b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts
@@ -30,6 +30,20 @@ describe('oauth adoption auth api', () => {
})
})
+ it('posts bind-login decisions when finalizing pending oauth bind flow', async () => {
+ const { completePendingOAuthBindLogin } = await import('@/api/auth')
+
+ await completePendingOAuthBindLogin({
+ adoptDisplayName: true,
+ adoptAvatar: false
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/pending/exchange', {
+ adopt_display_name: true,
+ adopt_avatar: false
+ })
+ })
+
it('posts linuxdo invitation completion with adoption decisions', async () => {
const { completeLinuxDoOAuthRegistration } = await import('@/api/auth')
@@ -45,6 +59,21 @@ describe('oauth adoption auth api', () => {
})
})
+ it('posts linuxdo create-account completion with adoption decisions', async () => {
+ const { createPendingLinuxDoOAuthAccount } = await import('@/api/auth')
+
+ await createPendingLinuxDoOAuthAccount('invite-code', {
+ adoptDisplayName: false,
+ adoptAvatar: true
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: false,
+ adopt_avatar: true
+ })
+ })
+
it('posts oidc invitation completion with adoption decisions', async () => {
const { completeOIDCOAuthRegistration } = await import('@/api/auth')
@@ -60,6 +89,21 @@ describe('oauth adoption auth api', () => {
})
})
+ it('posts oidc create-account completion with adoption decisions', async () => {
+ const { createPendingOIDCOAuthAccount } = await import('@/api/auth')
+
+ await createPendingOIDCOAuthAccount('invite-code', {
+ adoptDisplayName: true,
+ adoptAvatar: false
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: true,
+ adopt_avatar: false
+ })
+ })
+
it('posts wechat invitation completion with adoption decisions', async () => {
const { completeWeChatOAuthRegistration } = await import('@/api/auth')
@@ -75,6 +119,21 @@ describe('oauth adoption auth api', () => {
})
})
+ it('posts wechat create-account completion with adoption decisions', async () => {
+ const { createPendingWeChatOAuthAccount } = await import('@/api/auth')
+
+ await createPendingWeChatOAuthAccount('invite-code', {
+ adoptDisplayName: false,
+ adoptAvatar: false
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: false,
+ adopt_avatar: false
+ })
+ })
+
it('classifies oauth completion results as login or bind', async () => {
const { getOAuthCompletionKind } = await import('@/api/auth')
@@ -82,6 +141,38 @@ describe('oauth adoption auth api', () => {
expect(getOAuthCompletionKind({ redirect: '/profile' })).toBe('bind')
})
+ it('provides bind-login utility helpers for invitation and suggested profile states', async () => {
+ const {
+ getPendingOAuthBindLoginKind,
+ hasPendingOAuthSuggestedProfile,
+ isPendingOAuthCreateAccountRequired
+ } = await import('@/api/auth')
+
+ expect(getPendingOAuthBindLoginKind({ access_token: 'access-token' })).toBe('login')
+ expect(getPendingOAuthBindLoginKind({ redirect: '/profile' })).toBe('bind')
+ expect(
+ isPendingOAuthCreateAccountRequired({
+ error: 'invitation_required'
+ })
+ ).toBe(true)
+ expect(
+ isPendingOAuthCreateAccountRequired({
+ error: 'other'
+ })
+ ).toBe(false)
+ expect(
+ hasPendingOAuthSuggestedProfile({
+ suggested_display_name: 'OAuth Nick'
+ })
+ ).toBe(true)
+ expect(
+ hasPendingOAuthSuggestedProfile({
+ suggested_avatar_url: 'https://cdn.example/avatar.png'
+ })
+ ).toBe(true)
+ expect(hasPendingOAuthSuggestedProfile({})).toBe(false)
+ })
+
it('prepares an oauth bind access token cookie before redirect binding', async () => {
localStorage.setItem('auth_token', 'access-token-value')
const setCookie = vi.fn()
diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts
index c11bd90b..98b20154 100644
--- a/frontend/src/api/auth.ts
+++ b/frontend/src/api/auth.ts
@@ -193,7 +193,7 @@ export interface OAuthTokenResponse {
token_type?: string
}
-export interface PendingOAuthExchangeResponse extends Partial {
+export interface PendingOAuthBindLoginResponse extends Partial {
redirect?: string
error?: string
adoption_required?: boolean
@@ -201,6 +201,10 @@ export interface PendingOAuthExchangeResponse extends Partial
+): boolean {
+ return completion.error === 'invitation_required'
+}
+
+export function hasPendingOAuthSuggestedProfile(
+ completion: Pick<
+ PendingOAuthBindLoginResponse,
+ 'suggested_display_name' | 'suggested_avatar_url'
+ >
+): boolean {
+ return Boolean(completion.suggested_display_name || completion.suggested_avatar_url)
+}
+
export function persistOAuthTokenContext(tokens: Partial): void {
if (tokens.refresh_token) {
setRefreshToken(tokens.refresh_token)
@@ -431,11 +456,7 @@ export async function completeLinuxDoOAuthRegistration(
invitationCode: string,
decision?: OAuthAdoptionDecision
): Promise {
- const { data } = await apiClient.post('/auth/oauth/linuxdo/complete-registration', {
- invitation_code: invitationCode,
- ...serializeOAuthAdoptionDecision(decision)
- })
- return data
+ return createPendingLinuxDoOAuthAccount(invitationCode, decision)
}
/**
@@ -447,32 +468,66 @@ export async function completeOIDCOAuthRegistration(
invitationCode: string,
decision?: OAuthAdoptionDecision
): Promise {
- const { data } = await apiClient.post('/auth/oauth/oidc/complete-registration', {
- invitation_code: invitationCode,
- ...serializeOAuthAdoptionDecision(decision)
- })
- return data
+ return createPendingOIDCOAuthAccount(invitationCode, decision)
}
export async function completeWeChatOAuthRegistration(
invitationCode: string,
decision?: OAuthAdoptionDecision
): Promise {
- const { data } = await apiClient.post('/auth/oauth/wechat/complete-registration', {
- invitation_code: invitationCode,
- ...serializeOAuthAdoptionDecision(decision)
- })
+ return createPendingWeChatOAuthAccount(invitationCode, decision)
+}
+
+async function createPendingOAuthAccount(
+ provider: 'linuxdo' | 'oidc' | 'wechat',
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision
+): Promise {
+ const { data } = await apiClient.post(
+ `/auth/oauth/${provider}/complete-registration`,
+ {
+ invitation_code: invitationCode,
+ ...serializeOAuthAdoptionDecision(decision)
+ }
+ )
+ return data
+}
+
+export async function createPendingLinuxDoOAuthAccount(
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision
+): Promise {
+ return createPendingOAuthAccount('linuxdo', invitationCode, decision)
+}
+
+export async function createPendingOIDCOAuthAccount(
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision
+): Promise {
+ return createPendingOAuthAccount('oidc', invitationCode, decision)
+}
+
+export async function createPendingWeChatOAuthAccount(
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision
+): Promise {
+ return createPendingOAuthAccount('wechat', invitationCode, decision)
+}
+
+export async function completePendingOAuthBindLogin(
+ decision?: OAuthAdoptionDecision
+): Promise {
+ const { data } = await apiClient.post(
+ '/auth/oauth/pending/exchange',
+ serializeOAuthAdoptionDecision(decision)
+ )
return data
}
export async function exchangePendingOAuthCompletion(
decision?: OAuthAdoptionDecision
): Promise {
- const { data } = await apiClient.post(
- '/auth/oauth/pending/exchange',
- serializeOAuthAdoptionDecision(decision)
- )
- return data
+ return completePendingOAuthBindLogin(decision)
}
export const authAPI = {
@@ -498,6 +553,13 @@ export const authAPI = {
resetPassword,
refreshToken,
revokeAllSessions,
+ getPendingOAuthBindLoginKind,
+ isPendingOAuthCreateAccountRequired,
+ hasPendingOAuthSuggestedProfile,
+ completePendingOAuthBindLogin,
+ createPendingLinuxDoOAuthAccount,
+ createPendingOIDCOAuthAccount,
+ createPendingWeChatOAuthAccount,
exchangePendingOAuthCompletion,
completeLinuxDoOAuthRegistration,
completeOIDCOAuthRegistration,
diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts
index 1b1af87b..a8e03a51 100644
--- a/frontend/src/stores/app.ts
+++ b/frontend/src/stores/app.ts
@@ -316,6 +316,7 @@ export const useAppStore = defineStore('app', () => {
return {
registration_enabled: false,
email_verify_enabled: false,
+ force_email_on_third_party_signup: false,
registration_email_suffix_whitelist: [],
promo_code_enabled: true,
password_reset_enabled: false,
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index a19d6c26..a4b2277b 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -142,6 +142,7 @@ export interface CustomEndpoint {
export interface PublicSettings {
registration_enabled: boolean
email_verify_enabled: boolean
+ force_email_on_third_party_signup: boolean
registration_email_suffix_whitelist: string[]
promo_code_enabled: boolean
password_reset_enabled: boolean
diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue
index 6dc8f242..00b73868 100644
--- a/frontend/src/views/auth/LinuxDoCallbackView.vue
+++ b/frontend/src/views/auth/LinuxDoCallbackView.vue
@@ -11,7 +11,10 @@
-
+
@@ -127,11 +214,12 @@
diff --git a/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts
new file mode 100644
index 00000000..cfbd9f1c
--- /dev/null
+++ b/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts
@@ -0,0 +1,80 @@
+import { flushPromises, mount } from '@vue/test-utils'
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import WechatPaymentCallbackView from '@/views/auth/WechatPaymentCallbackView.vue'
+
+const { replaceMock, routeState, locationState } = vi.hoisted(() => ({
+ replaceMock: vi.fn(),
+ routeState: {
+ query: {} as Record
,
+ },
+ locationState: {
+ current: {
+ href: 'http://localhost/auth/wechat/payment/callback',
+ hash: '',
+ search: '',
+ pathname: '/auth/wechat/payment/callback',
+ origin: 'http://localhost',
+ } as Location & { origin: string },
+ },
+}))
+
+vi.mock('vue-router', () => ({
+ useRoute: () => routeState,
+ useRouter: () => ({
+ replace: replaceMock,
+ }),
+}))
+
+vi.mock('vue-i18n', () => ({
+ useI18n: () => ({
+ t: (key: string) => key,
+ locale: { value: 'zh-CN' },
+ }),
+}))
+
+describe('WechatPaymentCallbackView', () => {
+ beforeEach(() => {
+ replaceMock.mockReset()
+ routeState.query = {}
+ locationState.current = {
+ href: 'http://localhost/auth/wechat/payment/callback',
+ hash: '',
+ search: '',
+ pathname: '/auth/wechat/payment/callback',
+ origin: 'http://localhost',
+ } as Location & { origin: string }
+ Object.defineProperty(window, 'location', {
+ configurable: true,
+ value: locationState.current,
+ })
+ })
+
+ it('redirects back to purchase with openid and payment context from hash fragment', async () => {
+ locationState.current.hash = '#openid=openid-123&payment_type=wxpay&amount=12.5&order_type=balance&redirect=%2Fpurchase%3Ffrom%3Dwechat'
+
+ mount(WechatPaymentCallbackView)
+ await flushPromises()
+
+ expect(replaceMock).toHaveBeenCalledWith({
+ path: '/purchase',
+ query: {
+ from: 'wechat',
+ wechat_resume: '1',
+ openid: 'openid-123',
+ payment_type: 'wxpay',
+ amount: '12.5',
+ order_type: 'balance',
+ },
+ })
+ })
+
+ it('shows an error when the callback payload is missing openid', async () => {
+ locationState.current.hash = '#payment_type=wxpay'
+
+ const wrapper = mount(WechatPaymentCallbackView)
+ await flushPromises()
+
+ expect(replaceMock).not.toHaveBeenCalled()
+ expect(wrapper.text()).toContain('微信支付回调缺少 openid。')
+ })
+})
diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue
index f973ad5b..bfb9dae2 100644
--- a/frontend/src/views/user/PaymentView.vue
+++ b/frontend/src/views/user/PaymentView.vue
@@ -309,6 +309,20 @@ const previewImage = ref('')
const paymentPhase = ref<'select' | 'paying'>('select')
+interface CreateOrderOptions {
+ openid?: string
+ paymentType?: string
+ isResume?: boolean
+}
+
+interface WeixinJSBridgeLike {
+ invoke(
+ action: string,
+ payload: Record,
+ callback: (result: Record) => void,
+ ): void
+}
+
function emptyPaymentState(): PaymentRecoverySnapshot {
return {
orderId: 0,
@@ -326,6 +340,48 @@ function emptyPaymentState(): PaymentRecoverySnapshot {
}
}
+function readRouteQueryValue(value: unknown): string {
+ if (Array.isArray(value)) {
+ return typeof value[0] === 'string' ? value[0] : ''
+ }
+ return typeof value === 'string' ? value : ''
+}
+
+function getWeixinJSBridge(): WeixinJSBridgeLike | undefined {
+ return (window as Window & { WeixinJSBridge?: WeixinJSBridgeLike }).WeixinJSBridge
+}
+
+function waitForWeixinJSBridge(timeoutMs = 4000): Promise {
+ const existing = getWeixinJSBridge()
+ if (existing) return Promise.resolve(existing)
+
+ return new Promise((resolve) => {
+ let settled = false
+ const finish = (bridge: WeixinJSBridgeLike | null) => {
+ if (settled) return
+ settled = true
+ document.removeEventListener('WeixinJSBridgeReady', handleReady)
+ document.removeEventListener('onWeixinJSBridgeReady', handleReady)
+ window.clearTimeout(timer)
+ resolve(bridge)
+ }
+ const handleReady = () => finish(getWeixinJSBridge() ?? null)
+ const timer = window.setTimeout(() => finish(getWeixinJSBridge() ?? null), timeoutMs)
+ document.addEventListener('WeixinJSBridgeReady', handleReady, false)
+ document.addEventListener('onWeixinJSBridgeReady', handleReady, false)
+ })
+}
+
+async function invokeWechatJsapiPayment(payload: Record): Promise> {
+ const bridge = await waitForWeixinJSBridge()
+ if (!bridge) {
+ throw new Error('WeixinJSBridge is unavailable')
+ }
+ return new Promise((resolve) => {
+ bridge.invoke('getBrandWCPayRequest', payload, (result) => resolve(result || {}))
+ })
+}
+
const paymentState = ref(emptyPaymentState())
function persistRecoverySnapshot(snapshot: PaymentRecoverySnapshot) {
@@ -560,25 +616,32 @@ async function confirmSubscribe() {
await createOrder(selectedPlan.value.price, 'subscription', selectedPlan.value.id)
}
-async function createOrder(orderAmount: number, orderType: OrderType, planId?: number) {
+async function createOrder(orderAmount: number, orderType: OrderType, planId?: number, options: CreateOrderOptions = {}) {
submitting.value = true
errorMessage.value = ''
try {
- const result = await paymentStore.createOrder(buildCreateOrderPayload({
+ const requestType = normalizeVisibleMethod(options.paymentType || selectedMethod.value) || options.paymentType || selectedMethod.value
+ const payload = buildCreateOrderPayload({
amount: orderAmount,
- paymentType: selectedMethod.value,
+ paymentType: requestType,
orderType,
planId,
origin: typeof window !== 'undefined' ? window.location.origin : '',
isWechatBrowser: typeof window !== 'undefined' && /MicroMessenger/i.test(window.navigator.userAgent),
- })) as CreateOrderResult & { resume_token?: string }
+ })
+ if (options.openid) {
+ payload.openid = options.openid
+ }
+ payload.is_mobile = isMobileDevice()
+
+ const result = await paymentStore.createOrder(payload) as CreateOrderResult & { resume_token?: string }
const openWindow = (url: string, features = POPUP_WINDOW_FEATURES) => {
const win = window.open(url, 'paymentPopup', features)
if (!win || win.closed) {
window.location.href = url
}
}
- const visibleMethod = normalizeVisibleMethod(selectedMethod.value) || selectedMethod.value
+ const visibleMethod = normalizeVisibleMethod(requestType) || requestType
const stripeMethod = visibleMethod === 'wxpay' ? 'wechat_pay' : 'alipay'
const stripeRouteUrl = result.client_secret
? router.resolve({
@@ -599,6 +662,11 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n
stripeRouteUrl,
})
+ if (decision.kind === 'wechat_oauth' && decision.oauth?.authorize_url) {
+ window.location.href = decision.oauth.authorize_url
+ return
+ }
+
if (decision.kind === 'unhandled') {
errorMessage.value = t('payment.result.failed')
appStore.showError(errorMessage.value)
@@ -617,6 +685,16 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n
window.location.href = decision.paymentState.payUrl
return
}
+ if (decision.kind === 'wechat_jsapi' && decision.jsapi) {
+ const jsapiResult = await invokeWechatJsapiPayment(decision.jsapi as Record)
+ const errMsg = String(jsapiResult.err_msg || '').toLowerCase()
+ if (errMsg.includes('cancel')) {
+ appStore.showInfo(t('payment.qr.cancelled'))
+ } else if (errMsg && !errMsg.includes('ok')) {
+ appStore.showError(t('payment.result.failed'))
+ }
+ return
+ }
if (decision.kind === 'redirect_waiting' && decision.paymentState.payUrl) {
if (isMobileDevice()) {
window.location.href = decision.paymentState.payUrl
@@ -640,6 +718,50 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n
}
}
+async function resumeWechatPaymentFromQuery() {
+ const openid = readRouteQueryValue(route.query.openid)
+ if (readRouteQueryValue(route.query.wechat_resume) !== '1' || !openid) {
+ return
+ }
+
+ const paymentType = normalizeVisibleMethod(readRouteQueryValue(route.query.payment_type)) || 'wxpay'
+ const orderType = readRouteQueryValue(route.query.order_type) === 'subscription' ? 'subscription' : 'balance'
+ const planId = Number.parseInt(readRouteQueryValue(route.query.plan_id), 10)
+ const rawAmount = Number.parseFloat(readRouteQueryValue(route.query.amount))
+ const orderAmount = Number.isFinite(rawAmount) && rawAmount > 0
+ ? rawAmount
+ : (orderType === 'subscription'
+ ? (checkout.value.plans.find(plan => plan.id === planId)?.price ?? 0)
+ : validAmount.value)
+
+ selectedMethod.value = paymentType
+ if (orderType === 'balance' && orderAmount > 0) {
+ amount.value = orderAmount
+ }
+ if (orderType === 'subscription' && Number.isFinite(planId) && planId > 0) {
+ selectedPlan.value = checkout.value.plans.find(plan => plan.id === planId) ?? null
+ }
+
+ const nextQuery = { ...route.query }
+ delete nextQuery.wechat_resume
+ delete nextQuery.openid
+ delete nextQuery.state
+ delete nextQuery.scope
+ delete nextQuery.payment_type
+ delete nextQuery.amount
+ delete nextQuery.order_type
+ delete nextQuery.plan_id
+ await router.replace({ path: route.path, query: nextQuery })
+
+ if (orderAmount > 0) {
+ await createOrder(orderAmount, orderType, Number.isFinite(planId) && planId > 0 ? planId : undefined, {
+ openid,
+ paymentType,
+ isResume: true,
+ })
+ }
+}
+
onMounted(async () => {
try {
const res = await paymentAPI.getCheckoutInfo()
@@ -672,6 +794,7 @@ onMounted(async () => {
removeRecoverySnapshot()
}
}
+ await resumeWechatPaymentFromQuery()
if (checkout.value.balance_disabled) {
activeTab.value = 'subscription'
}
From 0fa47f18ed556a49ec582669e85b6863cdcdb2fd Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:02:51 +0800
Subject: [PATCH 059/189] feat: complete pending oauth account creation UI
---
.../auth/PendingOAuthCreateAccountForm.vue | 212 ++++++++++++++++++
.../PendingOAuthCreateAccountForm.spec.ts | 80 +++++++
.../src/views/auth/LinuxDoCallbackView.vue | 55 ++---
frontend/src/views/auth/OidcCallbackView.vue | 55 ++---
.../src/views/auth/WechatCallbackView.vue | 55 ++---
.../__tests__/LinuxDoCallbackView.spec.ts | 43 +++-
.../auth/__tests__/OidcCallbackView.spec.ts | 43 +++-
.../auth/__tests__/WechatCallbackView.spec.ts | 42 +++-
8 files changed, 469 insertions(+), 116 deletions(-)
create mode 100644 frontend/src/components/auth/PendingOAuthCreateAccountForm.vue
create mode 100644 frontend/src/components/auth/__tests__/PendingOAuthCreateAccountForm.spec.ts
diff --git a/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue b/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue
new file mode 100644
index 00000000..39588a86
--- /dev/null
+++ b/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue
@@ -0,0 +1,212 @@
+
+
+
+
+
+
+
diff --git a/frontend/src/components/auth/__tests__/PendingOAuthCreateAccountForm.spec.ts b/frontend/src/components/auth/__tests__/PendingOAuthCreateAccountForm.spec.ts
new file mode 100644
index 00000000..63aeebc6
--- /dev/null
+++ b/frontend/src/components/auth/__tests__/PendingOAuthCreateAccountForm.spec.ts
@@ -0,0 +1,80 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { flushPromises, mount } from '@vue/test-utils'
+
+import PendingOAuthCreateAccountForm from '../PendingOAuthCreateAccountForm.vue'
+
+const sendVerifyCode = vi.fn()
+
+vi.mock('vue-i18n', async () => {
+ const actual = await vi.importActual('vue-i18n')
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string) => key
+ })
+ }
+})
+
+vi.mock('@/api/auth', async () => {
+ const actual = await vi.importActual('@/api/auth')
+ return {
+ ...actual,
+ sendVerifyCode: (...args: any[]) => sendVerifyCode(...args)
+ }
+})
+
+describe('PendingOAuthCreateAccountForm', () => {
+ beforeEach(() => {
+ sendVerifyCode.mockReset()
+ })
+
+ it('emits trimmed email, password, and verify code on submit', async () => {
+ const wrapper = mount(PendingOAuthCreateAccountForm, {
+ props: {
+ providerName: 'LinuxDo',
+ testIdPrefix: 'linuxdo',
+ initialEmail: 'prefill@example.com',
+ isSubmitting: false
+ }
+ })
+
+ await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' user@example.com ')
+ await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123')
+ await wrapper.get('[data-testid="linuxdo-create-account-verify-code"]').setValue(' 246810 ')
+ await wrapper.get('form').trigger('submit.prevent')
+
+ expect(wrapper.emitted('submit')).toEqual([
+ [
+ {
+ email: 'user@example.com',
+ password: 'secret-123',
+ verifyCode: '246810'
+ }
+ ]
+ ])
+ })
+
+ it('sends a verify code for the trimmed email value', async () => {
+ sendVerifyCode.mockResolvedValue({
+ message: 'sent',
+ countdown: 60
+ })
+
+ const wrapper = mount(PendingOAuthCreateAccountForm, {
+ props: {
+ providerName: 'LinuxDo',
+ testIdPrefix: 'linuxdo',
+ initialEmail: '',
+ isSubmitting: false
+ }
+ })
+
+ await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' user@example.com ')
+ await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click')
+ await flushPromises()
+
+ expect(sendVerifyCode).toHaveBeenCalledWith({
+ email: 'user@example.com'
+ })
+ })
+})
diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue
index ce7f92ef..128410bc 100644
--- a/frontend/src/views/auth/LinuxDoCallbackView.vue
+++ b/frontend/src/views/auth/LinuxDoCallbackView.vue
@@ -113,37 +113,14 @@
Enter an email address to create your account and continue.
-
-
-
- {{ isSubmitting ? t('common.processing') : 'Create account' }}
-
-
- I already have an account
-
-
-
-
- {{ accountActionError }}
-
-
+
@@ -258,6 +235,9 @@ import { computed, onMounted, ref } from 'vue'
import { useRoute, useRouter } from 'vue-router'
import { useI18n } from 'vue-i18n'
import { AuthLayout } from '@/components/layout'
+import PendingOAuthCreateAccountForm, {
+ type PendingOAuthCreateAccountPayload
+} from '@/components/auth/PendingOAuthCreateAccountForm.vue'
import Icon from '@/components/icons/Icon.vue'
import { apiClient } from '@/api/client'
import { useAuthStore, useAppStore } from '@/stores'
@@ -432,9 +412,9 @@ function applyTotpChallenge(completion: LinuxDoPendingActionResponse): boolean {
return true
}
-function switchToBindLoginMode() {
+function switchToBindLoginMode(nextEmail?: string) {
pendingAccountAction.value = 'bind_login'
- bindLoginEmail.value = bindLoginEmail.value.trim() || pendingAccountEmail.value.trim()
+ bindLoginEmail.value = bindLoginEmail.value.trim() || nextEmail?.trim() || pendingAccountEmail.value.trim()
bindLoginPassword.value = ''
accountActionError.value = ''
canReturnToCreateAccount.value = true
@@ -533,15 +513,16 @@ async function handleContinueLogin() {
}
}
-async function handleCreateAccount() {
+async function handleCreateAccount(payload: PendingOAuthCreateAccountPayload) {
accountActionError.value = ''
- const email = pendingAccountEmail.value.trim()
- if (!email) return
+ if (!payload.email || !payload.password) return
isSubmitting.value = true
try {
const { data } = await apiClient.post('/auth/oauth/pending/create-account', {
- email,
+ email: payload.email,
+ password: payload.password,
+ verify_code: payload.verifyCode || undefined,
...serializeAdoptionDecision(currentAdoptionDecision())
})
await finalizePendingAccountResponse(data)
diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue
index de3f2e40..344c3537 100644
--- a/frontend/src/views/auth/OidcCallbackView.vue
+++ b/frontend/src/views/auth/OidcCallbackView.vue
@@ -122,37 +122,14 @@
Enter an email address to create your account and continue.
-
-
-
- {{ isSubmitting ? t('common.processing') : 'Create account' }}
-
-
- I already have an account
-
-
-
-
- {{ accountActionError }}
-
-
+
@@ -267,6 +244,9 @@ import { computed, onMounted, ref } from 'vue'
import { useRoute, useRouter } from 'vue-router'
import { useI18n } from 'vue-i18n'
import { AuthLayout } from '@/components/layout'
+import PendingOAuthCreateAccountForm, {
+ type PendingOAuthCreateAccountPayload
+} from '@/components/auth/PendingOAuthCreateAccountForm.vue'
import Icon from '@/components/icons/Icon.vue'
import { apiClient } from '@/api/client'
import { useAuthStore, useAppStore } from '@/stores'
@@ -476,9 +456,9 @@ function applyTotpChallenge(completion: PendingOidcCompletion): boolean {
return true
}
-function switchToBindLoginMode() {
+function switchToBindLoginMode(nextEmail?: string) {
pendingAccountAction.value = 'bind_login'
- bindLoginEmail.value = bindLoginEmail.value.trim() || pendingAccountEmail.value.trim()
+ bindLoginEmail.value = bindLoginEmail.value.trim() || nextEmail?.trim() || pendingAccountEmail.value.trim()
bindLoginPassword.value = ''
accountActionError.value = ''
canReturnToCreateAccount.value = true
@@ -577,15 +557,16 @@ async function handleContinueLogin() {
}
}
-async function handleCreateAccount() {
+async function handleCreateAccount(payload: PendingOAuthCreateAccountPayload) {
accountActionError.value = ''
- const email = pendingAccountEmail.value.trim()
- if (!email) return
+ if (!payload.email || !payload.password) return
isSubmitting.value = true
try {
const { data } = await apiClient.post('/auth/oauth/pending/create-account', {
- email,
+ email: payload.email,
+ password: payload.password,
+ verify_code: payload.verifyCode || undefined,
...serializeAdoptionDecision(currentAdoptionDecision())
})
await finalizePendingAccountResponse(data)
diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue
index e4dd6301..10b83b1c 100644
--- a/frontend/src/views/auth/WechatCallbackView.vue
+++ b/frontend/src/views/auth/WechatCallbackView.vue
@@ -160,37 +160,14 @@
Enter an email address to create your account and continue.
-
-
-
- {{ isSubmitting ? t('common.processing') : 'Create account' }}
-
-
- I already have an account
-
-
-
-
- {{ accountActionError }}
-
-
+
@@ -305,6 +282,9 @@ import { computed, onMounted, ref } from 'vue'
import { useRoute, useRouter } from 'vue-router'
import { useI18n } from 'vue-i18n'
import { AuthLayout } from '@/components/layout'
+import PendingOAuthCreateAccountForm, {
+ type PendingOAuthCreateAccountPayload
+} from '@/components/auth/PendingOAuthCreateAccountForm.vue'
import Icon from '@/components/icons/Icon.vue'
import { apiClient } from '@/api/client'
import { useAuthStore, useAppStore } from '@/stores'
@@ -575,9 +555,9 @@ function applyTotpChallenge(completion: PendingWeChatCompletion): boolean {
return true
}
-function switchToBindLoginMode() {
+function switchToBindLoginMode(nextEmail?: string) {
pendingAccountAction.value = 'bind_login'
- bindLoginEmail.value = bindLoginEmail.value.trim() || pendingAccountEmail.value.trim()
+ bindLoginEmail.value = bindLoginEmail.value.trim() || nextEmail?.trim() || pendingAccountEmail.value.trim()
bindLoginPassword.value = ''
accountActionError.value = ''
canReturnToCreateAccount.value = true
@@ -676,15 +656,16 @@ async function handleContinueLogin() {
}
}
-async function handleCreateAccount() {
+async function handleCreateAccount(payload: PendingOAuthCreateAccountPayload) {
accountActionError.value = ''
- const email = pendingAccountEmail.value.trim()
- if (!email) return
+ if (!payload.email || !payload.password) return
isSubmitting.value = true
try {
const { data } = await apiClient.post('/auth/oauth/pending/create-account', {
- email,
+ email: payload.email,
+ password: payload.password,
+ verify_code: payload.verifyCode || undefined,
...serializeAdoptionDecision(currentAdoptionDecision())
})
await finalizePendingAccountResponse(data)
diff --git a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
index 8b2482fa..b9930b70 100644
--- a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
@@ -11,6 +11,7 @@ const exchangePendingOAuthCompletion = vi.fn()
const completeLinuxDoOAuthRegistration = vi.fn()
const login2FA = vi.fn()
const apiClientPost = vi.fn()
+const sendVerifyCode = vi.fn()
vi.mock('vue-router', () => ({
useRoute: () => ({
@@ -53,7 +54,8 @@ vi.mock('@/api/auth', async () => {
...actual,
exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args),
completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args),
- login2FA: (...args: any[]) => login2FA(...args)
+ login2FA: (...args: any[]) => login2FA(...args),
+ sendVerifyCode: (...args: any[]) => sendVerifyCode(...args)
}
})
@@ -67,6 +69,7 @@ describe('LinuxDoCallbackView', () => {
completeLinuxDoOAuthRegistration.mockReset()
login2FA.mockReset()
apiClientPost.mockReset()
+ sendVerifyCode.mockReset()
})
it('does not send adoption decisions during the initial exchange', async () => {
@@ -251,7 +254,7 @@ describe('LinuxDoCallbackView', () => {
})
})
- it('collects email for pending oauth account creation and submits adoption decisions', async () => {
+ it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => {
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'email_required',
redirect: '/welcome',
@@ -286,11 +289,15 @@ describe('LinuxDoCallbackView', () => {
expect(checkboxes).toHaveLength(2)
await checkboxes[1].setValue(false)
await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' new@example.com ')
+ await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123')
+ await wrapper.get('[data-testid="linuxdo-create-account-verify-code"]').setValue('246810')
await wrapper.get('[data-testid="linuxdo-create-account-submit"]').trigger('click')
await flushPromises()
expect(apiClientPost).toHaveBeenCalledWith('/auth/oauth/pending/create-account', {
email: 'new@example.com',
+ password: 'secret-123',
+ verify_code: '246810',
adopt_display_name: true,
adopt_avatar: false
})
@@ -298,6 +305,38 @@ describe('LinuxDoCallbackView', () => {
expect(replace).toHaveBeenCalledWith('/welcome')
})
+ it('sends a verify code for pending oauth account creation', async () => {
+ exchangePendingOAuthCompletion.mockResolvedValue({
+ error: 'email_required',
+ redirect: '/welcome'
+ })
+ sendVerifyCode.mockResolvedValue({
+ message: 'sent',
+ countdown: 60
+ })
+
+ const wrapper = mount(LinuxDoCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' new@example.com ')
+ await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click')
+ await flushPromises()
+
+ expect(sendVerifyCode).toHaveBeenCalledWith({
+ email: 'new@example.com'
+ })
+ })
+
it('shows bind-login form for existing account binding and submits credentials with adoption decisions', async () => {
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'bind_login_required',
diff --git a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
index c7460d8f..cb28e283 100644
--- a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
@@ -12,6 +12,7 @@ const completeOIDCOAuthRegistration = vi.fn()
const getPublicSettings = vi.fn()
const login2FA = vi.fn()
const apiClientPost = vi.fn()
+const sendVerifyCode = vi.fn()
vi.mock('vue-router', () => ({
useRoute: () => ({
@@ -60,7 +61,8 @@ vi.mock('@/api/auth', async () => {
exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args),
completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args),
getPublicSettings: (...args: any[]) => getPublicSettings(...args),
- login2FA: (...args: any[]) => login2FA(...args)
+ login2FA: (...args: any[]) => login2FA(...args),
+ sendVerifyCode: (...args: any[]) => sendVerifyCode(...args)
}
})
@@ -75,6 +77,7 @@ describe('OidcCallbackView', () => {
getPublicSettings.mockReset()
login2FA.mockReset()
apiClientPost.mockReset()
+ sendVerifyCode.mockReset()
getPublicSettings.mockResolvedValue({
oidc_oauth_provider_name: 'ExampleID'
})
@@ -234,7 +237,7 @@ describe('OidcCallbackView', () => {
})
})
- it('collects email for pending oauth account creation and submits adoption decisions', async () => {
+ it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => {
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'email_required',
redirect: '/welcome',
@@ -269,11 +272,15 @@ describe('OidcCallbackView', () => {
expect(checkboxes).toHaveLength(2)
await checkboxes[1].setValue(false)
await wrapper.get('[data-testid="oidc-create-account-email"]').setValue(' new@example.com ')
+ await wrapper.get('[data-testid="oidc-create-account-password"]').setValue('secret-123')
+ await wrapper.get('[data-testid="oidc-create-account-verify-code"]').setValue('246810')
await wrapper.get('[data-testid="oidc-create-account-submit"]').trigger('click')
await flushPromises()
expect(apiClientPost).toHaveBeenCalledWith('/auth/oauth/pending/create-account', {
email: 'new@example.com',
+ password: 'secret-123',
+ verify_code: '246810',
adopt_display_name: true,
adopt_avatar: false
})
@@ -281,6 +288,38 @@ describe('OidcCallbackView', () => {
expect(replace).toHaveBeenCalledWith('/welcome')
})
+ it('sends a verify code for pending oauth account creation', async () => {
+ exchangePendingOAuthCompletion.mockResolvedValue({
+ error: 'email_required',
+ redirect: '/welcome'
+ })
+ sendVerifyCode.mockResolvedValue({
+ message: 'sent',
+ countdown: 60
+ })
+
+ const wrapper = mount(OidcCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ await wrapper.get('[data-testid="oidc-create-account-email"]').setValue(' new@example.com ')
+ await wrapper.get('[data-testid="oidc-create-account-send-code"]').trigger('click')
+ await flushPromises()
+
+ expect(sendVerifyCode).toHaveBeenCalledWith({
+ email: 'new@example.com'
+ })
+ })
+
it('shows bind-login form for existing account binding and submits credentials with adoption decisions', async () => {
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'adopt_existing_user_by_email',
diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
index c49d0243..aa673238 100644
--- a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
@@ -7,6 +7,7 @@ const {
completeWeChatOAuthRegistrationMock,
login2FAMock,
apiClientPostMock,
+ sendVerifyCodeMock,
prepareOAuthBindAccessTokenCookieMock,
getAuthTokenMock,
replaceMock,
@@ -20,6 +21,7 @@ const {
completeWeChatOAuthRegistrationMock: vi.fn(),
login2FAMock: vi.fn(),
apiClientPostMock: vi.fn(),
+ sendVerifyCodeMock: vi.fn(),
prepareOAuthBindAccessTokenCookieMock: vi.fn(),
getAuthTokenMock: vi.fn(),
replaceMock: vi.fn(),
@@ -118,6 +120,7 @@ vi.mock('@/api/auth', async () => {
exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletionMock(...args),
completeWeChatOAuthRegistration: (...args: any[]) => completeWeChatOAuthRegistrationMock(...args),
login2FA: (...args: any[]) => login2FAMock(...args),
+ sendVerifyCode: (...args: any[]) => sendVerifyCodeMock(...args),
prepareOAuthBindAccessTokenCookie: (...args: any[]) => prepareOAuthBindAccessTokenCookieMock(...args),
getAuthToken: (...args: any[]) => getAuthTokenMock(...args),
}
@@ -129,6 +132,7 @@ describe('WechatCallbackView', () => {
completeWeChatOAuthRegistrationMock.mockReset()
login2FAMock.mockReset()
apiClientPostMock.mockReset()
+ sendVerifyCodeMock.mockReset()
replaceMock.mockReset()
setTokenMock.mockReset()
showSuccessMock.mockReset()
@@ -374,7 +378,7 @@ describe('WechatCallbackView', () => {
expect(locationState.current.href).toContain('mode=open')
})
- it('collects email for pending oauth account creation and submits adoption decisions', async () => {
+ it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => {
exchangePendingOAuthCompletionMock.mockResolvedValue({
error: 'email_required',
redirect: '/welcome',
@@ -409,11 +413,15 @@ describe('WechatCallbackView', () => {
expect(checkboxes).toHaveLength(2)
await checkboxes[1].setValue(false)
await wrapper.get('[data-testid="wechat-create-account-email"]').setValue(' new@example.com ')
+ await wrapper.get('[data-testid="wechat-create-account-password"]').setValue('secret-123')
+ await wrapper.get('[data-testid="wechat-create-account-verify-code"]').setValue('246810')
await wrapper.get('[data-testid="wechat-create-account-submit"]').trigger('click')
await flushPromises()
expect(apiClientPostMock).toHaveBeenCalledWith('/auth/oauth/pending/create-account', {
email: 'new@example.com',
+ password: 'secret-123',
+ verify_code: '246810',
adopt_display_name: true,
adopt_avatar: false,
})
@@ -421,6 +429,38 @@ describe('WechatCallbackView', () => {
expect(replaceMock).toHaveBeenCalledWith('/welcome')
})
+ it('sends a verify code for pending oauth account creation', async () => {
+ exchangePendingOAuthCompletionMock.mockResolvedValue({
+ error: 'email_required',
+ redirect: '/welcome',
+ })
+ sendVerifyCodeMock.mockResolvedValue({
+ message: 'sent',
+ countdown: 60,
+ })
+
+ const wrapper = mount(WechatCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ await wrapper.get('[data-testid="wechat-create-account-email"]').setValue(' new@example.com ')
+ await wrapper.get('[data-testid="wechat-create-account-send-code"]').trigger('click')
+ await flushPromises()
+
+ expect(sendVerifyCodeMock).toHaveBeenCalledWith({
+ email: 'new@example.com',
+ })
+ })
+
it('shows bind-login form for existing account binding and submits credentials with adoption decisions', async () => {
exchangePendingOAuthCompletionMock.mockResolvedValue({
step: 'bind_login_required',
From 4ebdfcd13a6966085fdbc5145d1ac93222d11b35 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:03:27 +0800
Subject: [PATCH 060/189] test(admin): constrain payment visible method sources
---
.../settings.paymentVisibleMethods.spec.ts | 63 +++
frontend/src/api/admin/settings.ts | 68 +++
frontend/src/views/admin/SettingsView.vue | 61 ++-
.../admin/__tests__/SettingsView.spec.ts | 452 ++++++++++++++++++
4 files changed, 631 insertions(+), 13 deletions(-)
create mode 100644 frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts
create mode 100644 frontend/src/views/admin/__tests__/SettingsView.spec.ts
diff --git a/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts b/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts
new file mode 100644
index 00000000..3b1a373f
--- /dev/null
+++ b/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts
@@ -0,0 +1,63 @@
+import { describe, expect, it } from 'vitest'
+
+import {
+ getPaymentVisibleMethodSourceOptions,
+ normalizePaymentVisibleMethodSource,
+} from '@/api/admin/settings'
+
+describe('admin settings payment visible method helpers', () => {
+ it('normalizes aliases into canonical source keys per visible method', () => {
+ expect(normalizePaymentVisibleMethodSource('alipay', 'official')).toBe('official_alipay')
+ expect(normalizePaymentVisibleMethodSource('alipay', 'alipay_direct')).toBe('official_alipay')
+ expect(normalizePaymentVisibleMethodSource('alipay', 'easypay')).toBe('easypay_alipay')
+
+ expect(normalizePaymentVisibleMethodSource('wxpay', 'official')).toBe('official_wxpay')
+ expect(normalizePaymentVisibleMethodSource('wxpay', 'wechat')).toBe('official_wxpay')
+ expect(normalizePaymentVisibleMethodSource('wxpay', 'easypay')).toBe('easypay_wxpay')
+ })
+
+ it('rejects unknown or cross-method source values', () => {
+ expect(normalizePaymentVisibleMethodSource('alipay', 'official_wxpay')).toBe('')
+ expect(normalizePaymentVisibleMethodSource('wxpay', 'official_alipay')).toBe('')
+ expect(normalizePaymentVisibleMethodSource('alipay', 'unknown')).toBe('')
+ expect(normalizePaymentVisibleMethodSource('wxpay', null)).toBe('')
+ })
+
+ it('exposes method-scoped source options instead of arbitrary strings', () => {
+ expect(getPaymentVisibleMethodSourceOptions('alipay')).toEqual([
+ {
+ value: '',
+ labelZh: '自动路由',
+ labelEn: 'Automatic routing',
+ },
+ {
+ value: 'official_alipay',
+ labelZh: '支付宝官方',
+ labelEn: 'Official Alipay',
+ },
+ {
+ value: 'easypay_alipay',
+ labelZh: '易支付支付宝',
+ labelEn: 'EasyPay Alipay',
+ },
+ ])
+
+ expect(getPaymentVisibleMethodSourceOptions('wxpay')).toEqual([
+ {
+ value: '',
+ labelZh: '自动路由',
+ labelEn: 'Automatic routing',
+ },
+ {
+ value: 'official_wxpay',
+ labelZh: '微信官方',
+ labelEn: 'Official WeChat Pay',
+ },
+ {
+ value: 'easypay_wxpay',
+ labelZh: '易支付微信',
+ labelEn: 'EasyPay WeChat Pay',
+ },
+ ])
+ })
+})
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index 8e182c1c..505fcdca 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -22,10 +22,60 @@ export interface AuthSourceDefaultsValue {
}
export type AuthSourceDefaultsState = Record
+export type PaymentVisibleMethod = 'alipay' | 'wxpay'
+export type PaymentVisibleMethodSource =
+ | ''
+ | 'official_alipay'
+ | 'easypay_alipay'
+ | 'official_wxpay'
+ | 'easypay_wxpay'
+
+export interface PaymentVisibleMethodSourceOption {
+ value: PaymentVisibleMethodSource
+ labelZh: string
+ labelEn: string
+}
const AUTH_SOURCE_TYPES: AuthSourceType[] = ['email', 'linuxdo', 'oidc', 'wechat']
const AUTH_SOURCE_DEFAULT_BALANCE = 0
const AUTH_SOURCE_DEFAULT_CONCURRENCY = 5
+const PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS: Record<
+ PaymentVisibleMethod,
+ PaymentVisibleMethodSourceOption[]
+> = {
+ alipay: [
+ { value: '', labelZh: '自动路由', labelEn: 'Automatic routing' },
+ { value: 'official_alipay', labelZh: '支付宝官方', labelEn: 'Official Alipay' },
+ { value: 'easypay_alipay', labelZh: '易支付支付宝', labelEn: 'EasyPay Alipay' },
+ ],
+ wxpay: [
+ { value: '', labelZh: '自动路由', labelEn: 'Automatic routing' },
+ { value: 'official_wxpay', labelZh: '微信官方', labelEn: 'Official WeChat Pay' },
+ { value: 'easypay_wxpay', labelZh: '易支付微信', labelEn: 'EasyPay WeChat Pay' },
+ ],
+}
+const PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES: Record<
+ PaymentVisibleMethod,
+ Record
+> = {
+ alipay: {
+ official_alipay: 'official_alipay',
+ alipay: 'official_alipay',
+ alipay_direct: 'official_alipay',
+ official: 'official_alipay',
+ easypay_alipay: 'easypay_alipay',
+ easypay: 'easypay_alipay',
+ },
+ wxpay: {
+ official_wxpay: 'official_wxpay',
+ wxpay: 'official_wxpay',
+ wxpay_direct: 'official_wxpay',
+ wechat: 'official_wxpay',
+ official: 'official_wxpay',
+ easypay_wxpay: 'easypay_wxpay',
+ easypay: 'easypay_wxpay',
+ },
+}
export function normalizeDefaultSubscriptionSettings(
subscriptions: DefaultSubscriptionSetting[] | null | undefined
@@ -86,6 +136,24 @@ export function appendAuthSourceDefaultsToUpdateRequest(
return payload
}
+export function getPaymentVisibleMethodSourceOptions(
+ method: PaymentVisibleMethod
+): PaymentVisibleMethodSourceOption[] {
+ return PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS[method]
+}
+
+export function normalizePaymentVisibleMethodSource(
+ method: PaymentVisibleMethod,
+ source: unknown
+): PaymentVisibleMethodSource {
+ if (typeof source !== 'string') return ''
+
+ const normalized = source.trim().toLowerCase()
+ if (!normalized) return ''
+
+ return PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES[method][normalized] ?? ''
+}
+
/**
* System settings interface
*/
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index 0d23baa5..8a042e70 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -2717,20 +2717,19 @@
- {{ localText('来源键', 'Source key') }}
+ {{ localText('支付来源', 'Payment source') }}
-
{{
localText(
- '留空表示由后端使用默认来源;可填 easypay、alipay、wxpay 等来源标识。',
- 'Leave blank to let the backend decide. Typical values are easypay, alipay, or wxpay.'
+ '留空表示自动路由;仅允许当前系统支持的官方或易支付来源。',
+ 'Leave blank for automatic routing. Only supported official or EasyPay sources are allowed.'
)
}}
@@ -3117,11 +3116,14 @@ import { adminAPI } from '@/api'
import {
appendAuthSourceDefaultsToUpdateRequest,
buildAuthSourceDefaultsState,
+ getPaymentVisibleMethodSourceOptions,
+ normalizePaymentVisibleMethodSource,
normalizeDefaultSubscriptionSettings,
} from '@/api/admin/settings'
import type {
AuthSourceDefaultsState,
AuthSourceType,
+ PaymentVisibleMethod,
SystemSettings,
UpdateSettingsRequest,
DefaultSubscriptionSetting,
@@ -3429,12 +3431,23 @@ function getPaymentVisibleMethodSource(method: 'alipay' | 'wxpay'): string {
: form.payment_visible_method_wxpay_source
}
-function setPaymentVisibleMethodSource(method: 'alipay' | 'wxpay', source: string) {
+function getPaymentVisibleMethodSourceSelectOptions(method: PaymentVisibleMethod) {
+ return getPaymentVisibleMethodSourceOptions(method).map((option) => ({
+ value: option.value,
+ label: localText(option.labelZh, option.labelEn),
+ }))
+}
+
+function setPaymentVisibleMethodSource(
+ method: 'alipay' | 'wxpay',
+ source: string | number | boolean | null
+) {
+ const normalized = normalizePaymentVisibleMethodSource(method, source)
if (method === 'alipay') {
- form.payment_visible_method_alipay_source = source
+ form.payment_visible_method_alipay_source = normalized
return
}
- form.payment_visible_method_wxpay_source = source
+ form.payment_visible_method_wxpay_source = normalized
}
// Proxies for web search emulation ProxySelector
@@ -3805,6 +3818,14 @@ async function loadSettings() {
Object.assign(authSourceDefaults, buildAuthSourceDefaultsState(settings))
form.backend_mode_enabled = settings.backend_mode_enabled
form.default_subscriptions = normalizeDefaultSubscriptionSettings(settings.default_subscriptions)
+ form.payment_visible_method_alipay_source = normalizePaymentVisibleMethodSource(
+ 'alipay',
+ settings.payment_visible_method_alipay_source
+ )
+ form.payment_visible_method_wxpay_source = normalizePaymentVisibleMethodSource(
+ 'wxpay',
+ settings.payment_visible_method_wxpay_source
+ )
registrationEmailSuffixWhitelistTags.value = normalizeRegistrationEmailSuffixDomains(
settings.registration_email_suffix_whitelist
)
@@ -4070,8 +4091,14 @@ async function saveSettings() {
payment_cancel_rate_limit_window: Number(form.payment_cancel_rate_limit_window) || 1,
payment_cancel_rate_limit_unit: form.payment_cancel_rate_limit_unit,
payment_cancel_rate_limit_window_mode: form.payment_cancel_rate_limit_window_mode,
- payment_visible_method_alipay_source: form.payment_visible_method_alipay_source,
- payment_visible_method_wxpay_source: form.payment_visible_method_wxpay_source,
+ payment_visible_method_alipay_source: normalizePaymentVisibleMethodSource(
+ 'alipay',
+ form.payment_visible_method_alipay_source
+ ),
+ payment_visible_method_wxpay_source: normalizePaymentVisibleMethodSource(
+ 'wxpay',
+ form.payment_visible_method_wxpay_source
+ ),
payment_visible_method_alipay_enabled: form.payment_visible_method_alipay_enabled,
payment_visible_method_wxpay_enabled: form.payment_visible_method_wxpay_enabled,
openai_advanced_scheduler_enabled: form.openai_advanced_scheduler_enabled,
@@ -4092,6 +4119,14 @@ async function saveSettings() {
}
}
Object.assign(authSourceDefaults, buildAuthSourceDefaultsState(updated))
+ form.payment_visible_method_alipay_source = normalizePaymentVisibleMethodSource(
+ 'alipay',
+ updated.payment_visible_method_alipay_source
+ )
+ form.payment_visible_method_wxpay_source = normalizePaymentVisibleMethodSource(
+ 'wxpay',
+ updated.payment_visible_method_wxpay_source
+ )
registrationEmailSuffixWhitelistTags.value = normalizeRegistrationEmailSuffixDomains(
updated.registration_email_suffix_whitelist
)
diff --git a/frontend/src/views/admin/__tests__/SettingsView.spec.ts b/frontend/src/views/admin/__tests__/SettingsView.spec.ts
new file mode 100644
index 00000000..f20170e9
--- /dev/null
+++ b/frontend/src/views/admin/__tests__/SettingsView.spec.ts
@@ -0,0 +1,452 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { defineComponent, h, ref } from 'vue'
+import { flushPromises, mount } from '@vue/test-utils'
+
+import SettingsView from '../SettingsView.vue'
+
+const {
+ getSettings,
+ updateSettings,
+ getWebSearchEmulationConfig,
+ updateWebSearchEmulationConfig,
+ getAdminApiKey,
+ getOverloadCooldownSettings,
+ getStreamTimeoutSettings,
+ getRectifierSettings,
+ getBetaPolicySettings,
+ getGroups,
+ listProxies,
+ getProviders,
+ fetchPublicSettings,
+ adminSettingsFetch,
+ showError,
+ showSuccess,
+} = vi.hoisted(() => ({
+ getSettings: vi.fn(),
+ updateSettings: vi.fn(),
+ getWebSearchEmulationConfig: vi.fn(),
+ updateWebSearchEmulationConfig: vi.fn(),
+ getAdminApiKey: vi.fn(),
+ getOverloadCooldownSettings: vi.fn(),
+ getStreamTimeoutSettings: vi.fn(),
+ getRectifierSettings: vi.fn(),
+ getBetaPolicySettings: vi.fn(),
+ getGroups: vi.fn(),
+ listProxies: vi.fn(),
+ getProviders: vi.fn(),
+ fetchPublicSettings: vi.fn(),
+ adminSettingsFetch: vi.fn(),
+ showError: vi.fn(),
+ showSuccess: vi.fn(),
+}))
+
+vi.mock('@/api', () => ({
+ adminAPI: {
+ settings: {
+ getSettings,
+ updateSettings,
+ getWebSearchEmulationConfig,
+ updateWebSearchEmulationConfig,
+ getAdminApiKey,
+ getOverloadCooldownSettings,
+ getStreamTimeoutSettings,
+ getRectifierSettings,
+ getBetaPolicySettings,
+ },
+ groups: {
+ getAll: getGroups,
+ },
+ proxies: {
+ list: listProxies,
+ },
+ payment: {
+ getProviders,
+ },
+ },
+}))
+
+vi.mock('@/stores', () => ({
+ useAppStore: () => ({
+ showError,
+ showSuccess,
+ showWarning: vi.fn(),
+ showInfo: vi.fn(),
+ fetchPublicSettings,
+ }),
+}))
+
+vi.mock('@/stores/adminSettings', () => ({
+ useAdminSettingsStore: () => ({
+ fetch: adminSettingsFetch,
+ }),
+}))
+
+vi.mock('@/composables/useClipboard', () => ({
+ useClipboard: () => ({
+ copyToClipboard: vi.fn(),
+ }),
+}))
+
+vi.mock('@/utils/apiError', () => ({
+ extractApiErrorMessage: () => 'error',
+}))
+
+vi.mock('vue-i18n', async () => {
+ const actual = await vi.importActual
('vue-i18n')
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string) => key,
+ locale: ref('zh-CN'),
+ }),
+ }
+})
+
+const AppLayoutStub = { template: '
' }
+const ToggleStub = defineComponent({
+ props: {
+ modelValue: {
+ type: Boolean,
+ default: false,
+ },
+ },
+ emits: ['update:modelValue'],
+ setup(props, { emit }) {
+ return () =>
+ h('input', {
+ class: 'toggle-stub',
+ type: 'checkbox',
+ checked: props.modelValue,
+ onChange: (event: Event) => {
+ emit('update:modelValue', (event.target as HTMLInputElement).checked)
+ },
+ })
+ },
+})
+
+const SelectStub = defineComponent({
+ props: {
+ modelValue: {
+ type: [String, Number, Boolean, null],
+ default: '',
+ },
+ options: {
+ type: Array,
+ default: () => [],
+ },
+ placeholder: {
+ type: String,
+ default: '',
+ },
+ },
+ emits: ['update:modelValue', 'change'],
+ setup(props, { emit }) {
+ const onChange = (event: Event) => {
+ const target = event.target as HTMLSelectElement
+ emit('update:modelValue', target.value)
+ const option = (props.options as Array>).find(
+ (item) => String(item.value ?? '') === target.value
+ ) ?? null
+ emit('change', target.value, option)
+ }
+
+ return () =>
+ h(
+ 'select',
+ {
+ class: 'select-stub',
+ value: props.modelValue ?? '',
+ 'data-placeholder': props.placeholder,
+ onChange,
+ },
+ (props.options as Array>).map((option) =>
+ h(
+ 'option',
+ {
+ key: `${String(option.value ?? '')}:${String(option.label ?? '')}`,
+ value: option.value as string,
+ },
+ String(option.label ?? '')
+ )
+ )
+ )
+ },
+})
+
+const baseSettingsResponse = {
+ registration_enabled: true,
+ email_verify_enabled: false,
+ registration_email_suffix_whitelist: [],
+ promo_code_enabled: true,
+ invitation_code_enabled: false,
+ password_reset_enabled: false,
+ totp_enabled: false,
+ totp_encryption_key_configured: false,
+ default_balance: 0,
+ default_concurrency: 1,
+ default_subscriptions: [],
+ site_name: 'Sub2API',
+ site_logo: '',
+ site_subtitle: '',
+ api_base_url: '',
+ contact_info: '',
+ doc_url: '',
+ home_content: '',
+ hide_ccs_import_button: false,
+ table_default_page_size: 20,
+ table_page_size_options: [10, 20, 50, 100],
+ backend_mode_enabled: false,
+ custom_menu_items: [],
+ custom_endpoints: [],
+ frontend_url: '',
+ smtp_host: '',
+ smtp_port: 587,
+ smtp_username: '',
+ smtp_password_configured: false,
+ smtp_from_email: '',
+ smtp_from_name: '',
+ smtp_use_tls: true,
+ turnstile_enabled: false,
+ turnstile_site_key: '',
+ turnstile_secret_key_configured: false,
+ linuxdo_connect_enabled: false,
+ linuxdo_connect_client_id: '',
+ linuxdo_connect_client_secret_configured: false,
+ linuxdo_connect_redirect_url: '',
+ oidc_connect_enabled: false,
+ oidc_connect_provider_name: 'OIDC',
+ oidc_connect_client_id: '',
+ oidc_connect_client_secret_configured: false,
+ oidc_connect_issuer_url: '',
+ oidc_connect_discovery_url: '',
+ oidc_connect_authorize_url: '',
+ oidc_connect_token_url: '',
+ oidc_connect_userinfo_url: '',
+ oidc_connect_jwks_url: '',
+ oidc_connect_scopes: 'openid email profile',
+ oidc_connect_redirect_url: '',
+ oidc_connect_frontend_redirect_url: '/auth/oidc/callback',
+ oidc_connect_token_auth_method: 'client_secret_post',
+ oidc_connect_use_pkce: true,
+ oidc_connect_validate_id_token: true,
+ oidc_connect_allowed_signing_algs: 'RS256,ES256,PS256',
+ oidc_connect_clock_skew_seconds: 120,
+ oidc_connect_require_email_verified: false,
+ oidc_connect_userinfo_email_path: '',
+ oidc_connect_userinfo_id_path: '',
+ oidc_connect_userinfo_username_path: '',
+ enable_model_fallback: false,
+ fallback_model_anthropic: '',
+ fallback_model_openai: '',
+ fallback_model_gemini: '',
+ fallback_model_antigravity: '',
+ enable_identity_patch: false,
+ identity_patch_prompt: '',
+ ops_monitoring_enabled: false,
+ ops_realtime_monitoring_enabled: false,
+ ops_query_mode_default: 'auto',
+ ops_metrics_interval_seconds: 60,
+ min_claude_code_version: '',
+ max_claude_code_version: '',
+ allow_ungrouped_key_scheduling: false,
+ enable_fingerprint_unification: true,
+ enable_metadata_passthrough: false,
+ enable_cch_signing: false,
+ payment_enabled: true,
+ payment_min_amount: 1,
+ payment_max_amount: 10000,
+ payment_daily_limit: 50000,
+ payment_order_timeout_minutes: 30,
+ payment_max_pending_orders: 3,
+ payment_enabled_types: [],
+ payment_balance_disabled: false,
+ payment_balance_recharge_multiplier: 1,
+ payment_recharge_fee_rate: 0,
+ payment_load_balance_strategy: 'round-robin',
+ payment_product_name_prefix: '',
+ payment_product_name_suffix: '',
+ payment_help_image_url: '',
+ payment_help_text: '',
+ payment_cancel_rate_limit_enabled: false,
+ payment_cancel_rate_limit_max: 10,
+ payment_cancel_rate_limit_window: 1,
+ payment_cancel_rate_limit_unit: 'day',
+ payment_cancel_rate_limit_window_mode: 'rolling',
+ payment_visible_method_alipay_source: 'alipay_direct',
+ payment_visible_method_wxpay_source: 'invalid-source',
+ payment_visible_method_alipay_enabled: true,
+ payment_visible_method_wxpay_enabled: true,
+ openai_advanced_scheduler_enabled: false,
+ balance_low_notify_enabled: false,
+ balance_low_notify_threshold: 0,
+ balance_low_notify_recharge_url: '',
+ account_quota_notify_enabled: false,
+ account_quota_notify_emails: [],
+}
+
+function mountView() {
+ return mount(SettingsView, {
+ global: {
+ stubs: {
+ AppLayout: AppLayoutStub,
+ Select: SelectStub,
+ Toggle: ToggleStub,
+ Icon: true,
+ ConfirmDialog: true,
+ PaymentProviderList: true,
+ PaymentProviderDialog: true,
+ GroupBadge: true,
+ GroupOptionItem: true,
+ ProxySelector: true,
+ ImageUpload: true,
+ BackupSettings: true,
+ },
+ },
+ })
+}
+
+async function openPaymentTab(wrapper: ReturnType) {
+ const paymentTabButton = wrapper
+ .findAll('button')
+ .find((node) => node.text().includes('admin.settings.tabs.payment'))
+
+ expect(paymentTabButton).toBeDefined()
+ await paymentTabButton?.trigger('click')
+ await flushPromises()
+}
+
+describe('admin SettingsView payment visible method controls', () => {
+ beforeEach(() => {
+ getSettings.mockReset()
+ updateSettings.mockReset()
+ getWebSearchEmulationConfig.mockReset()
+ updateWebSearchEmulationConfig.mockReset()
+ getAdminApiKey.mockReset()
+ getOverloadCooldownSettings.mockReset()
+ getStreamTimeoutSettings.mockReset()
+ getRectifierSettings.mockReset()
+ getBetaPolicySettings.mockReset()
+ getGroups.mockReset()
+ listProxies.mockReset()
+ getProviders.mockReset()
+ fetchPublicSettings.mockReset()
+ adminSettingsFetch.mockReset()
+ showError.mockReset()
+ showSuccess.mockReset()
+
+ getSettings.mockResolvedValue({ ...baseSettingsResponse })
+ updateSettings.mockImplementation(async (payload) => ({
+ ...baseSettingsResponse,
+ ...payload,
+ }))
+ getWebSearchEmulationConfig.mockResolvedValue({
+ enabled: false,
+ providers: [],
+ })
+ updateWebSearchEmulationConfig.mockResolvedValue({
+ enabled: false,
+ providers: [],
+ })
+ getAdminApiKey.mockResolvedValue({
+ exists: false,
+ masked_key: '',
+ })
+ getOverloadCooldownSettings.mockResolvedValue({
+ enabled: true,
+ cooldown_minutes: 10,
+ })
+ getStreamTimeoutSettings.mockResolvedValue({
+ enabled: true,
+ action: 'temp_unsched',
+ temp_unsched_minutes: 5,
+ threshold_count: 3,
+ threshold_window_minutes: 10,
+ })
+ getRectifierSettings.mockResolvedValue({
+ enabled: true,
+ thinking_signature_enabled: true,
+ thinking_budget_enabled: true,
+ apikey_signature_enabled: false,
+ apikey_signature_patterns: [],
+ })
+ getBetaPolicySettings.mockResolvedValue({
+ rules: [],
+ })
+ getGroups.mockResolvedValue([])
+ listProxies.mockResolvedValue({
+ items: [],
+ })
+ getProviders.mockResolvedValue({
+ data: [],
+ })
+ fetchPublicSettings.mockResolvedValue(undefined)
+ adminSettingsFetch.mockResolvedValue(undefined)
+ })
+
+ it('loads canonical source options and normalizes existing values', async () => {
+ const wrapper = mountView()
+
+ await flushPromises()
+ await openPaymentTab(wrapper)
+
+ const paymentSourceSelects = wrapper
+ .findAll('select.select-stub')
+ .filter((node) => ['alipay', 'wxpay'].includes(node.attributes('data-placeholder')))
+
+ expect(paymentSourceSelects).toHaveLength(2)
+
+ const alipaySelect = paymentSourceSelects.find(
+ (node) => node.attributes('data-placeholder') === 'alipay'
+ )
+ const wxpaySelect = paymentSourceSelects.find(
+ (node) => node.attributes('data-placeholder') === 'wxpay'
+ )
+
+ expect(alipaySelect?.element.value).toBe('official_alipay')
+ expect(alipaySelect?.findAll('option').map((option) => option.element.value)).toEqual([
+ '',
+ 'official_alipay',
+ 'easypay_alipay',
+ ])
+
+ expect(wxpaySelect?.element.value).toBe('')
+ expect(wxpaySelect?.findAll('option').map((option) => option.element.value)).toEqual([
+ '',
+ 'official_wxpay',
+ 'easypay_wxpay',
+ ])
+ })
+
+ it('saves canonical source keys selected from the dropdowns', async () => {
+ const wrapper = mountView()
+
+ await flushPromises()
+ await openPaymentTab(wrapper)
+
+ const paymentSourceSelects = wrapper
+ .findAll('select.select-stub')
+ .filter((node) => ['alipay', 'wxpay'].includes(node.attributes('data-placeholder')))
+
+ const alipaySelect = paymentSourceSelects.find(
+ (node) => node.attributes('data-placeholder') === 'alipay'
+ )
+ const wxpaySelect = paymentSourceSelects.find(
+ (node) => node.attributes('data-placeholder') === 'wxpay'
+ )
+
+ await alipaySelect?.setValue('easypay_alipay')
+ await wxpaySelect?.setValue('official_wxpay')
+ await wrapper.find('form').trigger('submit.prevent')
+ await flushPromises()
+
+ expect(updateSettings).toHaveBeenCalledTimes(1)
+ expect(updateSettings).toHaveBeenCalledWith(
+ expect.objectContaining({
+ payment_visible_method_alipay_source: 'easypay_alipay',
+ payment_visible_method_wxpay_source: 'official_wxpay',
+ payment_visible_method_alipay_enabled: true,
+ payment_visible_method_wxpay_enabled: true,
+ })
+ )
+ })
+})
From f83fd59dcadf9348c97620235a4f0cd66241f48d Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:05:09 +0800
Subject: [PATCH 061/189] Refine payment UX for wallet flows
---
backend/internal/payment/provider/alipay.go | 33 +++--
.../internal/payment/provider/alipay_test.go | 113 ++++++++++++++++++
frontend/src/i18n/locales/en.ts | 12 ++
frontend/src/i18n/locales/zh.ts | 12 ++
frontend/src/views/user/PaymentResultView.vue | 9 +-
frontend/src/views/user/PaymentView.vue | 33 ++++-
.../user/__tests__/PaymentResultView.spec.ts | 25 ++++
.../views/user/__tests__/paymentUx.spec.ts | 49 ++++++++
frontend/src/views/user/paymentUx.ts | 105 ++++++++++++++++
9 files changed, 373 insertions(+), 18 deletions(-)
create mode 100644 frontend/src/views/user/__tests__/paymentUx.spec.ts
create mode 100644 frontend/src/views/user/paymentUx.ts
diff --git a/backend/internal/payment/provider/alipay.go b/backend/internal/payment/provider/alipay.go
index af8a90c6..4f87e5a7 100644
--- a/backend/internal/payment/provider/alipay.go
+++ b/backend/internal/payment/provider/alipay.go
@@ -26,6 +26,18 @@ const (
alipayRefundSuffix = "-refund"
)
+var (
+ alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) {
+ return client.TradeWapPay(param)
+ }
+ alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
+ return client.TradePagePay(param)
+ }
+ alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) {
+ return client.TradePreCreate(ctx, param)
+ }
+)
+
// Alipay implements payment.Provider and payment.CancelableProvider using the smartwalle/alipay SDK.
type Alipay struct {
instanceID string
@@ -80,7 +92,7 @@ func (a *Alipay) SupportedTypes() []payment.PaymentType {
}
// CreatePayment creates an Alipay payment page URL.
-func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+func (a *Alipay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
client, err := a.getClient()
if err != nil {
return nil, err
@@ -96,12 +108,12 @@ func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentReque
}
if req.IsMobile {
- return a.createTrade(client, req, notifyURL, returnURL, true)
+ return a.createTrade(ctx, client, req, notifyURL, returnURL, true)
}
- return a.createTrade(client, req, notifyURL, returnURL, false)
+ return a.createTrade(ctx, client, req, notifyURL, returnURL, false)
}
-func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string, isMobile bool) (*payment.CreatePaymentResponse, error) {
+func (a *Alipay) createTrade(ctx context.Context, client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string, isMobile bool) (*payment.CreatePaymentResponse, error) {
if isMobile {
param := alipay.TradeWapPay{}
param.OutTradeNo = req.OrderID
@@ -111,7 +123,7 @@ func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentReq
param.NotifyURL = notifyURL
param.ReturnURL = returnURL
- payURL, err := client.TradeWapPay(param)
+ payURL, err := alipayTradeWapPay(client, param)
if err != nil {
return nil, fmt.Errorf("alipay TradeWapPay: %w", err)
}
@@ -121,22 +133,19 @@ func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentReq
}, nil
}
- param := alipay.TradePagePay{}
+ param := alipay.TradePreCreate{}
param.OutTradeNo = req.OrderID
param.TotalAmount = req.Amount
param.Subject = req.Subject
- param.ProductCode = alipayProductCodePagePay
param.NotifyURL = notifyURL
- param.ReturnURL = returnURL
- payURL, err := client.TradePagePay(param)
+ resp, err := alipayTradePreCreate(ctx, client, param)
if err != nil {
- return nil, fmt.Errorf("alipay TradePagePay: %w", err)
+ return nil, fmt.Errorf("alipay TradePreCreate: %w", err)
}
return &payment.CreatePaymentResponse{
TradeNo: req.OrderID,
- PayURL: payURL.String(),
- QRCode: payURL.String(),
+ QRCode: strings.TrimSpace(resp.QRCode),
}, nil
}
diff --git a/backend/internal/payment/provider/alipay_test.go b/backend/internal/payment/provider/alipay_test.go
index 7b0ce0d8..6cc4246c 100644
--- a/backend/internal/payment/provider/alipay_test.go
+++ b/backend/internal/payment/provider/alipay_test.go
@@ -3,9 +3,14 @@
package provider
import (
+ "context"
"errors"
+ "net/url"
"strings"
"testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/smartwalle/alipay/v3"
)
func TestIsTradeNotExist(t *testing.T) {
@@ -130,3 +135,111 @@ func TestNewAlipay(t *testing.T) {
})
}
}
+
+func TestCreateTradeUsesPreCreateForDesktop(t *testing.T) {
+ origPreCreate := alipayTradePreCreate
+ origPagePay := alipayTradePagePay
+ origWapPay := alipayTradeWapPay
+ t.Cleanup(func() {
+ alipayTradePreCreate = origPreCreate
+ alipayTradePagePay = origPagePay
+ alipayTradeWapPay = origWapPay
+ })
+
+ preCreateCalls := 0
+ pagePayCalls := 0
+ wapPayCalls := 0
+ alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) {
+ preCreateCalls++
+ if param.OutTradeNo != "sub2_100" {
+ t.Fatalf("out_trade_no = %q, want %q", param.OutTradeNo, "sub2_100")
+ }
+ if param.NotifyURL != "https://merchant.example.com/api/v1/payment/webhook/alipay" {
+ t.Fatalf("notify_url = %q", param.NotifyURL)
+ }
+ return &alipay.TradePreCreateRsp{
+ OutTradeNo: "sub2_100",
+ QRCode: "https://qr.alipay.example.com/precreate-token",
+ }, nil
+ }
+ alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
+ pagePayCalls++
+ return url.Parse("https://openapi.alipay.com/gateway.do?page-pay")
+ }
+ alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) {
+ wapPayCalls++
+ return url.Parse("https://openapi.alipay.com/gateway.do?wap-pay")
+ }
+
+ provider := &Alipay{}
+ resp, err := provider.createTrade(context.Background(), &alipay.Client{}, payment.CreatePaymentRequest{
+ OrderID: "sub2_100",
+ Amount: "88.00",
+ Subject: "Balance recharge",
+ }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result", false)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if preCreateCalls != 1 {
+ t.Fatalf("precreate calls = %d, want 1", preCreateCalls)
+ }
+ if pagePayCalls != 0 {
+ t.Fatalf("page pay calls = %d, want 0", pagePayCalls)
+ }
+ if wapPayCalls != 0 {
+ t.Fatalf("wap pay calls = %d, want 0", wapPayCalls)
+ }
+ if resp.QRCode != "https://qr.alipay.example.com/precreate-token" {
+ t.Fatalf("qr_code = %q", resp.QRCode)
+ }
+ if resp.PayURL != "" {
+ t.Fatalf("pay_url = %q, want empty", resp.PayURL)
+ }
+}
+
+func TestCreateTradeUsesWapPayForMobile(t *testing.T) {
+ origPreCreate := alipayTradePreCreate
+ origWapPay := alipayTradeWapPay
+ t.Cleanup(func() {
+ alipayTradePreCreate = origPreCreate
+ alipayTradeWapPay = origWapPay
+ })
+
+ preCreateCalls := 0
+ alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) {
+ preCreateCalls++
+ return &alipay.TradePreCreateRsp{}, nil
+ }
+
+ wapPayCalls := 0
+ alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) {
+ wapPayCalls++
+ if param.ReturnURL != "https://merchant.example.com/payment/result" {
+ t.Fatalf("return_url = %q", param.ReturnURL)
+ }
+ return url.Parse("https://openapi.alipay.com/gateway.do?wap-pay")
+ }
+
+ provider := &Alipay{}
+ resp, err := provider.createTrade(context.Background(), &alipay.Client{}, payment.CreatePaymentRequest{
+ OrderID: "sub2_101",
+ Amount: "18.00",
+ Subject: "Balance recharge",
+ IsMobile: true,
+ }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result", true)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if preCreateCalls != 0 {
+ t.Fatalf("precreate calls = %d, want 0", preCreateCalls)
+ }
+ if wapPayCalls != 1 {
+ t.Fatalf("wap pay calls = %d, want 1", wapPayCalls)
+ }
+ if resp.PayURL == "" {
+ t.Fatal("expected pay_url for mobile wap pay")
+ }
+ if resp.QRCode != "" {
+ t.Fatalf("qr_code = %q, want empty", resp.QRCode)
+ }
+}
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 7d058a74..e17ed616 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -5453,6 +5453,18 @@ export default {
errors: {
tooManyPending: 'Too many pending orders (max {max}). Please complete or cancel existing orders first.',
cancelRateLimited: 'Too many cancellations. Please try again later.',
+ wechatH5NotAuthorized: 'This merchant has not enabled WeChat H5 payment. Open this page in WeChat to continue.',
+ wechatPaymentMpNotConfigured: 'This site has not completed WeChat MP/JSAPI payment setup, so in-app WeChat payment is unavailable right now.',
+ wechatJsapiUnavailable: 'WeChat payment could not be invoked in the current environment. Reopen this page inside WeChat and try again.',
+ wechatJsapiFailed: 'WeChat payment did not complete. Try invoking it again or switch to QR payment.',
+ wechatUnavailable: 'WeChat payment is temporarily unavailable. Please try again later.',
+ wechatOpenInWeChatHint: 'Open the current page inside WeChat, or switch to desktop WeChat QR payment.',
+ wechatScanOnDesktopHint: 'On desktop, use WeChat Scan to pay; on mobile, reopen the current page inside WeChat.',
+ wechatSwitchBrowserHint: 'Switch to desktop WeChat QR payment, or reopen this page in an external browser and retry.',
+ alipayDesktopUnavailable: 'The desktop Alipay flow could not generate a QR code.',
+ alipayDesktopQrHint: 'Desktop Alipay should render a QR code. Refresh and retry, or make sure the payment page was not blocked.',
+ alipayMobileUnavailable: 'This page could not hand off to Alipay.',
+ alipayMobileOpenHint: 'Allow the current page to open the Alipay app, or retry from the system browser.',
PENDING_ORDERS: 'This provider has pending orders. Please wait for them to complete before making changes.',
},
stripePay: 'Pay Now',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 6dd74334..d54c0aba 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -5641,6 +5641,18 @@ export default {
errors: {
tooManyPending: '待支付订单过多(最多 {max} 个),请先完成或取消现有订单',
cancelRateLimited: '取消订单过于频繁,请稍后再试',
+ wechatH5NotAuthorized: '当前商户未开通微信 H5 支付,请在微信中打开当前页面继续支付。',
+ wechatPaymentMpNotConfigured: '当前站点未完成公众号/JSAPI 支付配置,暂时无法在微信内直接拉起支付。',
+ wechatJsapiUnavailable: '当前环境未能拉起微信支付,请确认正在微信内打开本页后重试。',
+ wechatJsapiFailed: '微信支付未完成,请重新拉起支付或改用扫码支付。',
+ wechatUnavailable: '当前微信支付暂不可用,请稍后重试。',
+ wechatOpenInWeChatHint: '请复制当前页面链接到微信内打开,或直接改用电脑端微信扫码支付。',
+ wechatScanOnDesktopHint: '电脑端请直接使用微信扫一扫完成支付;移动端请在微信内打开当前页面。',
+ wechatSwitchBrowserHint: '请改用电脑端微信扫码,或在外部浏览器重新打开本页后再试。',
+ alipayDesktopUnavailable: '当前支付宝桌面支付未成功生成二维码。',
+ alipayDesktopQrHint: '电脑端支付宝应展示扫码单,请刷新后重试,或确认浏览器未拦截当前支付页。',
+ alipayMobileUnavailable: '当前页面未成功跳转到支付宝。',
+ alipayMobileOpenHint: '请允许当前页面打开支付宝 App,或改用系统浏览器重新发起支付。',
PENDING_ORDERS: '该服务商有未完成的订单,请等待订单完成后再操作',
},
stripePay: '立即支付',
diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue
index 9687d1c7..53bbb550 100644
--- a/frontend/src/views/user/PaymentResultView.vue
+++ b/frontend/src/views/user/PaymentResultView.vue
@@ -54,7 +54,7 @@
{{ t('payment.orders.paymentMethod') }}
- {{ t('payment.methods.' + order.payment_type, order.payment_type) }}
+ {{ t(paymentMethodI18nKey(order.payment_type), normalizedOrderPaymentType(order.payment_type)) }}
{{ t('payment.orders.status') }}
@@ -75,7 +75,7 @@
{{ t('payment.orders.paymentMethod') }}
- {{ t('payment.methods.' + returnInfo.type, returnInfo.type) }}
+ {{ t(paymentMethodI18nKey(returnInfo.type), normalizedOrderPaymentType(returnInfo.type)) }}
@@ -98,6 +98,7 @@ import { PAYMENT_RECOVERY_STORAGE_KEY, readPaymentRecoverySnapshot } from '@/com
import { usePaymentStore } from '@/stores/payment'
import { paymentAPI } from '@/api/payment'
import type { PaymentOrder } from '@/types/payment'
+import { normalizePaymentMethodForDisplay, paymentMethodI18nKey } from './paymentUx'
const { t } = useI18n()
const route = useRoute()
@@ -133,6 +134,10 @@ const isSuccess = computed(() => {
return !!order.value && SUCCESS_STATUSES.has(order.value.status)
})
+function normalizedOrderPaymentType(paymentType: string): string {
+ return normalizePaymentMethodForDisplay(paymentType) || paymentType
+}
+
onMounted(async () => {
const resumeToken = typeof route.query.resume_token === 'string'
? route.query.resume_token
diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue
index bfb9dae2..019de16a 100644
--- a/frontend/src/views/user/PaymentView.vue
+++ b/frontend/src/views/user/PaymentView.vue
@@ -88,6 +88,7 @@
{{ errorMessage }}
+
{{ errorHintMessage }}
@@ -174,6 +175,7 @@
{{ t('common.cancel') }}
{{ errorMessage }}
+
{{ errorHintMessage }}
@@ -281,6 +283,7 @@ import SubscriptionPlanCard from '@/components/payment/SubscriptionPlanCard.vue'
import PaymentStatusPanel from '@/components/payment/PaymentStatusPanel.vue'
import Icon from '@/components/icons/Icon.vue'
import type { PaymentMethodOption } from '@/components/payment/PaymentMethodSelector.vue'
+import { describePaymentScenarioError } from './paymentUx'
const { t } = useI18n()
const route = useRoute()
@@ -301,6 +304,7 @@ function getDaysRemaining(expiresAt: string): number {
const loading = ref(true)
const submitting = ref(false)
const errorMessage = ref('')
+const errorHintMessage = ref('')
const activeTab = ref<'recharge' | 'subscription'>('recharge')
const amount = ref(null)
const selectedMethod = ref('')
@@ -619,6 +623,7 @@ async function confirmSubscribe() {
async function createOrder(orderAmount: number, orderType: OrderType, planId?: number, options: CreateOrderOptions = {}) {
submitting.value = true
errorMessage.value = ''
+ errorHintMessage.value = ''
try {
const requestType = normalizeVisibleMethod(options.paymentType || selectedMethod.value) || options.paymentType || selectedMethod.value
const payload = buildCreateOrderPayload({
@@ -668,8 +673,7 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n
}
if (decision.kind === 'unhandled') {
- errorMessage.value = t('payment.result.failed')
- appStore.showError(errorMessage.value)
+ applyScenarioError({ reason: 'UNHANDLED_PAYMENT_SCENARIO' }, visibleMethod)
return
}
@@ -691,7 +695,7 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n
if (errMsg.includes('cancel')) {
appStore.showInfo(t('payment.qr.cancelled'))
} else if (errMsg && !errMsg.includes('ok')) {
- appStore.showError(t('payment.result.failed'))
+ applyScenarioError({ reason: 'WECHAT_JSAPI_FAILED', message: errMsg }, visibleMethod)
}
return
}
@@ -707,10 +711,16 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n
if (apiErr.reason === 'TOO_MANY_PENDING') {
const metadata = apiErr.metadata as Record | undefined
errorMessage.value = t('payment.errors.tooManyPending', { max: metadata?.max || '' })
+ errorHintMessage.value = ''
} else if (apiErr.reason === 'CANCEL_RATE_LIMITED') {
errorMessage.value = t('payment.errors.cancelRateLimited')
+ errorHintMessage.value = ''
} else {
- errorMessage.value = extractApiErrorMessage(err, t('payment.result.failed'))
+ applyScenarioError(err, normalizeVisibleMethod(options.paymentType || selectedMethod.value) || selectedMethod.value)
+ if (!errorMessage.value) {
+ errorMessage.value = extractApiErrorMessage(err, t('payment.result.failed'))
+ errorHintMessage.value = ''
+ }
}
appStore.showError(errorMessage.value)
} finally {
@@ -718,6 +728,21 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n
}
}
+function applyScenarioError(err: unknown, paymentMethod: string) {
+ const descriptor = describePaymentScenarioError(err, {
+ paymentMethod,
+ isMobile: isMobileDevice(),
+ isWechatBrowser: typeof window !== 'undefined' && /MicroMessenger/i.test(window.navigator.userAgent),
+ })
+ if (!descriptor) {
+ errorMessage.value = ''
+ errorHintMessage.value = ''
+ return
+ }
+ errorMessage.value = t(descriptor.messageKey)
+ errorHintMessage.value = descriptor.hintKey ? t(descriptor.hintKey) : ''
+}
+
async function resumeWechatPaymentFromQuery() {
const openid = readRouteQueryValue(route.query.openid)
if (readRouteQueryValue(route.query.wechat_resume) !== '1' || !openid) {
diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
index d23a60d9..b1caa526 100644
--- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
+++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
@@ -155,4 +155,29 @@ describe('PaymentResultView', () => {
expect(wrapper.text()).toContain('payment.result.success')
expect(verifyOrderPublic).not.toHaveBeenCalled()
})
+
+ it('normalizes aliased payment methods before rendering the label', async () => {
+ routeState.query = {
+ resume_token: 'resume-88',
+ }
+ resolveOrderPublicByResumeToken.mockResolvedValueOnce({
+ data: {
+ ...orderFactory('PAID'),
+ payment_type: 'alipay_direct',
+ },
+ })
+
+ const wrapper = mount(PaymentResultView, {
+ global: {
+ stubs: {
+ OrderStatusBadge: true,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(wrapper.text()).toContain('payment.methods.alipay')
+ expect(wrapper.text()).not.toContain('payment.methods.alipay_direct')
+ })
})
diff --git a/frontend/src/views/user/__tests__/paymentUx.spec.ts b/frontend/src/views/user/__tests__/paymentUx.spec.ts
new file mode 100644
index 00000000..6cf105f2
--- /dev/null
+++ b/frontend/src/views/user/__tests__/paymentUx.spec.ts
@@ -0,0 +1,49 @@
+import { describe, expect, it } from 'vitest'
+import {
+ describePaymentScenarioError,
+ normalizePaymentMethodForDisplay,
+} from '../paymentUx'
+
+describe('normalizePaymentMethodForDisplay', () => {
+ it('collapses visible payment aliases to canonical method ids', () => {
+ expect(normalizePaymentMethodForDisplay(' alipay_direct ')).toBe('alipay')
+ expect(normalizePaymentMethodForDisplay('wxpay_direct')).toBe('wxpay')
+ expect(normalizePaymentMethodForDisplay('wechat_pay')).toBe('wxpay')
+ })
+
+ it('leaves non-aliased methods untouched', () => {
+ expect(normalizePaymentMethodForDisplay('stripe')).toBe('stripe')
+ })
+})
+
+describe('describePaymentScenarioError', () => {
+ it('maps WeChat H5 authorization errors to explicit in-app guidance', () => {
+ expect(describePaymentScenarioError(
+ { reason: 'WECHAT_H5_NOT_AUTHORIZED' },
+ { paymentMethod: 'wxpay', isMobile: true, isWechatBrowser: false },
+ )).toEqual({
+ messageKey: 'payment.errors.wechatH5NotAuthorized',
+ hintKey: 'payment.errors.wechatOpenInWeChatHint',
+ })
+ })
+
+ it('maps missing WeixinJSBridge to a JSAPI-specific prompt', () => {
+ expect(describePaymentScenarioError(
+ new Error('WeixinJSBridge is unavailable'),
+ { paymentMethod: 'wxpay', isMobile: true, isWechatBrowser: true },
+ )).toEqual({
+ messageKey: 'payment.errors.wechatJsapiUnavailable',
+ hintKey: 'payment.errors.wechatOpenInWeChatHint',
+ })
+ })
+
+ it('maps generic desktop Alipay failures to QR guidance', () => {
+ expect(describePaymentScenarioError(
+ { reason: 'PAYMENT_GATEWAY_ERROR' },
+ { paymentMethod: 'alipay', isMobile: false, isWechatBrowser: false },
+ )).toEqual({
+ messageKey: 'payment.errors.alipayDesktopUnavailable',
+ hintKey: 'payment.errors.alipayDesktopQrHint',
+ })
+ })
+})
diff --git a/frontend/src/views/user/paymentUx.ts b/frontend/src/views/user/paymentUx.ts
new file mode 100644
index 00000000..443529a7
--- /dev/null
+++ b/frontend/src/views/user/paymentUx.ts
@@ -0,0 +1,105 @@
+import { normalizeVisibleMethod } from '@/components/payment/paymentFlow'
+import { extractApiErrorCode } from '@/utils/apiError'
+
+const DISPLAY_METHOD_ALIASES: Record = {
+ wechat: 'wxpay',
+ wechat_pay: 'wxpay',
+}
+
+export interface PaymentScenarioContext {
+ paymentMethod: string
+ isMobile: boolean
+ isWechatBrowser: boolean
+}
+
+export interface PaymentScenarioErrorDescriptor {
+ messageKey: string
+ hintKey?: string
+}
+
+export function normalizePaymentMethodForDisplay(paymentType: string): string {
+ const trimmed = paymentType.trim().toLowerCase()
+ const visibleMethod = normalizeVisibleMethod(trimmed)
+ if (visibleMethod) return visibleMethod
+ return DISPLAY_METHOD_ALIASES[trimmed] ?? trimmed
+}
+
+export function paymentMethodI18nKey(paymentType: string): string {
+ return `payment.methods.${normalizePaymentMethodForDisplay(paymentType)}`
+}
+
+function defaultWechatHint(context: PaymentScenarioContext): string {
+ if (!context.isMobile) return 'payment.errors.wechatScanOnDesktopHint'
+ return 'payment.errors.wechatOpenInWeChatHint'
+}
+
+function defaultAlipayHint(context: PaymentScenarioContext): string {
+ if (context.isMobile) return 'payment.errors.alipayMobileOpenHint'
+ return 'payment.errors.alipayDesktopQrHint'
+}
+
+export function describePaymentScenarioError(
+ error: unknown,
+ context: PaymentScenarioContext,
+): PaymentScenarioErrorDescriptor | null {
+ const method = normalizePaymentMethodForDisplay(context.paymentMethod)
+ const code = extractApiErrorCode(error)
+ const message = error instanceof Error
+ ? error.message
+ : (typeof error === 'object' && error && 'message' in error && typeof error.message === 'string'
+ ? error.message
+ : String(error || ''))
+ const normalizedMessage = message.toLowerCase()
+
+ if (method === 'wxpay') {
+ if (code === 'WECHAT_H5_NOT_AUTHORIZED') {
+ return {
+ messageKey: 'payment.errors.wechatH5NotAuthorized',
+ hintKey: defaultWechatHint(context),
+ }
+ }
+ if (code === 'WECHAT_PAYMENT_MP_NOT_CONFIGURED') {
+ return {
+ messageKey: 'payment.errors.wechatPaymentMpNotConfigured',
+ hintKey: context.isWechatBrowser
+ ? 'payment.errors.wechatSwitchBrowserHint'
+ : defaultWechatHint(context),
+ }
+ }
+ if (code === 'NO_AVAILABLE_INSTANCE') {
+ return {
+ messageKey: 'payment.errors.wechatUnavailable',
+ hintKey: defaultWechatHint(context),
+ }
+ }
+ if (code === 'WECHAT_JSAPI_FAILED' || normalizedMessage.includes('get_brand_wcpay_request:fail')) {
+ return {
+ messageKey: 'payment.errors.wechatJsapiFailed',
+ hintKey: defaultWechatHint(context),
+ }
+ }
+ if (normalizedMessage.includes('weixinjsbridge is unavailable')) {
+ return {
+ messageKey: 'payment.errors.wechatJsapiUnavailable',
+ hintKey: 'payment.errors.wechatOpenInWeChatHint',
+ }
+ }
+ if (code === 'PAYMENT_GATEWAY_ERROR' || code === 'UNHANDLED_PAYMENT_SCENARIO') {
+ return {
+ messageKey: 'payment.errors.wechatUnavailable',
+ hintKey: defaultWechatHint(context),
+ }
+ }
+ }
+
+ if (method === 'alipay' && (code === 'PAYMENT_GATEWAY_ERROR' || code === 'UNHANDLED_PAYMENT_SCENARIO')) {
+ return {
+ messageKey: context.isMobile
+ ? 'payment.errors.alipayMobileUnavailable'
+ : 'payment.errors.alipayDesktopUnavailable',
+ hintKey: defaultAlipayHint(context),
+ }
+ }
+
+ return null
+}
From 9e84e2fd2bed765d38ca4a438fbbf1739950ae32 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:05:17 +0800
Subject: [PATCH 062/189] fix: persist admin payment visibility and scheduler
settings
---
.../internal/handler/admin/setting_handler.go | 64 ++++++++++++++++
...tting_handler_auth_source_defaults_test.go | 74 +++++++++++++++++++
backend/internal/handler/dto/settings.go | 9 +++
backend/internal/server/api_contract_test.go | 63 +++++++++++++++-
.../service/admin_service_apikey_test.go | 6 ++
backend/internal/service/setting_service.go | 54 +++++++++++++-
.../service/setting_service_update_test.go | 31 ++++++++
backend/internal/service/settings_view.go | 9 +++
backend/internal/service/user_service_test.go | 6 ++
9 files changed, 311 insertions(+), 5 deletions(-)
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index fe5c7928..e5681208 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -180,6 +180,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
EnableCCHSigning: settings.EnableCCHSigning,
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
+ PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource,
+ PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource,
+ PaymentVisibleMethodAlipayEnabled: settings.PaymentVisibleMethodAlipayEnabled,
+ PaymentVisibleMethodWxpayEnabled: settings.PaymentVisibleMethodWxpayEnabled,
+ OpenAIAdvancedSchedulerEnabled: settings.OpenAIAdvancedSchedulerEnabled,
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
@@ -338,6 +343,15 @@ type UpdateSettingsRequest struct {
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
EnableCCHSigning *bool `json:"enable_cch_signing"`
+ // Payment visible method routing
+ PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
+ PaymentVisibleMethodWxpaySource *string `json:"payment_visible_method_wxpay_source"`
+ PaymentVisibleMethodAlipayEnabled *bool `json:"payment_visible_method_alipay_enabled"`
+ PaymentVisibleMethodWxpayEnabled *bool `json:"payment_visible_method_wxpay_enabled"`
+
+ // OpenAI account scheduling
+ OpenAIAdvancedSchedulerEnabled *bool `json:"openai_advanced_scheduler_enabled"`
+
// Balance low notification
BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"`
@@ -935,6 +949,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.EnableCCHSigning
}(),
+ PaymentVisibleMethodAlipaySource: func() string {
+ if req.PaymentVisibleMethodAlipaySource != nil {
+ return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource)
+ }
+ return previousSettings.PaymentVisibleMethodAlipaySource
+ }(),
+ PaymentVisibleMethodWxpaySource: func() string {
+ if req.PaymentVisibleMethodWxpaySource != nil {
+ return strings.TrimSpace(*req.PaymentVisibleMethodWxpaySource)
+ }
+ return previousSettings.PaymentVisibleMethodWxpaySource
+ }(),
+ PaymentVisibleMethodAlipayEnabled: func() bool {
+ if req.PaymentVisibleMethodAlipayEnabled != nil {
+ return *req.PaymentVisibleMethodAlipayEnabled
+ }
+ return previousSettings.PaymentVisibleMethodAlipayEnabled
+ }(),
+ PaymentVisibleMethodWxpayEnabled: func() bool {
+ if req.PaymentVisibleMethodWxpayEnabled != nil {
+ return *req.PaymentVisibleMethodWxpayEnabled
+ }
+ return previousSettings.PaymentVisibleMethodWxpayEnabled
+ }(),
+ OpenAIAdvancedSchedulerEnabled: func() bool {
+ if req.OpenAIAdvancedSchedulerEnabled != nil {
+ return *req.OpenAIAdvancedSchedulerEnabled
+ }
+ return previousSettings.OpenAIAdvancedSchedulerEnabled
+ }(),
BalanceLowNotifyEnabled: func() bool {
if req.BalanceLowNotifyEnabled != nil {
return *req.BalanceLowNotifyEnabled
@@ -1153,6 +1197,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
EnableCCHSigning: updatedSettings.EnableCCHSigning,
+ PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource,
+ PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource,
+ PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled,
+ PaymentVisibleMethodWxpayEnabled: updatedSettings.PaymentVisibleMethodWxpayEnabled,
+ OpenAIAdvancedSchedulerEnabled: updatedSettings.OpenAIAdvancedSchedulerEnabled,
BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL,
@@ -1455,6 +1504,21 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EnableCCHSigning != after.EnableCCHSigning {
changed = append(changed, "enable_cch_signing")
}
+ if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource {
+ changed = append(changed, "payment_visible_method_alipay_source")
+ }
+ if before.PaymentVisibleMethodWxpaySource != after.PaymentVisibleMethodWxpaySource {
+ changed = append(changed, "payment_visible_method_wxpay_source")
+ }
+ if before.PaymentVisibleMethodAlipayEnabled != after.PaymentVisibleMethodAlipayEnabled {
+ changed = append(changed, "payment_visible_method_alipay_enabled")
+ }
+ if before.PaymentVisibleMethodWxpayEnabled != after.PaymentVisibleMethodWxpayEnabled {
+ changed = append(changed, "payment_visible_method_wxpay_enabled")
+ }
+ if before.OpenAIAdvancedSchedulerEnabled != after.OpenAIAdvancedSchedulerEnabled {
+ changed = append(changed, "openai_advanced_scheduler_enabled")
+ }
// Balance & quota notification
if before.BalanceLowNotifyEnabled != after.BalanceLowNotifyEnabled {
changed = append(changed, "balance_low_notify_enabled")
diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
index b26fa447..bf51fc68 100644
--- a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
+++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
@@ -147,3 +147,77 @@ func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *tes
require.Equal(t, float64(8), data["auth_source_default_email_concurrency"])
require.Equal(t, true, data["force_email_on_third_party_signup"])
}
+
+func TestSettingHandler_UpdateSettings_PersistsPaymentVisibleMethodsAndAdvancedScheduler(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyPromoCodeEnabled: "true",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "promo_code_enabled": true,
+ "payment_visible_method_alipay_source": "easypay",
+ "payment_visible_method_wxpay_source": "wxpay",
+ "payment_visible_method_alipay_enabled": true,
+ "payment_visible_method_wxpay_enabled": false,
+ "openai_advanced_scheduler_enabled": true,
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, service.VisibleMethodSourceEasyPayAlipay, repo.values[service.SettingPaymentVisibleMethodAlipaySource])
+ require.Equal(t, service.VisibleMethodSourceOfficialWechat, repo.values[service.SettingPaymentVisibleMethodWxpaySource])
+ require.Equal(t, "true", repo.values[service.SettingPaymentVisibleMethodAlipayEnabled])
+ require.Equal(t, "false", repo.values[service.SettingPaymentVisibleMethodWxpayEnabled])
+ require.Equal(t, "true", repo.values["openai_advanced_scheduler_enabled"])
+
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, service.VisibleMethodSourceEasyPayAlipay, data["payment_visible_method_alipay_source"])
+ require.Equal(t, service.VisibleMethodSourceOfficialWechat, data["payment_visible_method_wxpay_source"])
+ require.Equal(t, true, data["payment_visible_method_alipay_enabled"])
+ require.Equal(t, false, data["payment_visible_method_wxpay_enabled"])
+ require.Equal(t, true, data["openai_advanced_scheduler_enabled"])
+}
+
+func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyPromoCodeEnabled: "true",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "promo_code_enabled": true,
+ "payment_visible_method_alipay_source": "bogus",
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.NotContains(t, repo.values, service.SettingPaymentVisibleMethodAlipaySource)
+}
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index 637e317b..cc3f8496 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -127,6 +127,15 @@ type SystemSettings struct {
// Web Search Emulation
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
+ // Payment visible method routing
+ PaymentVisibleMethodAlipaySource string `json:"payment_visible_method_alipay_source"`
+ PaymentVisibleMethodWxpaySource string `json:"payment_visible_method_wxpay_source"`
+ PaymentVisibleMethodAlipayEnabled bool `json:"payment_visible_method_alipay_enabled"`
+ PaymentVisibleMethodWxpayEnabled bool `json:"payment_visible_method_wxpay_enabled"`
+
+ // OpenAI account scheduling
+ OpenAIAdvancedSchedulerEnabled bool `json:"openai_advanced_scheduler_enabled"`
+
// Payment configuration
PaymentEnabled bool `json:"payment_enabled"`
PaymentMinAmount float64 `json:"payment_min_amount"`
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index e903898f..533f2dac 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -500,10 +500,15 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyTableDefaultPageSize: "20",
service.SettingKeyTablePageSizeOptions: "[10,20,50,100]",
- service.SettingKeyOpsMonitoringEnabled: "false",
- service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
- service.SettingKeyOpsQueryModeDefault: "auto",
- service.SettingKeyOpsMetricsIntervalSeconds: "60",
+ service.SettingKeyOpsMonitoringEnabled: "false",
+ service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
+ service.SettingKeyOpsQueryModeDefault: "auto",
+ service.SettingKeyOpsMetricsIntervalSeconds: "60",
+ service.SettingPaymentVisibleMethodAlipaySource: service.VisibleMethodSourceEasyPayAlipay,
+ service.SettingPaymentVisibleMethodWxpaySource: service.VisibleMethodSourceOfficialWechat,
+ service.SettingPaymentVisibleMethodAlipayEnabled: "true",
+ service.SettingPaymentVisibleMethodWxpayEnabled: "false",
+ "openai_advanced_scheduler_enabled": "true",
})
},
method: http.MethodGet,
@@ -567,6 +572,27 @@ func TestAPIContracts(t *testing.T) {
"api_base_url": "https://api.example.com",
"contact_info": "support",
"doc_url": "https://docs.example.com",
+ "auth_source_default_email_balance": 0,
+ "auth_source_default_email_concurrency": 5,
+ "auth_source_default_email_subscriptions": [],
+ "auth_source_default_email_grant_on_signup": false,
+ "auth_source_default_email_grant_on_first_bind": false,
+ "auth_source_default_linuxdo_balance": 0,
+ "auth_source_default_linuxdo_concurrency": 5,
+ "auth_source_default_linuxdo_subscriptions": [],
+ "auth_source_default_linuxdo_grant_on_signup": false,
+ "auth_source_default_linuxdo_grant_on_first_bind": false,
+ "auth_source_default_oidc_balance": 0,
+ "auth_source_default_oidc_concurrency": 5,
+ "auth_source_default_oidc_subscriptions": [],
+ "auth_source_default_oidc_grant_on_signup": false,
+ "auth_source_default_oidc_grant_on_first_bind": false,
+ "auth_source_default_wechat_balance": 0,
+ "auth_source_default_wechat_concurrency": 5,
+ "auth_source_default_wechat_subscriptions": [],
+ "auth_source_default_wechat_grant_on_signup": false,
+ "auth_source_default_wechat_grant_on_first_bind": false,
+ "force_email_on_third_party_signup": false,
"default_concurrency": 5,
"default_balance": 1.25,
"default_subscriptions": [],
@@ -592,6 +618,11 @@ func TestAPIContracts(t *testing.T) {
"enable_fingerprint_unification": true,
"enable_metadata_passthrough": false,
"web_search_emulation_enabled": false,
+ "payment_visible_method_alipay_source": "easypay_alipay",
+ "payment_visible_method_wxpay_source": "official_wxpay",
+ "payment_visible_method_alipay_enabled": true,
+ "payment_visible_method_wxpay_enabled": false,
+ "openai_advanced_scheduler_enabled": true,
"custom_menu_items": [],
"custom_endpoints": [],
"payment_enabled": false,
@@ -858,6 +889,18 @@ func (r *stubUserRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
+func (r *stubUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
+ return nil, nil
+}
+
+func (r *stubUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (r *stubUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ return errors.New("not implemented")
+}
+
func (r *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
@@ -894,6 +937,18 @@ func (r *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64
return errors.New("not implemented")
}
+func (r *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (r *stubUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+
+func (r *stubUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ return nil, nil
+}
+
func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
return errors.New("not implemented")
}
diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go
index 487fb5f1..e2eae0b4 100644
--- a/backend/internal/service/admin_service_apikey_test.go
+++ b/backend/internal/service/admin_service_apikey_test.go
@@ -82,6 +82,12 @@ func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error {
func (s *userRepoStubForGroupUpdate) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
panic("unexpected")
}
+func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ panic("unexpected")
+}
func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
panic("unexpected")
}
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index a2644fcd..02a64c1c 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -566,6 +566,16 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
normalizedWhitelist = []string{}
}
settings.RegistrationEmailSuffixWhitelist = normalizedWhitelist
+ alipaySource, err := normalizeVisibleMethodSettingSource("alipay", settings.PaymentVisibleMethodAlipaySource, settings.PaymentVisibleMethodAlipayEnabled)
+ if err != nil {
+ return err
+ }
+ wxpaySource, err := normalizeVisibleMethodSettingSource("wxpay", settings.PaymentVisibleMethodWxpaySource, settings.PaymentVisibleMethodWxpayEnabled)
+ if err != nil {
+ return err
+ }
+ settings.PaymentVisibleMethodAlipaySource = alipaySource
+ settings.PaymentVisibleMethodWxpaySource = wxpaySource
updates := make(map[string]string)
@@ -701,6 +711,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEnableFingerprintUnification] = strconv.FormatBool(settings.EnableFingerprintUnification)
updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough)
updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning)
+ updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource
+ updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource
+ updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled)
+ updates[SettingPaymentVisibleMethodWxpayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodWxpayEnabled)
+ updates[openAIAdvancedSchedulerSettingKey] = strconv.FormatBool(settings.OpenAIAdvancedSchedulerEnabled)
// Balance low notification
updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled)
@@ -730,6 +745,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
cchSigning: settings.EnableCCHSigning,
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
})
+ openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey)
+ openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
+ enabled: settings.OpenAIAdvancedSchedulerEnabled,
+ expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
+ })
if s.onUpdate != nil {
s.onUpdate() // Invalidate cache after settings update
}
@@ -1137,7 +1157,12 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyMaxClaudeCodeVersion: "",
// 分组隔离(默认不允许未分组 Key 调度)
- SettingKeyAllowUngroupedKeyScheduling: "false",
+ SettingKeyAllowUngroupedKeyScheduling: "false",
+ SettingPaymentVisibleMethodAlipaySource: "",
+ SettingPaymentVisibleMethodWxpaySource: "",
+ SettingPaymentVisibleMethodAlipayEnabled: "false",
+ SettingPaymentVisibleMethodWxpayEnabled: "false",
+ openAIAdvancedSchedulerSettingKey: "false",
}
return s.settingRepo.SetMultiple(ctx, defaults)
@@ -1429,6 +1454,11 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0
}
}
+ result.PaymentVisibleMethodAlipaySource = NormalizeVisibleMethodSource("alipay", settings[SettingPaymentVisibleMethodAlipaySource])
+ result.PaymentVisibleMethodWxpaySource = NormalizeVisibleMethodSource("wxpay", settings[SettingPaymentVisibleMethodWxpaySource])
+ result.PaymentVisibleMethodAlipayEnabled = settings[SettingPaymentVisibleMethodAlipayEnabled] == "true"
+ result.PaymentVisibleMethodWxpayEnabled = settings[SettingPaymentVisibleMethodWxpayEnabled] == "true"
+ result.OpenAIAdvancedSchedulerEnabled = settings[openAIAdvancedSchedulerSettingKey] == "true"
// Balance low notification
result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
@@ -1458,6 +1488,28 @@ func isFalseSettingValue(value string) bool {
}
}
+func normalizeVisibleMethodSettingSource(method, source string, enabled bool) (string, error) {
+ source = strings.TrimSpace(source)
+ if source == "" {
+ if enabled {
+ return "", infraerrors.BadRequest(
+ "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
+ fmt.Sprintf("%s source is required when the visible method is enabled", method),
+ )
+ }
+ return "", nil
+ }
+
+ normalized := NormalizeVisibleMethodSource(method, source)
+ if normalized == "" {
+ return "", infraerrors.BadRequest(
+ "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
+ fmt.Sprintf("%s source must be one of the supported payment providers", method),
+ )
+ }
+ return normalized, nil
+}
+
func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting {
raw = strings.TrimSpace(raw)
if raw == "" {
diff --git a/backend/internal/service/setting_service_update_test.go b/backend/internal/service/setting_service_update_test.go
index e62218b4..9dc0ca59 100644
--- a/backend/internal/service/setting_service_update_test.go
+++ b/backend/internal/service/setting_service_update_test.go
@@ -223,3 +223,34 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) {
require.Equal(t, "1000", repo.updates[SettingKeyTableDefaultPageSize])
require.Equal(t, "[20,100]", repo.updates[SettingKeyTablePageSizeOptions])
}
+
+func TestSettingService_UpdateSettings_PaymentVisibleMethodsAndAdvancedScheduler(t *testing.T) {
+ repo := &settingUpdateRepoStub{}
+ svc := NewSettingService(repo, &config.Config{})
+
+ err := svc.UpdateSettings(context.Background(), &SystemSettings{
+ PaymentVisibleMethodAlipaySource: "alipay",
+ PaymentVisibleMethodWxpaySource: "easypay",
+ PaymentVisibleMethodAlipayEnabled: true,
+ PaymentVisibleMethodWxpayEnabled: false,
+ OpenAIAdvancedSchedulerEnabled: true,
+ })
+ require.NoError(t, err)
+ require.Equal(t, VisibleMethodSourceOfficialAlipay, repo.updates[SettingPaymentVisibleMethodAlipaySource])
+ require.Equal(t, VisibleMethodSourceEasyPayWechat, repo.updates[SettingPaymentVisibleMethodWxpaySource])
+ require.Equal(t, "true", repo.updates[SettingPaymentVisibleMethodAlipayEnabled])
+ require.Equal(t, "false", repo.updates[SettingPaymentVisibleMethodWxpayEnabled])
+ require.Equal(t, "true", repo.updates[openAIAdvancedSchedulerSettingKey])
+}
+
+func TestSettingService_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) {
+ repo := &settingUpdateRepoStub{}
+ svc := NewSettingService(repo, &config.Config{})
+
+ err := svc.UpdateSettings(context.Background(), &SystemSettings{
+ PaymentVisibleMethodAlipaySource: "not-a-provider",
+ })
+ require.Error(t, err)
+ require.Equal(t, "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", infraerrors.Reason(err))
+ require.Nil(t, repo.updates)
+}
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index 72db4e31..9bd461f9 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -110,6 +110,15 @@ type SystemSettings struct {
// Web Search Emulation
WebSearchEmulationEnabled bool // 是否启用 web search 模拟
+ // Payment visible method routing
+ PaymentVisibleMethodAlipaySource string
+ PaymentVisibleMethodWxpaySource string
+ PaymentVisibleMethodAlipayEnabled bool
+ PaymentVisibleMethodWxpayEnabled bool
+
+ // OpenAI account scheduling
+ OpenAIAdvancedSchedulerEnabled bool
+
// Balance low notification
BalanceLowNotifyEnabled bool
BalanceLowNotifyThreshold float64
diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go
index 89f0362e..e13fb95d 100644
--- a/backend/internal/service/user_service_test.go
+++ b/backend/internal/service/user_service_test.go
@@ -103,6 +103,12 @@ func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) er
func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
return nil, nil
}
+func (m *mockUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+func (m *mockUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
From 4f6966d7b3979809da9238498bf403ced840d20b Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:05:42 +0800
Subject: [PATCH 063/189] frontend: route wechat oauth entry by public settings
---
frontend/src/api/auth.ts | 55 ++++++
.../components/auth/WechatOAuthSection.vue | 62 ++++++-
.../auth/__tests__/WechatOAuthSection.spec.ts | 170 ++++++++++++++++--
3 files changed, 261 insertions(+), 26 deletions(-)
diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts
index 6a2feb87..6c877d76 100644
--- a/frontend/src/api/auth.ts
+++ b/frontend/src/api/auth.ts
@@ -349,6 +349,61 @@ export async function getPublicSettings(): Promise {
return data
}
+export type WeChatOAuthMode = 'open' | 'mp'
+export type WeChatOAuthUnavailableReason =
+ | 'not_configured'
+ | 'external_browser_required'
+ | 'wechat_browser_required'
+
+export interface ResolvedWeChatOAuthStart {
+ mode: WeChatOAuthMode | null
+ openEnabled: boolean
+ mpEnabled: boolean
+ isWeChatBrowser: boolean
+ unavailableReason: WeChatOAuthUnavailableReason | null
+}
+
+type WeChatOAuthPublicSettings = {
+ wechat_oauth_enabled?: boolean
+ wechat_oauth_open_enabled?: boolean
+ wechat_oauth_mp_enabled?: boolean
+}
+
+export function resolveWeChatOAuthStart(
+ settings: WeChatOAuthPublicSettings | null | undefined,
+ userAgent?: string
+): ResolvedWeChatOAuthStart {
+ const normalizedUserAgent = (userAgent
+ ?? (typeof navigator !== 'undefined' ? navigator.userAgent : '')
+ ?? '').trim()
+ const isWeChatBrowser = /MicroMessenger/i.test(normalizedUserAgent)
+ const legacyEnabled = settings?.wechat_oauth_enabled ?? false
+ const openEnabled = typeof settings?.wechat_oauth_open_enabled === 'boolean'
+ ? settings.wechat_oauth_open_enabled
+ : legacyEnabled
+ const mpEnabled = typeof settings?.wechat_oauth_mp_enabled === 'boolean'
+ ? settings.wechat_oauth_mp_enabled
+ : legacyEnabled
+
+ if (isWeChatBrowser) {
+ if (mpEnabled) {
+ return { mode: 'mp', openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: null }
+ }
+ if (openEnabled) {
+ return { mode: null, openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: 'external_browser_required' }
+ }
+ return { mode: null, openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: 'not_configured' }
+ }
+
+ if (openEnabled) {
+ return { mode: 'open', openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: null }
+ }
+ if (mpEnabled) {
+ return { mode: null, openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: 'wechat_browser_required' }
+ }
+ return { mode: null, openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: 'not_configured' }
+}
+
/**
* Send verification code to email
* @param request - Email and optional Turnstile token
diff --git a/frontend/src/components/auth/WechatOAuthSection.vue b/frontend/src/components/auth/WechatOAuthSection.vue
index 94e20222..01bbd180 100644
--- a/frontend/src/components/auth/WechatOAuthSection.vue
+++ b/frontend/src/components/auth/WechatOAuthSection.vue
@@ -1,6 +1,6 @@
-
+
@@ -9,6 +9,14 @@
{{ t('auth.oidc.signIn', { providerName }) }}
+
+ {{ disabledHint }}
+
+
@@ -20,33 +28,69 @@
diff --git a/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts b/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts
new file mode 100644
index 00000000..5e6b0ae0
--- /dev/null
+++ b/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts
@@ -0,0 +1,243 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { flushPromises, mount } from '@vue/test-utils'
+import { defineComponent, h } from 'vue'
+
+import AuthIdentityMigrationReportsView from '../AuthIdentityMigrationReportsView.vue'
+
+const { getAuthIdentityMigrationReportSummary, listAuthIdentityMigrationReports, resolveAuthIdentityMigrationReport } = vi.hoisted(() => ({
+ getAuthIdentityMigrationReportSummary: vi.fn(),
+ listAuthIdentityMigrationReports: vi.fn(),
+ resolveAuthIdentityMigrationReport: vi.fn(),
+}))
+
+const { showError, showSuccess } = vi.hoisted(() => ({
+ showError: vi.fn(),
+ showSuccess: vi.fn(),
+}))
+
+vi.mock('@/api/admin', () => ({
+ adminAPI: {
+ users: {
+ getAuthIdentityMigrationReportSummary,
+ listAuthIdentityMigrationReports,
+ resolveAuthIdentityMigrationReport,
+ },
+ },
+}))
+
+vi.mock('@/stores/app', () => ({
+ useAppStore: () => ({
+ showError,
+ showSuccess,
+ }),
+}))
+
+vi.mock('vue-i18n', async () => {
+ const actual = await vi.importActual('vue-i18n')
+ return {
+ ...actual,
+ useI18n: () => ({
+ locale: { value: 'en' },
+ t: (key: string) => key,
+ }),
+ }
+})
+
+vi.mock('@/utils/format', () => ({
+ formatDateTime: (value: string | null | undefined) => value ?? '',
+}))
+
+const sampleReport = {
+ id: 1,
+ report_type: 'oidc_synthetic_email_requires_manual_recovery',
+ report_key: 'legacy@example.invalid',
+ details: {
+ user_id: 42,
+ legacy_email: 'legacy@example.invalid',
+ provider_key: 'https://issuer.example',
+ provider_subject: 'subject-123',
+ },
+ created_at: '2026-04-20T01:02:03Z',
+ resolved_at: null,
+ resolved_by_user_id: null,
+ resolution_note: '',
+}
+
+const summaryResponse = {
+ total: 2,
+ open_total: 1,
+ resolved_total: 1,
+ by_type: {
+ oidc_synthetic_email_requires_manual_recovery: 2,
+ },
+}
+
+const listResponse = {
+ items: [sampleReport],
+ total: 1,
+ page: 1,
+ page_size: 20,
+ pages: 1,
+}
+
+const AppLayoutStub = defineComponent({
+ setup(_, { slots }) {
+ return () => h('div', slots.default?.())
+ },
+})
+
+const TablePageLayoutStub = defineComponent({
+ setup(_, { slots }) {
+ return () => h('div', [
+ slots.actions?.(),
+ slots.filters?.(),
+ slots.table?.(),
+ slots.default?.(),
+ slots.pagination?.(),
+ ])
+ },
+})
+
+const DataTableStub = defineComponent({
+ props: {
+ columns: { type: Array, default: () => [] },
+ data: { type: Array, default: () => [] },
+ loading: { type: Boolean, default: false },
+ },
+ setup(props, { slots }) {
+ return () => h('div', { 'data-test': 'data-table' }, [
+ props.loading
+ ? h('div', 'loading')
+ : (props.data as Array>).map((row) =>
+ h(
+ 'div',
+ { key: String(row.id ?? row.report_key) },
+ (props.columns as Array<{ key: string }>).map((column) => {
+ const slot = slots[`cell-${column.key}`]
+ return h(
+ 'div',
+ { key: column.key, [`data-test-cell`]: `${String(row.id)}-${column.key}` },
+ slot
+ ? slot({ row, value: row[column.key] })
+ : String(row[column.key] ?? '')
+ )
+ })
+ )
+ ),
+ ])
+ },
+})
+
+const PaginationStub = defineComponent({
+ props: {
+ total: { type: Number, required: true },
+ page: { type: Number, required: true },
+ pageSize: { type: Number, required: true },
+ },
+ emits: ['update:page', 'update:pageSize'],
+ setup(props, { emit }) {
+ return () => h('div', { 'data-test': 'pagination' }, [
+ h('button', {
+ type: 'button',
+ 'data-test': 'next-page',
+ onClick: () => emit('update:page', props.page + 1),
+ }, 'next'),
+ h('button', {
+ type: 'button',
+ 'data-test': 'page-size-50',
+ onClick: () => emit('update:pageSize', 50),
+ }, '50'),
+ ])
+ },
+})
+
+describe('AuthIdentityMigrationReportsView', () => {
+ beforeEach(() => {
+ getAuthIdentityMigrationReportSummary.mockReset()
+ listAuthIdentityMigrationReports.mockReset()
+ resolveAuthIdentityMigrationReport.mockReset()
+ showError.mockReset()
+ showSuccess.mockReset()
+
+ getAuthIdentityMigrationReportSummary.mockResolvedValue(summaryResponse)
+ listAuthIdentityMigrationReports.mockResolvedValue(listResponse)
+ resolveAuthIdentityMigrationReport.mockResolvedValue({
+ ...sampleReport,
+ resolved_at: '2026-04-20T02:00:00Z',
+ resolved_by_user_id: 100,
+ resolution_note: 'resolved by admin',
+ })
+ })
+
+ const mountView = () =>
+ mount(AuthIdentityMigrationReportsView, {
+ global: {
+ stubs: {
+ AppLayout: AppLayoutStub,
+ TablePageLayout: TablePageLayoutStub,
+ DataTable: DataTableStub,
+ Pagination: PaginationStub,
+ Icon: true,
+ },
+ },
+ })
+
+ it('loads summary and first page of reports on mount', async () => {
+ const wrapper = mountView()
+
+ await flushPromises()
+
+ expect(getAuthIdentityMigrationReportSummary).toHaveBeenCalledTimes(1)
+ expect(listAuthIdentityMigrationReports).toHaveBeenCalledWith({
+ page: 1,
+ pageSize: 20,
+ reportType: '',
+ })
+ expect(wrapper.get('[data-test="summary-total"]').text()).toContain('2')
+ expect(wrapper.get('[data-test="summary-open"]').text()).toContain('1')
+ expect(wrapper.get('[data-test="summary-resolved"]').text()).toContain('1')
+ expect(wrapper.text()).toContain('legacy@example.invalid')
+ })
+
+ it('reloads list when the report type filter changes', async () => {
+ const wrapper = mountView()
+
+ await flushPromises()
+
+ listAuthIdentityMigrationReports.mockClear()
+
+ await wrapper.get('[data-test="report-type-filter"]').setValue(
+ 'oidc_synthetic_email_requires_manual_recovery'
+ )
+ await flushPromises()
+
+ expect(listAuthIdentityMigrationReports).toHaveBeenCalledWith({
+ page: 1,
+ pageSize: 20,
+ reportType: 'oidc_synthetic_email_requires_manual_recovery',
+ })
+ })
+
+ it('submits resolve note for the selected report and refreshes data', async () => {
+ const wrapper = mountView()
+
+ await flushPromises()
+
+ getAuthIdentityMigrationReportSummary.mockClear()
+ listAuthIdentityMigrationReports.mockClear()
+
+ await wrapper.get('[data-test="select-report-1"]').trigger('click')
+ await wrapper.get('[data-test="resolution-note"]').setValue('resolved by admin')
+ await wrapper.get('[data-test="resolve-submit"]').trigger('click')
+ await flushPromises()
+
+ expect(resolveAuthIdentityMigrationReport).toHaveBeenCalledWith(1, 'resolved by admin')
+ expect(showSuccess).toHaveBeenCalled()
+ expect(getAuthIdentityMigrationReportSummary).toHaveBeenCalledTimes(1)
+ expect(listAuthIdentityMigrationReports).toHaveBeenCalledWith({
+ page: 1,
+ pageSize: 20,
+ reportType: '',
+ })
+ })
+})
From 9204145746623d382f2fe5f9f2f44723042942dd Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:11:03 +0800
Subject: [PATCH 066/189] Close profile identity and avatar loop
---
backend/internal/handler/auth_handler.go | 1 +
.../internal/handler/auth_linuxdo_oauth.go | 2 +-
.../handler/auth_oauth_pending_flow.go | 15 +-
.../handler/auth_oauth_pending_flow_test.go | 145 ++++++++++++-
backend/internal/handler/auth_oidc_oauth.go | 2 +-
backend/internal/handler/auth_wechat_oauth.go | 2 +-
backend/internal/handler/user_handler.go | 112 +++++++++-
backend/internal/handler/user_handler_test.go | 79 +++++++
backend/internal/service/user_service.go | 75 +++++--
frontend/src/api/user.ts | 1 +
.../user/profile/ProfileInfoCard.vue | 202 +++++++++++++++++-
.../profile/__tests__/ProfileInfoCard.spec.ts | 172 +++++++++++++++
frontend/src/i18n/locales/en.ts | 14 ++
frontend/src/i18n/locales/zh.ts | 14 ++
14 files changed, 801 insertions(+), 35 deletions(-)
create mode 100644 frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts
diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go
index e4697609..b984a436 100644
--- a/backend/internal/handler/auth_handler.go
+++ b/backend/internal/handler/auth_handler.go
@@ -296,6 +296,7 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
c.Request.Context(),
h.entClient(),
h.authService,
+ h.userService,
pendingSession,
decision,
&user.ID,
diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go
index a0760a3b..175b1e1f 100644
--- a/backend/internal/handler/auth_linuxdo_oauth.go
+++ b/backend/internal/handler/auth_linuxdo_oauth.go
@@ -495,7 +495,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil {
+ if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
index 1b3c1380..94186858 100644
--- a/backend/internal/handler/auth_oauth_pending_flow.go
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -852,6 +852,7 @@ func applyPendingOAuthBinding(
ctx context.Context,
client *dbent.Client,
authService *service.AuthService,
+ userService *service.UserService,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
overrideUserID *int64,
@@ -938,6 +939,12 @@ func applyPendingOAuthBinding(
}
}
+ if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" && userService != nil {
+ if _, err := userService.SetAvatar(txCtx, targetUserID, adoptedAvatarURL); err != nil {
+ return err
+ }
+ }
+
return tx.Commit()
}
@@ -945,6 +952,7 @@ func applyPendingOAuthAdoption(
ctx context.Context,
client *dbent.Client,
authService *service.AuthService,
+ userService *service.UserService,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
overrideUserID *int64,
@@ -953,6 +961,7 @@ func applyPendingOAuthAdoption(
ctx,
client,
authService,
+ userService,
session,
decision,
overrideUserID,
@@ -1092,7 +1101,7 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
})
return
}
- if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID, true, true); err != nil {
+ if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID, true, true); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return
}
@@ -1188,7 +1197,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
response.ErrorFrom(c, err)
return
}
- if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, session, decision, &user.ID, true, false); err != nil {
+ if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return
}
@@ -1278,7 +1287,7 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, session.TargetUserID); err != nil {
+ if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, session.TargetUserID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go
index 8c468fdc..2521186e 100644
--- a/backend/internal/handler/auth_oauth_pending_flow_test.go
+++ b/backend/internal/handler/auth_oauth_pending_flow_test.go
@@ -152,6 +152,11 @@ func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecisio
require.Equal(t, "Alice Example", identity.Metadata["display_name"])
require.Equal(t, "https://cdn.example/alice.png", identity.Metadata["avatar_url"])
+ avatar := loadUserAvatarRecord(t, client, userEntity.ID)
+ require.NotNil(t, avatar)
+ require.Equal(t, "remote_url", avatar.StorageProvider)
+ require.Equal(t, "https://cdn.example/alice.png", avatar.URL)
+
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
@@ -1242,6 +1247,18 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
UNIQUE(user_id, provider_type, grant_reason)
)`)
require.NoError(t, err)
+ _, err = db.Exec(`
+CREATE TABLE IF NOT EXISTS user_avatars (
+ user_id INTEGER PRIMARY KEY,
+ storage_provider TEXT NOT NULL,
+ storage_key TEXT NOT NULL DEFAULT '',
+ url TEXT NOT NULL,
+ content_type TEXT NOT NULL DEFAULT '',
+ byte_size INTEGER NOT NULL DEFAULT 0,
+ sha256 TEXT NOT NULL DEFAULT '',
+ updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
+)`)
+ require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
@@ -1492,6 +1509,35 @@ func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[strin
return payload
}
+type oauthPendingFlowAvatarRecord struct {
+ StorageProvider string
+ URL string
+}
+
+func loadUserAvatarRecord(t *testing.T, client *dbent.Client, userID int64) *oauthPendingFlowAvatarRecord {
+ t.Helper()
+
+ var rows entsql.Rows
+ err := client.Driver().Query(
+ context.Background(),
+ `SELECT storage_provider, url FROM user_avatars WHERE user_id = ?`,
+ []any{userID},
+ &rows,
+ )
+ require.NoError(t, err)
+ defer rows.Close()
+
+ if !rows.Next() {
+ require.NoError(t, rows.Err())
+ return nil
+ }
+
+ var record oauthPendingFlowAvatarRecord
+ require.NoError(t, rows.Scan(&record.StorageProvider, &record.URL))
+ require.NoError(t, rows.Err())
+ return &record
+}
+
func countProviderGrantRecords(
t *testing.T,
client *dbent.Client,
@@ -1604,16 +1650,95 @@ func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error {
return r.client.User.DeleteOneID(id).Exec(ctx)
}
-func (r *oauthPendingFlowUserRepo) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
- return nil, nil
+func (r *oauthPendingFlowUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
+ driver := r.client.Driver()
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ driver = tx.Client().Driver()
+ }
+
+ var rows entsql.Rows
+ if err := driver.Query(
+ ctx,
+ `SELECT storage_provider, storage_key, url, content_type, byte_size, sha256 FROM user_avatars WHERE user_id = ?`,
+ []any{userID},
+ &rows,
+ ); err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ if !rows.Next() {
+ return nil, rows.Err()
+ }
+
+ var avatar service.UserAvatar
+ if err := rows.Scan(
+ &avatar.StorageProvider,
+ &avatar.StorageKey,
+ &avatar.URL,
+ &avatar.ContentType,
+ &avatar.ByteSize,
+ &avatar.SHA256,
+ ); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return &avatar, nil
}
-func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
- panic("unexpected UpsertUserAvatar call")
+func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ driver := r.client.Driver()
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ driver = tx.Client().Driver()
+ }
+
+ var result entsql.Result
+ if err := driver.Exec(
+ ctx,
+ `INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at)
+VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
+ON CONFLICT(user_id) DO UPDATE SET
+ storage_provider = excluded.storage_provider,
+ storage_key = excluded.storage_key,
+ url = excluded.url,
+ content_type = excluded.content_type,
+ byte_size = excluded.byte_size,
+ sha256 = excluded.sha256,
+ updated_at = CURRENT_TIMESTAMP`,
+ []any{
+ userID,
+ input.StorageProvider,
+ input.StorageKey,
+ input.URL,
+ input.ContentType,
+ input.ByteSize,
+ input.SHA256,
+ },
+ &result,
+ ); err != nil {
+ return nil, err
+ }
+
+ return &service.UserAvatar{
+ StorageProvider: input.StorageProvider,
+ StorageKey: input.StorageKey,
+ URL: input.URL,
+ ContentType: input.ContentType,
+ ByteSize: input.ByteSize,
+ SHA256: input.SHA256,
+ }, nil
}
-func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(context.Context, int64) error {
- return nil
+func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ driver := r.client.Driver()
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ driver = tx.Client().Driver()
+ }
+
+ var result entsql.Result
+ return driver.Exec(ctx, `DELETE FROM user_avatars WHERE user_id = ?`, []any{userID}, &result)
}
func (r *oauthPendingFlowUserRepo) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
@@ -1636,6 +1761,14 @@ func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int
panic("unexpected UpdateConcurrency call")
}
+func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+
+func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
+
func (r *oauthPendingFlowUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
count, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Count(ctx)
return count > 0, err
diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go
index 70424ec5..0f9f1895 100644
--- a/backend/internal/handler/auth_oidc_oauth.go
+++ b/backend/internal/handler/auth_oidc_oauth.go
@@ -537,7 +537,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil {
+ if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go
index 4d25a763..b078b804 100644
--- a/backend/internal/handler/auth_wechat_oauth.go
+++ b/backend/internal/handler/auth_wechat_oauth.go
@@ -517,7 +517,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil {
+ if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go
index 9dcff828..b1ade5c0 100644
--- a/backend/internal/handler/user_handler.go
+++ b/backend/internal/handler/user_handler.go
@@ -2,6 +2,7 @@ package handler
import (
"context"
+ "strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -43,8 +44,24 @@ type UpdateProfileRequest struct {
type userProfileResponse struct {
dto.User
- AvatarURL string `json:"avatar_url,omitempty"`
- Identities service.UserIdentitySummarySet `json:"identities"`
+ AvatarURL string `json:"avatar_url,omitempty"`
+ AvatarSource *userProfileSourceContext `json:"avatar_source,omitempty"`
+ UsernameSource *userProfileSourceContext `json:"username_source,omitempty"`
+ DisplayNameSource *userProfileSourceContext `json:"display_name_source,omitempty"`
+ NicknameSource *userProfileSourceContext `json:"nickname_source,omitempty"`
+ ProfileSources map[string]*userProfileSourceContext `json:"profile_sources,omitempty"`
+ Identities service.UserIdentitySummarySet `json:"identities"`
+ AuthBindings map[string]service.UserIdentitySummary `json:"auth_bindings"`
+ IdentityBindings map[string]service.UserIdentitySummary `json:"identity_bindings"`
+ EmailBound bool `json:"email_bound"`
+ LinuxDoBound bool `json:"linuxdo_bound"`
+ OIDCBound bool `json:"oidc_bound"`
+ WeChatBound bool `json:"wechat_bound"`
+}
+
+type userProfileSourceContext struct {
+ Provider string `json:"provider,omitempty"`
+ Source string `json:"source,omitempty"`
}
// GetProfile handles getting user profile
@@ -335,9 +352,94 @@ func userProfileResponseFromService(user *service.User, identities service.UserI
if base == nil {
return userProfileResponse{}
}
+ bindings := userProfileBindingMap(identities)
+ profileSources, avatarSource, usernameSource := inferUserProfileSources(user, identities)
return userProfileResponse{
- User: *base,
- AvatarURL: user.AvatarURL,
- Identities: identities,
+ User: *base,
+ AvatarURL: user.AvatarURL,
+ AvatarSource: avatarSource,
+ UsernameSource: usernameSource,
+ DisplayNameSource: usernameSource,
+ NicknameSource: usernameSource,
+ ProfileSources: profileSources,
+ Identities: identities,
+ AuthBindings: bindings,
+ IdentityBindings: bindings,
+ EmailBound: identities.Email.Bound,
+ LinuxDoBound: identities.LinuxDo.Bound,
+ OIDCBound: identities.OIDC.Bound,
+ WeChatBound: identities.WeChat.Bound,
+ }
+}
+
+func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string]service.UserIdentitySummary {
+ return map[string]service.UserIdentitySummary{
+ "email": identities.Email,
+ "linuxdo": identities.LinuxDo,
+ "oidc": identities.OIDC,
+ "wechat": identities.WeChat,
+ }
+}
+
+func inferUserProfileSources(user *service.User, identities service.UserIdentitySummarySet) (
+ map[string]*userProfileSourceContext,
+ *userProfileSourceContext,
+ *userProfileSourceContext,
+) {
+ if user == nil {
+ return nil, nil, nil
+ }
+
+ thirdParty := thirdPartyIdentityProviders(identities)
+ var avatarSource *userProfileSourceContext
+ if strings.TrimSpace(user.AvatarURL) != "" && len(thirdParty) == 1 {
+ avatarSource = buildUserProfileSourceContext(thirdParty[0].Provider)
+ }
+
+ usernameValue := strings.TrimSpace(user.Username)
+ var usernameSource *userProfileSourceContext
+ for _, summary := range thirdParty {
+ if usernameValue != "" && usernameValue == strings.TrimSpace(summary.DisplayName) {
+ usernameSource = buildUserProfileSourceContext(summary.Provider)
+ break
+ }
+ }
+ if usernameSource == nil && usernameValue != "" && len(thirdParty) == 1 {
+ usernameSource = buildUserProfileSourceContext(thirdParty[0].Provider)
+ }
+
+ profileSources := map[string]*userProfileSourceContext{}
+ if avatarSource != nil {
+ profileSources["avatar"] = avatarSource
+ }
+ if usernameSource != nil {
+ profileSources["username"] = usernameSource
+ profileSources["display_name"] = usernameSource
+ profileSources["nickname"] = usernameSource
+ }
+ if len(profileSources) == 0 {
+ return nil, avatarSource, usernameSource
+ }
+ return profileSources, avatarSource, usernameSource
+}
+
+func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary {
+ out := make([]service.UserIdentitySummary, 0, 3)
+ for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} {
+ if summary.Bound {
+ out = append(out, summary)
+ }
+ }
+ return out
+}
+
+func buildUserProfileSourceContext(provider string) *userProfileSourceContext {
+ provider = strings.TrimSpace(provider)
+ if provider == "" {
+ return nil
+ }
+ return &userProfileSourceContext{
+ Provider: provider,
+ Source: provider,
}
}
diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go
index b71846c1..1216f9c4 100644
--- a/backend/internal/handler/user_handler_test.go
+++ b/backend/internal/handler/user_handler_test.go
@@ -92,6 +92,12 @@ func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int6
func (s *userHandlerRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error {
return nil
}
+func (s *userHandlerRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+func (s *userHandlerRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
func (s *userHandlerRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
return nil
}
@@ -230,6 +236,79 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
require.Contains(t, resp.Data.Identities.WeChat.BindStartPath, "/api/v1/auth/oauth/wechat/start")
}
+func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC)
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 21,
+ Email: "legacy-profile@example.com",
+ Username: "linuxdo-handle",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ AvatarURL: "https://cdn.example.com/linuxdo.png",
+ AvatarSource: "remote_url",
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-21",
+ VerifiedAt: &verifiedAt,
+ Metadata: map[string]any{
+ "username": "linuxdo-handle",
+ },
+ },
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 21})
+
+ handler.GetProfile(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, true, resp.Data["email_bound"])
+ require.Equal(t, true, resp.Data["linuxdo_bound"])
+ require.Equal(t, false, resp.Data["oidc_bound"])
+ require.Equal(t, false, resp.Data["wechat_bound"])
+ require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"])
+
+ authBindings, ok := resp.Data["auth_bindings"].(map[string]any)
+ require.True(t, ok)
+ linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, true, linuxdoBinding["bound"])
+ require.Equal(t, "linuxdo", linuxdoBinding["provider"])
+
+ identityBindings, ok := resp.Data["identity_bindings"].(map[string]any)
+ require.True(t, ok)
+ emailBinding, ok := identityBindings["email"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, true, emailBinding["bound"])
+
+ avatarSource, ok := resp.Data["avatar_source"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "linuxdo", avatarSource["provider"])
+
+ profileSources, ok := resp.Data["profile_sources"].(map[string]any)
+ require.True(t, ok)
+ usernameSource, ok := profileSources["username"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "linuxdo", usernameSource["provider"])
+}
+
func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
gin.SetMode(gin.TestMode)
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index c52f91bb..c106a3f5 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -65,6 +65,8 @@ type UserRepository interface {
List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error)
+ GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error)
+ GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error)
UpdateBalance(ctx context.Context, id int64, amount float64) error
DeductBalance(ctx context.Context, id int64, amount float64) error
@@ -159,6 +161,33 @@ type userAuthIdentityReader interface {
ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error)
}
+type emailAuthIdentitySynchronizer interface {
+ EnsureEmailAuthIdentity(ctx context.Context, userID int64, email string) error
+ ReplaceEmailAuthIdentity(ctx context.Context, userID int64, oldEmail, newEmail string) error
+}
+
+func ensureEmailAuthIdentitySync(ctx context.Context, repo UserRepository, userID int64, email string) error {
+ syncer, ok := repo.(emailAuthIdentitySynchronizer)
+ if !ok {
+ return nil
+ }
+ return syncer.EnsureEmailAuthIdentity(ctx, userID, email)
+}
+
+func replaceEmailAuthIdentitySync(ctx context.Context, repo UserRepository, userID int64, oldEmail, newEmail string) error {
+ oldNormalized := strings.ToLower(strings.TrimSpace(oldEmail))
+ newNormalized := strings.ToLower(strings.TrimSpace(newEmail))
+ if oldNormalized == newNormalized {
+ return nil
+ }
+
+ syncer, ok := repo.(emailAuthIdentitySynchronizer)
+ if !ok {
+ return nil
+ }
+ return syncer.ReplaceEmailAuthIdentity(ctx, userID, oldEmail, newEmail)
+}
+
// ChangePasswordRequest 修改密码请求
type ChangePasswordRequest struct {
CurrentPassword string `json:"current_password"`
@@ -252,6 +281,7 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
return nil, fmt.Errorf("get user: %w", err)
}
oldConcurrency := user.Concurrency
+ oldEmail := user.Email
// 更新字段
if req.Email != nil {
@@ -271,24 +301,11 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
}
if req.AvatarURL != nil {
- avatarValue := strings.TrimSpace(*req.AvatarURL)
- switch {
- case avatarValue == "":
- if err := s.userRepo.DeleteUserAvatar(ctx, userID); err != nil {
- return nil, fmt.Errorf("delete avatar: %w", err)
- }
- applyUserAvatar(user, nil)
- default:
- avatarInput, err := normalizeUserAvatarInput(avatarValue)
- if err != nil {
- return nil, err
- }
- avatar, err := s.userRepo.UpsertUserAvatar(ctx, userID, avatarInput)
- if err != nil {
- return nil, fmt.Errorf("upsert avatar: %w", err)
- }
- applyUserAvatar(user, avatar)
+ avatar, err := s.SetAvatar(ctx, userID, *req.AvatarURL)
+ if err != nil {
+ return nil, err
}
+ applyUserAvatar(user, avatar)
}
if req.Concurrency != nil {
@@ -309,6 +326,9 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, fmt.Errorf("update user: %w", err)
}
+ if err := replaceEmailAuthIdentitySync(ctx, s.userRepo, user.ID, oldEmail, user.Email); err != nil {
+ return nil, fmt.Errorf("sync email auth identity: %w", err)
+ }
if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
@@ -316,6 +336,27 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
return user, nil
}
+func (s *UserService) SetAvatar(ctx context.Context, userID int64, raw string) (*UserAvatar, error) {
+ avatarValue := strings.TrimSpace(raw)
+ if avatarValue == "" {
+ if err := s.userRepo.DeleteUserAvatar(ctx, userID); err != nil {
+ return nil, fmt.Errorf("delete avatar: %w", err)
+ }
+ return nil, nil
+ }
+
+ avatarInput, err := normalizeUserAvatarInput(avatarValue)
+ if err != nil {
+ return nil, err
+ }
+
+ avatar, err := s.userRepo.UpsertUserAvatar(ctx, userID, avatarInput)
+ if err != nil {
+ return nil, fmt.Errorf("upsert avatar: %w", err)
+ }
+ return avatar, nil
+}
+
func applyUserAvatar(user *User, avatar *UserAvatar) {
if user == nil {
return
diff --git a/frontend/src/api/user.ts b/frontend/src/api/user.ts
index 1f6e4cd9..c7b1e503 100644
--- a/frontend/src/api/user.ts
+++ b/frontend/src/api/user.ts
@@ -23,6 +23,7 @@ export async function getProfile(): Promise {
*/
export async function updateProfile(profile: {
username?: string
+ avatar_url?: string | null
balance_notify_enabled?: boolean
balance_notify_threshold?: number | null
balance_notify_extra_emails?: NotifyEmailEntry[]
diff --git a/frontend/src/components/user/profile/ProfileInfoCard.vue b/frontend/src/components/user/profile/ProfileInfoCard.vue
index e82ae229..d6273431 100644
--- a/frontend/src/components/user/profile/ProfileInfoCard.vue
+++ b/frontend/src/components/user/profile/ProfileInfoCard.vue
@@ -61,6 +61,71 @@
+
+
+
+
+ {{ t('profile.avatar.title') }}
+
+
+ {{ t('profile.avatar.description') }}
+
+
+
+ {{ t('common.delete') }}
+
+
+
+
+
+ {{ t('profile.avatar.inputLabel') }}
+
+
+
+
+
+ {{ t('profile.avatar.uploadAction') }}
+
+
+ {{ t('common.save') }}
+
+
+ {{ t('profile.avatar.uploadHint') }}
+
+
+
+
+
diff --git a/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts
new file mode 100644
index 00000000..238a898e
--- /dev/null
+++ b/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts
@@ -0,0 +1,172 @@
+import { mount } from '@vue/test-utils'
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import ProfileInfoCard from '@/components/user/profile/ProfileInfoCard.vue'
+import type { User } from '@/types'
+
+const {
+ updateProfileMock,
+ showSuccessMock,
+ showErrorMock,
+ authStoreState
+} = vi.hoisted(() => ({
+ updateProfileMock: vi.fn(),
+ showSuccessMock: vi.fn(),
+ showErrorMock: vi.fn(),
+ authStoreState: {
+ user: null as User | null
+ }
+}))
+
+vi.mock('@/api', () => ({
+ userAPI: {
+ updateProfile: updateProfileMock
+ }
+}))
+
+vi.mock('@/stores/auth', () => ({
+ useAuthStore: () => authStoreState
+}))
+
+vi.mock('@/stores/app', () => ({
+ useAppStore: () => ({
+ showSuccess: showSuccessMock,
+ showError: showErrorMock
+ })
+}))
+
+vi.mock('@/utils/apiError', () => ({
+ extractApiErrorMessage: (error: unknown) => (error as Error).message || 'request failed'
+}))
+
+vi.mock('vue-i18n', async (importOriginal) => {
+ const actual = await importOriginal()
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string, params?: Record) => {
+ if (key === 'profile.administrator') return 'Administrator'
+ if (key === 'profile.user') return 'User'
+ if (key === 'profile.avatar.title') return 'Profile avatar'
+ if (key === 'profile.avatar.description') return 'Set avatar by image URL or upload'
+ if (key === 'profile.avatar.inputLabel') return 'Avatar URL or data URL'
+ if (key === 'profile.avatar.inputPlaceholder') return 'https://cdn.example.com/avatar.png'
+ if (key === 'profile.avatar.uploadAction') return 'Upload image'
+ if (key === 'profile.avatar.uploadHint') return 'Images must be 100KB or smaller'
+ if (key === 'profile.avatar.saveSuccess') return 'Avatar updated'
+ if (key === 'profile.avatar.deleteSuccess') return 'Avatar removed'
+ if (key === 'profile.avatar.invalidType') return 'Please choose an image file'
+ if (key === 'profile.avatar.fileTooLarge') return 'Avatar image must be 100KB or smaller'
+ if (key === 'profile.avatar.invalidValue') return 'Enter a valid avatar URL or image data URL'
+ if (key === 'profile.avatar.emptyDeleteHint') return 'Avatar already removed'
+ if (key === 'profile.authBindings.providers.email') return 'Email'
+ if (key === 'profile.authBindings.providers.linuxdo') return 'LinuxDo'
+ if (key === 'profile.authBindings.providers.wechat') return 'WeChat'
+ if (key === 'profile.authBindings.providers.oidc') return params?.providerName || 'OIDC'
+ if (key === 'common.save') return 'Save'
+ if (key === 'common.delete') return 'Delete'
+ return key
+ }
+ })
+ }
+})
+
+function createUser(overrides: Partial = {}): User {
+ return {
+ id: 5,
+ username: 'alice',
+ email: 'alice@example.com',
+ avatar_url: null,
+ role: 'user',
+ balance: 10,
+ concurrency: 2,
+ status: 'active',
+ allowed_groups: null,
+ balance_notify_enabled: true,
+ balance_notify_threshold: null,
+ balance_notify_extra_emails: [],
+ created_at: '2026-04-20T00:00:00Z',
+ updated_at: '2026-04-20T00:00:00Z',
+ ...overrides
+ }
+}
+
+describe('ProfileInfoCard', () => {
+ beforeEach(() => {
+ updateProfileMock.mockReset()
+ showSuccessMock.mockReset()
+ showErrorMock.mockReset()
+ authStoreState.user = null
+ })
+
+ it('saves a remote avatar URL and updates the auth store', async () => {
+ const updatedUser = createUser({ avatar_url: 'https://cdn.example.com/new.png' })
+ updateProfileMock.mockResolvedValue(updatedUser)
+ authStoreState.user = createUser()
+
+ const wrapper = mount(ProfileInfoCard, {
+ props: {
+ user: authStoreState.user
+ },
+ global: {
+ stubs: {
+ Icon: true,
+ ProfileIdentityBindingsSection: true
+ }
+ }
+ })
+
+ await wrapper.get('[data-testid="profile-avatar-input"]').setValue('https://cdn.example.com/new.png')
+ await wrapper.get('[data-testid="profile-avatar-save"]').trigger('click')
+
+ expect(updateProfileMock).toHaveBeenCalledWith({ avatar_url: 'https://cdn.example.com/new.png' })
+ expect(authStoreState.user?.avatar_url).toBe('https://cdn.example.com/new.png')
+ expect(showSuccessMock).toHaveBeenCalledWith('Avatar updated')
+ })
+
+ it('rejects an oversized data URL before sending the request', async () => {
+ authStoreState.user = createUser()
+ const oversized = `data:image/png;base64,${Buffer.from(new Uint8Array(102401)).toString('base64')}`
+
+ const wrapper = mount(ProfileInfoCard, {
+ props: {
+ user: authStoreState.user
+ },
+ global: {
+ stubs: {
+ Icon: true,
+ ProfileIdentityBindingsSection: true
+ }
+ }
+ })
+
+ await wrapper.get('[data-testid="profile-avatar-input"]').setValue(oversized)
+ await wrapper.get('[data-testid="profile-avatar-save"]').trigger('click')
+
+ expect(updateProfileMock).not.toHaveBeenCalled()
+ expect(showErrorMock).toHaveBeenCalledWith('Avatar image must be 100KB or smaller')
+ })
+
+ it('deletes the current avatar', async () => {
+ const updatedUser = createUser({ avatar_url: null })
+ updateProfileMock.mockResolvedValue(updatedUser)
+ authStoreState.user = createUser({ avatar_url: 'https://cdn.example.com/old.png' })
+
+ const wrapper = mount(ProfileInfoCard, {
+ props: {
+ user: authStoreState.user
+ },
+ global: {
+ stubs: {
+ Icon: true,
+ ProfileIdentityBindingsSection: true
+ }
+ }
+ })
+
+ await wrapper.get('[data-testid="profile-avatar-delete"]').trigger('click')
+
+ expect(updateProfileMock).toHaveBeenCalledWith({ avatar_url: '' })
+ expect(authStoreState.user?.avatar_url).toBeNull()
+ expect(showSuccessMock).toHaveBeenCalledWith('Avatar removed')
+ })
+})
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index e17ed616..62958642 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -941,6 +941,20 @@ export default {
unverified: 'Unverified',
verified: 'Verified',
},
+ avatar: {
+ title: 'Profile Avatar',
+ description: 'Set your avatar with a remote image URL or upload a small image.',
+ inputLabel: 'Avatar URL or data URL',
+ inputPlaceholder: 'https://cdn.example.com/avatar.png',
+ uploadAction: 'Upload image',
+ uploadHint: 'Images must be 100KB or smaller',
+ saveSuccess: 'Avatar updated',
+ deleteSuccess: 'Avatar removed',
+ invalidType: 'Please choose an image file',
+ fileTooLarge: 'Avatar image must be 100KB or smaller',
+ invalidValue: 'Enter a valid avatar URL or image data URL',
+ emptyDeleteHint: 'Avatar is already empty',
+ },
authBindings: {
title: 'Connected Sign-In Methods',
description: 'View current bindings and connect another provider to this account.',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index d54c0aba..7b7cdbb4 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -945,6 +945,20 @@ export default {
unverified: '未验证',
verified: '已验证',
},
+ avatar: {
+ title: '资料头像',
+ description: '支持填写远程图片 URL,或上传不超过 100KB 的头像图片。',
+ inputLabel: '头像 URL 或 data URL',
+ inputPlaceholder: 'https://cdn.example.com/avatar.png',
+ uploadAction: '上传图片',
+ uploadHint: '图片大小需不超过 100KB',
+ saveSuccess: '头像已更新',
+ deleteSuccess: '头像已删除',
+ invalidType: '请选择图片文件',
+ fileTooLarge: '头像图片必须不超过 100KB',
+ invalidValue: '请输入有效的头像 URL 或图片 data URL',
+ emptyDeleteHint: '当前没有可删除的头像',
+ },
authBindings: {
title: '登录方式绑定',
description: '查看当前绑定状态,并将更多第三方登录方式关联到这个账号。',
From 5d58c7c6fba7e7531ddfcf80b6a8b4915998b4b4 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:13:40 +0800
Subject: [PATCH 067/189] Add auth identity legacy backfill and email sync
---
...ntity_legacy_migration_integration_test.go | 206 ++++++++++++++++++
backend/internal/repository/user_repo.go | 102 +++++++++
...er_repo_email_identity_integration_test.go | 86 ++++++++
.../repository/user_repo_integration_test.go | 2 +
backend/internal/service/admin_service.go | 7 +
.../admin_service_email_identity_sync_test.go | 95 ++++++++
.../user_service_email_identity_sync_test.go | 68 ++++++
...auth_identity_legacy_external_backfill.sql | 187 ++++++++++++++++
8 files changed, 753 insertions(+)
create mode 100644 backend/internal/repository/auth_identity_legacy_migration_integration_test.go
create mode 100644 backend/internal/repository/user_repo_email_identity_integration_test.go
create mode 100644 backend/internal/service/admin_service_email_identity_sync_test.go
create mode 100644 backend/internal/service/user_service_email_identity_sync_test.go
create mode 100644 backend/migrations/115_auth_identity_legacy_external_backfill.sql
diff --git a/backend/internal/repository/auth_identity_legacy_migration_integration_test.go b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go
new file mode 100644
index 00000000..6a6312d4
--- /dev/null
+++ b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go
@@ -0,0 +1,206 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "strconv"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthIdentityLegacyExternalBackfillMigration(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, `
+CREATE TABLE IF NOT EXISTS user_external_identities (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL,
+ provider TEXT NOT NULL,
+ provider_user_id TEXT NOT NULL,
+ provider_union_id TEXT NULL,
+ provider_username TEXT NOT NULL DEFAULT '',
+ display_name TEXT NOT NULL DEFAULT '',
+ profile_url TEXT NOT NULL DEFAULT '',
+ avatar_url TEXT NOT NULL DEFAULT '',
+ metadata TEXT NOT NULL DEFAULT '{}',
+ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
+);
+
+ TRUNCATE TABLE
+ auth_identity_channels,
+ auth_identities,
+ auth_identity_migration_reports,
+ user_external_identities,
+ users
+ RESTART IDENTITY;
+`)
+ require.NoError(t, err)
+
+ var linuxDoUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoUserID))
+
+ var wechatUnionUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-union@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatUnionUserID))
+
+ var wechatOpenIDOnlyUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-openid@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatOpenIDOnlyUserID))
+
+ var syntheticAuthIdentityID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
+VALUES ($1, 'wechat', 'wechat-main', 'openid-synthetic', '{"backfill_source":"synthetic_email"}'::jsonb)
+RETURNING id`, wechatOpenIDOnlyUserID).Scan(&syntheticAuthIdentityID))
+
+ var linuxDoLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-user-1', NULL, 'linux-user', 'Linux User', '{"source":"legacy"}')
+RETURNING id
+`, linuxDoUserID).Scan(&linuxDoLegacyID))
+
+ var wechatUnionLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-union-1', 'union-1', 'wechat-union-user', 'WeChat Union User', '{"channel":"oa","appid":"wx-app-1"}')
+RETURNING id
+`, wechatUnionUserID).Scan(&wechatUnionLegacyID))
+
+ var wechatOpenIDLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-only-1', NULL, 'wechat-openid-user', 'WeChat OpenID User', '{"channel":"oa","appid":"wx-app-2"}')
+RETURNING id
+`, wechatOpenIDOnlyUserID).Scan(&wechatOpenIDLegacyID))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err)
+
+ var linuxDoCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'linuxdo'
+ AND provider_key = 'linuxdo'
+ AND provider_subject = 'linuxdo-user-1'
+`, linuxDoUserID).Scan(&linuxDoCount))
+ require.Equal(t, 1, linuxDoCount)
+
+ var wechatSubject string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT provider_subject
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'wechat'
+ AND provider_key = 'wechat-main'
+ AND provider_subject = 'union-1'
+`, wechatUnionUserID).Scan(&wechatSubject))
+ require.Equal(t, "union-1", wechatSubject)
+
+ var wechatChannelCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_channels channel
+JOIN auth_identities ai ON ai.id = channel.identity_id
+WHERE ai.user_id = $1
+ AND channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat-main'
+ AND channel.channel = 'oa'
+ AND channel.channel_app_id = 'wx-app-1'
+ AND channel.channel_subject = 'openid-union-1'
+`, wechatUnionUserID).Scan(&wechatChannelCount))
+ require.Equal(t, 1, wechatChannelCount)
+
+ var legacyOpenIDOnlyReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatOpenIDLegacyID, 10)).Scan(&legacyOpenIDOnlyReportCount))
+ require.Equal(t, 1, legacyOpenIDOnlyReportCount)
+
+ var syntheticReviewCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "synthetic_auth_identity:"+strconv.FormatInt(syntheticAuthIdentityID, 10)).Scan(&syntheticReviewCount))
+ require.Equal(t, 1, syntheticReviewCount)
+
+ var unionLegacyReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatUnionLegacyID, 10)).Scan(&unionLegacyReportCount))
+ require.Zero(t, unionLegacyReportCount)
+ require.NotZero(t, linuxDoLegacyID)
+}
+
+func TestAuthIdentityLegacyExternalBackfillMigration_IsSafeWhenLegacyTableMissing(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ var beforeCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+`).Scan(&beforeCount))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err)
+
+ var afterCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+`).Scan(&afterCount))
+ require.Equal(t, beforeCount, afterCount)
+}
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index b5efd19d..b2190b68 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -11,6 +11,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
@@ -76,6 +77,9 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
return err
}
+ if err := ensureEmailAuthIdentityWithClient(ctx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
+ return err
+ }
if tx != nil {
if err := tx.Commit(); err != nil {
@@ -150,6 +154,11 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
// 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
txClient = r.client
}
+ existing, err := clientFromContext(ctx, txClient).User.Get(ctx, userIn.ID)
+ if err != nil {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+ oldEmail := existing.Email
updateOp := txClient.User.UpdateOneID(userIn.ID).
SetEmail(userIn.Email).
@@ -185,6 +194,9 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
return err
}
+ if err := replaceEmailAuthIdentityWithClient(ctx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
+ return err
+ }
if tx != nil {
if err := tx.Commit(); err != nil {
@@ -196,6 +208,96 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
return nil
}
+func (r *userRepository) EnsureEmailAuthIdentity(ctx context.Context, userID int64, email string) error {
+ return ensureEmailAuthIdentityWithClient(ctx, r.client, userID, email, "service_dual_write")
+}
+
+func (r *userRepository) ReplaceEmailAuthIdentity(ctx context.Context, userID int64, oldEmail, newEmail string) error {
+ return replaceEmailAuthIdentityWithClient(ctx, r.client, userID, oldEmail, newEmail, "service_dual_write")
+}
+
+func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, email string, source string) error {
+ client = clientFromContext(ctx, client)
+ if client == nil || userID <= 0 {
+ return nil
+ }
+
+ subject := normalizeEmailAuthIdentitySubject(email)
+ if subject == "" {
+ return nil
+ }
+
+ if err := client.AuthIdentity.Create().
+ SetUserID(userID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject(subject).
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": source}).
+ OnConflictColumns(
+ authidentity.FieldProviderType,
+ authidentity.FieldProviderKey,
+ authidentity.FieldProviderSubject,
+ ).
+ DoNothing().
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ(subject),
+ ).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil
+ }
+ return err
+ }
+ if identity.UserID != userID {
+ return ErrAuthIdentityOwnershipConflict
+ }
+ return nil
+}
+
+func replaceEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, oldEmail, newEmail string, source string) error {
+ newSubject := normalizeEmailAuthIdentitySubject(newEmail)
+ if err := ensureEmailAuthIdentityWithClient(ctx, client, userID, newEmail, source); err != nil {
+ return err
+ }
+
+ oldSubject := normalizeEmailAuthIdentitySubject(oldEmail)
+ if oldSubject == "" || oldSubject == newSubject {
+ return nil
+ }
+
+ _, err := clientFromContext(ctx, client).AuthIdentity.Delete().
+ Where(
+ authidentity.UserIDEQ(userID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ(oldSubject),
+ ).
+ Exec(ctx)
+ return err
+}
+
+func normalizeEmailAuthIdentitySubject(email string) string {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ if normalized == "" {
+ return ""
+ }
+ if strings.HasSuffix(normalized, service.LinuxDoConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(normalized, service.OIDCConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(normalized, service.WeChatConnectSyntheticEmailDomain) {
+ return ""
+ }
+ return normalized
+}
+
func (r *userRepository) Delete(ctx context.Context, id int64) error {
affected, err := r.client.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
if err != nil {
diff --git a/backend/internal/repository/user_repo_email_identity_integration_test.go b/backend/internal/repository/user_repo_email_identity_integration_test.go
new file mode 100644
index 00000000..fddd82c5
--- /dev/null
+++ b/backend/internal/repository/user_repo_email_identity_integration_test.go
@@ -0,0 +1,86 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+func (s *UserRepoSuite) TestCreate_CreatesEmailAuthIdentityForNormalEmail() {
+ user := &service.User{
+ Email: "repo-create@example.com",
+ PasswordHash: "test-password-hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Concurrency: 2,
+ }
+
+ s.Require().NoError(s.repo.Create(s.ctx, user))
+
+ identity, err := s.client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("repo-create@example.com"),
+ ).
+ Only(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Equal(user.ID, identity.UserID)
+}
+
+func (s *UserRepoSuite) TestCreate_SkipsEmailAuthIdentityForSyntheticLinuxDoEmail() {
+ user := &service.User{
+ Email: "linuxdo-legacy-user@linuxdo-connect.invalid",
+ PasswordHash: "test-password-hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Concurrency: 2,
+ }
+
+ s.Require().NoError(s.repo.Create(s.ctx, user))
+
+ count, err := s.client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ ).
+ Count(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Zero(count)
+}
+
+func (s *UserRepoSuite) TestUpdate_ReplacesEmailAuthIdentityWhenEmailChanges() {
+ user := s.mustCreateUser(&service.User{
+ Email: "before-update@example.com",
+ })
+
+ user.Email = "after-update@example.com"
+ s.Require().NoError(s.repo.Update(s.ctx, user))
+
+ newIdentity, err := s.client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("after-update@example.com"),
+ ).
+ Only(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Equal(user.ID, newIdentity.UserID)
+
+ oldCount, err := s.client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("before-update@example.com"),
+ ).
+ Count(context.Background())
+ s.Require().NoError(err)
+ s.Require().Zero(oldCount)
+}
diff --git a/backend/internal/repository/user_repo_integration_test.go b/backend/internal/repository/user_repo_integration_test.go
index f5d0f9ff..07fb0598 100644
--- a/backend/internal/repository/user_repo_integration_test.go
+++ b/backend/internal/repository/user_repo_integration_test.go
@@ -26,6 +26,8 @@ func (s *UserRepoSuite) SetupTest() {
s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
// 清理测试数据,确保每个测试从干净状态开始
+ _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM auth_identity_channels")
+ _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM auth_identities")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_subscriptions")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_allowed_groups")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM users")
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 3490374e..79840e5b 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -630,6 +630,9 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
if err := s.userRepo.Create(ctx, user); err != nil {
return nil, err
}
+ if err := ensureEmailAuthIdentitySync(ctx, s.userRepo, user.ID, user.Email); err != nil {
+ return nil, fmt.Errorf("sync email auth identity: %w", err)
+ }
s.assignDefaultSubscriptions(ctx, user.ID)
return user, nil
}
@@ -665,6 +668,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
oldConcurrency := user.Concurrency
oldStatus := user.Status
oldRole := user.Role
+ oldEmail := user.Email
if input.Email != "" {
user.Email = input.Email
@@ -697,6 +701,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
+ if err := replaceEmailAuthIdentitySync(ctx, s.userRepo, user.ID, oldEmail, user.Email); err != nil {
+ return nil, fmt.Errorf("sync email auth identity: %w", err)
+ }
// 同步用户专属分组倍率
if input.GroupRates != nil && s.userGroupRateRepo != nil {
diff --git a/backend/internal/service/admin_service_email_identity_sync_test.go b/backend/internal/service/admin_service_email_identity_sync_test.go
new file mode 100644
index 00000000..3f3d867c
--- /dev/null
+++ b/backend/internal/service/admin_service_email_identity_sync_test.go
@@ -0,0 +1,95 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+type ensureEmailCall struct {
+ userID int64
+ email string
+}
+
+type replaceEmailCall struct {
+ userID int64
+ oldEmail string
+ newEmail string
+}
+
+type emailSyncUserRepoStub struct {
+ *userRepoStub
+ ensureCalls []ensureEmailCall
+ replaceCalls []replaceEmailCall
+}
+
+func (s *emailSyncUserRepoStub) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error {
+ s.ensureCalls = append(s.ensureCalls, ensureEmailCall{userID: userID, email: email})
+ return nil
+}
+
+func (s *emailSyncUserRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error {
+ s.replaceCalls = append(s.replaceCalls, replaceEmailCall{
+ userID: userID,
+ oldEmail: oldEmail,
+ newEmail: newEmail,
+ })
+ return nil
+}
+
+func (s *emailSyncUserRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+
+func (s *emailSyncUserRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
+
+func TestAdminService_CreateUser_EnsuresEmailAuthIdentity(t *testing.T) {
+ repo := &emailSyncUserRepoStub{userRepoStub: &userRepoStub{nextID: 55}}
+ svc := &adminServiceImpl{userRepo: repo}
+
+ user, err := svc.CreateUser(context.Background(), &CreateUserInput{
+ Email: "admin-created@example.com",
+ Password: "strong-pass",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.Equal(t, []ensureEmailCall{{
+ userID: 55,
+ email: "admin-created@example.com",
+ }}, repo.ensureCalls)
+ require.Empty(t, repo.replaceCalls)
+}
+
+func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) {
+ repo := &emailSyncUserRepoStub{
+ userRepoStub: &userRepoStub{
+ user: &User{
+ ID: 91,
+ Email: "before@example.com",
+ Role: RoleUser,
+ Status: StatusActive,
+ Concurrency: 3,
+ },
+ },
+ }
+ svc := &adminServiceImpl{userRepo: repo}
+
+ updated, err := svc.UpdateUser(context.Background(), 91, &UpdateUserInput{
+ Email: "after@example.com",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, updated)
+ require.Equal(t, "after@example.com", updated.Email)
+ require.Equal(t, []replaceEmailCall{{
+ userID: 91,
+ oldEmail: "before@example.com",
+ newEmail: "after@example.com",
+ }}, repo.replaceCalls)
+ require.Empty(t, repo.ensureCalls)
+}
diff --git a/backend/internal/service/user_service_email_identity_sync_test.go b/backend/internal/service/user_service_email_identity_sync_test.go
new file mode 100644
index 00000000..3950df8b
--- /dev/null
+++ b/backend/internal/service/user_service_email_identity_sync_test.go
@@ -0,0 +1,68 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+type emailSyncMockUserRepo struct {
+ *mockUserRepo
+ ensureCalls []ensureEmailCall
+ replaceCalls []replaceEmailCall
+}
+
+func (m *emailSyncMockUserRepo) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error {
+ m.ensureCalls = append(m.ensureCalls, ensureEmailCall{userID: userID, email: email})
+ return nil
+}
+
+func (m *emailSyncMockUserRepo) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error {
+ m.replaceCalls = append(m.replaceCalls, replaceEmailCall{
+ userID: userID,
+ oldEmail: oldEmail,
+ newEmail: newEmail,
+ })
+ return nil
+}
+
+func (m *emailSyncMockUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+
+func (m *emailSyncMockUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
+
+func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) {
+ repo := &emailSyncMockUserRepo{
+ mockUserRepo: &mockUserRepo{
+ getByIDUser: &User{
+ ID: 19,
+ Email: "profile-before@example.com",
+ Username: "tester",
+ Concurrency: 2,
+ },
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ newEmail := "profile-after@example.com"
+ updated, err := svc.UpdateProfile(context.Background(), 19, UpdateProfileRequest{
+ Email: &newEmail,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, updated)
+ require.Equal(t, newEmail, updated.Email)
+ require.Equal(t, 1, repo.updateCalls)
+ require.Equal(t, []replaceEmailCall{{
+ userID: 19,
+ oldEmail: "profile-before@example.com",
+ newEmail: "profile-after@example.com",
+ }}, repo.replaceCalls)
+ require.Empty(t, repo.ensureCalls)
+}
diff --git a/backend/migrations/115_auth_identity_legacy_external_backfill.sql b/backend/migrations/115_auth_identity_legacy_external_backfill.sql
new file mode 100644
index 00000000..f4a13c36
--- /dev/null
+++ b/backend/migrations/115_auth_identity_legacy_external_backfill.sql
@@ -0,0 +1,187 @@
+DO $$
+BEGIN
+ IF to_regclass('public.user_external_identities') IS NULL THEN
+ RETURN;
+ END IF;
+
+ EXECUTE $sql$
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ legacy.user_id,
+ 'linuxdo',
+ 'linuxdo',
+ legacy.provider_user_id,
+ COALESCE(legacy.updated_at, legacy.created_at, NOW()),
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'provider_user_id', legacy.provider_user_id,
+ 'provider_username', legacy.provider_username,
+ 'display_name', legacy.display_name,
+ 'migration', '115_auth_identity_legacy_external_backfill'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ BTRIM(uei.provider_user_id) AS provider_user_id,
+ BTRIM(uei.provider_username) AS provider_username,
+ BTRIM(uei.display_name) AS display_name,
+ COALESCE(NULLIF(BTRIM(COALESCE(uei.metadata, '')), '')::jsonb, '{}'::jsonb) AS metadata_json,
+ uei.created_at,
+ uei.updated_at
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo'
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+) AS legacy
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ legacy.user_id,
+ 'wechat',
+ 'wechat-main',
+ legacy.provider_union_id,
+ COALESCE(legacy.updated_at, legacy.created_at, NOW()),
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'openid', legacy.provider_user_id,
+ 'unionid', legacy.provider_union_id,
+ 'provider_user_id', legacy.provider_user_id,
+ 'provider_union_id', legacy.provider_union_id,
+ 'provider_username', legacy.provider_username,
+ 'display_name', legacy.display_name,
+ 'migration', '115_auth_identity_legacy_external_backfill'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ BTRIM(uei.provider_user_id) AS provider_user_id,
+ BTRIM(uei.provider_union_id) AS provider_union_id,
+ BTRIM(uei.provider_username) AS provider_username,
+ BTRIM(uei.display_name) AS display_name,
+ COALESCE(NULLIF(BTRIM(COALESCE(uei.metadata, '')), '')::jsonb, '{}'::jsonb) AS metadata_json,
+ uei.created_at,
+ uei.updated_at
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+) AS legacy
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_channels (
+ identity_id,
+ provider_type,
+ provider_key,
+ channel,
+ channel_app_id,
+ channel_subject,
+ metadata
+)
+SELECT
+ ai.id,
+ 'wechat',
+ 'wechat-main',
+ legacy.channel,
+ legacy.channel_app_id,
+ legacy.provider_user_id,
+ legacy.metadata_json || jsonb_build_object(
+ 'openid', legacy.provider_user_id,
+ 'unionid', legacy.provider_union_id,
+ 'migration', '115_auth_identity_legacy_external_backfill'
+ )
+FROM (
+ SELECT
+ uei.user_id,
+ BTRIM(uei.provider_user_id) AS provider_user_id,
+ BTRIM(uei.provider_union_id) AS provider_union_id,
+ BTRIM(COALESCE(meta.metadata_json ->> 'channel', '')) AS channel,
+ BTRIM(COALESCE(meta.metadata_json ->> 'channel_app_id', meta.metadata_json ->> 'appid', meta.metadata_json ->> 'app_id', '')) AS channel_app_id,
+ meta.metadata_json
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ CROSS JOIN LATERAL (
+ SELECT COALESCE(NULLIF(BTRIM(COALESCE(uei.metadata, '')), '')::jsonb, '{}'::jsonb) AS metadata_json
+ ) AS meta
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+) AS legacy
+JOIN auth_identities AS ai
+ ON ai.user_id = legacy.user_id
+ AND ai.provider_type = 'wechat'
+ AND ai.provider_key = 'wechat-main'
+ AND ai.provider_subject = legacy.provider_union_id
+WHERE legacy.channel <> ''
+ AND legacy.channel_app_id <> ''
+ AND legacy.provider_user_id <> ''
+ON CONFLICT DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_openid_only_requires_remediation',
+ 'legacy_external_identity:' || legacy.id::text,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'user_id', legacy.user_id,
+ 'openid', legacy.provider_user_id,
+ 'reason', 'legacy user_external_identities row only has openid and cannot be canonicalized offline',
+ 'migration', '115_auth_identity_legacy_external_backfill'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ BTRIM(uei.provider_user_id) AS provider_user_id,
+ COALESCE(NULLIF(BTRIM(COALESCE(uei.metadata, '')), '')::jsonb, '{}'::jsonb) AS metadata_json
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) = ''
+) AS legacy
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+END $$;
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_openid_only_requires_remediation',
+ 'synthetic_auth_identity:' || ai.id::text,
+ COALESCE(ai.metadata, '{}'::jsonb) || jsonb_build_object(
+ 'auth_identity_id', ai.id,
+ 'user_id', ai.user_id,
+ 'provider_subject', ai.provider_subject,
+ 'reason', 'synthetic wechat auth identity still lacks unionid metadata and needs remediation',
+ 'migration', '115_auth_identity_legacy_external_backfill'
+ )
+FROM auth_identities AS ai
+WHERE ai.provider_type = 'wechat'
+ AND COALESCE(ai.metadata ->> 'backfill_source', '') = 'synthetic_email'
+ AND BTRIM(COALESCE(ai.metadata ->> 'unionid', '')) = ''
+ON CONFLICT (report_type, report_key) DO NOTHING;
From 16be82b9599a47760e9e07b75675f7a0c6db11dd Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:14:05 +0800
Subject: [PATCH 068/189] fix payment visible methods and resume recovery
---
backend/internal/payment/provider/wxpay.go | 6 +-
.../internal/payment/provider/wxpay_test.go | 31 ++++++++
.../internal/service/payment_config_limits.go | 46 ++++++++++++
.../service/payment_config_limits_test.go | 71 +++++++++++++++++++
frontend/src/views/user/PaymentResultView.vue | 42 ++++++-----
.../user/__tests__/PaymentResultView.spec.ts | 24 +++++++
6 files changed, 201 insertions(+), 19 deletions(-)
diff --git a/backend/internal/payment/provider/wxpay.go b/backend/internal/payment/provider/wxpay.go
index 84b324a8..47569ee3 100644
--- a/backend/internal/payment/provider/wxpay.go
+++ b/backend/internal/payment/provider/wxpay.go
@@ -310,12 +310,14 @@ func buildWxpayResultURL(returnURL string, req payment.CreatePaymentRequest) (st
return "", fmt.Errorf("return URL must be an absolute http(s) URL")
}
- values := url.Values{}
+ values := u.Query()
values.Set("out_trade_no", strings.TrimSpace(req.OrderID))
if paymentType := strings.TrimSpace(req.PaymentType); paymentType != "" {
values.Set("payment_type", paymentType)
}
- u.Path = wxpayResultPath
+ if strings.TrimSpace(u.Path) == "" {
+ u.Path = wxpayResultPath
+ }
u.RawPath = ""
u.RawQuery = values.Encode()
u.Fragment = ""
diff --git a/backend/internal/payment/provider/wxpay_test.go b/backend/internal/payment/provider/wxpay_test.go
index 5074c545..8489e261 100644
--- a/backend/internal/payment/provider/wxpay_test.go
+++ b/backend/internal/payment/provider/wxpay_test.go
@@ -4,6 +4,7 @@ package provider
import (
"context"
+ "net/url"
"strings"
"testing"
@@ -263,6 +264,36 @@ func TestNewWxpay(t *testing.T) {
}
}
+func TestBuildWxpayResultURLPreservesResumeToken(t *testing.T) {
+ t.Parallel()
+
+ resultURL, err := buildWxpayResultURL("https://app.example.com/payment/result?order_id=42&resume_token=resume-42&status=success", payment.CreatePaymentRequest{
+ OrderID: "sub2_42",
+ PaymentType: payment.TypeWxpay,
+ })
+ if err != nil {
+ t.Fatalf("buildWxpayResultURL returned error: %v", err)
+ }
+
+ parsed, err := url.Parse(resultURL)
+ if err != nil {
+ t.Fatalf("url.Parse returned error: %v", err)
+ }
+ query := parsed.Query()
+ if parsed.Path != wxpayResultPath {
+ t.Fatalf("path = %q, want %q", parsed.Path, wxpayResultPath)
+ }
+ if query.Get("resume_token") != "resume-42" {
+ t.Fatalf("resume_token = %q, want %q", query.Get("resume_token"), "resume-42")
+ }
+ if query.Get("order_id") != "42" {
+ t.Fatalf("order_id = %q, want %q", query.Get("order_id"), "42")
+ }
+ if query.Get("out_trade_no") != "sub2_42" {
+ t.Fatalf("out_trade_no = %q, want %q", query.Get("out_trade_no"), "sub2_42")
+ }
+}
+
func TestResolveWxpayJSAPIAppID(t *testing.T) {
t.Parallel()
diff --git a/backend/internal/service/payment_config_limits.go b/backend/internal/service/payment_config_limits.go
index 56905278..f30b119a 100644
--- a/backend/internal/service/payment_config_limits.go
+++ b/backend/internal/service/payment_config_limits.go
@@ -20,6 +20,18 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return nil, fmt.Errorf("query provider instances: %w", err)
}
typeInstances := pcGroupByPaymentType(instances)
+ if s.settingRepo != nil {
+ vals, err := s.settingRepo.GetMultiple(ctx, []string{
+ SettingPaymentVisibleMethodAlipayEnabled,
+ SettingPaymentVisibleMethodAlipaySource,
+ SettingPaymentVisibleMethodWxpayEnabled,
+ SettingPaymentVisibleMethodWxpaySource,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("query visible method settings: %w", err)
+ }
+ typeInstances = pcApplyVisibleMethodRouting(typeInstances, vals, buildVisibleMethodSourceAvailability(instances))
+ }
resp := &MethodLimitsResponse{
Methods: make(map[string]MethodLimits, len(typeInstances)),
}
@@ -31,6 +43,40 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return resp, nil
}
+func pcApplyVisibleMethodRouting(typeInstances map[string][]*dbent.PaymentProviderInstance, vals map[string]string, available map[string]bool) map[string][]*dbent.PaymentProviderInstance {
+ if len(typeInstances) == 0 {
+ return typeInstances
+ }
+
+ filtered := make(map[string][]*dbent.PaymentProviderInstance, len(typeInstances))
+ for paymentType, instances := range typeInstances {
+ visibleMethod := NormalizeVisibleMethod(paymentType)
+ switch visibleMethod {
+ case payment.TypeAlipay, payment.TypeWxpay:
+ if !visibleMethodShouldBeExposed(visibleMethod, vals, available) {
+ continue
+ }
+ targetProviderKey, ok := VisibleMethodProviderKeyForSource(visibleMethod, vals[visibleMethodSourceSettingKey(visibleMethod)])
+ if !ok {
+ continue
+ }
+ matching := make([]*dbent.PaymentProviderInstance, 0, len(instances))
+ for _, inst := range instances {
+ if inst.ProviderKey == targetProviderKey {
+ matching = append(matching, inst)
+ }
+ }
+ if len(matching) == 0 {
+ continue
+ }
+ filtered[paymentType] = matching
+ default:
+ filtered[paymentType] = instances
+ }
+ }
+ return filtered
+}
+
// GetMethodLimits returns per-payment-type limits from enabled provider instances.
func (s *PaymentConfigService) GetMethodLimits(ctx context.Context, types []string) ([]MethodLimits, error) {
instances, err := s.entClient.PaymentProviderInstance.Query().
diff --git a/backend/internal/service/payment_config_limits_test.go b/backend/internal/service/payment_config_limits_test.go
index 73ad66ef..4a9d663d 100644
--- a/backend/internal/service/payment_config_limits_test.go
+++ b/backend/internal/service/payment_config_limits_test.go
@@ -1,6 +1,7 @@
package service
import (
+ "context"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -299,3 +300,73 @@ func TestPcInstanceTypeLimits(t *testing.T) {
}
})
}
+
+func TestGetAvailableMethodLimitsRespectsVisibleMethodRouting(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("Official Alipay").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`).
+ SetEnabled(true).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create official alipay instance: %v", err)
+ }
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeEasyPay).
+ SetName("EasyPay Alipay").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetLimits(`{"alipay":{"singleMin":20,"singleMax":200}}`).
+ SetEnabled(true).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create easypay alipay instance: %v", err)
+ }
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("Official WeChat").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`).
+ SetEnabled(true).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create official wxpay instance: %v", err)
+ }
+
+ svc := &PaymentConfigService{
+ entClient: client,
+ settingRepo: &paymentConfigSettingRepoStub{
+ values: map[string]string{
+ SettingPaymentVisibleMethodAlipayEnabled: "true",
+ SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceEasyPayAlipay,
+ SettingPaymentVisibleMethodWxpayEnabled: "false",
+ SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
+ },
+ },
+ }
+
+ resp, err := svc.GetAvailableMethodLimits(ctx)
+ if err != nil {
+ t.Fatalf("GetAvailableMethodLimits returned error: %v", err)
+ }
+
+ alipayLimits, ok := resp.Methods[payment.TypeAlipay]
+ if !ok {
+ t.Fatalf("expected visible alipay limits, got %v", resp.Methods)
+ }
+ if alipayLimits.SingleMin != 20 || alipayLimits.SingleMax != 200 {
+ t.Fatalf("alipay limits = %+v, want easypay-only min=20 max=200", alipayLimits)
+ }
+ if _, ok := resp.Methods[payment.TypeWxpay]; ok {
+ t.Fatalf("wxpay should be hidden when visible method is disabled, got %v", resp.Methods[payment.TypeWxpay])
+ }
+ if resp.GlobalMin != 20 || resp.GlobalMax != 200 {
+ t.Fatalf("global range = (%v, %v), want (20, 200)", resp.GlobalMin, resp.GlobalMax)
+ }
+}
diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue
index 53bbb550..e1bcbbe5 100644
--- a/frontend/src/views/user/PaymentResultView.vue
+++ b/frontend/src/views/user/PaymentResultView.vue
@@ -142,10 +142,11 @@ onMounted(async () => {
const resumeToken = typeof route.query.resume_token === 'string'
? route.query.resume_token
: ''
- let orderId = Number(route.query.order_id) || 0
+ const routeOrderId = Number(route.query.order_id) || 0
const outTradeNo = String(route.query.out_trade_no || '')
+ let orderId = 0
- if (!orderId && resumeToken && typeof window !== 'undefined') {
+ if (resumeToken && typeof window !== 'undefined') {
const restored = readPaymentRecoverySnapshot(
window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY),
{ resumeToken },
@@ -155,17 +156,31 @@ onMounted(async () => {
}
}
- if (!order.value && !orderId && resumeToken) {
+ if (!order.value && resumeToken && orderId) {
try {
- const result = await paymentAPI.resolveOrderPublicByResumeToken(resumeToken)
- order.value = result.data
- orderId = result.data.id
+ order.value = await paymentStore.pollOrderStatus(orderId)
} catch (_err: unknown) {
- // Resume token recovery failed, continue to legacy fallback paths.
+ // Fall through to signed resume-token recovery below.
}
}
- if (!order.value && orderId) {
+ if (!order.value && resumeToken) {
+ try {
+ const result = await paymentAPI.resolveOrderPublicByResumeToken(resumeToken)
+ order.value = result.data
+ if (!orderId) {
+ orderId = result.data.id
+ }
+ } catch (_err: unknown) {
+ // Resume token recovery failed; do not trust legacy public out_trade_no fallback.
+ }
+ }
+
+ if (!resumeToken) {
+ orderId = routeOrderId
+ }
+
+ if (!order.value && !resumeToken && orderId) {
try {
order.value = await paymentStore.pollOrderStatus(orderId)
} catch (_err: unknown) {
@@ -173,7 +188,8 @@ onMounted(async () => {
}
}
- if (!order.value && outTradeNo) {
+ const hasLegacyFallbackContext = Boolean(route.query.trade_status || route.query.money || route.query.type)
+ if (!order.value && !resumeToken && !orderId && outTradeNo && hasLegacyFallbackContext) {
returnInfo.value = {
outTradeNo,
money: String(route.query.money || ''),
@@ -191,14 +207,6 @@ onMounted(async () => {
} catch (_e: unknown) { /* fall through */ }
}
}
-
- if (!order.value && orderId) {
- try {
- order.value = await paymentStore.pollOrderStatus(orderId)
- } catch (_err: unknown) {
- // Order lookup failed, will show returnInfo fallback.
- }
- }
loading.value = false
})
diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
index b1caa526..bfc044a7 100644
--- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
+++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
@@ -76,6 +76,7 @@ describe('PaymentResultView', () => {
it('restores order id from a matching resume token and does not trust query success flags', async () => {
routeState.query = {
resume_token: 'resume-42',
+ order_id: '999',
status: 'success',
}
window.localStorage.setItem(PAYMENT_RECOVERY_STORAGE_KEY, JSON.stringify({
@@ -110,6 +111,29 @@ describe('PaymentResultView', () => {
expect(wrapper.text()).not.toContain('payment.result.success')
})
+ it('does not fall back to public out_trade_no verification when resume_token recovery fails', async () => {
+ routeState.query = {
+ resume_token: 'resume-fail',
+ out_trade_no: 'legacy-should-not-run',
+ trade_status: 'TRADE_SUCCESS',
+ }
+ resolveOrderPublicByResumeToken.mockRejectedValueOnce(new Error('resume failed'))
+
+ mount(PaymentResultView, {
+ global: {
+ stubs: {
+ OrderStatusBadge: true,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(resolveOrderPublicByResumeToken).toHaveBeenCalledWith('resume-fail')
+ expect(verifyOrderPublic).not.toHaveBeenCalled()
+ expect(verifyOrder).not.toHaveBeenCalled()
+ })
+
it('keeps legacy out_trade_no verification as a fallback when no order context is available', async () => {
routeState.query = {
out_trade_no: 'legacy-123',
From b79052aaf2c5dfe201cc00c70cd37809961a5f2d Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:16:06 +0800
Subject: [PATCH 069/189] Decouple email sync tests from local stubs
---
.../admin_service_email_identity_sync_test.go | 130 +++++++++++++++---
.../user_service_email_identity_sync_test.go | 43 +-----
2 files changed, 114 insertions(+), 59 deletions(-)
diff --git a/backend/internal/service/admin_service_email_identity_sync_test.go b/backend/internal/service/admin_service_email_identity_sync_test.go
index 3f3d867c..d555d609 100644
--- a/backend/internal/service/admin_service_email_identity_sync_test.go
+++ b/backend/internal/service/admin_service_email_identity_sync_test.go
@@ -4,9 +4,11 @@ package service
import (
"context"
+ "fmt"
"testing"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
@@ -21,18 +23,112 @@ type replaceEmailCall struct {
newEmail string
}
-type emailSyncUserRepoStub struct {
- *userRepoStub
+type emailSyncRepoStub struct {
+ user *User
+ nextID int64
+ updateCalls int
+ created []*User
+ updated []*User
ensureCalls []ensureEmailCall
replaceCalls []replaceEmailCall
}
-func (s *emailSyncUserRepoStub) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error {
+func (s *emailSyncRepoStub) Create(_ context.Context, user *User) error {
+ if s.nextID != 0 && user.ID == 0 {
+ user.ID = s.nextID
+ }
+ s.created = append(s.created, user)
+ s.user = user
+ return nil
+}
+
+func (s *emailSyncRepoStub) GetByID(_ context.Context, _ int64) (*User, error) {
+ if s.user == nil {
+ return nil, ErrUserNotFound
+ }
+ cloned := *s.user
+ return &cloned, nil
+}
+
+func (s *emailSyncRepoStub) GetByEmail(_ context.Context, _ string) (*User, error) {
+ return nil, ErrUserNotFound
+}
+
+func (s *emailSyncRepoStub) GetFirstAdmin(context.Context) (*User, error) {
+ return nil, fmt.Errorf("unexpected GetFirstAdmin call")
+}
+
+func (s *emailSyncRepoStub) Update(_ context.Context, user *User) error {
+ s.updateCalls++
+ s.updated = append(s.updated, user)
+ s.user = user
+ return nil
+}
+
+func (s *emailSyncRepoStub) Delete(context.Context, int64) error { return nil }
+
+func (s *emailSyncRepoStub) GetUserAvatar(context.Context, int64) (*UserAvatar, error) {
+ return nil, fmt.Errorf("unexpected GetUserAvatar call")
+}
+
+func (s *emailSyncRepoStub) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) {
+ return nil, fmt.Errorf("unexpected UpsertUserAvatar call")
+}
+
+func (s *emailSyncRepoStub) DeleteUserAvatar(context.Context, int64) error {
+ return fmt.Errorf("unexpected DeleteUserAvatar call")
+}
+
+func (s *emailSyncRepoStub) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
+ return nil, nil, fmt.Errorf("unexpected List call")
+}
+
+func (s *emailSyncRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
+ return nil, nil, fmt.Errorf("unexpected ListWithFilters call")
+}
+
+func (s *emailSyncRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+
+func (s *emailSyncRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
+
+func (s *emailSyncRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
+
+func (s *emailSyncRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
+
+func (s *emailSyncRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
+
+func (s *emailSyncRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
+
+func (s *emailSyncRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+
+func (s *emailSyncRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
+
+func (s *emailSyncRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+
+func (s *emailSyncRepoStub) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
+ return nil, nil
+}
+
+func (s *emailSyncRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+
+func (s *emailSyncRepoStub) EnableTotp(context.Context, int64) error { return nil }
+
+func (s *emailSyncRepoStub) DisableTotp(context.Context, int64) error { return nil }
+
+func (s *emailSyncRepoStub) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error {
s.ensureCalls = append(s.ensureCalls, ensureEmailCall{userID: userID, email: email})
return nil
}
-func (s *emailSyncUserRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error {
+func (s *emailSyncRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error {
s.replaceCalls = append(s.replaceCalls, replaceEmailCall{
userID: userID,
oldEmail: oldEmail,
@@ -41,16 +137,8 @@ func (s *emailSyncUserRepoStub) ReplaceEmailAuthIdentity(_ context.Context, user
return nil
}
-func (s *emailSyncUserRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
- return map[int64]*time.Time{}, nil
-}
-
-func (s *emailSyncUserRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
- return nil, nil
-}
-
func TestAdminService_CreateUser_EnsuresEmailAuthIdentity(t *testing.T) {
- repo := &emailSyncUserRepoStub{userRepoStub: &userRepoStub{nextID: 55}}
+ repo := &emailSyncRepoStub{nextID: 55}
svc := &adminServiceImpl{userRepo: repo}
user, err := svc.CreateUser(context.Background(), &CreateUserInput{
@@ -67,15 +155,13 @@ func TestAdminService_CreateUser_EnsuresEmailAuthIdentity(t *testing.T) {
}
func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) {
- repo := &emailSyncUserRepoStub{
- userRepoStub: &userRepoStub{
- user: &User{
- ID: 91,
- Email: "before@example.com",
- Role: RoleUser,
- Status: StatusActive,
- Concurrency: 3,
- },
+ repo := &emailSyncRepoStub{
+ user: &User{
+ ID: 91,
+ Email: "before@example.com",
+ Role: RoleUser,
+ Status: StatusActive,
+ Concurrency: 3,
},
}
svc := &adminServiceImpl{userRepo: repo}
diff --git a/backend/internal/service/user_service_email_identity_sync_test.go b/backend/internal/service/user_service_email_identity_sync_test.go
index 3950df8b..8109b368 100644
--- a/backend/internal/service/user_service_email_identity_sync_test.go
+++ b/backend/internal/service/user_service_email_identity_sync_test.go
@@ -5,48 +5,17 @@ package service
import (
"context"
"testing"
- "time"
"github.com/stretchr/testify/require"
)
-type emailSyncMockUserRepo struct {
- *mockUserRepo
- ensureCalls []ensureEmailCall
- replaceCalls []replaceEmailCall
-}
-
-func (m *emailSyncMockUserRepo) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error {
- m.ensureCalls = append(m.ensureCalls, ensureEmailCall{userID: userID, email: email})
- return nil
-}
-
-func (m *emailSyncMockUserRepo) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error {
- m.replaceCalls = append(m.replaceCalls, replaceEmailCall{
- userID: userID,
- oldEmail: oldEmail,
- newEmail: newEmail,
- })
- return nil
-}
-
-func (m *emailSyncMockUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
- return map[int64]*time.Time{}, nil
-}
-
-func (m *emailSyncMockUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
- return nil, nil
-}
-
func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) {
- repo := &emailSyncMockUserRepo{
- mockUserRepo: &mockUserRepo{
- getByIDUser: &User{
- ID: 19,
- Email: "profile-before@example.com",
- Username: "tester",
- Concurrency: 2,
- },
+ repo := &emailSyncRepoStub{
+ user: &User{
+ ID: 19,
+ Email: "profile-before@example.com",
+ Username: "tester",
+ Concurrency: 2,
},
}
svc := NewUserService(repo, nil, nil, nil)
From beeab54ae3b1d44f291822f99d064d50bef59cd7 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:17:48 +0800
Subject: [PATCH 070/189] Implement latest-used user repo queries
---
backend/internal/repository/user_repo.go | 46 ++++++++++++++++++++++++
1 file changed, 46 insertions(+)
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index b2190b68..7378611d 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -19,6 +19,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/lib/pq"
entsql "entgo.io/ent/dialect/sql"
)
@@ -298,6 +299,51 @@ func normalizeEmailAuthIdentitySubject(email string) string {
return normalized
}
+func (r *userRepository) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ result := make(map[int64]*time.Time, len(userIDs))
+ if len(userIDs) == 0 {
+ return result, nil
+ }
+ if r.sql == nil {
+ return nil, fmt.Errorf("sql executor is not configured")
+ }
+
+ rows, err := r.sql.QueryContext(ctx, `
+ SELECT user_id, MAX(created_at) AS last_used_at
+ FROM usage_logs
+ WHERE user_id = ANY($1)
+ GROUP BY user_id
+ `, pq.Array(userIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ for rows.Next() {
+ var (
+ userID int64
+ lastUsedAt time.Time
+ )
+ if err := rows.Scan(&userID, &lastUsedAt); err != nil {
+ return nil, err
+ }
+ ts := lastUsedAt.UTC()
+ result[userID] = &ts
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return result, nil
+}
+
+func (r *userRepository) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ latestByUserID, err := r.GetLatestUsedAtByUserIDs(ctx, []int64{userID})
+ if err != nil {
+ return nil, err
+ }
+ return latestByUserID[userID], nil
+}
+
func (r *userRepository) Delete(ctx context.Context, id int64) error {
affected, err := r.client.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
if err != nil {
From bf3ef2d19aafa4210c15bdead68df7403098ee66 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:22:17 +0800
Subject: [PATCH 071/189] add admin user last used support
---
backend/internal/handler/dto/mappers.go | 1 +
backend/internal/handler/dto/types.go | 3 +-
.../handler/dto/user_mapper_activity_test.go | 4 +
backend/internal/repository/user_repo.go | 115 ++++++++-----
.../user_repo_sort_integration_test.go | 63 +++++++
backend/internal/service/admin_service.go | 20 +++
.../service/admin_service_list_users_test.go | 41 +++++
backend/internal/service/user.go | 1 +
frontend/src/types/index.ts | 1 +
frontend/src/views/admin/UsersView.vue | 9 +-
.../views/admin/__tests__/UsersView.spec.ts | 162 ++++++++++++++++++
11 files changed, 373 insertions(+), 47 deletions(-)
create mode 100644 frontend/src/views/admin/__tests__/UsersView.spec.ts
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index 2f18c64e..d88c110c 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -68,6 +68,7 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
return &AdminUser{
User: *base,
Notes: u.Notes,
+ LastUsedAt: u.LastUsedAt,
GroupRates: u.GroupRates,
}
}
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index 72fce1fe..15b8548a 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -36,7 +36,8 @@ type User struct {
type AdminUser struct {
User
- Notes string `json:"notes"`
+ Notes string `json:"notes"`
+ LastUsedAt *time.Time `json:"last_used_at"`
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
diff --git a/backend/internal/handler/dto/user_mapper_activity_test.go b/backend/internal/handler/dto/user_mapper_activity_test.go
index 668f886c..1e362fba 100644
--- a/backend/internal/handler/dto/user_mapper_activity_test.go
+++ b/backend/internal/handler/dto/user_mapper_activity_test.go
@@ -13,6 +13,7 @@ func TestUserFromServiceAdmin_MapsActivityTimestamps(t *testing.T) {
lastLoginAt := time.Date(2026, time.April, 20, 10, 0, 0, 0, time.UTC)
lastActiveAt := lastLoginAt.Add(15 * time.Minute)
+ lastUsedAt := lastLoginAt.Add(45 * time.Minute)
out := UserFromServiceAdmin(&service.User{
ID: 42,
@@ -22,11 +23,14 @@ func TestUserFromServiceAdmin_MapsActivityTimestamps(t *testing.T) {
Status: service.StatusActive,
LastLoginAt: &lastLoginAt,
LastActiveAt: &lastActiveAt,
+ LastUsedAt: &lastUsedAt,
})
require.NotNil(t, out)
require.NotNil(t, out.LastLoginAt)
require.NotNil(t, out.LastActiveAt)
+ require.NotNil(t, out.LastUsedAt)
require.WithinDuration(t, lastLoginAt, *out.LastLoginAt, time.Second)
require.WithinDuration(t, lastActiveAt, *out.LastActiveAt, time.Second)
+ require.WithinDuration(t, lastUsedAt, *out.LastUsedAt, time.Second)
}
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index 7378611d..25d3f1d6 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -299,51 +299,6 @@ func normalizeEmailAuthIdentitySubject(email string) string {
return normalized
}
-func (r *userRepository) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
- result := make(map[int64]*time.Time, len(userIDs))
- if len(userIDs) == 0 {
- return result, nil
- }
- if r.sql == nil {
- return nil, fmt.Errorf("sql executor is not configured")
- }
-
- rows, err := r.sql.QueryContext(ctx, `
- SELECT user_id, MAX(created_at) AS last_used_at
- FROM usage_logs
- WHERE user_id = ANY($1)
- GROUP BY user_id
- `, pq.Array(userIDs))
- if err != nil {
- return nil, err
- }
- defer func() { _ = rows.Close() }()
-
- for rows.Next() {
- var (
- userID int64
- lastUsedAt time.Time
- )
- if err := rows.Scan(&userID, &lastUsedAt); err != nil {
- return nil, err
- }
- ts := lastUsedAt.UTC()
- result[userID] = &ts
- }
- if err := rows.Err(); err != nil {
- return nil, err
- }
- return result, nil
-}
-
-func (r *userRepository) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
- latestByUserID, err := r.GetLatestUsedAtByUserIDs(ctx, []int64{userID})
- if err != nil {
- return nil, err
- }
- return latestByUserID[userID], nil
-}
-
func (r *userRepository) Delete(ctx context.Context, id int64) error {
affected, err := r.client.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
if err != nil {
@@ -469,6 +424,10 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
+ if sortBy == "last_used_at" {
+ return userLastUsedAtOrder(sortOrder)
+ }
+
var field string
defaultField := true
nullsLastField := false
@@ -530,6 +489,72 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)}
}
+func (r *userRepository) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ result := make(map[int64]*time.Time, len(userIDs))
+ if len(userIDs) == 0 {
+ return result, nil
+ }
+ if r.sql == nil {
+ return nil, fmt.Errorf("sql executor is not configured")
+ }
+
+ const query = `
+ SELECT user_id, MAX(created_at) AS last_used_at
+ FROM usage_logs
+ WHERE user_id = ANY($1)
+ GROUP BY user_id
+ `
+
+ rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ for rows.Next() {
+ var (
+ userID int64
+ lastUsedAt time.Time
+ )
+ if scanErr := rows.Scan(&userID, &lastUsedAt); scanErr != nil {
+ return nil, scanErr
+ }
+ ts := lastUsedAt.UTC()
+ result[userID] = &ts
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return result, nil
+}
+
+func (r *userRepository) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ latestByUserID, err := r.GetLatestUsedAtByUserIDs(ctx, []int64{userID})
+ if err != nil {
+ return nil, err
+ }
+ return latestByUserID[userID], nil
+}
+
+func userLastUsedAtOrder(sortOrder string) []func(*entsql.Selector) {
+ orderExpr := func(direction, nulls string, tieOrder func(string) string) func(*entsql.Selector) {
+ return func(s *entsql.Selector) {
+ subquery := fmt.Sprintf("(SELECT MAX(created_at) FROM usage_logs WHERE user_id = %s)", s.C(dbuser.FieldID))
+ s.OrderExpr(entsql.Expr(subquery + " " + direction + " NULLS " + nulls))
+ s.OrderBy(tieOrder(s.C(dbuser.FieldID)))
+ }
+ }
+
+ if sortOrder == pagination.SortOrderAsc {
+ return []func(*entsql.Selector){
+ orderExpr("ASC", "FIRST", entsql.Asc),
+ }
+ }
+ return []func(*entsql.Selector){
+ orderExpr("DESC", "LAST", entsql.Desc),
+ }
+}
+
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
if len(attrs) == 0 {
diff --git a/backend/internal/repository/user_repo_sort_integration_test.go b/backend/internal/repository/user_repo_sort_integration_test.go
index 8abef45a..e2445d5b 100644
--- a/backend/internal/repository/user_repo_sort_integration_test.go
+++ b/backend/internal/repository/user_repo_sort_integration_test.go
@@ -10,6 +10,24 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
)
+func (s *UserRepoSuite) mustInsertUsageLog(userID int64, createdAt time.Time) {
+ s.T().Helper()
+
+ account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "usage-log-account"})
+ apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: userID})
+
+ _, err := integrationDB.ExecContext(
+ s.ctx,
+ `INSERT INTO usage_logs (user_id, api_key_id, account_id, model, input_tokens, output_tokens, total_cost, actual_cost, created_at)
+ VALUES ($1, $2, $3, 'gpt-test', 1, 1, 0.01, 0.01, $4)`,
+ userID,
+ apiKey.ID,
+ account.ID,
+ createdAt.UTC(),
+ )
+ s.Require().NoError(err)
+}
+
func (s *UserRepoSuite) TestListWithFilters_SortByEmailAsc() {
s.mustCreateUser(&service.User{Email: "z-last@example.com", Username: "z-user"})
s.mustCreateUser(&service.User{Email: "a-first@example.com", Username: "a-user"})
@@ -119,4 +137,49 @@ func (s *UserRepoSuite) TestListWithFilters_SortByLastActiveAtAsc() {
s.Require().Equal("nil-active@example.com", users[2].Email)
}
+func (s *UserRepoSuite) TestGetLatestUsedAtByUserIDs_UsesUsageLogs() {
+ older := time.Now().Add(-4 * time.Hour).UTC().Truncate(time.Second)
+ newer := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Second)
+
+ userWithUsage := s.mustCreateUser(&service.User{Email: "usage-source@example.com"})
+ userWithoutUsage := s.mustCreateUser(&service.User{Email: "usage-missing@example.com"})
+ s.mustInsertUsageLog(userWithUsage.ID, older)
+ s.mustInsertUsageLog(userWithUsage.ID, newer)
+
+ got, err := s.repo.GetLatestUsedAtByUserIDs(s.ctx, []int64{userWithUsage.ID, userWithoutUsage.ID})
+ s.Require().NoError(err)
+ s.Require().Contains(got, userWithUsage.ID)
+ s.Require().NotContains(got, userWithoutUsage.ID)
+ s.Require().NotNil(got[userWithUsage.ID])
+ s.Require().True(got[userWithUsage.ID].Equal(newer))
+}
+
+func (s *UserRepoSuite) TestListWithFilters_SortByLastUsedAtDesc_UsesUsageLogsNotLastActiveAt() {
+ lastUsedOlder := time.Now().Add(-6 * time.Hour).UTC().Truncate(time.Second)
+ lastUsedNewer := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Second)
+ lastActiveVeryRecent := time.Now().Add(-10 * time.Minute).UTC().Truncate(time.Second)
+
+ nilUsage := s.mustCreateUser(&service.User{Email: "nil-last-used@example.com"})
+ wrongSource := s.mustCreateUser(&service.User{
+ Email: "active-not-usage@example.com",
+ LastActiveAt: &lastActiveVeryRecent,
+ })
+ rightSource := s.mustCreateUser(&service.User{Email: "usage-wins@example.com"})
+
+ s.mustInsertUsageLog(wrongSource.ID, lastUsedOlder)
+ s.mustInsertUsageLog(rightSource.ID, lastUsedNewer)
+
+ users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
+ Page: 1,
+ PageSize: 10,
+ SortBy: "last_used_at",
+ SortOrder: "desc",
+ }, service.UserListFilters{})
+ s.Require().NoError(err)
+ s.Require().Len(users, 3)
+ s.Require().Equal(rightSource.ID, users[0].ID)
+ s.Require().Equal(wrongSource.ID, users[1].ID)
+ s.Require().Equal(nilUsage.ID, users[2].ID)
+}
+
func TestUserRepoSortSuiteSmoke(_ *testing.T) {}
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 79840e5b..10b85f76 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -557,6 +557,20 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
if err != nil {
return nil, 0, err
}
+ if len(users) > 0 {
+ userIDs := make([]int64, 0, len(users))
+ for i := range users {
+ userIDs = append(userIDs, users[i].ID)
+ }
+ lastUsedByUserID, latestErr := s.userRepo.GetLatestUsedAtByUserIDs(ctx, userIDs)
+ if latestErr != nil {
+ logger.LegacyPrintf("service.admin", "failed to load user last_used_at in batch: err=%v", latestErr)
+ } else {
+ for i := range users {
+ users[i].LastUsedAt = lastUsedByUserID[users[i].ID]
+ }
+ }
+ }
// 批量加载用户专属分组倍率
if s.userGroupRateRepo != nil && len(users) > 0 {
if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok {
@@ -601,6 +615,12 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error)
if err != nil {
return nil, err
}
+ lastUsedAt, latestErr := s.userRepo.GetLatestUsedAtByUserID(ctx, id)
+ if latestErr != nil {
+ logger.LegacyPrintf("service.admin", "failed to load user last_used_at: user_id=%d err=%v", id, latestErr)
+ } else {
+ user.LastUsedAt = lastUsedAt
+ }
// 加载用户专属分组倍率
if s.userGroupRateRepo != nil {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, id)
diff --git a/backend/internal/service/admin_service_list_users_test.go b/backend/internal/service/admin_service_list_users_test.go
index ceeb52c2..657616c4 100644
--- a/backend/internal/service/admin_service_list_users_test.go
+++ b/backend/internal/service/admin_service_list_users_test.go
@@ -6,6 +6,7 @@ import (
"context"
"errors"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
@@ -16,6 +17,8 @@ type userRepoStubForListUsers struct {
users []User
err error
listWithFiltersParams pagination.PaginationParams
+ lastUsedByUserID map[int64]*time.Time
+ lastUsedErr error
}
func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) {
@@ -32,6 +35,26 @@ func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pag
}, nil
}
+func (s *userRepoStubForListUsers) GetLatestUsedAtByUserIDs(_ context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ if s.lastUsedErr != nil {
+ return nil, s.lastUsedErr
+ }
+ result := make(map[int64]*time.Time, len(userIDs))
+ for _, userID := range userIDs {
+ if ts, ok := s.lastUsedByUserID[userID]; ok {
+ result[userID] = ts
+ }
+ }
+ return result, nil
+}
+
+func (s *userRepoStubForListUsers) GetLatestUsedAtByUserID(_ context.Context, userID int64) (*time.Time, error) {
+ if s.lastUsedErr != nil {
+ return nil, s.lastUsedErr
+ }
+ return s.lastUsedByUserID[userID], nil
+}
+
type userGroupRateRepoStubForListUsers struct {
batchCalls int
singleCall []int64
@@ -130,3 +153,21 @@ func TestAdminService_ListUsers_PassesSortParams(t *testing.T) {
SortOrder: "ASC",
}, userRepo.listWithFiltersParams)
}
+
+func TestAdminService_ListUsers_PopulatesLastUsedAt(t *testing.T) {
+ lastUsed := time.Now().UTC().Add(-30 * time.Minute).Truncate(time.Second)
+ userRepo := &userRepoStubForListUsers{
+ users: []User{{ID: 101, Email: "u@example.com"}},
+ lastUsedByUserID: map[int64]*time.Time{
+ 101: &lastUsed,
+ },
+ }
+ svc := &adminServiceImpl{userRepo: userRepo}
+
+ users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}, "", "")
+ require.NoError(t, err)
+ require.Equal(t, int64(1), total)
+ require.Len(t, users, 1)
+ require.NotNil(t, users[0].LastUsedAt)
+ require.WithinDuration(t, lastUsed, *users[0].LastUsedAt, time.Second)
+}
diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go
index d8b5325c..fa04d95e 100644
--- a/backend/internal/service/user.go
+++ b/backend/internal/service/user.go
@@ -26,6 +26,7 @@ type User struct {
SignupSource string
LastLoginAt *time.Time
LastActiveAt *time.Time
+ LastUsedAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index a4b2277b..5a2e3184 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -93,6 +93,7 @@ export interface User {
export interface AdminUser extends User {
// 管理员备注(普通用户接口不返回)
notes: string
+ last_used_at?: string | null
// 用户专属分组倍率配置 (group_id -> rate_multiplier)
group_rates?: Record
// 当前并发数(仅管理员列表接口返回)
diff --git a/frontend/src/views/admin/UsersView.vue b/frontend/src/views/admin/UsersView.vue
index 55a8f2d8..07c9d437 100644
--- a/frontend/src/views/admin/UsersView.vue
+++ b/frontend/src/views/admin/UsersView.vue
@@ -461,6 +461,12 @@
+
+
+ {{ value ? formatDateTime(value) : '-' }}
+
+
+
{{ value ? formatDateTime(value) : '-' }}
@@ -713,6 +719,7 @@ const allColumns = computed(() => [
{ key: 'concurrency', label: t('admin.users.columns.concurrency'), sortable: true },
{ key: 'status', label: t('admin.users.columns.status'), sortable: true },
{ key: 'last_login_at', label: t('admin.users.columns.lastLogin'), sortable: true },
+ { key: 'last_used_at', label: t('admin.users.columns.lastUsed'), sortable: true },
{ key: 'last_active_at', label: t('admin.users.columns.lastActive'), sortable: true },
{ key: 'created_at', label: t('admin.users.columns.created'), sortable: true },
{ key: 'actions', label: t('admin.users.columns.actions'), sortable: false }
@@ -801,7 +808,7 @@ const searchQuery = ref('')
const USER_SORT_STORAGE_KEY = 'admin-users-table-sort'
const loadInitialSortState = (): { sort_by: string; sort_order: 'asc' | 'desc' } => {
const fallback = { sort_by: 'created_at', sort_order: 'desc' as 'asc' | 'desc' }
- const sortable = new Set(['email', 'id', 'username', 'role', 'balance', 'concurrency', 'status', 'last_login_at', 'last_active_at', 'created_at'])
+ const sortable = new Set(['email', 'id', 'username', 'role', 'balance', 'concurrency', 'status', 'last_login_at', 'last_used_at', 'last_active_at', 'created_at'])
try {
const raw = localStorage.getItem(USER_SORT_STORAGE_KEY)
if (!raw) return fallback
diff --git a/frontend/src/views/admin/__tests__/UsersView.spec.ts b/frontend/src/views/admin/__tests__/UsersView.spec.ts
new file mode 100644
index 00000000..1ea67b63
--- /dev/null
+++ b/frontend/src/views/admin/__tests__/UsersView.spec.ts
@@ -0,0 +1,162 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { flushPromises, mount } from '@vue/test-utils'
+
+import type { AdminUser } from '@/types'
+import UsersView from '../UsersView.vue'
+
+const {
+ listUsers,
+ getAllGroups,
+ getBatchUsersUsage,
+ listEnabledDefinitions,
+ getBatchUserAttributes
+} = vi.hoisted(() => ({
+ listUsers: vi.fn(),
+ getAllGroups: vi.fn(),
+ getBatchUsersUsage: vi.fn(),
+ listEnabledDefinitions: vi.fn(),
+ getBatchUserAttributes: vi.fn()
+}))
+
+vi.mock('@/api/admin', () => ({
+ adminAPI: {
+ users: {
+ list: listUsers,
+ toggleStatus: vi.fn(),
+ delete: vi.fn()
+ },
+ groups: {
+ getAll: getAllGroups
+ },
+ dashboard: {
+ getBatchUsersUsage
+ },
+ userAttributes: {
+ listEnabledDefinitions,
+ getBatchUserAttributes
+ }
+ }
+}))
+
+vi.mock('@/stores/app', () => ({
+ useAppStore: () => ({
+ showError: vi.fn(),
+ showSuccess: vi.fn()
+ })
+}))
+
+vi.mock('vue-i18n', async () => {
+ const actual = await vi.importActual('vue-i18n')
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string) => key
+ })
+ }
+})
+
+const createAdminUser = (): AdminUser => ({
+ id: 42,
+ username: 'scoped-user',
+ email: 'scoped@example.com',
+ role: 'user',
+ balance: 0,
+ concurrency: 1,
+ status: 'active',
+ allowed_groups: [],
+ balance_notify_enabled: false,
+ balance_notify_threshold: null,
+ balance_notify_extra_emails: [],
+ created_at: '2026-04-17T00:00:00Z',
+ updated_at: '2026-04-17T00:00:00Z',
+ notes: '',
+ last_login_at: '2026-04-16T01:00:00Z',
+ last_active_at: '2026-04-16T02:00:00Z',
+ last_used_at: '2026-04-17T02:00:00Z',
+ current_concurrency: 0
+})
+
+const DataTableStub = {
+ props: ['columns', 'data'],
+ emits: ['sort'],
+ template: `
+
+
{{ columns.map(col => col.key).join(',') }}
+
sort
+
+
+
+
+ `
+}
+
+describe('admin UsersView', () => {
+ beforeEach(() => {
+ localStorage.clear()
+
+ listUsers.mockReset()
+ getAllGroups.mockReset()
+ getBatchUsersUsage.mockReset()
+ listEnabledDefinitions.mockReset()
+ getBatchUserAttributes.mockReset()
+
+ listUsers.mockResolvedValue({
+ items: [createAdminUser()],
+ total: 1,
+ page: 1,
+ page_size: 20,
+ pages: 1
+ })
+ getAllGroups.mockResolvedValue([])
+ getBatchUsersUsage.mockResolvedValue({ stats: {} })
+ listEnabledDefinitions.mockResolvedValue([])
+ getBatchUserAttributes.mockResolvedValue({ values: {} })
+ })
+
+ it('shows last_used_at column and requests last_used_at sort', async () => {
+ const wrapper = mount(UsersView, {
+ global: {
+ stubs: {
+ AppLayout: { template: '
' },
+ TablePageLayout: {
+ template: '
'
+ },
+ DataTable: DataTableStub,
+ Pagination: true,
+ ConfirmDialog: true,
+ EmptyState: true,
+ GroupBadge: true,
+ Select: true,
+ UserAttributesConfigModal: true,
+ UserConcurrencyCell: true,
+ UserCreateModal: true,
+ UserEditModal: true,
+ UserApiKeysModal: true,
+ UserAllowedGroupsModal: true,
+ UserBalanceModal: true,
+ UserBalanceHistoryModal: true,
+ GroupReplaceModal: true,
+ Icon: true,
+ Teleport: true
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(wrapper.get('[data-test="columns"]').text()).toContain('last_used_at')
+
+ await wrapper.get('[data-test="sort-last-used"]').trigger('click')
+ await flushPromises()
+
+ expect(listUsers).toHaveBeenLastCalledWith(
+ 1,
+ 20,
+ expect.objectContaining({
+ sort_by: 'last_used_at',
+ sort_order: 'desc'
+ }),
+ expect.any(Object)
+ )
+ })
+})
From 1521d503990f2b3ab6d474958b64c1e4f5fb3baf Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:31:52 +0800
Subject: [PATCH 072/189] fix: apply email first-bind defaults on legacy login
---
backend/internal/service/auth_service.go | 80 ++++++--
.../auth_service_identity_sync_test.go | 189 +++++++++++++++++-
2 files changed, 238 insertions(+), 31 deletions(-)
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index dda6df04..d0d5e4e3 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -807,37 +807,75 @@ func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context
if s == nil || user == nil || user.ID <= 0 {
return
}
- s.ensureEmailAuthIdentity(ctx, user)
+ if s.ensureEmailAuthIdentity(ctx, user) {
+ if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, "email"); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to apply email first bind defaults: user_id=%d err=%v", user.ID, err)
+ }
+ }
}
-func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) {
+func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) bool {
if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
- return
+ return false
}
email := strings.ToLower(strings.TrimSpace(user.Email))
if email == "" || isReservedEmail(email) {
- return
+ return false
}
- if err := s.entClient.AuthIdentity.Create().
- SetUserID(user.ID).
- SetProviderType("email").
- SetProviderKey("email").
- SetProviderSubject(email).
- SetVerifiedAt(time.Now().UTC()).
- SetMetadata(map[string]any{
- "source": "auth_service_dual_write",
- }).
- OnConflictColumns(
- authidentity.FieldProviderType,
- authidentity.FieldProviderKey,
- authidentity.FieldProviderSubject,
- ).
- DoNothing().
- Exec(ctx); err != nil {
- logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
+ client := s.entClient
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ client = tx.Client()
}
+
+ buildQuery := func() *dbent.AuthIdentityQuery {
+ return client.AuthIdentity.Query().Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ(email),
+ )
+ }
+
+ existed, err := buildQuery().Exist(ctx)
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
+ return false
+ }
+
+ if !existed {
+ if err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject(email).
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{
+ "source": "auth_service_dual_write",
+ }).
+ OnConflictColumns(
+ authidentity.FieldProviderType,
+ authidentity.FieldProviderKey,
+ authidentity.FieldProviderSubject,
+ ).
+ DoNothing().
+ Exec(ctx); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
+ return false
+ }
+ }
+
+ identity, err := buildQuery().Only(ctx)
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to reload email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
+ return false
+ }
+ if identity.UserID != user.ID {
+ logger.LegacyPrintf("service.auth", "[Auth] Email auth identity ownership mismatch: user_id=%d email=%s owner_id=%d", user.ID, email, identity.UserID)
+ return false
+ }
+
+ return !existed
}
func inferLegacySignupSource(email string) string {
diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go
index fcb4813b..e2a94b13 100644
--- a/backend/internal/service/auth_service_identity_sync_test.go
+++ b/backend/internal/service/auth_service_identity_sync_test.go
@@ -21,6 +21,19 @@ import (
_ "modernc.org/sqlite"
)
+type authIdentityDefaultSubAssignerStub struct {
+ calls []*service.AssignSubscriptionInput
+}
+
+func (s *authIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
+ _ context.Context,
+ input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+ cloned := *input
+ s.calls = append(s.calls, &cloned)
+ return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
+}
+
type authIdentitySettingRepoStub struct {
values map[string]string
}
@@ -40,8 +53,14 @@ func (s *authIdentitySettingRepoStub) Set(context.Context, string, string) error
panic("unexpected Set call")
}
-func (s *authIdentitySettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
- panic("unexpected GetMultiple call")
+func (s *authIdentitySettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if v, ok := s.values[key]; ok {
+ out[key] = v
+ }
+ }
+ return out, nil
}
func (s *authIdentitySettingRepoStub) SetMultiple(context.Context, map[string]string) error {
@@ -56,7 +75,11 @@ func (s *authIdentitySettingRepoStub) Delete(context.Context, string) error {
panic("unexpected Delete call")
}
-func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepository, *dbent.Client) {
+func newAuthServiceWithEnt(
+ t *testing.T,
+ settings map[string]string,
+ defaultSubAssigner service.DefaultSubscriptionAssigner,
+) (*service.AuthService, service.UserRepository, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:auth_service_identity_sync?mode=memory&cache=shared")
@@ -65,6 +88,16 @@ func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepo
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
+ _, err = db.Exec(`
+CREATE TABLE IF NOT EXISTS user_provider_default_grants (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ provider_type TEXT NOT NULL,
+ grant_reason TEXT NOT NULL DEFAULT 'first_bind',
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ UNIQUE(user_id, provider_type, grant_reason)
+)`)
+ require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
@@ -82,17 +115,17 @@ func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepo
},
}
settingSvc := service.NewSettingService(&authIdentitySettingRepoStub{
- values: map[string]string{
- service.SettingKeyRegistrationEnabled: "true",
- },
+ values: settings,
}, cfg)
- svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, nil)
+ svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner)
return svc, repo, client
}
func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
- svc, _, client := newAuthServiceWithEnt(t)
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ }, nil)
ctx := context.Background()
token, user, err := svc.Register(ctx, "user@example.com", "password")
@@ -119,7 +152,9 @@ func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
}
func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) {
- svc, repo, client := newAuthServiceWithEnt(t)
+ svc, repo, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ }, nil)
ctx := context.Background()
user := &service.User{
@@ -163,7 +198,9 @@ func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) {
}
func TestAuthServiceRecordSuccessfulLoginBackfillsEmailIdentity(t *testing.T) {
- svc, repo, client := newAuthServiceWithEnt(t)
+ svc, repo, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ }, nil)
ctx := context.Background()
user := &service.User{
@@ -188,3 +225,135 @@ func TestAuthServiceRecordSuccessfulLoginBackfillsEmailIdentity(t *testing.T) {
require.NoError(t, err)
require.Equal(t, user.ID, identity.UserID)
}
+
+func TestAuthServiceLogin_AppliesEmailFirstBindDefaultsOnlyWhenEmailIdentityIsNew(t *testing.T) {
+ assigner := &authIdentityDefaultSubAssignerStub{}
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, assigner)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("legacy@example.com").
+ SetUsername("legacy-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(1.5).
+ SetConcurrency(2).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 10.0, storedUser.Balance)
+ require.Equal(t, 6, storedUser.Concurrency)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(11), assigner.calls[0].GroupID)
+ require.Equal(t, 30, assigner.calls[0].ValidityDays)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("legacy@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, identityCount)
+ require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+
+ token, gotUser, err = svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+
+ storedUser, err = client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 10.0, storedUser.Balance)
+ require.Equal(t, 6, storedUser.Concurrency)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyExists(t *testing.T) {
+ assigner := &authIdentityDefaultSubAssignerStub{}
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, assigner)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("bound@example.com").
+ SetUsername("bound-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(2).
+ SetConcurrency(3).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject("bound@example.com").
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": "preexisting"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 2.0, storedUser.Balance)
+ require.Equal(t, 3, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func countProviderGrantRecords(
+ t *testing.T,
+ client *dbent.Client,
+ userID int64,
+ providerType string,
+ grantReason string,
+) int {
+ t.Helper()
+
+ var count int
+ rows, err := client.QueryContext(
+ context.Background(),
+ `SELECT COUNT(*) FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ?`,
+ userID,
+ providerType,
+ grantReason,
+ )
+ require.NoError(t, err)
+ defer rows.Close()
+ require.True(t, rows.Next())
+ require.NoError(t, rows.Scan(&count))
+ require.NoError(t, rows.Err())
+ return count
+}
From 55e8dd550a4a261ab4bdf26f865da2eab11dafd8 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:33:23 +0800
Subject: [PATCH 073/189] Tighten WeChat payment resume flow
---
backend/internal/handler/auth_wechat_oauth.go | 36 +++++---
.../handler/auth_wechat_oauth_test.go | 61 +++++++++++++
backend/internal/handler/payment_handler.go | 69 ++++++++++++--
.../handler/payment_handler_resume_test.go | 61 +++++++++++++
.../internal/service/payment_resume_lookup.go | 4 +
.../service/payment_resume_service.go | 91 ++++++++++++++++---
.../service/payment_resume_service_test.go | 33 +++++++
frontend/src/types/payment.ts | 1 +
.../views/auth/WechatPaymentCallbackView.vue | 24 ++---
.../WechatPaymentCallbackView.spec.ts | 13 +--
frontend/src/views/user/PaymentResultView.vue | 3 +-
frontend/src/views/user/PaymentView.vue | 64 ++++++-------
.../user/__tests__/PaymentResultView.spec.ts | 19 ++++
.../__tests__/paymentWechatResume.spec.ts | 56 ++++++++++++
.../src/views/user/paymentWechatResume.ts | 77 ++++++++++++++++
15 files changed, 514 insertions(+), 98 deletions(-)
create mode 100644 backend/internal/handler/payment_handler_resume_test.go
create mode 100644 frontend/src/views/user/__tests__/paymentWechatResume.spec.ts
create mode 100644 frontend/src/views/user/paymentWechatResume.ts
diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go
index b078b804..45de30a8 100644
--- a/backend/internal/handler/auth_wechat_oauth.go
+++ b/backend/internal/handler/auth_wechat_oauth.go
@@ -435,24 +435,34 @@ func (h *AuthHandler) WeChatPaymentOAuthCallback(c *gin.Context) {
scope = strings.TrimSpace(tokenResp.Scope)
}
+ resumeToken, err := h.wechatPaymentResumeService().CreateWeChatPaymentResumeToken(service.WeChatPaymentResumeClaims{
+ OpenID: openid,
+ PaymentType: paymentContext.PaymentType,
+ Amount: paymentContext.Amount,
+ OrderType: paymentContext.OrderType,
+ PlanID: paymentContext.PlanID,
+ RedirectTo: redirectTo,
+ Scope: scope,
+ })
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "invalid_context", "failed to encode payment resume context", "")
+ return
+ }
+
fragment := url.Values{}
- fragment.Set("openid", openid)
- fragment.Set("state", state)
- fragment.Set("scope", scope)
- fragment.Set("payment_type", paymentContext.PaymentType)
- if paymentContext.Amount != "" {
- fragment.Set("amount", paymentContext.Amount)
- }
- if paymentContext.OrderType != "" {
- fragment.Set("order_type", paymentContext.OrderType)
- }
- if paymentContext.PlanID > 0 {
- fragment.Set("plan_id", strconv.FormatInt(paymentContext.PlanID, 10))
- }
+ fragment.Set("wechat_resume_token", resumeToken)
fragment.Set("redirect", redirectTo)
redirectWithFragment(c, frontendCallback, fragment)
}
+func (h *AuthHandler) wechatPaymentResumeService() *service.PaymentResumeService {
+ key, err := payment.ProvideEncryptionKey(h.cfg)
+ if err != nil {
+ return service.NewPaymentResumeService(nil)
+ }
+ return service.NewPaymentResumeService([]byte(key))
+}
+
type completeWeChatOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go
index def9d5d6..c65f4cd1 100644
--- a/backend/internal/handler/auth_wechat_oauth_test.go
+++ b/backend/internal/handler/auth_wechat_oauth_test.go
@@ -21,6 +21,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -175,6 +176,66 @@ func TestWeChatOAuthCallbackRejectsMissingUnionID(t *testing.T) {
require.Zero(t, count)
}
+func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) {
+ t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app")
+ t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wx-mp-secret")
+
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if strings.Contains(r.URL.Path, "/sns/oauth2/access_token") {
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_base"}`))
+ return
+ }
+ http.NotFound(w, r)
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+ handler.cfg.Totp.EncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/payment/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatPaymentOAuthStateName, "state-123"))
+ req.AddCookie(encodedCookie(wechatPaymentOAuthRedirect, "/purchase?from=wechat"))
+ req.AddCookie(encodedCookie(wechatPaymentOAuthContextName, `{"payment_type":"wxpay","amount":"12.5","order_type":"subscription","plan_id":7}`))
+ req.AddCookie(encodedCookie(wechatPaymentOAuthScope, "snsapi_base"))
+ c.Request = req
+
+ handler.WeChatPaymentOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ parsed, err := url.Parse(location)
+ require.NoError(t, err)
+ fragment, err := url.ParseQuery(parsed.Fragment)
+ require.NoError(t, err)
+ require.Equal(t, "/purchase?from=wechat", fragment.Get("redirect"))
+ require.NotEmpty(t, fragment.Get("wechat_resume_token"))
+ require.Empty(t, fragment.Get("openid"))
+ require.Empty(t, fragment.Get("payment_type"))
+ require.Empty(t, fragment.Get("amount"))
+ require.Empty(t, fragment.Get("order_type"))
+ require.Empty(t, fragment.Get("plan_id"))
+
+ claims, err := handler.wechatPaymentResumeService().ParseWeChatPaymentResumeToken(fragment.Get("wechat_resume_token"))
+ require.NoError(t, err)
+ require.Equal(t, "openid-123", claims.OpenID)
+ require.Equal(t, payment.TypeWxpay, claims.PaymentType)
+ require.Equal(t, "12.5", claims.Amount)
+ require.Equal(t, payment.OrderTypeSubscription, claims.OrderType)
+ require.EqualValues(t, 7, claims.PlanID)
+ require.Equal(t, "/purchase?from=wechat", claims.RedirectTo)
+}
+
func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *testing.T) {
testCases := []struct {
name string
diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go
index d54cbe92..5fd6b43e 100644
--- a/backend/internal/handler/payment_handler.go
+++ b/backend/internal/handler/payment_handler.go
@@ -1,9 +1,12 @@
package handler
import (
+ "fmt"
"strconv"
"strings"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -202,14 +205,15 @@ func (h *PaymentHandler) GetLimits(c *gin.Context) {
// CreateOrderRequest is the request body for creating a payment order.
type CreateOrderRequest struct {
- Amount float64 `json:"amount"`
- PaymentType string `json:"payment_type" binding:"required"`
- OpenID string `json:"openid"`
- ReturnURL string `json:"return_url"`
- PaymentSource string `json:"payment_source"`
- OrderType string `json:"order_type"`
- PlanID int64 `json:"plan_id"`
- IsMobile *bool `json:"is_mobile,omitempty"`
+ Amount float64 `json:"amount"`
+ PaymentType string `json:"payment_type" binding:"required"`
+ OpenID string `json:"openid"`
+ WechatResumeToken string `json:"wechat_resume_token"`
+ ReturnURL string `json:"return_url"`
+ PaymentSource string `json:"payment_source"`
+ OrderType string `json:"order_type"`
+ PlanID int64 `json:"plan_id"`
+ IsMobile *bool `json:"is_mobile,omitempty"`
}
// CreateOrder creates a new payment order.
@@ -225,6 +229,17 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
+ if strings.TrimSpace(req.WechatResumeToken) != "" {
+ claims, err := h.paymentService.ParseWeChatPaymentResumeToken(req.WechatResumeToken)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyWeChatPaymentResumeClaims(&req, claims); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ }
mobile := isMobile(c)
if req.IsMobile != nil {
@@ -253,6 +268,44 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
response.Success(c, result)
}
+func applyWeChatPaymentResumeClaims(req *CreateOrderRequest, claims *service.WeChatPaymentResumeClaims) error {
+ if req == nil || claims == nil {
+ return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume context is missing")
+ }
+ openid := strings.TrimSpace(claims.OpenID)
+ if openid == "" {
+ return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token missing openid")
+ }
+
+ paymentType := service.NormalizeVisibleMethod(claims.PaymentType)
+ if paymentType == "" {
+ paymentType = payment.TypeWxpay
+ }
+ if req.PaymentType != "" {
+ requestPaymentType := service.NormalizeVisibleMethod(req.PaymentType)
+ if requestPaymentType != "" && requestPaymentType != paymentType {
+ return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payment type mismatch")
+ }
+ }
+ req.PaymentType = paymentType
+ req.OpenID = openid
+
+ if strings.TrimSpace(claims.Amount) != "" {
+ amount, err := strconv.ParseFloat(strings.TrimSpace(claims.Amount), 64)
+ if err != nil || amount <= 0 {
+ return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", fmt.Sprintf("invalid resume amount: %s", claims.Amount))
+ }
+ req.Amount = amount
+ }
+ if claims.OrderType != "" {
+ req.OrderType = claims.OrderType
+ }
+ if claims.PlanID > 0 {
+ req.PlanID = claims.PlanID
+ }
+ return nil
+}
+
// GetMyOrders returns the authenticated user's orders.
// GET /api/v1/payment/orders/my
func (h *PaymentHandler) GetMyOrders(c *gin.Context) {
diff --git a/backend/internal/handler/payment_handler_resume_test.go b/backend/internal/handler/payment_handler_resume_test.go
new file mode 100644
index 00000000..323f7292
--- /dev/null
+++ b/backend/internal/handler/payment_handler_resume_test.go
@@ -0,0 +1,61 @@
+//go:build unit
+
+package handler
+
+import (
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+func TestApplyWeChatPaymentResumeClaims(t *testing.T) {
+ t.Parallel()
+
+ req := CreateOrderRequest{
+ Amount: 0,
+ PaymentType: payment.TypeWxpay,
+ OrderType: payment.OrderTypeBalance,
+ }
+
+ err := applyWeChatPaymentResumeClaims(&req, &service.WeChatPaymentResumeClaims{
+ OpenID: "openid-123",
+ PaymentType: payment.TypeWxpay,
+ Amount: "12.50",
+ OrderType: payment.OrderTypeSubscription,
+ PlanID: 7,
+ })
+ if err != nil {
+ t.Fatalf("applyWeChatPaymentResumeClaims returned error: %v", err)
+ }
+ if req.OpenID != "openid-123" {
+ t.Fatalf("openid = %q, want %q", req.OpenID, "openid-123")
+ }
+ if req.Amount != 12.5 {
+ t.Fatalf("amount = %v, want 12.5", req.Amount)
+ }
+ if req.OrderType != payment.OrderTypeSubscription {
+ t.Fatalf("order_type = %q, want %q", req.OrderType, payment.OrderTypeSubscription)
+ }
+ if req.PlanID != 7 {
+ t.Fatalf("plan_id = %d, want 7", req.PlanID)
+ }
+}
+
+func TestApplyWeChatPaymentResumeClaimsRejectsPaymentTypeMismatch(t *testing.T) {
+ t.Parallel()
+
+ req := CreateOrderRequest{
+ PaymentType: payment.TypeAlipay,
+ }
+
+ err := applyWeChatPaymentResumeClaims(&req, &service.WeChatPaymentResumeClaims{
+ OpenID: "openid-123",
+ PaymentType: payment.TypeWxpay,
+ Amount: "12.50",
+ OrderType: payment.OrderTypeBalance,
+ })
+ if err == nil {
+ t.Fatal("applyWeChatPaymentResumeClaims should reject mismatched payment types")
+ }
+}
diff --git a/backend/internal/service/payment_resume_lookup.go b/backend/internal/service/payment_resume_lookup.go
index 493ca325..69033afd 100644
--- a/backend/internal/service/payment_resume_lookup.go
+++ b/backend/internal/service/payment_resume_lookup.go
@@ -33,3 +33,7 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
return order, nil
}
+
+func (s *PaymentService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
+ return s.paymentResume().ParseWeChatPaymentResumeToken(strings.TrimSpace(token))
+}
diff --git a/backend/internal/service/payment_resume_service.go b/backend/internal/service/payment_resume_service.go
index 4f63645e..64d1d125 100644
--- a/backend/internal/service/payment_resume_service.go
+++ b/backend/internal/service/payment_resume_service.go
@@ -31,6 +31,8 @@ const (
VisibleMethodSourceEasyPayAlipay = "easypay_alipay"
VisibleMethodSourceOfficialWechat = "official_wxpay"
VisibleMethodSourceEasyPayWechat = "easypay_wxpay"
+
+ wechatPaymentResumeTokenType = "wechat_payment_resume"
)
type ResumeTokenClaims struct {
@@ -43,6 +45,18 @@ type ResumeTokenClaims struct {
IssuedAt int64 `json:"iat"`
}
+type WeChatPaymentResumeClaims struct {
+ TokenType string `json:"tk,omitempty"`
+ OpenID string `json:"openid"`
+ PaymentType string `json:"pt,omitempty"`
+ Amount string `json:"amt,omitempty"`
+ OrderType string `json:"ot,omitempty"`
+ PlanID int64 `json:"pid,omitempty"`
+ RedirectTo string `json:"rd,omitempty"`
+ Scope string `json:"scp,omitempty"`
+ IssuedAt int64 `json:"iat"`
+}
+
type PaymentResumeService struct {
signingKey []byte
}
@@ -232,6 +246,66 @@ func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, er
if claims.IssuedAt == 0 {
claims.IssuedAt = time.Now().Unix()
}
+ return s.createSignedToken(claims)
+}
+
+func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) {
+ var claims ResumeTokenClaims
+ if err := s.parseSignedToken(token, &claims); err != nil {
+ return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid")
+ }
+ if claims.OrderID <= 0 {
+ return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token missing order id")
+ }
+ return &claims, nil
+}
+
+func (s *PaymentResumeService) CreateWeChatPaymentResumeToken(claims WeChatPaymentResumeClaims) (string, error) {
+ claims.OpenID = strings.TrimSpace(claims.OpenID)
+ if claims.OpenID == "" {
+ return "", fmt.Errorf("wechat payment resume token requires openid")
+ }
+ if claims.IssuedAt == 0 {
+ claims.IssuedAt = time.Now().Unix()
+ }
+ if normalized := NormalizeVisibleMethod(claims.PaymentType); normalized != "" {
+ claims.PaymentType = normalized
+ }
+ if claims.PaymentType == "" {
+ claims.PaymentType = payment.TypeWxpay
+ }
+ if claims.OrderType == "" {
+ claims.OrderType = payment.OrderTypeBalance
+ }
+ claims.TokenType = wechatPaymentResumeTokenType
+ return s.createSignedToken(claims)
+}
+
+func (s *PaymentResumeService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
+ var claims WeChatPaymentResumeClaims
+ if err := s.parseSignedToken(token, &claims); err != nil {
+ return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payload is invalid")
+ }
+ if claims.TokenType != wechatPaymentResumeTokenType {
+ return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token type mismatch")
+ }
+ claims.OpenID = strings.TrimSpace(claims.OpenID)
+ if claims.OpenID == "" {
+ return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token missing openid")
+ }
+ if normalized := NormalizeVisibleMethod(claims.PaymentType); normalized != "" {
+ claims.PaymentType = normalized
+ }
+ if claims.PaymentType == "" {
+ claims.PaymentType = payment.TypeWxpay
+ }
+ if claims.OrderType == "" {
+ claims.OrderType = payment.OrderTypeBalance
+ }
+ return &claims, nil
+}
+
+func (s *PaymentResumeService) createSignedToken(claims any) (string, error) {
payload, err := json.Marshal(claims)
if err != nil {
return "", fmt.Errorf("marshal resume claims: %w", err)
@@ -240,26 +314,19 @@ func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, er
return encodedPayload + "." + s.sign(encodedPayload), nil
}
-func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) {
+func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
parts := strings.Split(token, ".")
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
- return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
+ return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
}
if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) {
- return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
+ return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
- return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is malformed")
+ return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is malformed")
}
- var claims ResumeTokenClaims
- if err := json.Unmarshal(payload, &claims); err != nil {
- return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid")
- }
- if claims.OrderID <= 0 {
- return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token missing order id")
- }
- return &claims, nil
+ return json.Unmarshal(payload, dest)
}
func (s *PaymentResumeService) sign(payload string) string {
diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go
index 9c35ac3d..24d50494 100644
--- a/backend/internal/service/payment_resume_service_test.go
+++ b/backend/internal/service/payment_resume_service_test.go
@@ -150,6 +150,39 @@ func TestPaymentResumeTokenRoundTrip(t *testing.T) {
}
}
+func TestWeChatPaymentResumeTokenRoundTrip(t *testing.T) {
+ t.Parallel()
+
+ svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
+ OpenID: "openid-123",
+ PaymentType: payment.TypeWxpay,
+ Amount: "12.50",
+ OrderType: payment.OrderTypeSubscription,
+ PlanID: 7,
+ RedirectTo: "/purchase?from=wechat",
+ Scope: "snsapi_base",
+ IssuedAt: 1234567890,
+ })
+ if err != nil {
+ t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
+ }
+
+ claims, err := svc.ParseWeChatPaymentResumeToken(token)
+ if err != nil {
+ t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
+ }
+ if claims.OpenID != "openid-123" || claims.PaymentType != payment.TypeWxpay {
+ t.Fatalf("claims mismatch: %+v", claims)
+ }
+ if claims.Amount != "12.50" || claims.OrderType != payment.OrderTypeSubscription || claims.PlanID != 7 {
+ t.Fatalf("claims payment context mismatch: %+v", claims)
+ }
+ if claims.RedirectTo != "/purchase?from=wechat" || claims.Scope != "snsapi_base" {
+ t.Fatalf("claims redirect/scope mismatch: %+v", claims)
+ }
+}
+
func TestNormalizeVisibleMethodSource(t *testing.T) {
t.Parallel()
diff --git a/frontend/src/types/payment.ts b/frontend/src/types/payment.ts
index 5cd49064..77fbe689 100644
--- a/frontend/src/types/payment.ts
+++ b/frontend/src/types/payment.ts
@@ -157,6 +157,7 @@ export interface CreateOrderRequest {
return_url?: string
payment_source?: string
openid?: string
+ wechat_resume_token?: string
is_mobile?: boolean
}
diff --git a/frontend/src/views/auth/WechatPaymentCallbackView.vue b/frontend/src/views/auth/WechatPaymentCallbackView.vue
index 422a0bb8..73095102 100644
--- a/frontend/src/views/auth/WechatPaymentCallbackView.vue
+++ b/frontend/src/views/auth/WechatPaymentCallbackView.vue
@@ -114,23 +114,17 @@ onMounted(async () => {
return
}
- const openid = readParam('openid')
- const state = readParam('state')
- const scope = readParam('scope')
- const paymentType = readParam('payment_type')
- const amount = readParam('amount')
- const orderType = readParam('order_type')
- const planId = readParam('plan_id')
+ const resumeToken = readParam('wechat_resume_token')
const redirectURL = new URL(
normalizeRedirectPath(readParam('redirect')),
window.location.origin,
)
- if (!openid) {
+ if (!resumeToken) {
errorMessage.value = textWithFallback(
- 'auth.wechatPayment.callbackMissingOpenId',
- '微信支付回调缺少 openid。',
- 'The WeChat payment callback is missing the openid.',
+ 'auth.wechatPayment.callbackMissingResumeToken',
+ '微信支付回调缺少恢复令牌。',
+ 'The WeChat payment callback is missing the resume token.',
)
return
}
@@ -138,14 +132,8 @@ onMounted(async () => {
const query: Record = {
...Object.fromEntries(redirectURL.searchParams.entries()),
wechat_resume: '1',
- openid,
+ wechat_resume_token: resumeToken,
}
- if (state) query.state = state
- if (scope) query.scope = scope
- if (paymentType) query.payment_type = paymentType
- if (amount) query.amount = amount
- if (orderType) query.order_type = orderType
- if (planId) query.plan_id = planId
await router.replace({
path: redirectURL.pathname,
diff --git a/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts
index cfbd9f1c..400e50d5 100644
--- a/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts
@@ -49,8 +49,8 @@ describe('WechatPaymentCallbackView', () => {
})
})
- it('redirects back to purchase with openid and payment context from hash fragment', async () => {
- locationState.current.hash = '#openid=openid-123&payment_type=wxpay&amount=12.5&order_type=balance&redirect=%2Fpurchase%3Ffrom%3Dwechat'
+ it('redirects back to purchase with an opaque resume token from hash fragment', async () => {
+ locationState.current.hash = '#wechat_resume_token=resume-token-123&redirect=%2Fpurchase%3Ffrom%3Dwechat'
mount(WechatPaymentCallbackView)
await flushPromises()
@@ -60,21 +60,18 @@ describe('WechatPaymentCallbackView', () => {
query: {
from: 'wechat',
wechat_resume: '1',
- openid: 'openid-123',
- payment_type: 'wxpay',
- amount: '12.5',
- order_type: 'balance',
+ wechat_resume_token: 'resume-token-123',
},
})
})
- it('shows an error when the callback payload is missing openid', async () => {
+ it('shows an error when the callback payload is missing the resume token', async () => {
locationState.current.hash = '#payment_type=wxpay'
const wrapper = mount(WechatPaymentCallbackView)
await flushPromises()
expect(replaceMock).not.toHaveBeenCalled()
- expect(wrapper.text()).toContain('微信支付回调缺少 openid。')
+ expect(wrapper.text()).toContain('微信支付回调缺少恢复令牌。')
})
})
diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue
index e1bcbbe5..5c843d1f 100644
--- a/frontend/src/views/user/PaymentResultView.vue
+++ b/frontend/src/views/user/PaymentResultView.vue
@@ -188,7 +188,8 @@ onMounted(async () => {
}
}
- const hasLegacyFallbackContext = Boolean(route.query.trade_status || route.query.money || route.query.type)
+ const hasLegacyFallbackContext = typeof route.query.trade_status === 'string'
+ && route.query.trade_status.trim() !== ''
if (!order.value && !resumeToken && !orderId && outTradeNo && hasLegacyFallbackContext) {
returnInfo.value = {
outTradeNo,
diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue
index 019de16a..a3a81ccb 100644
--- a/frontend/src/views/user/PaymentView.vue
+++ b/frontend/src/views/user/PaymentView.vue
@@ -284,6 +284,7 @@ import PaymentStatusPanel from '@/components/payment/PaymentStatusPanel.vue'
import Icon from '@/components/icons/Icon.vue'
import type { PaymentMethodOption } from '@/components/payment/PaymentMethodSelector.vue'
import { describePaymentScenarioError } from './paymentUx'
+import { parseWechatResumeRoute, stripWechatResumeQuery } from './paymentWechatResume'
const { t } = useI18n()
const route = useRoute()
@@ -315,6 +316,7 @@ const paymentPhase = ref<'select' | 'paying'>('select')
interface CreateOrderOptions {
openid?: string
+ wechatResumeToken?: string
paymentType?: string
isResume?: boolean
}
@@ -344,13 +346,6 @@ function emptyPaymentState(): PaymentRecoverySnapshot {
}
}
-function readRouteQueryValue(value: unknown): string {
- if (Array.isArray(value)) {
- return typeof value[0] === 'string' ? value[0] : ''
- }
- return typeof value === 'string' ? value : ''
-}
-
function getWeixinJSBridge(): WeixinJSBridgeLike | undefined {
return (window as Window & { WeixinJSBridge?: WeixinJSBridgeLike }).WeixinJSBridge
}
@@ -637,6 +632,9 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n
if (options.openid) {
payload.openid = options.openid
}
+ if (options.wechatResumeToken) {
+ payload.wechat_resume_token = options.wechatResumeToken
+ }
payload.is_mobile = isMobileDevice()
const result = await paymentStore.createOrder(payload) as CreateOrderResult & { resume_token?: string }
@@ -744,44 +742,34 @@ function applyScenarioError(err: unknown, paymentMethod: string) {
}
async function resumeWechatPaymentFromQuery() {
- const openid = readRouteQueryValue(route.query.openid)
- if (readRouteQueryValue(route.query.wechat_resume) !== '1' || !openid) {
+ const resume = parseWechatResumeRoute(route.query, checkout.value.plans, validAmount.value)
+ if (!resume) {
return
}
- const paymentType = normalizeVisibleMethod(readRouteQueryValue(route.query.payment_type)) || 'wxpay'
- const orderType = readRouteQueryValue(route.query.order_type) === 'subscription' ? 'subscription' : 'balance'
- const planId = Number.parseInt(readRouteQueryValue(route.query.plan_id), 10)
- const rawAmount = Number.parseFloat(readRouteQueryValue(route.query.amount))
- const orderAmount = Number.isFinite(rawAmount) && rawAmount > 0
- ? rawAmount
- : (orderType === 'subscription'
- ? (checkout.value.plans.find(plan => plan.id === planId)?.price ?? 0)
- : validAmount.value)
-
- selectedMethod.value = paymentType
- if (orderType === 'balance' && orderAmount > 0) {
- amount.value = orderAmount
+ selectedMethod.value = resume.paymentType
+ if (resume.orderType === 'balance' && resume.orderAmount > 0) {
+ amount.value = resume.orderAmount
}
- if (orderType === 'subscription' && Number.isFinite(planId) && planId > 0) {
- selectedPlan.value = checkout.value.plans.find(plan => plan.id === planId) ?? null
+ if (resume.orderType === 'subscription' && resume.planId) {
+ selectedPlan.value = checkout.value.plans.find(plan => plan.id === resume.planId) ?? null
}
- const nextQuery = { ...route.query }
- delete nextQuery.wechat_resume
- delete nextQuery.openid
- delete nextQuery.state
- delete nextQuery.scope
- delete nextQuery.payment_type
- delete nextQuery.amount
- delete nextQuery.order_type
- delete nextQuery.plan_id
- await router.replace({ path: route.path, query: nextQuery })
+ await router.replace({ path: route.path, query: stripWechatResumeQuery(route.query) })
- if (orderAmount > 0) {
- await createOrder(orderAmount, orderType, Number.isFinite(planId) && planId > 0 ? planId : undefined, {
- openid,
- paymentType,
+ if (resume.wechatResumeToken) {
+ await createOrder(0, resume.orderType, resume.planId, {
+ wechatResumeToken: resume.wechatResumeToken,
+ paymentType: resume.paymentType,
+ isResume: true,
+ })
+ return
+ }
+
+ if (resume.orderAmount > 0 && resume.openid) {
+ await createOrder(resume.orderAmount, resume.orderType, resume.planId, {
+ openid: resume.openid,
+ paymentType: resume.paymentType,
isResume: true,
})
}
diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
index bfc044a7..d8199e3b 100644
--- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
+++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
@@ -157,6 +157,25 @@ describe('PaymentResultView', () => {
expect(wrapper.text()).toContain('payment.result.success')
})
+ it('does not use public out_trade_no verification for bare order numbers without legacy return markers', async () => {
+ routeState.query = {
+ out_trade_no: 'legacy-bare',
+ }
+
+ mount(PaymentResultView, {
+ global: {
+ stubs: {
+ OrderStatusBadge: true,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(verifyOrderPublic).not.toHaveBeenCalled()
+ expect(verifyOrder).not.toHaveBeenCalled()
+ })
+
it('resolves order by resume token when local recovery snapshot is missing', async () => {
routeState.query = {
resume_token: 'resume-77',
diff --git a/frontend/src/views/user/__tests__/paymentWechatResume.spec.ts b/frontend/src/views/user/__tests__/paymentWechatResume.spec.ts
new file mode 100644
index 00000000..c850ec1b
--- /dev/null
+++ b/frontend/src/views/user/__tests__/paymentWechatResume.spec.ts
@@ -0,0 +1,56 @@
+import { describe, expect, it } from 'vitest'
+import { parseWechatResumeRoute, stripWechatResumeQuery } from '../paymentWechatResume'
+
+describe('parseWechatResumeRoute', () => {
+ it('prefers the opaque resume token over legacy openid query params', () => {
+ expect(parseWechatResumeRoute({
+ wechat_resume: '1',
+ wechat_resume_token: 'resume-token-123',
+ openid: 'openid-123',
+ payment_type: 'wxpay',
+ amount: '12.5',
+ order_type: 'subscription',
+ plan_id: '7',
+ }, [], 88)).toEqual({
+ wechatResumeToken: 'resume-token-123',
+ paymentType: 'wxpay',
+ orderType: 'balance',
+ orderAmount: 0,
+ })
+ })
+
+ it('falls back to legacy openid-based resume when opaque token is absent', () => {
+ expect(parseWechatResumeRoute({
+ wechat_resume: '1',
+ openid: 'openid-123',
+ payment_type: 'wxpay',
+ amount: '12.5',
+ order_type: 'balance',
+ }, [], 88)).toEqual({
+ openid: 'openid-123',
+ paymentType: 'wxpay',
+ orderType: 'balance',
+ orderAmount: 12.5,
+ planId: undefined,
+ })
+ })
+})
+
+describe('stripWechatResumeQuery', () => {
+ it('removes both opaque-token and legacy resume params from the route query', () => {
+ expect(stripWechatResumeQuery({
+ foo: 'bar',
+ wechat_resume: '1',
+ wechat_resume_token: 'resume-token-123',
+ openid: 'openid-123',
+ payment_type: 'wxpay',
+ amount: '12.5',
+ order_type: 'subscription',
+ plan_id: '7',
+ state: 'state-123',
+ scope: 'snsapi_base',
+ })).toEqual({
+ foo: 'bar',
+ })
+ })
+})
diff --git a/frontend/src/views/user/paymentWechatResume.ts b/frontend/src/views/user/paymentWechatResume.ts
new file mode 100644
index 00000000..f53c8457
--- /dev/null
+++ b/frontend/src/views/user/paymentWechatResume.ts
@@ -0,0 +1,77 @@
+import type { LocationQuery, LocationQueryRaw } from 'vue-router'
+import type { SubscriptionPlan } from '@/types/payment'
+import { normalizeVisibleMethod } from '@/components/payment/paymentFlow'
+
+export interface ParsedWechatResumeRoute {
+ orderAmount: number
+ orderType: 'balance' | 'subscription'
+ paymentType: string
+ planId?: number
+ openid?: string
+ wechatResumeToken?: string
+}
+
+function readQueryString(query: LocationQuery, key: string): string {
+ const value = query[key]
+ if (Array.isArray(value)) {
+ return typeof value[0] === 'string' ? value[0] : ''
+ }
+ return typeof value === 'string' ? value : ''
+}
+
+export function parseWechatResumeRoute(
+ query: LocationQuery,
+ plans: SubscriptionPlan[],
+ fallbackBalanceAmount: number,
+): ParsedWechatResumeRoute | null {
+ if (readQueryString(query, 'wechat_resume') !== '1') {
+ return null
+ }
+
+ const wechatResumeToken = readQueryString(query, 'wechat_resume_token')
+ if (wechatResumeToken) {
+ return {
+ wechatResumeToken,
+ paymentType: 'wxpay',
+ orderType: 'balance',
+ orderAmount: 0,
+ }
+ }
+
+ const openid = readQueryString(query, 'openid')
+ if (!openid) {
+ return null
+ }
+
+ const paymentType = normalizeVisibleMethod(readQueryString(query, 'payment_type')) || 'wxpay'
+ const orderType = readQueryString(query, 'order_type') === 'subscription' ? 'subscription' : 'balance'
+ const planId = Number.parseInt(readQueryString(query, 'plan_id'), 10)
+ const rawAmount = Number.parseFloat(readQueryString(query, 'amount'))
+ const orderAmount = Number.isFinite(rawAmount) && rawAmount > 0
+ ? rawAmount
+ : (orderType === 'subscription'
+ ? (plans.find(plan => plan.id === planId)?.price ?? 0)
+ : fallbackBalanceAmount)
+
+ return {
+ openid,
+ paymentType,
+ orderType,
+ orderAmount,
+ planId: Number.isFinite(planId) && planId > 0 ? planId : undefined,
+ }
+}
+
+export function stripWechatResumeQuery(query: LocationQuery): LocationQueryRaw {
+ const nextQuery: LocationQueryRaw = { ...query }
+ delete nextQuery.wechat_resume
+ delete nextQuery.wechat_resume_token
+ delete nextQuery.openid
+ delete nextQuery.state
+ delete nextQuery.scope
+ delete nextQuery.payment_type
+ delete nextQuery.amount
+ delete nextQuery.order_type
+ delete nextQuery.plan_id
+ return nextQuery
+}
From 030da8c2f66bdf031c245867322660b44776ed8d Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:41:29 +0800
Subject: [PATCH 074/189] fix: close admin settings review gaps
---
.../settings.paymentVisibleMethods.spec.ts | 8 ++---
frontend/src/api/admin/settings.ts | 4 +--
frontend/src/components/layout/AppSidebar.vue | 6 ++++
.../layout/__tests__/AppSidebar.spec.ts | 6 ++++
.../AuthIdentityMigrationReportsView.vue | 30 +++++++++++++++----
frontend/src/views/admin/SettingsView.vue | 30 +++++++++++++++++--
.../AuthIdentityMigrationReportsView.spec.ts | 17 +++++++++++
.../admin/__tests__/SettingsView.spec.ts | 23 ++++++++++++++
8 files changed, 110 insertions(+), 14 deletions(-)
diff --git a/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts b/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts
index 3b1a373f..ad355afe 100644
--- a/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts
+++ b/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts
@@ -27,8 +27,8 @@ describe('admin settings payment visible method helpers', () => {
expect(getPaymentVisibleMethodSourceOptions('alipay')).toEqual([
{
value: '',
- labelZh: '自动路由',
- labelEn: 'Automatic routing',
+ labelZh: '未配置',
+ labelEn: 'Not configured',
},
{
value: 'official_alipay',
@@ -45,8 +45,8 @@ describe('admin settings payment visible method helpers', () => {
expect(getPaymentVisibleMethodSourceOptions('wxpay')).toEqual([
{
value: '',
- labelZh: '自动路由',
- labelEn: 'Automatic routing',
+ labelZh: '未配置',
+ labelEn: 'Not configured',
},
{
value: 'official_wxpay',
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index 505fcdca..235bda7b 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -44,12 +44,12 @@ const PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS: Record<
PaymentVisibleMethodSourceOption[]
> = {
alipay: [
- { value: '', labelZh: '自动路由', labelEn: 'Automatic routing' },
+ { value: '', labelZh: '未配置', labelEn: 'Not configured' },
{ value: 'official_alipay', labelZh: '支付宝官方', labelEn: 'Official Alipay' },
{ value: 'easypay_alipay', labelZh: '易支付支付宝', labelEn: 'EasyPay Alipay' },
],
wxpay: [
- { value: '', labelZh: '自动路由', labelEn: 'Automatic routing' },
+ { value: '', labelZh: '未配置', labelEn: 'Not configured' },
{ value: 'official_wxpay', labelZh: '微信官方', labelEn: 'Official WeChat Pay' },
{ value: 'easypay_wxpay', labelZh: '易支付微信', labelEn: 'EasyPay WeChat Pay' },
],
diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue
index 92dcc519..b7158d9b 100644
--- a/frontend/src/components/layout/AppSidebar.vue
+++ b/frontend/src/components/layout/AppSidebar.vue
@@ -663,6 +663,12 @@ const adminNavItems = computed((): NavItem[] => {
? [{ path: '/admin/ops', label: t('nav.ops'), icon: ChartIcon }]
: []),
{ path: '/admin/users', label: t('nav.users'), icon: UsersIcon, hideInSimpleMode: true },
+ {
+ path: '/admin/users/auth-identity-migration-reports',
+ label: 'Migration Reports',
+ icon: UsersIcon,
+ hideInSimpleMode: true
+ },
{ path: '/admin/groups', label: t('nav.groups'), icon: FolderIcon, hideInSimpleMode: true },
{ path: '/admin/channels', label: t('nav.channels', '渠道管理'), icon: ChannelIcon, hideInSimpleMode: true },
{ path: '/admin/subscriptions', label: t('nav.subscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
diff --git a/frontend/src/components/layout/__tests__/AppSidebar.spec.ts b/frontend/src/components/layout/__tests__/AppSidebar.spec.ts
index 118c7615..915a67f8 100644
--- a/frontend/src/components/layout/__tests__/AppSidebar.spec.ts
+++ b/frontend/src/components/layout/__tests__/AppSidebar.spec.ts
@@ -30,3 +30,9 @@ describe('AppSidebar header styles', () => {
expect(sidebarBrandBlockMatch?.[0]).not.toContain('overflow: hidden;')
})
})
+
+describe('AppSidebar admin navigation', () => {
+ it('includes a visible entry for auth identity migration reports', () => {
+ expect(componentSource).toContain("'/admin/users/auth-identity-migration-reports'")
+ })
+})
diff --git a/frontend/src/views/admin/AuthIdentityMigrationReportsView.vue b/frontend/src/views/admin/AuthIdentityMigrationReportsView.vue
index 5aeb6b28..35c232b6 100644
--- a/frontend/src/views/admin/AuthIdentityMigrationReportsView.vue
+++ b/frontend/src/views/admin/AuthIdentityMigrationReportsView.vue
@@ -318,6 +318,7 @@ const pagination = reactive({
pageSize: 20,
total: 0,
})
+const knownReportTypes = ref([])
const columns: Column[] = [
{ key: 'status', label: text('状态', 'Status') },
@@ -330,12 +331,16 @@ const columns: Column[] = [
]
const reportTypeOptions = computed(() =>
- Object.entries(summary.value.by_type)
- .sort(([left], [right]) => left.localeCompare(right))
- .map(([value, count]) => ({
- value,
- label: `${value} (${count})`,
- }))
+ knownReportTypes.value
+ .slice()
+ .sort((left, right) => left.localeCompare(right))
+ .map((value) => {
+ const count = summary.value.by_type[value]
+ return {
+ value,
+ label: count === undefined ? value : `${value} (${count})`,
+ }
+ })
)
const canResolve = computed(() =>
@@ -347,10 +352,22 @@ const canResolve = computed(() =>
)
)
+const mergeKnownReportTypes = (...values: Array) => {
+ const merged = new Set(knownReportTypes.value)
+ for (const value of values) {
+ const normalized = value?.trim()
+ if (normalized) {
+ merged.add(normalized)
+ }
+ }
+ knownReportTypes.value = Array.from(merged)
+}
+
const loadSummary = async () => {
summaryLoading.value = true
try {
summary.value = await adminAPI.users.getAuthIdentityMigrationReportSummary()
+ mergeKnownReportTypes(...Object.keys(summary.value.by_type))
} catch (error) {
console.error('Failed to load auth identity migration report summary:', error)
appStore.showError(text('加载 migration reports 汇总失败', 'Failed to load migration report summary'))
@@ -370,6 +387,7 @@ const loadReports = async () => {
reports.value = response.items
pagination.total = response.total
+ mergeKnownReportTypes(filters.reportType, ...response.items.map((report) => report.report_type))
if (selectedReport.value) {
const refreshed = response.items.find((report) => report.id === selectedReport.value?.id) ?? null
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index 8a042e70..9fb8da41 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -2728,8 +2728,8 @@
{{
localText(
- '留空表示自动路由;仅允许当前系统支持的官方或易支付来源。',
- 'Leave blank for automatic routing. Only supported official or EasyPay sources are allowed.'
+ '启用后必须明确选择一个来源;未配置状态不会对外展示该支付方式。',
+ 'Choose an explicit source before enabling the method. Not configured methods are not exposed.'
)
}}
@@ -3450,6 +3450,28 @@ function setPaymentVisibleMethodSource(
form.payment_visible_method_wxpay_source = normalized
}
+function validatePaymentVisibleMethodSelections(): boolean {
+ for (const visibleMethod of paymentVisibleMethodCards.value) {
+ if (!getPaymentVisibleMethodEnabled(visibleMethod.key)) {
+ continue
+ }
+
+ if (getPaymentVisibleMethodSource(visibleMethod.key)) {
+ continue
+ }
+
+ appStore.showError(
+ localText(
+ `${visibleMethod.title} 已启用,请先选择支付来源`,
+ `Select a payment source before enabling ${visibleMethod.title}`
+ )
+ )
+ return false
+ }
+
+ return true
+}
+
// Proxies for web search emulation ProxySelector
const webSearchProxies = ref([])
@@ -3979,6 +4001,10 @@ async function saveSettings() {
}
}
+ if (!validatePaymentVisibleMethodSelections()) {
+ return
+ }
+
// Validate URL fields — novalidate disables browser-native checks, so we validate here
const isValidHttpUrl = (url: string): boolean => {
if (!url) return true
diff --git a/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts b/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts
index 5e6b0ae0..406baaf1 100644
--- a/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts
+++ b/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts
@@ -240,4 +240,21 @@ describe('AuthIdentityMigrationReportsView', () => {
reportType: '',
})
})
+
+ it('keeps report type filter options available from list data when summary fails', async () => {
+ getAuthIdentityMigrationReportSummary.mockRejectedValueOnce(new Error('summary failed'))
+ listAuthIdentityMigrationReports.mockResolvedValueOnce(listResponse)
+
+ const wrapper = mountView()
+
+ await flushPromises()
+
+ const options = wrapper
+ .get('[data-test="report-type-filter"]')
+ .findAll('option')
+ .map((node) => node.element.value)
+
+ expect(showError).toHaveBeenCalled()
+ expect(options).toContain('oidc_synthetic_email_requires_manual_recovery')
+ })
})
diff --git a/frontend/src/views/admin/__tests__/SettingsView.spec.ts b/frontend/src/views/admin/__tests__/SettingsView.spec.ts
index f20170e9..b6f8ab17 100644
--- a/frontend/src/views/admin/__tests__/SettingsView.spec.ts
+++ b/frontend/src/views/admin/__tests__/SettingsView.spec.ts
@@ -449,4 +449,27 @@ describe('admin SettingsView payment visible method controls', () => {
})
)
})
+
+ it('blocks saving when a visible payment method is enabled without a source', async () => {
+ const wrapper = mountView()
+
+ await flushPromises()
+ await openPaymentTab(wrapper)
+
+ const paymentSourceSelects = wrapper
+ .findAll('select.select-stub')
+ .filter((node) => ['alipay', 'wxpay'].includes(node.attributes('data-placeholder')))
+
+ const alipaySelect = paymentSourceSelects.find(
+ (node) => node.attributes('data-placeholder') === 'alipay'
+ )
+
+ await alipaySelect?.setValue('')
+ await wrapper.find('form').trigger('submit.prevent')
+ await flushPromises()
+
+ expect(updateSettings).not.toHaveBeenCalled()
+ expect(showError).toHaveBeenCalled()
+ expect(String(showError.mock.calls.at(-1)?.[0] ?? '')).toContain('支付来源')
+ })
})
From e4fe9fae2a5e8992a822f58334660d9f1aaa29a1 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:42:55 +0800
Subject: [PATCH 075/189] Fix profile refresh identity compatibility
---
.../handler/auth_current_user_test.go | 78 ++++++++++++++++
backend/internal/handler/auth_handler.go | 13 ++-
.../handler/auth_oauth_pending_flow.go | 18 +++-
.../handler/auth_oauth_pending_flow_test.go | 72 +++++++++++++++
backend/internal/handler/user_handler.go | 88 ++-----------------
backend/internal/handler/user_handler_test.go | 13 +--
backend/internal/service/user_service.go | 5 ++
7 files changed, 195 insertions(+), 92 deletions(-)
create mode 100644 backend/internal/handler/auth_current_user_test.go
diff --git a/backend/internal/handler/auth_current_user_test.go b/backend/internal/handler/auth_current_user_test.go
new file mode 100644
index 00000000..dab95e29
--- /dev/null
+++ b/backend/internal/handler/auth_current_user_test.go
@@ -0,0 +1,78 @@
+//go:build unit
+
+package handler
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthHandlerGetCurrentUserReturnsProfileCompatibilityFields(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC)
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 31,
+ Email: "me@example.com",
+ Username: "linuxdo-handle",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ AvatarURL: "https://cdn.example.com/linuxdo.png",
+ AvatarSource: "remote_url",
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-31",
+ VerifiedAt: &verifiedAt,
+ Metadata: map[string]any{
+ "username": "linuxdo-handle",
+ },
+ },
+ },
+ }
+
+ handler := &AuthHandler{
+ userService: service.NewUserService(repo, nil, nil, nil),
+ }
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 31})
+
+ handler.GetCurrentUser(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, true, resp.Data["email_bound"])
+ require.Equal(t, true, resp.Data["linuxdo_bound"])
+ require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"])
+
+ authBindings, ok := resp.Data["auth_bindings"].(map[string]any)
+ require.True(t, ok)
+ linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, true, linuxdoBinding["bound"])
+
+ _, hasAvatarSource := resp.Data["avatar_source"]
+ require.False(t, hasAvatarSource)
+ _, hasProfileSources := resp.Data["profile_sources"]
+ require.False(t, hasProfileSources)
+}
diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go
index b984a436..76ca153d 100644
--- a/backend/internal/handler/auth_handler.go
+++ b/backend/internal/handler/auth_handler.go
@@ -348,8 +348,14 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
return
}
+ identities, err := h.userService.GetProfileIdentitySummaries(c.Request.Context(), subject.UserID, user)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
type UserResponse struct {
- *dto.User
+ userProfileResponse
RunMode string `json:"run_mode"`
}
@@ -358,7 +364,10 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
runMode = h.cfg.RunMode
}
- response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode})
+ response.Success(c, UserResponse{
+ userProfileResponse: userProfileResponseFromService(user, identities),
+ RunMode: runMode,
+ })
}
// ValidatePromoCodeRequest 验证优惠码请求
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
index 94186858..461810f1 100644
--- a/backend/internal/handler/auth_oauth_pending_flow.go
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -848,6 +848,12 @@ func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision
}
}
+func shouldSkipAvatarAdoption(err error) bool {
+ return errors.Is(err, service.ErrAvatarInvalid) ||
+ errors.Is(err, service.ErrAvatarTooLarge) ||
+ errors.Is(err, service.ErrAvatarNotImage)
+}
+
func applyPendingOAuthBinding(
ctx context.Context,
client *dbent.Client,
@@ -885,6 +891,14 @@ func applyPendingOAuthBinding(
if decision != nil && decision.AdoptAvatar {
adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url")
}
+ shouldAdoptAvatar := false
+ if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" {
+ if err := service.ValidateUserAvatar(adoptedAvatarURL); err == nil {
+ shouldAdoptAvatar = true
+ } else if !shouldSkipAvatarAdoption(err) {
+ return err
+ }
+ }
tx, err := client.Tx(ctx)
if err != nil {
@@ -913,7 +927,7 @@ func applyPendingOAuthBinding(
if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
metadata["display_name"] = adoptedDisplayName
}
- if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" {
+ if shouldAdoptAvatar {
metadata["avatar_url"] = adoptedAvatarURL
}
@@ -939,7 +953,7 @@ func applyPendingOAuthBinding(
}
}
- if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" && userService != nil {
+ if shouldAdoptAvatar && userService != nil {
if _, err := userService.SetAvatar(txCtx, targetUserID, adoptedAvatarURL); err != nil {
return err
}
diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go
index 2521186e..d29e4b88 100644
--- a/backend/internal/handler/auth_oauth_pending_flow_test.go
+++ b/backend/internal/handler/auth_oauth_pending_flow_test.go
@@ -173,6 +173,78 @@ func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecisio
require.NotNil(t, consumed.ConsumedAt)
}
+func TestExchangePendingOAuthCompletionSkipsInvalidAvatarAdoptionWithoutBlockingCompletion(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("invalid-avatar@example.com").
+ SetUsername("legacy-name").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("pending-invalid-avatar-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("invalid-avatar-123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("browser-invalid-avatar-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ "suggested_display_name": "Alice Example",
+ "suggested_avatar_url": "/avatars/alice.png",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ "redirect": "/dashboard",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-invalid-avatar-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("invalid-avatar-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "Alice Example", identity.Metadata["display_name"])
+ _, hasAdoptedAvatar := identity.Metadata["avatar_url"]
+ require.False(t, hasAdoptedAvatar)
+
+ avatar := loadUserAvatarRecord(t, client, userEntity.ID)
+ require.Nil(t, avatar)
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+}
+
func TestExchangePendingOAuthCompletionBindCurrentUserPreviewThenFinalizeBindsIdentityWithoutAdoption(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go
index b1ade5c0..843b0bd9 100644
--- a/backend/internal/handler/user_handler.go
+++ b/backend/internal/handler/user_handler.go
@@ -2,7 +2,6 @@ package handler
import (
"context"
- "strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -353,22 +352,16 @@ func userProfileResponseFromService(user *service.User, identities service.UserI
return userProfileResponse{}
}
bindings := userProfileBindingMap(identities)
- profileSources, avatarSource, usernameSource := inferUserProfileSources(user, identities)
return userProfileResponse{
- User: *base,
- AvatarURL: user.AvatarURL,
- AvatarSource: avatarSource,
- UsernameSource: usernameSource,
- DisplayNameSource: usernameSource,
- NicknameSource: usernameSource,
- ProfileSources: profileSources,
- Identities: identities,
- AuthBindings: bindings,
- IdentityBindings: bindings,
- EmailBound: identities.Email.Bound,
- LinuxDoBound: identities.LinuxDo.Bound,
- OIDCBound: identities.OIDC.Bound,
- WeChatBound: identities.WeChat.Bound,
+ User: *base,
+ AvatarURL: user.AvatarURL,
+ Identities: identities,
+ AuthBindings: bindings,
+ IdentityBindings: bindings,
+ EmailBound: identities.Email.Bound,
+ LinuxDoBound: identities.LinuxDo.Bound,
+ OIDCBound: identities.OIDC.Bound,
+ WeChatBound: identities.WeChat.Bound,
}
}
@@ -380,66 +373,3 @@ func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string
"wechat": identities.WeChat,
}
}
-
-func inferUserProfileSources(user *service.User, identities service.UserIdentitySummarySet) (
- map[string]*userProfileSourceContext,
- *userProfileSourceContext,
- *userProfileSourceContext,
-) {
- if user == nil {
- return nil, nil, nil
- }
-
- thirdParty := thirdPartyIdentityProviders(identities)
- var avatarSource *userProfileSourceContext
- if strings.TrimSpace(user.AvatarURL) != "" && len(thirdParty) == 1 {
- avatarSource = buildUserProfileSourceContext(thirdParty[0].Provider)
- }
-
- usernameValue := strings.TrimSpace(user.Username)
- var usernameSource *userProfileSourceContext
- for _, summary := range thirdParty {
- if usernameValue != "" && usernameValue == strings.TrimSpace(summary.DisplayName) {
- usernameSource = buildUserProfileSourceContext(summary.Provider)
- break
- }
- }
- if usernameSource == nil && usernameValue != "" && len(thirdParty) == 1 {
- usernameSource = buildUserProfileSourceContext(thirdParty[0].Provider)
- }
-
- profileSources := map[string]*userProfileSourceContext{}
- if avatarSource != nil {
- profileSources["avatar"] = avatarSource
- }
- if usernameSource != nil {
- profileSources["username"] = usernameSource
- profileSources["display_name"] = usernameSource
- profileSources["nickname"] = usernameSource
- }
- if len(profileSources) == 0 {
- return nil, avatarSource, usernameSource
- }
- return profileSources, avatarSource, usernameSource
-}
-
-func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary {
- out := make([]service.UserIdentitySummary, 0, 3)
- for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} {
- if summary.Bound {
- out = append(out, summary)
- }
- }
- return out
-}
-
-func buildUserProfileSourceContext(provider string) *userProfileSourceContext {
- provider = strings.TrimSpace(provider)
- if provider == "" {
- return nil
- }
- return &userProfileSourceContext{
- Provider: provider,
- Source: provider,
- }
-}
diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go
index 1216f9c4..7c6460e8 100644
--- a/backend/internal/handler/user_handler_test.go
+++ b/backend/internal/handler/user_handler_test.go
@@ -298,15 +298,10 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
require.True(t, ok)
require.Equal(t, true, emailBinding["bound"])
- avatarSource, ok := resp.Data["avatar_source"].(map[string]any)
- require.True(t, ok)
- require.Equal(t, "linuxdo", avatarSource["provider"])
-
- profileSources, ok := resp.Data["profile_sources"].(map[string]any)
- require.True(t, ok)
- usernameSource, ok := profileSources["username"].(map[string]any)
- require.True(t, ok)
- require.Equal(t, "linuxdo", usernameSource["provider"])
+ _, hasAvatarSource := resp.Data["avatar_source"]
+ require.False(t, hasAvatarSource)
+ _, hasProfileSources := resp.Data["profile_sources"]
+ require.False(t, hasProfileSources)
}
func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index c106a3f5..cd1bc2bb 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -403,6 +403,11 @@ func normalizeUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
}, nil
}
+func ValidateUserAvatar(raw string) error {
+ _, err := normalizeUserAvatarInput(raw)
+ return err
+}
+
func normalizeInlineUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
body := strings.TrimPrefix(raw, "data:")
meta, encoded, ok := strings.Cut(body, ",")
From 12f4af742f5c7c40503430f468c5e0e42894c72f Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:45:56 +0800
Subject: [PATCH 076/189] fix auth pending adoption and turnstile flow
---
.../auth/PendingOAuthCreateAccountForm.vue | 62 +++++++++++++-
.../PendingOAuthCreateAccountForm.spec.ts | 50 +++++++++++-
.../src/views/auth/LinuxDoCallbackView.vue | 28 ++++++-
frontend/src/views/auth/OidcCallbackView.vue | 28 ++++++-
.../__tests__/LinuxDoCallbackView.spec.ts | 81 ++++++++++++++++++-
.../auth/__tests__/OidcCallbackView.spec.ts | 78 +++++++++++++++++-
6 files changed, 313 insertions(+), 14 deletions(-)
diff --git a/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue b/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue
index 39588a86..36e78d36 100644
--- a/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue
+++ b/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue
@@ -16,6 +16,15 @@
placeholder="Password"
:disabled="isSubmitting"
/>
+
+
+
{{
@@ -80,9 +89,10 @@
diff --git a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
index 1c9531e3..7db2ecd7 100644
--- a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
+++ b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
@@ -1,6 +1,8 @@
import { mount } from '@vue/test-utils'
+import { createPinia, setActivePinia } from 'pinia'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import ProfileIdentityBindingsSection from '@/components/user/profile/ProfileIdentityBindingsSection.vue'
+import { useAppStore } from '@/stores'
import type { User } from '@/types'
const routeState = vi.hoisted(() => ({
@@ -11,6 +13,8 @@ const locationState = vi.hoisted(() => ({
current: { href: 'http://localhost/profile' } as { href: string },
}))
+let pinia: ReturnType
+
vi.mock('vue-router', () => ({
useRoute: () => routeState,
}))
@@ -57,6 +61,8 @@ function createUser(overrides: Partial = {}): User {
describe('ProfileIdentityBindingsSection', () => {
beforeEach(() => {
+ pinia = createPinia()
+ setActivePinia(pinia)
routeState.fullPath = '/profile'
locationState.current = { href: 'http://localhost/profile' }
Object.defineProperty(window, 'location', {
@@ -67,6 +73,9 @@ describe('ProfileIdentityBindingsSection', () => {
configurable: true,
value: 'Mozilla/5.0',
})
+ const appStore = useAppStore()
+ appStore.cachedPublicSettings = null
+ appStore.publicSettingsLoaded = false
})
afterEach(() => {
@@ -75,6 +84,9 @@ describe('ProfileIdentityBindingsSection', () => {
it('renders provider binding states and provider-specific bind actions', () => {
const wrapper = mount(ProfileIdentityBindingsSection, {
+ global: {
+ plugins: [pinia],
+ },
props: {
user: createUser({
auth_bindings: {
@@ -102,11 +114,16 @@ describe('ProfileIdentityBindingsSection', () => {
it('starts the WeChat bind flow for the current profile page', async () => {
const wrapper = mount(ProfileIdentityBindingsSection, {
+ global: {
+ plugins: [pinia],
+ },
props: {
user: createUser(),
linuxdoEnabled: false,
oidcEnabled: false,
wechatEnabled: true,
+ wechatOpenEnabled: true,
+ wechatMpEnabled: false,
},
})
@@ -117,4 +134,22 @@ describe('ProfileIdentityBindingsSection', () => {
expect(locationState.current.href).toContain('intent=bind_current_user')
expect(locationState.current.href).toContain('redirect=%2Fprofile')
})
+
+ it('hides the WeChat bind action outside the WeChat browser when only mp mode is configured', () => {
+ const wrapper = mount(ProfileIdentityBindingsSection, {
+ global: {
+ plugins: [pinia],
+ },
+ props: {
+ user: createUser(),
+ linuxdoEnabled: false,
+ oidcEnabled: false,
+ wechatEnabled: true,
+ wechatOpenEnabled: false,
+ wechatMpEnabled: true,
+ },
+ })
+
+ expect(wrapper.find('[data-testid="profile-binding-wechat-action"]').exists()).toBe(false)
+ })
})
diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts
index a8e03a51..0ed3f37f 100644
--- a/frontend/src/stores/app.ts
+++ b/frontend/src/stores/app.ts
@@ -338,6 +338,8 @@ export const useAppStore = defineStore('app', () => {
custom_endpoints: [],
linuxdo_oauth_enabled: false,
wechat_oauth_enabled: false,
+ wechat_oauth_open_enabled: false,
+ wechat_oauth_mp_enabled: false,
oidc_oauth_enabled: false,
oidc_oauth_provider_name: 'OIDC',
backend_mode_enabled: false,
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index 5a2e3184..07341919 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -165,6 +165,8 @@ export interface PublicSettings {
custom_endpoints: CustomEndpoint[]
linuxdo_oauth_enabled: boolean
wechat_oauth_enabled: boolean
+ wechat_oauth_open_enabled?: boolean
+ wechat_oauth_mp_enabled?: boolean
oidc_oauth_enabled: boolean
oidc_oauth_provider_name: string
backend_mode_enabled: boolean
diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue
index 10b83b1c..a5da35e5 100644
--- a/frontend/src/views/auth/WechatCallbackView.vue
+++ b/frontend/src/views/auth/WechatCallbackView.vue
@@ -297,6 +297,7 @@ import {
login2FA,
prepareOAuthBindAccessTokenCookie,
persistOAuthTokenContext,
+ resolveWeChatOAuthStart,
type OAuthAdoptionDecision,
type PendingOAuthExchangeResponse
} from '@/api/auth'
@@ -378,7 +379,47 @@ function normalizeWeChatOAuthMode(value: unknown): 'open' | 'mp' | null {
return value === 'open' || value === 'mp' ? value : null
}
-function resolveRequestedWeChatOAuthMode(): 'open' | 'mp' {
+async function ensurePublicSettingsLoaded(): Promise {
+ if (appStore.cachedPublicSettings || appStore.publicSettingsLoaded) {
+ return
+ }
+
+ try {
+ await appStore.fetchPublicSettings()
+ } catch {
+ // Fall back to legacy mode selection when public settings are unavailable.
+ }
+}
+
+function resolveConfiguredWeChatOAuthMode(): 'open' | 'mp' | null {
+ if (!appStore.cachedPublicSettings && !appStore.publicSettingsLoaded) {
+ return null
+ }
+
+ return resolveWeChatOAuthStart(appStore.cachedPublicSettings).mode
+}
+
+function resolveWeChatOAuthUnavailableMessage(): string {
+ const resolved = resolveWeChatOAuthStart(appStore.cachedPublicSettings)
+
+ switch (resolved.unavailableReason) {
+ case 'external_browser_required':
+ return 'This WeChat sign-in flow is only available in your system browser.'
+ case 'wechat_browser_required':
+ return 'This WeChat sign-in flow is only available inside the WeChat browser.'
+ case 'not_configured':
+ return 'WeChat sign-in is not configured yet.'
+ default:
+ return t('auth.loginFailed')
+ }
+}
+
+function resolveRequestedWeChatOAuthMode(): 'open' | 'mp' | null {
+ const configuredMode = resolveConfiguredWeChatOAuthMode()
+ if (configuredMode) {
+ return configuredMode
+ }
+
const queryMode = normalizeWeChatOAuthMode(route.query.mode)
return queryMode || resolveWeChatOAuthMode()
}
@@ -389,11 +430,15 @@ function resolveRedirectTarget(): string {
)
}
-function resolveWeChatStartURL(intent: 'bind_current_user' | 'adopt_existing_user_by_email'): string {
+function resolveWeChatStartURL(intent: 'bind_current_user' | 'adopt_existing_user_by_email'): string | null {
const apiBase = (import.meta.env.VITE_API_BASE_URL as string | undefined) || '/api/v1'
const normalized = apiBase.replace(/\/$/, '')
+ const mode = resolveRequestedWeChatOAuthMode()
+ if (!mode) {
+ return null
+ }
const params = new URLSearchParams({
- mode: resolveRequestedWeChatOAuthMode(),
+ mode,
redirect: resolveRedirectTarget(),
intent,
})
@@ -406,11 +451,15 @@ function resolveWeChatStartURL(intent: 'bind_current_user' | 'adopt_existing_use
return `${normalized}/auth/oauth/wechat/start?${params.toString()}`
}
-function buildExistingAccountResumePath(): string {
+function buildExistingAccountResumePath(): string | null {
+ const mode = resolveRequestedWeChatOAuthMode()
+ if (!mode) {
+ return null
+ }
const params = new URLSearchParams({
wechat_bind_existing: '1',
redirect: resolveRedirectTarget(),
- mode: resolveRequestedWeChatOAuthMode(),
+ mode,
})
const email = existingAccountEmail.value.trim()
@@ -444,14 +493,31 @@ function serializeAdoptionDecision(decision: OAuthAdoptionDecision): Record {
+ await ensurePublicSettingsLoaded()
+
if (typeof route.query.email === 'string') {
existingAccountEmail.value = route.query.email
}
if (route.query.wechat_bind_existing === '1') {
if (getAuthToken()) {
+ const startURL = resolveWeChatStartURL('bind_current_user')
+ if (!startURL) {
+ errorMessage.value = resolveWeChatOAuthUnavailableMessage()
+ appStore.showError(errorMessage.value)
+ isProcessing.value = false
+ return
+ }
prepareOAuthBindAccessTokenCookie()
- window.location.href = resolveWeChatStartURL('bind_current_user')
+ window.location.href = startURL
+ return
+ }
+
+ const resumePath = buildExistingAccountResumePath()
+ if (!resumePath) {
+ errorMessage.value = resolveWeChatOAuthUnavailableMessage()
+ appStore.showError(errorMessage.value)
+ isProcessing.value = false
return
}
const params = new URLSearchParams({
- redirect: buildExistingAccountResumePath(),
+ redirect: resumePath,
})
const email = existingAccountEmail.value.trim()
if (email) {
diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
index aa673238..7f26f3c8 100644
--- a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
@@ -14,8 +14,10 @@ const {
setTokenMock,
showSuccessMock,
showErrorMock,
+ fetchPublicSettingsMock,
routeState,
locationState,
+ appStoreState,
} = vi.hoisted(() => ({
exchangePendingOAuthCompletionMock: vi.fn(),
completeWeChatOAuthRegistrationMock: vi.fn(),
@@ -28,6 +30,7 @@ const {
setTokenMock: vi.fn(),
showSuccessMock: vi.fn(),
showErrorMock: vi.fn(),
+ fetchPublicSettingsMock: vi.fn(),
routeState: {
query: {} as Record,
},
@@ -39,6 +42,10 @@ const {
pathname: '/auth/wechat/callback'
} as { href: string; hash: string; search: string; pathname: string },
},
+ appStoreState: {
+ cachedPublicSettings: null as null | Record,
+ publicSettingsLoaded: false,
+ },
}))
vi.mock('vue-router', () => ({
@@ -102,8 +109,10 @@ vi.mock('@/stores', () => ({
setToken: setTokenMock,
}),
useAppStore: () => ({
+ ...appStoreState,
showSuccess: showSuccessMock,
showError: showErrorMock,
+ fetchPublicSettings: fetchPublicSettingsMock,
}),
}))
@@ -139,7 +148,10 @@ describe('WechatCallbackView', () => {
showErrorMock.mockReset()
prepareOAuthBindAccessTokenCookieMock.mockReset()
getAuthTokenMock.mockReset()
+ fetchPublicSettingsMock.mockReset()
routeState.query = {}
+ appStoreState.cachedPublicSettings = null
+ appStoreState.publicSettingsLoaded = false
localStorage.clear()
locationState.current = {
href: 'http://localhost/auth/wechat/callback',
@@ -157,6 +169,38 @@ describe('WechatCallbackView', () => {
})
})
+ it('overrides an incompatible query mode with the configured open capability during bind recovery', async () => {
+ routeState.query = {
+ wechat_bind_existing: '1',
+ mode: 'mp',
+ redirect: '/profile',
+ }
+ appStoreState.cachedPublicSettings = {
+ wechat_oauth_enabled: true,
+ wechat_oauth_open_enabled: true,
+ wechat_oauth_mp_enabled: false,
+ }
+ appStoreState.publicSettingsLoaded = true
+ getAuthTokenMock.mockReturnValue('current-auth-token')
+
+ mount(WechatCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(prepareOAuthBindAccessTokenCookieMock).toHaveBeenCalledTimes(1)
+ expect(locationState.current.href).toContain('mode=open')
+ expect(locationState.current.href).not.toContain('mode=mp')
+ })
+
it('does not send adoption decisions during the initial exchange', async () => {
exchangePendingOAuthCompletionMock.mockResolvedValue({
access_token: 'access-token',
diff --git a/frontend/src/views/user/ProfileView.vue b/frontend/src/views/user/ProfileView.vue
index f7418be9..14d7efea 100644
--- a/frontend/src/views/user/ProfileView.vue
+++ b/frontend/src/views/user/ProfileView.vue
@@ -67,7 +67,6 @@
diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
index d8199e3b..c4e38523 100644
--- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
+++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
@@ -1,4 +1,4 @@
-import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { flushPromises, mount } from '@vue/test-utils'
const routeState = vi.hoisted(() => ({
@@ -73,7 +73,11 @@ describe('PaymentResultView', () => {
window.localStorage.clear()
})
- it('restores order id from a matching resume token and does not trust query success flags', async () => {
+ afterEach(() => {
+ vi.useRealTimers()
+ })
+
+ it('renders a pending state instead of a failure state when the restored order is still pending', async () => {
routeState.query = {
resume_token: 'resume-42',
order_id: '999',
@@ -107,8 +111,43 @@ describe('PaymentResultView', () => {
expect(pollOrderStatus).toHaveBeenCalledWith(42)
expect(verifyOrderPublic).not.toHaveBeenCalled()
- expect(wrapper.text()).toContain('payment.result.failed')
+ expect(wrapper.text()).toContain('payment.result.processing')
expect(wrapper.text()).not.toContain('payment.result.success')
+ expect(wrapper.text()).not.toContain('payment.result.failed')
+ })
+
+ it('refreshes a pending resume-token result until the order becomes paid', async () => {
+ vi.useFakeTimers()
+ routeState.query = {
+ resume_token: 'resume-77',
+ }
+ resolveOrderPublicByResumeToken
+ .mockResolvedValueOnce({
+ data: orderFactory('PENDING'),
+ })
+ .mockResolvedValueOnce({
+ data: orderFactory('PAID'),
+ })
+
+ const wrapper = mount(PaymentResultView, {
+ global: {
+ stubs: {
+ OrderStatusBadge: true,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(resolveOrderPublicByResumeToken).toHaveBeenCalledTimes(1)
+ expect(wrapper.text()).toContain('payment.result.processing')
+
+ await vi.advanceTimersByTimeAsync(2000)
+ await flushPromises()
+
+ expect(resolveOrderPublicByResumeToken).toHaveBeenCalledTimes(2)
+ expect(wrapper.text()).toContain('payment.result.success')
+ expect(wrapper.text()).not.toContain('payment.result.failed')
})
it('does not fall back to public out_trade_no verification when resume_token recovery fails', async () => {
From 7a9488ff37b07d70ae4ad62ac1dab17e37bc7470 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:59:20 +0800
Subject: [PATCH 079/189] Add legacy identity safety remediation migration
---
...ntity_legacy_migration_integration_test.go | 246 +++++++++++++
...dentity_legacy_external_safety_reports.sql | 336 ++++++++++++++++++
2 files changed, 582 insertions(+)
create mode 100644 backend/migrations/116_auth_identity_legacy_external_safety_reports.sql
diff --git a/backend/internal/repository/auth_identity_legacy_migration_integration_test.go b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go
index 6a6312d4..7f2f363f 100644
--- a/backend/internal/repository/auth_identity_legacy_migration_integration_test.go
+++ b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go
@@ -200,6 +200,252 @@ FROM auth_identity_migration_reports
var afterCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
+FROM auth_identity_migration_reports
+ `).Scan(&afterCount))
+ require.Equal(t, beforeCount, afterCount)
+}
+
+func TestAuthIdentityLegacyExternalSafetyMigration_ReportsConflictsAndDowngradesInvalidJSON(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, `
+CREATE TABLE IF NOT EXISTS user_external_identities (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL,
+ provider TEXT NOT NULL,
+ provider_user_id TEXT NOT NULL,
+ provider_union_id TEXT NULL,
+ provider_username TEXT NOT NULL DEFAULT '',
+ display_name TEXT NOT NULL DEFAULT '',
+ profile_url TEXT NOT NULL DEFAULT '',
+ avatar_url TEXT NOT NULL DEFAULT '',
+ metadata TEXT NOT NULL DEFAULT '{}',
+ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
+);
+
+ TRUNCATE TABLE
+ auth_identity_channels,
+ auth_identities,
+ auth_identity_migration_reports,
+ user_external_identities,
+ users
+ RESTART IDENTITY;
+`)
+ require.NoError(t, err)
+
+ userIDs := make([]int64, 0, 8)
+ for _, email := range []string{
+ "linuxdo-conflict-legacy@example.com",
+ "linuxdo-conflict-owner@example.com",
+ "wechat-conflict-legacy@example.com",
+ "wechat-conflict-owner@example.com",
+ "wechat-channel-legacy@example.com",
+ "wechat-channel-owner@example.com",
+ "linuxdo-invalid-json@example.com",
+ "wechat-openid-invalid-json@example.com",
+ } {
+ var userID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ($1, 'hash', 'user', 'active', 0, 1)
+RETURNING id`, email).Scan(&userID))
+ userIDs = append(userIDs, userID)
+ }
+
+ linuxdoConflictLegacyUserID := userIDs[0]
+ linuxdoConflictOwnerUserID := userIDs[1]
+ wechatConflictLegacyUserID := userIDs[2]
+ wechatConflictOwnerUserID := userIDs[3]
+ wechatChannelLegacyUserID := userIDs[4]
+ wechatChannelOwnerUserID := userIDs[5]
+ linuxdoInvalidJSONUserID := userIDs[6]
+ wechatInvalidOpenIDUserID := userIDs[7]
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
+VALUES ($1, 'linuxdo', 'linuxdo', 'linuxdo-conflict', '{}'::jsonb)
+RETURNING id`, linuxdoConflictOwnerUserID).Scan(new(int64)))
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
+VALUES ($1, 'wechat', 'wechat-main', 'union-conflict', '{}'::jsonb)
+RETURNING id`, wechatConflictOwnerUserID).Scan(new(int64)))
+
+ var wechatChannelOwnerIdentityID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
+VALUES ($1, 'wechat', 'wechat-main', 'union-channel-owner', '{}'::jsonb)
+RETURNING id`, wechatChannelOwnerUserID).Scan(&wechatChannelOwnerIdentityID))
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identity_channels (
+ identity_id,
+ provider_type,
+ provider_key,
+ channel,
+ channel_app_id,
+ channel_subject,
+ metadata
+)
+VALUES ($1, 'wechat', 'wechat-main', 'oa', 'wx-app-conflict', 'openid-channel-conflict', '{}'::jsonb)
+RETURNING id`, wechatChannelOwnerIdentityID).Scan(new(int64)))
+
+ var linuxdoConflictLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-conflict', NULL, 'legacy-linuxdo', 'Legacy LinuxDo Conflict', '{"source":"legacy"}')
+RETURNING id
+`, linuxdoConflictLegacyUserID).Scan(&linuxdoConflictLegacyID))
+
+ var wechatConflictLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-union-conflict', 'union-conflict', 'legacy-wechat', 'Legacy WeChat Conflict', '{"channel":"oa","appid":"wx-app-conflict-canon"}')
+RETURNING id
+`, wechatConflictLegacyUserID).Scan(&wechatConflictLegacyID))
+
+ var wechatChannelConflictLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-channel-conflict', 'union-channel-legacy', 'legacy-wechat-channel', 'Legacy WeChat Channel Conflict', '{"channel":"oa","appid":"wx-app-conflict"}')
+RETURNING id
+`, wechatChannelLegacyUserID).Scan(&wechatChannelConflictLegacyID))
+
+ var linuxdoInvalidJSONLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-invalid-json', NULL, 'legacy-linuxdo-invalid', 'Legacy LinuxDo Invalid JSON', '{invalid')
+RETURNING id
+`, linuxdoInvalidJSONUserID).Scan(&linuxdoInvalidJSONLegacyID))
+
+ var wechatInvalidOpenIDLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-invalid-json-only', NULL, 'legacy-wechat-invalid', 'Legacy WeChat Invalid JSON', '{still-invalid')
+RETURNING id
+`, wechatInvalidOpenIDUserID).Scan(&wechatInvalidOpenIDLegacyID))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err)
+
+ var linuxdoConflictReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_conflict'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(linuxdoConflictLegacyID, 10)).Scan(&linuxdoConflictReportCount))
+ require.Equal(t, 1, linuxdoConflictReportCount)
+
+ var wechatConflictReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_conflict'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatConflictLegacyID, 10)).Scan(&wechatConflictReportCount))
+ require.Equal(t, 1, wechatConflictReportCount)
+
+ var channelConflictReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_channel_conflict'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatChannelConflictLegacyID, 10)).Scan(&channelConflictReportCount))
+ require.Equal(t, 1, channelConflictReportCount)
+
+ var invalidJSONReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_invalid_metadata_json'
+ AND report_key IN ($1, $2)
+`, "legacy_external_identity:"+strconv.FormatInt(linuxdoInvalidJSONLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatInvalidOpenIDLegacyID, 10)).Scan(&invalidJSONReportCount))
+ require.Equal(t, 2, invalidJSONReportCount)
+
+ var linuxdoInvalidIdentityCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'linuxdo'
+ AND provider_key = 'linuxdo'
+ AND provider_subject = 'linuxdo-invalid-json'
+`, linuxdoInvalidJSONUserID).Scan(&linuxdoInvalidIdentityCount))
+ require.Equal(t, 1, linuxdoInvalidIdentityCount)
+
+ var wechatOpenIDOnlyReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatInvalidOpenIDLegacyID, 10)).Scan(&wechatOpenIDOnlyReportCount))
+ require.Equal(t, 1, wechatOpenIDOnlyReportCount)
+}
+
+func TestAuthIdentityLegacyExternalSafetyMigration_IsSafeWhenLegacyTableMissing(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ var beforeCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+`).Scan(&beforeCount))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err)
+
+ var afterCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
FROM auth_identity_migration_reports
`).Scan(&afterCount))
require.Equal(t, beforeCount, afterCount)
diff --git a/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql
new file mode 100644
index 00000000..994f3f37
--- /dev/null
+++ b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql
@@ -0,0 +1,336 @@
+CREATE OR REPLACE FUNCTION public.__migration_116_safe_legacy_metadata_jsonb(input_text TEXT)
+RETURNS JSONB
+LANGUAGE plpgsql
+AS $$
+DECLARE
+ parsed JSONB;
+BEGIN
+ IF input_text IS NULL OR BTRIM(input_text) = '' THEN
+ RETURN '{}'::jsonb;
+ END IF;
+
+ BEGIN
+ parsed := input_text::jsonb;
+ EXCEPTION
+ WHEN OTHERS THEN
+ RETURN '{}'::jsonb;
+ END;
+
+ IF jsonb_typeof(parsed) = 'object' THEN
+ RETURN parsed;
+ END IF;
+
+ RETURN jsonb_build_object('_legacy_metadata_raw_json', parsed);
+END;
+$$;
+
+CREATE OR REPLACE FUNCTION public.__migration_116_is_valid_legacy_metadata_jsonb(input_text TEXT)
+RETURNS BOOLEAN
+LANGUAGE plpgsql
+AS $$
+DECLARE
+ parsed JSONB;
+BEGIN
+ IF input_text IS NULL OR BTRIM(input_text) = '' THEN
+ RETURN TRUE;
+ END IF;
+
+ parsed := input_text::jsonb;
+ RETURN TRUE;
+EXCEPTION
+ WHEN OTHERS THEN
+ RETURN FALSE;
+END;
+$$;
+
+DO $$
+BEGIN
+ IF to_regclass('public.user_external_identities') IS NULL THEN
+ RETURN;
+ END IF;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'legacy_external_identity_invalid_metadata_json',
+ 'legacy_external_identity:' || uei.id::text,
+ jsonb_build_object(
+ 'legacy_identity_id', uei.id,
+ 'user_id', uei.user_id,
+ 'provider', LOWER(BTRIM(COALESCE(uei.provider, ''))),
+ 'provider_user_id', BTRIM(COALESCE(uei.provider_user_id, '')),
+ 'provider_union_id', BTRIM(COALESCE(uei.provider_union_id, '')),
+ 'reason', 'legacy metadata is not valid JSON; migration downgraded metadata to empty object',
+ 'raw_metadata', LEFT(BTRIM(COALESCE(uei.metadata, '')), 1000),
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM user_external_identities AS uei
+JOIN users AS u ON u.id = uei.user_id
+WHERE u.deleted_at IS NULL
+ AND BTRIM(COALESCE(uei.metadata, '')) <> ''
+ AND NOT public.__migration_116_is_valid_legacy_metadata_jsonb(uei.metadata)
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'legacy_external_identity_conflict',
+ 'legacy_external_identity:' || legacy.id::text,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'legacy_user_id', legacy.user_id,
+ 'existing_identity_id', ai.id,
+ 'existing_user_id', ai.user_id,
+ 'provider_type', legacy.provider_type,
+ 'provider_key', legacy.provider_key,
+ 'provider_subject', legacy.provider_subject,
+ 'reason', 'legacy canonical identity subject already belongs to another user',
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
+ ELSE 'linuxdo'
+ END AS provider_key,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
+ ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
+ END AS provider_subject,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
+ BTRIM(COALESCE(uei.provider_username, '')) AS provider_username,
+ BTRIM(COALESCE(uei.display_name, '')) AS display_name,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
+ AND (
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
+ OR
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
+ )
+) AS legacy
+JOIN auth_identities AS ai
+ ON ai.provider_type = legacy.provider_type
+ AND ai.provider_key = legacy.provider_key
+ AND ai.provider_subject = legacy.provider_subject
+WHERE ai.user_id <> legacy.user_id
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ legacy.user_id,
+ legacy.provider_type,
+ legacy.provider_key,
+ legacy.provider_subject,
+ legacy.verified_at,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'provider_user_id', legacy.provider_user_id,
+ 'provider_union_id', NULLIF(legacy.provider_union_id, ''),
+ 'provider_username', legacy.provider_username,
+ 'display_name', legacy.display_name,
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
+ ELSE 'linuxdo'
+ END AS provider_key,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
+ ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
+ END AS provider_subject,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
+ BTRIM(COALESCE(uei.provider_username, '')) AS provider_username,
+ BTRIM(COALESCE(uei.display_name, '')) AS display_name,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
+ COALESCE(uei.updated_at, uei.created_at, NOW()) AS verified_at
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
+ AND (
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
+ OR
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
+ )
+) AS legacy
+LEFT JOIN auth_identities AS ai
+ ON ai.provider_type = legacy.provider_type
+ AND ai.provider_key = legacy.provider_key
+ AND ai.provider_subject = legacy.provider_subject
+WHERE ai.id IS NULL
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'legacy_external_channel_conflict',
+ 'legacy_external_identity:' || legacy.id::text,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'legacy_user_id', legacy.user_id,
+ 'existing_channel_id', channel.id,
+ 'existing_identity_id', existing_ai.id,
+ 'existing_user_id', existing_ai.user_id,
+ 'provider_type', 'wechat',
+ 'provider_key', 'wechat-main',
+ 'provider_subject', legacy.provider_union_id,
+ 'channel', legacy.channel,
+ 'channel_app_id', legacy.channel_app_id,
+ 'channel_subject', legacy.provider_user_id,
+ 'reason', 'legacy channel subject already belongs to another user',
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
+ BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel,
+ BTRIM(COALESCE(
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id',
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid',
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id',
+ ''
+ )) AS channel_app_id
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+) AS legacy
+JOIN auth_identities AS legacy_ai
+ ON legacy_ai.user_id = legacy.user_id
+ AND legacy_ai.provider_type = 'wechat'
+ AND legacy_ai.provider_key = 'wechat-main'
+ AND legacy_ai.provider_subject = legacy.provider_union_id
+JOIN auth_identity_channels AS channel
+ ON channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat-main'
+ AND channel.channel = legacy.channel
+ AND channel.channel_app_id = legacy.channel_app_id
+ AND channel.channel_subject = legacy.provider_user_id
+JOIN auth_identities AS existing_ai
+ ON existing_ai.id = channel.identity_id
+WHERE legacy.channel <> ''
+ AND legacy.channel_app_id <> ''
+ AND existing_ai.user_id <> legacy.user_id
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_channels (
+ identity_id,
+ provider_type,
+ provider_key,
+ channel,
+ channel_app_id,
+ channel_subject,
+ metadata
+)
+SELECT
+ legacy_ai.id,
+ 'wechat',
+ 'wechat-main',
+ legacy.channel,
+ legacy.channel_app_id,
+ legacy.provider_user_id,
+ legacy.metadata_json || jsonb_build_object(
+ 'openid', legacy.provider_user_id,
+ 'unionid', legacy.provider_union_id,
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM (
+ SELECT
+ uei.user_id,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
+ BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel,
+ BTRIM(COALESCE(
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id',
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid',
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id',
+ ''
+ )) AS channel_app_id
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+) AS legacy
+JOIN auth_identities AS legacy_ai
+ ON legacy_ai.user_id = legacy.user_id
+ AND legacy_ai.provider_type = 'wechat'
+ AND legacy_ai.provider_key = 'wechat-main'
+ AND legacy_ai.provider_subject = legacy.provider_union_id
+LEFT JOIN auth_identity_channels AS channel
+ ON channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat-main'
+ AND channel.channel = legacy.channel
+ AND channel.channel_app_id = legacy.channel_app_id
+ AND channel.channel_subject = legacy.provider_user_id
+WHERE legacy.channel <> ''
+ AND legacy.channel_app_id <> ''
+ AND channel.id IS NULL
+ON CONFLICT DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_openid_only_requires_remediation',
+ 'legacy_external_identity:' || legacy.id::text,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'user_id', legacy.user_id,
+ 'openid', legacy.provider_user_id,
+ 'reason', 'legacy user_external_identities row only has openid and cannot be canonicalized offline',
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) = ''
+) AS legacy
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+END $$;
+
+DROP FUNCTION IF EXISTS public.__migration_116_is_valid_legacy_metadata_jsonb(TEXT);
+DROP FUNCTION IF EXISTS public.__migration_116_safe_legacy_metadata_jsonb(TEXT);
From ea27ac6fd7d21ec5b6ce6b861b3d7889293624d3 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 01:00:59 +0800
Subject: [PATCH 080/189] fix: unify email identity sync and retry first-bind
defaults
---
backend/internal/repository/user_repo.go | 8 --
backend/internal/service/admin_service.go | 7 --
.../admin_service_email_identity_sync_test.go | 28 +++----
backend/internal/service/auth_service.go | 81 ++++++++++++++++---
.../auth_service_identity_sync_test.go | 68 ++++++++++++++++
backend/internal/service/user_service.go | 31 -------
.../user_service_email_identity_sync_test.go | 9 +--
7 files changed, 154 insertions(+), 78 deletions(-)
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index 25d3f1d6..195776a3 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -209,14 +209,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
return nil
}
-func (r *userRepository) EnsureEmailAuthIdentity(ctx context.Context, userID int64, email string) error {
- return ensureEmailAuthIdentityWithClient(ctx, r.client, userID, email, "service_dual_write")
-}
-
-func (r *userRepository) ReplaceEmailAuthIdentity(ctx context.Context, userID int64, oldEmail, newEmail string) error {
- return replaceEmailAuthIdentityWithClient(ctx, r.client, userID, oldEmail, newEmail, "service_dual_write")
-}
-
func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, email string, source string) error {
client = clientFromContext(ctx, client)
if client == nil || userID <= 0 {
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 10b85f76..ce1c1a77 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -650,9 +650,6 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
if err := s.userRepo.Create(ctx, user); err != nil {
return nil, err
}
- if err := ensureEmailAuthIdentitySync(ctx, s.userRepo, user.ID, user.Email); err != nil {
- return nil, fmt.Errorf("sync email auth identity: %w", err)
- }
s.assignDefaultSubscriptions(ctx, user.ID)
return user, nil
}
@@ -688,7 +685,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
oldConcurrency := user.Concurrency
oldStatus := user.Status
oldRole := user.Role
- oldEmail := user.Email
if input.Email != "" {
user.Email = input.Email
@@ -721,9 +717,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
- if err := replaceEmailAuthIdentitySync(ctx, s.userRepo, user.ID, oldEmail, user.Email); err != nil {
- return nil, fmt.Errorf("sync email auth identity: %w", err)
- }
// 同步用户专属分组倍率
if input.GroupRates != nil && s.userGroupRateRepo != nil {
diff --git a/backend/internal/service/admin_service_email_identity_sync_test.go b/backend/internal/service/admin_service_email_identity_sync_test.go
index d555d609..d6a7af9a 100644
--- a/backend/internal/service/admin_service_email_identity_sync_test.go
+++ b/backend/internal/service/admin_service_email_identity_sync_test.go
@@ -31,6 +31,8 @@ type emailSyncRepoStub struct {
updated []*User
ensureCalls []ensureEmailCall
replaceCalls []replaceEmailCall
+ ensureErr error
+ replaceErr error
}
func (s *emailSyncRepoStub) Create(_ context.Context, user *User) error {
@@ -125,7 +127,7 @@ func (s *emailSyncRepoStub) DisableTotp(context.Context, int64) error { return n
func (s *emailSyncRepoStub) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error {
s.ensureCalls = append(s.ensureCalls, ensureEmailCall{userID: userID, email: email})
- return nil
+ return s.ensureErr
}
func (s *emailSyncRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error {
@@ -134,11 +136,14 @@ func (s *emailSyncRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID i
oldEmail: oldEmail,
newEmail: newEmail,
})
- return nil
+ return s.replaceErr
}
-func TestAdminService_CreateUser_EnsuresEmailAuthIdentity(t *testing.T) {
- repo := &emailSyncRepoStub{nextID: 55}
+func TestAdminService_CreateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
+ repo := &emailSyncRepoStub{
+ nextID: 55,
+ ensureErr: fmt.Errorf("unexpected email resync"),
+ }
svc := &adminServiceImpl{userRepo: repo}
user, err := svc.CreateUser(context.Background(), &CreateUserInput{
@@ -147,14 +152,12 @@ func TestAdminService_CreateUser_EnsuresEmailAuthIdentity(t *testing.T) {
})
require.NoError(t, err)
require.NotNil(t, user)
- require.Equal(t, []ensureEmailCall{{
- userID: 55,
- email: "admin-created@example.com",
- }}, repo.ensureCalls)
+ require.Equal(t, int64(55), user.ID)
+ require.Empty(t, repo.ensureCalls)
require.Empty(t, repo.replaceCalls)
}
-func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) {
+func TestAdminService_UpdateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
repo := &emailSyncRepoStub{
user: &User{
ID: 91,
@@ -163,6 +166,7 @@ func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) {
Status: StatusActive,
Concurrency: 3,
},
+ replaceErr: fmt.Errorf("unexpected email resync"),
}
svc := &adminServiceImpl{userRepo: repo}
@@ -172,10 +176,6 @@ func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, "after@example.com", updated.Email)
- require.Equal(t, []replaceEmailCall{{
- userID: 91,
- oldEmail: "before@example.com",
- newEmail: "after@example.com",
- }}, repo.replaceCalls)
+ require.Empty(t, repo.replaceCalls)
require.Empty(t, repo.ensureCalls)
}
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index d0d5e4e3..00fefd82 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -768,9 +768,6 @@ func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, sig
}
s.updateUserSignupSource(ctx, user.ID, signupSource)
- if signupSource == "email" {
- s.ensureEmailAuthIdentity(ctx, user)
- }
if touchLogin {
s.touchUserLogin(ctx, user.ID)
}
@@ -807,21 +804,81 @@ func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context
if s == nil || user == nil || user.ID <= 0 {
return
}
- if s.ensureEmailAuthIdentity(ctx, user) {
+ identity, created := s.ensureEmailAuthIdentity(ctx, user)
+ if s.shouldApplyEmailFirstBindDefaults(ctx, user.ID, identity, created) {
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, "email"); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to apply email first bind defaults: user_id=%d err=%v", user.ID, err)
}
}
}
-func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) bool {
- if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
+func (s *AuthService) shouldApplyEmailFirstBindDefaults(
+ ctx context.Context,
+ userID int64,
+ identity *dbent.AuthIdentity,
+ created bool,
+) bool {
+ if created {
+ return true
+ }
+ if s == nil || s.entClient == nil || userID <= 0 || identity == nil || identity.UserID != userID {
return false
}
+ if emailAuthIdentitySource(identity.Metadata) != "auth_service_dual_write" {
+ return false
+ }
+
+ hasGrant, err := s.hasProviderGrantRecord(ctx, userID, "email", "first_bind")
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email first bind grant state: user_id=%d err=%v", userID, err)
+ return false
+ }
+ return !hasGrant
+}
+
+func emailAuthIdentitySource(metadata map[string]any) string {
+ if len(metadata) == 0 {
+ return ""
+ }
+ raw, ok := metadata["source"]
+ if !ok {
+ return ""
+ }
+ return strings.TrimSpace(fmt.Sprint(raw))
+}
+
+func (s *AuthService) hasProviderGrantRecord(
+ ctx context.Context,
+ userID int64,
+ providerType string,
+ grantReason string,
+) (bool, error) {
+ if s == nil || s.entClient == nil || userID <= 0 {
+ return false, nil
+ }
+
+ rows, err := s.entClient.QueryContext(
+ ctx,
+ `SELECT 1 FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ? LIMIT 1`,
+ userID,
+ strings.TrimSpace(providerType),
+ strings.TrimSpace(grantReason),
+ )
+ if err != nil {
+ return false, err
+ }
+ defer rows.Close()
+ return rows.Next(), rows.Err()
+}
+
+func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) (*dbent.AuthIdentity, bool) {
+ if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
+ return nil, false
+ }
email := strings.ToLower(strings.TrimSpace(user.Email))
if email == "" || isReservedEmail(email) {
- return false
+ return nil, false
}
client := s.entClient
@@ -840,7 +897,7 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) b
existed, err := buildQuery().Exist(ctx)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
- return false
+ return nil, false
}
if !existed {
@@ -861,21 +918,21 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) b
DoNothing().
Exec(ctx); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
- return false
+ return nil, false
}
}
identity, err := buildQuery().Only(ctx)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to reload email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
- return false
+ return nil, false
}
if identity.UserID != user.ID {
logger.LegacyPrintf("service.auth", "[Auth] Email auth identity ownership mismatch: user_id=%d email=%s owner_id=%d", user.ID, email, identity.UserID)
- return false
+ return nil, false
}
- return !existed
+ return identity, !existed
}
func inferLegacySignupSource(email string) string {
diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go
index e2a94b13..95c9c933 100644
--- a/backend/internal/service/auth_service_identity_sync_test.go
+++ b/backend/internal/service/auth_service_identity_sync_test.go
@@ -5,6 +5,7 @@ package service_test
import (
"context"
"database/sql"
+ "errors"
"testing"
"time"
@@ -34,6 +35,24 @@ func (s *authIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
}
+type flakyAuthIdentityDefaultSubAssignerStub struct {
+ failuresRemaining int
+ calls []*service.AssignSubscriptionInput
+}
+
+func (s *flakyAuthIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
+ _ context.Context,
+ input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+ cloned := *input
+ s.calls = append(s.calls, &cloned)
+ if s.failuresRemaining > 0 {
+ s.failuresRemaining--
+ return nil, false, errors.New("temporary assign failure")
+ }
+ return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
+}
+
type authIdentitySettingRepoStub struct {
values map[string]string
}
@@ -333,6 +352,55 @@ func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyE
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
}
+func TestAuthServiceLogin_RetriesEmailFirstBindDefaultsAfterPreviousFailure(t *testing.T) {
+ assigner := &flakyAuthIdentityDefaultSubAssignerStub{failuresRemaining: 1}
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, assigner)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("retry-first-bind@example.com").
+ SetUsername("retry-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(1.5).
+ SetConcurrency(2).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 1.5, storedUser.Balance)
+ require.Equal(t, 2, storedUser.Concurrency)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+
+ token, gotUser, err = svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+
+ storedUser, err = client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 10.0, storedUser.Balance)
+ require.Equal(t, 6, storedUser.Concurrency)
+ require.Len(t, assigner.calls, 2)
+ require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
func countProviderGrantRecords(
t *testing.T,
client *dbent.Client,
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index cd1bc2bb..7c2ca2d0 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -161,33 +161,6 @@ type userAuthIdentityReader interface {
ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error)
}
-type emailAuthIdentitySynchronizer interface {
- EnsureEmailAuthIdentity(ctx context.Context, userID int64, email string) error
- ReplaceEmailAuthIdentity(ctx context.Context, userID int64, oldEmail, newEmail string) error
-}
-
-func ensureEmailAuthIdentitySync(ctx context.Context, repo UserRepository, userID int64, email string) error {
- syncer, ok := repo.(emailAuthIdentitySynchronizer)
- if !ok {
- return nil
- }
- return syncer.EnsureEmailAuthIdentity(ctx, userID, email)
-}
-
-func replaceEmailAuthIdentitySync(ctx context.Context, repo UserRepository, userID int64, oldEmail, newEmail string) error {
- oldNormalized := strings.ToLower(strings.TrimSpace(oldEmail))
- newNormalized := strings.ToLower(strings.TrimSpace(newEmail))
- if oldNormalized == newNormalized {
- return nil
- }
-
- syncer, ok := repo.(emailAuthIdentitySynchronizer)
- if !ok {
- return nil
- }
- return syncer.ReplaceEmailAuthIdentity(ctx, userID, oldEmail, newEmail)
-}
-
// ChangePasswordRequest 修改密码请求
type ChangePasswordRequest struct {
CurrentPassword string `json:"current_password"`
@@ -281,7 +254,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
return nil, fmt.Errorf("get user: %w", err)
}
oldConcurrency := user.Concurrency
- oldEmail := user.Email
// 更新字段
if req.Email != nil {
@@ -326,9 +298,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, fmt.Errorf("update user: %w", err)
}
- if err := replaceEmailAuthIdentitySync(ctx, s.userRepo, user.ID, oldEmail, user.Email); err != nil {
- return nil, fmt.Errorf("sync email auth identity: %w", err)
- }
if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
diff --git a/backend/internal/service/user_service_email_identity_sync_test.go b/backend/internal/service/user_service_email_identity_sync_test.go
index 8109b368..702b3b1a 100644
--- a/backend/internal/service/user_service_email_identity_sync_test.go
+++ b/backend/internal/service/user_service_email_identity_sync_test.go
@@ -9,7 +9,7 @@ import (
"github.com/stretchr/testify/require"
)
-func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) {
+func TestUpdateProfile_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
repo := &emailSyncRepoStub{
user: &User{
ID: 19,
@@ -17,6 +17,7 @@ func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) {
Username: "tester",
Concurrency: 2,
},
+ replaceErr: context.DeadlineExceeded,
}
svc := NewUserService(repo, nil, nil, nil)
@@ -28,10 +29,6 @@ func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) {
require.NotNil(t, updated)
require.Equal(t, newEmail, updated.Email)
require.Equal(t, 1, repo.updateCalls)
- require.Equal(t, []replaceEmailCall{{
- userID: 19,
- oldEmail: "profile-before@example.com",
- newEmail: "profile-after@example.com",
- }}, repo.replaceCalls)
+ require.Empty(t, repo.replaceCalls)
require.Empty(t, repo.ensureCalls)
}
From a70f7aca07c6d3c7b42c0a9487de2edf5cee0a43 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 01:05:59 +0800
Subject: [PATCH 081/189] fix pending auth session restore
---
frontend/src/stores/__tests__/auth.spec.ts | 29 ++++++++++++++++++++++
frontend/src/stores/auth.ts | 24 +++++++++++-------
2 files changed, 44 insertions(+), 9 deletions(-)
diff --git a/frontend/src/stores/__tests__/auth.spec.ts b/frontend/src/stores/__tests__/auth.spec.ts
index c96a606a..56195b5b 100644
--- a/frontend/src/stores/__tests__/auth.spec.ts
+++ b/frontend/src/stores/__tests__/auth.spec.ts
@@ -261,6 +261,35 @@ describe('useAuthStore', () => {
expect(localStorage.getItem('pending_auth_session')).toBeNull()
})
+ it('restores a persisted pending oauth session without requiring a token value', () => {
+ const firstStore = useAuthStore()
+
+ firstStore.setPendingAuthSession({
+ token: '',
+ token_field: 'pending_oauth_token',
+ provider: 'oidc',
+ redirect: '/welcome',
+ adoption_required: true,
+ suggested_display_name: 'OIDC Nick'
+ })
+
+ setActivePinia(createPinia())
+ const restoredStore = useAuthStore()
+ restoredStore.checkAuth()
+
+ expect(restoredStore.isAuthenticated).toBe(false)
+ expect(restoredStore.hasPendingAuthSession).toBe(true)
+ expect(restoredStore.pendingAuthSession).toEqual({
+ token: '',
+ token_field: 'pending_oauth_token',
+ provider: 'oidc',
+ redirect: '/welcome',
+ adoption_required: true,
+ suggested_display_name: 'OIDC Nick',
+ suggested_avatar_url: undefined
+ })
+ })
+
it('preserves pending auth session when registration fails', async () => {
const store = useAuthStore()
store.setPendingAuthSession({
diff --git a/frontend/src/stores/auth.ts b/frontend/src/stores/auth.ts
index 08ba5484..4b712692 100644
--- a/frontend/src/stores/auth.ts
+++ b/frontend/src/stores/auth.ts
@@ -28,6 +28,10 @@ interface PendingAuthSessionSummary {
suggested_avatar_url?: string
}
+function normalizePendingAuthTokenField(value: unknown): PendingAuthTokenField {
+ return value === 'pending_oauth_token' ? 'pending_oauth_token' : 'pending_auth_token'
+}
+
function getPersistedPendingAuthSession(): PendingAuthSessionSummary | null {
const raw = localStorage.getItem(PENDING_AUTH_SESSION_KEY)
if (!raw) {
@@ -35,18 +39,20 @@ function getPersistedPendingAuthSession(): PendingAuthSessionSummary | null {
}
try {
- const parsed = JSON.parse(raw) as PendingAuthSessionSummary
- if (!parsed?.token || !parsed?.provider) {
+ const parsed = JSON.parse(raw) as Partial | null
+ const provider = typeof parsed?.provider === 'string' ? parsed.provider.trim() : ''
+ if (!provider) {
+ localStorage.removeItem(PENDING_AUTH_SESSION_KEY)
return null
}
return {
- token: parsed.token,
- token_field: parsed.token_field || 'pending_auth_token',
- provider: parsed.provider,
- redirect: parsed.redirect,
- adoption_required: parsed.adoption_required,
- suggested_display_name: parsed.suggested_display_name,
- suggested_avatar_url: parsed.suggested_avatar_url
+ token: typeof parsed?.token === 'string' ? parsed.token : '',
+ token_field: normalizePendingAuthTokenField(parsed?.token_field),
+ provider,
+ redirect: typeof parsed?.redirect === 'string' ? parsed.redirect : undefined,
+ adoption_required: typeof parsed?.adoption_required === 'boolean' ? parsed.adoption_required : undefined,
+ suggested_display_name: typeof parsed?.suggested_display_name === 'string' ? parsed.suggested_display_name : undefined,
+ suggested_avatar_url: typeof parsed?.suggested_avatar_url === 'string' ? parsed.suggested_avatar_url : undefined
}
} catch {
localStorage.removeItem(PENDING_AUTH_SESSION_KEY)
From 0a461d824803fadf118313bb5645d1a9fcba0c9d Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 01:30:37 +0800
Subject: [PATCH 082/189] fix: harden auth identity legacy migrations
---
...ntity_legacy_migration_integration_test.go | 193 ++++++++++++++++++
.../repository/user_profile_identity_repo.go | 29 +++
...ser_profile_identity_repo_contract_test.go | 42 ++++
...auth_identity_legacy_external_backfill.sql | 36 +++-
...dentity_legacy_external_safety_reports.sql | 33 +++
5 files changed, 329 insertions(+), 4 deletions(-)
diff --git a/backend/internal/repository/auth_identity_legacy_migration_integration_test.go b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go
index 7f2f363f..29654417 100644
--- a/backend/internal/repository/auth_identity_legacy_migration_integration_test.go
+++ b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go
@@ -205,6 +205,199 @@ FROM auth_identity_migration_reports
require.Equal(t, beforeCount, afterCount)
}
+func TestAuthIdentityLegacyExternalMigrations_ChainHandlesMalformedAndNonObjectMetadata(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migration115Path := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
+ migration115SQL, err := os.ReadFile(migration115Path)
+ require.NoError(t, err)
+
+ migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
+ migration116SQL, err := os.ReadFile(migration116Path)
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, `
+CREATE TABLE IF NOT EXISTS user_external_identities (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL,
+ provider TEXT NOT NULL,
+ provider_user_id TEXT NOT NULL,
+ provider_union_id TEXT NULL,
+ provider_username TEXT NOT NULL DEFAULT '',
+ display_name TEXT NOT NULL DEFAULT '',
+ profile_url TEXT NOT NULL DEFAULT '',
+ avatar_url TEXT NOT NULL DEFAULT '',
+ metadata TEXT NOT NULL DEFAULT '{}',
+ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
+);
+
+TRUNCATE TABLE
+ auth_identity_channels,
+ auth_identities,
+ auth_identity_migration_reports,
+ user_external_identities,
+ users
+RESTART IDENTITY;
+`)
+ require.NoError(t, err)
+
+ var linuxDoMalformedUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-malformed@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoMalformedUserID))
+
+ var linuxDoArrayUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-array@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoArrayUserID))
+
+ var wechatUnionArrayUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-array@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatUnionArrayUserID))
+
+ var wechatOpenIDArrayUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-openid-array@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatOpenIDArrayUserID))
+
+ var linuxDoMalformedLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-malformed', NULL, 'legacy-linuxdo-malformed', 'Legacy LinuxDo Malformed', '{invalid')
+RETURNING id
+`, linuxDoMalformedUserID).Scan(&linuxDoMalformedLegacyID))
+
+ var linuxDoArrayLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-array', NULL, 'legacy-linuxdo-array', 'Legacy LinuxDo Array', '["legacy-linuxdo-array"]')
+RETURNING id
+`, linuxDoArrayUserID).Scan(&linuxDoArrayLegacyID))
+
+ var wechatUnionArrayLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-array', 'union-array', 'legacy-wechat-array', 'Legacy WeChat Array', '["legacy-wechat-array"]')
+RETURNING id
+`, wechatUnionArrayUserID).Scan(&wechatUnionArrayLegacyID))
+
+ var wechatOpenIDArrayLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-array-only', NULL, 'legacy-wechat-array-only', 'Legacy WeChat Array Only', '["legacy-wechat-openid-array"]')
+RETURNING id
+`, wechatOpenIDArrayUserID).Scan(&wechatOpenIDArrayLegacyID))
+
+ _, err = tx.ExecContext(ctx, string(migration115SQL))
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, string(migration116SQL))
+ require.NoError(t, err)
+
+ var linuxDoMalformedMetadataType string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT jsonb_typeof(metadata)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'linuxdo'
+ AND provider_key = 'linuxdo'
+ AND provider_subject = 'linuxdo-malformed'
+`, linuxDoMalformedUserID).Scan(&linuxDoMalformedMetadataType))
+ require.Equal(t, "object", linuxDoMalformedMetadataType)
+
+ var linuxDoArrayMetadataType string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT jsonb_typeof(metadata)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'linuxdo'
+ AND provider_key = 'linuxdo'
+ AND provider_subject = 'linuxdo-array'
+`, linuxDoArrayUserID).Scan(&linuxDoArrayMetadataType))
+ require.Equal(t, "object", linuxDoArrayMetadataType)
+
+ var wechatUnionArrayMetadataType string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT jsonb_typeof(metadata)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'wechat'
+ AND provider_key = 'wechat-main'
+ AND provider_subject = 'union-array'
+`, wechatUnionArrayUserID).Scan(&wechatUnionArrayMetadataType))
+ require.Equal(t, "object", wechatUnionArrayMetadataType)
+
+ var invalidJSONReportDetailsType string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT jsonb_typeof(details)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_invalid_metadata_json'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(linuxDoMalformedLegacyID, 10)).Scan(&invalidJSONReportDetailsType))
+ require.Equal(t, "object", invalidJSONReportDetailsType)
+
+ var openIDOnlyReportDetailsType string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT jsonb_typeof(details)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatOpenIDArrayLegacyID, 10)).Scan(&openIDOnlyReportDetailsType))
+ require.Equal(t, "object", openIDOnlyReportDetailsType)
+
+ var preservedArrayMetadataCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE id IN (
+ SELECT id
+ FROM auth_identities
+ WHERE (user_id = $1 AND provider_subject = 'linuxdo-array')
+ OR (user_id = $2 AND provider_subject = 'union-array')
+)
+ AND metadata ? '_legacy_metadata_raw_json'
+`, linuxDoArrayUserID, wechatUnionArrayUserID).Scan(&preservedArrayMetadataCount))
+ require.Equal(t, 2, preservedArrayMetadataCount)
+
+ require.NotZero(t, linuxDoArrayLegacyID)
+ require.NotZero(t, wechatUnionArrayLegacyID)
+}
+
func TestAuthIdentityLegacyExternalSafetyMigration_ReportsConflictsAndDowngradesInvalidJSON(t *testing.T) {
tx := testTx(t)
ctx := context.Background()
diff --git a/backend/internal/repository/user_profile_identity_repo.go b/backend/internal/repository/user_profile_identity_repo.go
index c1b4b6bf..dbba364d 100644
--- a/backend/internal/repository/user_profile_identity_repo.go
+++ b/backend/internal/repository/user_profile_identity_repo.go
@@ -26,6 +26,10 @@ var (
"AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT",
"auth identity channel already belongs to another user",
)
+ ErrAuthIdentityChannelProviderMismatch = infraerrors.BadRequest(
+ "AUTH_IDENTITY_CHANNEL_PROVIDER_MISMATCH",
+ "auth identity channel provider must match canonical identity",
+ )
)
type ProviderGrantReason string
@@ -133,6 +137,10 @@ func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(
}
func (r *userRepository) CreateAuthIdentity(ctx context.Context, input CreateAuthIdentityInput) (*CreateAuthIdentityResult, error) {
+ if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil {
+ return nil, err
+ }
+
client := clientFromContext(ctx, r.client)
create := client.AuthIdentity.Create().
@@ -240,6 +248,10 @@ func (r *userRepository) ListUserAuthIdentities(ctx context.Context, userID int6
}
func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) {
+ if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil {
+ return nil, err
+ }
+
var result *CreateAuthIdentityResult
err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
client := clientFromContext(txCtx, r.client)
@@ -531,6 +543,23 @@ func copyMetadata(in map[string]any) map[string]any {
return out
}
+func validateAuthIdentityChannelProviderMatch(canonical AuthIdentityKey, channel *AuthIdentityChannelKey) error {
+ if channel == nil {
+ return nil
+ }
+
+ canonicalProviderType := strings.TrimSpace(canonical.ProviderType)
+ canonicalProviderKey := strings.TrimSpace(canonical.ProviderKey)
+ channelProviderType := strings.TrimSpace(channel.ProviderType)
+ channelProviderKey := strings.TrimSpace(channel.ProviderKey)
+
+ if canonicalProviderType != channelProviderType || canonicalProviderKey != channelProviderKey {
+ return ErrAuthIdentityChannelProviderMismatch
+ }
+
+ return nil
+}
+
func txAwareSQLExecutor(ctx context.Context, fallback sqlExecutor, client *dbent.Client) sqlQueryExecutor {
if tx := dbent.TxFromContext(ctx); tx != nil {
if exec := sqlExecutorFromEntClient(tx.Client()); exec != nil {
diff --git a/backend/internal/repository/user_profile_identity_repo_contract_test.go b/backend/internal/repository/user_profile_identity_repo_contract_test.go
index 19022ec1..a02af62b 100644
--- a/backend/internal/repository/user_profile_identity_repo_contract_test.go
+++ b/backend/internal/repository/user_profile_identity_repo_contract_test.go
@@ -187,6 +187,48 @@ func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAn
s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict)
}
+func (s *UserProfileIdentityRepoSuite) TestCreateAuthIdentity_RejectsChannelProviderMismatch() {
+ user := s.mustCreateUser("provider-mismatch-create")
+
+ _, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
+ UserID: user.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "union-create-mismatch",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ Channel: "oauth",
+ ChannelAppID: "app-mismatch",
+ ChannelSubject: "openid-create-mismatch",
+ },
+ })
+ s.Require().ErrorIs(err, ErrAuthIdentityChannelProviderMismatch)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_RejectsChannelProviderMismatch() {
+ user := s.mustCreateUser("provider-mismatch-bind")
+
+ _, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+ UserID: user.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "union-bind-mismatch",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-legacy",
+ Channel: "oa",
+ ChannelAppID: "wx-app-bind-mismatch",
+ ChannelSubject: "openid-bind-mismatch",
+ },
+ })
+ s.Require().ErrorIs(err, ErrAuthIdentityChannelProviderMismatch)
+}
+
func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_RollsBackIdentityAndGrantOnError() {
user := s.mustCreateUser("tx-rollback")
expectedErr := errors.New("rollback")
diff --git a/backend/migrations/115_auth_identity_legacy_external_backfill.sql b/backend/migrations/115_auth_identity_legacy_external_backfill.sql
index f4a13c36..7a20f8eb 100644
--- a/backend/migrations/115_auth_identity_legacy_external_backfill.sql
+++ b/backend/migrations/115_auth_identity_legacy_external_backfill.sql
@@ -1,3 +1,29 @@
+CREATE OR REPLACE FUNCTION public.__migration_115_safe_legacy_metadata_jsonb(input_text TEXT)
+RETURNS JSONB
+LANGUAGE plpgsql
+AS $$
+DECLARE
+ parsed JSONB;
+BEGIN
+ IF input_text IS NULL OR BTRIM(input_text) = '' THEN
+ RETURN '{}'::jsonb;
+ END IF;
+
+ BEGIN
+ parsed := input_text::jsonb;
+ EXCEPTION
+ WHEN OTHERS THEN
+ RETURN '{}'::jsonb;
+ END;
+
+ IF jsonb_typeof(parsed) = 'object' THEN
+ RETURN parsed;
+ END IF;
+
+ RETURN jsonb_build_object('_legacy_metadata_raw_json', parsed);
+END;
+$$;
+
DO $$
BEGIN
IF to_regclass('public.user_external_identities') IS NULL THEN
@@ -33,7 +59,7 @@ FROM (
BTRIM(uei.provider_user_id) AS provider_user_id,
BTRIM(uei.provider_username) AS provider_username,
BTRIM(uei.display_name) AS display_name,
- COALESCE(NULLIF(BTRIM(COALESCE(uei.metadata, '')), '')::jsonb, '{}'::jsonb) AS metadata_json,
+ public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
uei.created_at,
uei.updated_at
FROM user_external_identities AS uei
@@ -78,7 +104,7 @@ FROM (
BTRIM(uei.provider_union_id) AS provider_union_id,
BTRIM(uei.provider_username) AS provider_username,
BTRIM(uei.display_name) AS display_name,
- COALESCE(NULLIF(BTRIM(COALESCE(uei.metadata, '')), '')::jsonb, '{}'::jsonb) AS metadata_json,
+ public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
uei.created_at,
uei.updated_at
FROM user_external_identities AS uei
@@ -123,7 +149,7 @@ FROM (
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
CROSS JOIN LATERAL (
- SELECT COALESCE(NULLIF(BTRIM(COALESCE(uei.metadata, '')), '')::jsonb, '{}'::jsonb) AS metadata_json
+ SELECT public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
) AS meta
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
@@ -157,7 +183,7 @@ FROM (
uei.id,
uei.user_id,
BTRIM(uei.provider_user_id) AS provider_user_id,
- COALESCE(NULLIF(BTRIM(COALESCE(uei.metadata, '')), '')::jsonb, '{}'::jsonb) AS metadata_json
+ public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
WHERE u.deleted_at IS NULL
@@ -185,3 +211,5 @@ WHERE ai.provider_type = 'wechat'
AND COALESCE(ai.metadata ->> 'backfill_source', '') = 'synthetic_email'
AND BTRIM(COALESCE(ai.metadata ->> 'unionid', '')) = ''
ON CONFLICT (report_type, report_key) DO NOTHING;
+
+DROP FUNCTION IF EXISTS public.__migration_115_safe_legacy_metadata_jsonb(TEXT);
diff --git a/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql
index 994f3f37..3983bb1a 100644
--- a/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql
+++ b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql
@@ -332,5 +332,38 @@ ON CONFLICT (report_type, report_key) DO NOTHING;
$sql$;
END $$;
+DO $$
+BEGIN
+ IF NOT EXISTS (
+ SELECT 1
+ FROM pg_constraint
+ WHERE conname = 'auth_identities_metadata_is_object_check'
+ ) THEN
+ ALTER TABLE auth_identities
+ ADD CONSTRAINT auth_identities_metadata_is_object_check
+ CHECK (jsonb_typeof(metadata) = 'object');
+ END IF;
+
+ IF NOT EXISTS (
+ SELECT 1
+ FROM pg_constraint
+ WHERE conname = 'auth_identity_channels_metadata_is_object_check'
+ ) THEN
+ ALTER TABLE auth_identity_channels
+ ADD CONSTRAINT auth_identity_channels_metadata_is_object_check
+ CHECK (jsonb_typeof(metadata) = 'object');
+ END IF;
+
+ IF NOT EXISTS (
+ SELECT 1
+ FROM pg_constraint
+ WHERE conname = 'auth_identity_migration_reports_details_is_object_check'
+ ) THEN
+ ALTER TABLE auth_identity_migration_reports
+ ADD CONSTRAINT auth_identity_migration_reports_details_is_object_check
+ CHECK (jsonb_typeof(details) = 'object');
+ END IF;
+END $$;
+
DROP FUNCTION IF EXISTS public.__migration_116_is_valid_legacy_metadata_jsonb(TEXT);
DROP FUNCTION IF EXISTS public.__migration_116_safe_legacy_metadata_jsonb(TEXT);
From 1d8432b8a4a3f2f6fde27ee0d5c01f00783a1c04 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 01:40:56 +0800
Subject: [PATCH 083/189] fix: harden payment resume and wxpay webhook routing
---
.../handler/payment_webhook_handler.go | 37 +++++-
.../handler/payment_webhook_handler_test.go | 70 +++++++++++
backend/internal/payment/load_balancer.go | 38 ++++++
backend/internal/service/payment_order.go | 66 +++++++++--
.../service/payment_order_jsapi_test.go | 112 ++++++++++++++++++
.../service/payment_resume_service.go | 32 ++++-
.../service/payment_resume_service_test.go | 64 ++++++++++
.../service/payment_webhook_provider.go | 76 +++++++++++-
.../service/payment_webhook_provider_test.go | 62 +++++++++-
9 files changed, 526 insertions(+), 31 deletions(-)
create mode 100644 backend/internal/service/payment_order_jsapi_test.go
diff --git a/backend/internal/handler/payment_webhook_handler.go b/backend/internal/handler/payment_webhook_handler.go
index 9fdefa93..c06a5b7e 100644
--- a/backend/internal/handler/payment_webhook_handler.go
+++ b/backend/internal/handler/payment_webhook_handler.go
@@ -1,6 +1,8 @@
package handler
import (
+ "context"
+ "fmt"
"io"
"log/slog"
"net/http"
@@ -77,9 +79,13 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
// This is needed when multiple instances of the same provider exist (e.g. multiple EasyPay accounts).
outTradeNo := extractOutTradeNo(rawBody, providerKey)
- provider, err := h.paymentService.GetWebhookProvider(c.Request.Context(), providerKey, outTradeNo)
+ providers, err := h.paymentService.GetWebhookProviders(c.Request.Context(), providerKey, outTradeNo)
if err != nil {
slog.Warn("[Payment Webhook] provider not found", "provider", providerKey, "outTradeNo", outTradeNo, "error", err)
+ if providerKey == payment.TypeWxpay {
+ c.String(http.StatusBadRequest, "verify failed")
+ return
+ }
writeSuccessResponse(c, providerKey)
return
}
@@ -89,7 +95,7 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
headers[strings.ToLower(k)] = c.GetHeader(k)
}
- notification, err := provider.VerifyNotification(c.Request.Context(), rawBody, headers)
+ resolvedProviderKey, notification, err := verifyNotificationWithProviders(c.Request.Context(), providers, rawBody, headers)
if err != nil {
truncatedBody := rawBody
if len(truncatedBody) > webhookLogTruncateLen {
@@ -103,17 +109,17 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
// nil notification means irrelevant event (e.g. Stripe non-payment event); return success.
if notification == nil {
- writeSuccessResponse(c, providerKey)
+ writeSuccessResponse(c, resolvedProviderKey)
return
}
- if err := h.paymentService.HandlePaymentNotification(c.Request.Context(), notification, providerKey); err != nil {
- slog.Error("[Payment Webhook] handle notification failed", "provider", providerKey, "error", err)
+ if err := h.paymentService.HandlePaymentNotification(c.Request.Context(), notification, resolvedProviderKey); err != nil {
+ slog.Error("[Payment Webhook] handle notification failed", "provider", resolvedProviderKey, "error", err)
c.String(http.StatusInternalServerError, "handle failed")
return
}
- writeSuccessResponse(c, providerKey)
+ writeSuccessResponse(c, resolvedProviderKey)
}
// extractOutTradeNo parses the webhook body to find the out_trade_no.
@@ -131,6 +137,25 @@ func extractOutTradeNo(rawBody, providerKey string) string {
return ""
}
+func verifyNotificationWithProviders(ctx context.Context, providers []payment.Provider, rawBody string, headers map[string]string) (string, *payment.PaymentNotification, error) {
+ var lastErr error
+ for _, provider := range providers {
+ if provider == nil {
+ continue
+ }
+ notification, err := provider.VerifyNotification(ctx, rawBody, headers)
+ if err != nil {
+ lastErr = err
+ continue
+ }
+ return provider.ProviderKey(), notification, nil
+ }
+ if lastErr != nil {
+ return "", nil, lastErr
+ }
+ return "", nil, fmt.Errorf("no webhook provider could verify notification")
+}
+
// wxpaySuccessResponse is the JSON response expected by WeChat Pay webhook.
type wxpaySuccessResponse struct {
Code string `json:"code"`
diff --git a/backend/internal/handler/payment_webhook_handler_test.go b/backend/internal/handler/payment_webhook_handler_test.go
index 6f448131..88221b5c 100644
--- a/backend/internal/handler/payment_webhook_handler_test.go
+++ b/backend/internal/handler/payment_webhook_handler_test.go
@@ -3,11 +3,14 @@
package handler
import (
+ "context"
"encoding/json"
+ "errors"
"net/http"
"net/http/httptest"
"testing"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -131,3 +134,70 @@ func TestExtractOutTradeNo(t *testing.T) {
})
}
}
+
+func TestVerifyNotificationWithProvidersReturnsMatchedProvider(t *testing.T) {
+ firstErr := errors.New("wrong provider")
+ providers := []payment.Provider{
+ webhookHandlerProviderStub{
+ key: payment.TypeWxpay,
+ verifyErr: firstErr,
+ },
+ webhookHandlerProviderStub{
+ key: payment.TypeWxpay,
+ notification: &payment.PaymentNotification{
+ OrderID: "sub2_42",
+ TradeNo: "trade-42",
+ Status: payment.NotificationStatusSuccess,
+ },
+ },
+ }
+
+ providerKey, notification, err := verifyNotificationWithProviders(context.Background(), providers, "{}", map[string]string{"wechatpay-signature": "sig"})
+ require.NoError(t, err)
+ require.Equal(t, payment.TypeWxpay, providerKey)
+ require.NotNil(t, notification)
+ require.Equal(t, "sub2_42", notification.OrderID)
+}
+
+func TestVerifyNotificationWithProvidersFailsWhenAllProvidersReject(t *testing.T) {
+ providers := []payment.Provider{
+ webhookHandlerProviderStub{
+ key: payment.TypeWxpay,
+ verifyErr: errors.New("verify failed a"),
+ },
+ webhookHandlerProviderStub{
+ key: payment.TypeWxpay,
+ verifyErr: errors.New("verify failed b"),
+ },
+ }
+
+ _, _, err := verifyNotificationWithProviders(context.Background(), providers, "{}", nil)
+ require.Error(t, err)
+}
+
+type webhookHandlerProviderStub struct {
+ key string
+ notification *payment.PaymentNotification
+ verifyErr error
+}
+
+func (p webhookHandlerProviderStub) Name() string { return p.key }
+func (p webhookHandlerProviderStub) ProviderKey() string { return p.key }
+func (p webhookHandlerProviderStub) SupportedTypes() []payment.PaymentType {
+ return []payment.PaymentType{payment.PaymentType(p.key)}
+}
+func (p webhookHandlerProviderStub) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ panic("unexpected call")
+}
+func (p webhookHandlerProviderStub) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) {
+ panic("unexpected call")
+}
+func (p webhookHandlerProviderStub) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
+ if p.verifyErr != nil {
+ return nil, p.verifyErr
+ }
+ return p.notification, nil
+}
+func (p webhookHandlerProviderStub) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
+ panic("unexpected call")
+}
diff --git a/backend/internal/payment/load_balancer.go b/backend/internal/payment/load_balancer.go
index ec4ed1d3..ddf792bb 100644
--- a/backend/internal/payment/load_balancer.go
+++ b/backend/internal/payment/load_balancer.go
@@ -45,11 +45,31 @@ type DefaultLoadBalancer struct {
counter atomic.Uint64
}
+type contextKey string
+
+const wxpayJSAPIAppIDContextKey contextKey = "payment.wxpay.jsapi_app_id"
+
// NewDefaultLoadBalancer creates a new load balancer.
func NewDefaultLoadBalancer(db *dbent.Client, encryptionKey []byte) *DefaultLoadBalancer {
return &DefaultLoadBalancer{db: db, encryptionKey: encryptionKey}
}
+func WithWxpayJSAPIAppID(ctx context.Context, appID string) context.Context {
+ appID = strings.TrimSpace(appID)
+ if appID == "" {
+ return ctx
+ }
+ return context.WithValue(ctx, wxpayJSAPIAppIDContextKey, appID)
+}
+
+func wxpayJSAPIAppIDFromContext(ctx context.Context) string {
+ if ctx == nil {
+ return ""
+ }
+ appID, _ := ctx.Value(wxpayJSAPIAppIDContextKey).(string)
+ return strings.TrimSpace(appID)
+}
+
// instanceCandidate pairs an instance with its pre-fetched daily usage.
type instanceCandidate struct {
inst *dbent.PaymentProviderInstance
@@ -116,6 +136,7 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances(
}
var matched []*dbent.PaymentProviderInstance
+ expectedWxpayJSAPIAppID := wxpayJSAPIAppIDFromContext(ctx)
for _, inst := range instances {
// Stripe: match by provider_key because supported_types lists sub-types (card,link,alipay,wxpay),
// not "stripe" itself. The checkout page aggregates all sub-types under "stripe".
@@ -124,6 +145,16 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances(
matched = append(matched, inst)
}
} else if InstanceSupportsType(inst.SupportedTypes, paymentType) {
+ if expectedWxpayJSAPIAppID != "" && normalizeVisibleMethodSupportType(paymentType) == TypeWxpay && inst.ProviderKey == TypeWxpay {
+ config, cfgErr := lb.decryptConfig(inst.Config)
+ if cfgErr != nil {
+ slog.Warn("skip wxpay instance with unreadable config during jsapi filtering", "instance_id", inst.ID, "error", cfgErr)
+ continue
+ }
+ if resolveWxpayJSAPIAppID(config) != expectedWxpayJSAPIAppID {
+ continue
+ }
+ }
matched = append(matched, inst)
}
}
@@ -358,6 +389,13 @@ func legacyVisibleMethodAlias(paymentType PaymentType) PaymentType {
}
}
+func resolveWxpayJSAPIAppID(config map[string]string) string {
+ if appID := strings.TrimSpace(config["mpAppId"]); appID != "" {
+ return appID
+ }
+ return strings.TrimSpace(config["appId"])
+}
+
// GetInstanceConfig decrypts and returns the configuration for a provider instance by ID.
func (lb *DefaultLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) {
inst, err := lb.db.PaymentProviderInstance.Get(ctx, instanceID)
diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go
index 221d6b94..7d973b92 100644
--- a/backend/internal/service/payment_order.go
+++ b/backend/internal/service/payment_order.go
@@ -216,7 +216,11 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user
}
func (s *PaymentService) selectCreateOrderInstance(ctx context.Context, req CreateOrderRequest, cfg *PaymentConfig, payAmount float64) (*payment.InstanceSelection, error) {
- sel, err := s.loadBalancer.SelectInstance(ctx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
+ selectCtx, err := s.prepareCreateOrderSelectionContext(ctx, req)
+ if err != nil {
+ return nil, err
+ }
+ sel, err := s.loadBalancer.SelectInstance(selectCtx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
if err != nil {
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType))
}
@@ -226,6 +230,44 @@ func (s *PaymentService) selectCreateOrderInstance(ctx context.Context, req Crea
return sel, nil
}
+func (s *PaymentService) prepareCreateOrderSelectionContext(ctx context.Context, req CreateOrderRequest) (context.Context, error) {
+ if !requestNeedsWeChatJSAPICompatibility(req) {
+ return ctx, nil
+ }
+ if !s.usesOfficialWxpayVisibleMethod(ctx) {
+ return ctx, nil
+ }
+ expectedAppID, _, err := s.getWeChatPaymentOAuthCredential(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return payment.WithWxpayJSAPIAppID(ctx, expectedAppID), nil
+}
+
+func requestNeedsWeChatJSAPICompatibility(req CreateOrderRequest) bool {
+ if payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay {
+ return false
+ }
+ return req.IsWeChatBrowser || strings.TrimSpace(req.OpenID) != ""
+}
+
+func (s *PaymentService) usesOfficialWxpayVisibleMethod(ctx context.Context) bool {
+ if s == nil || s.configService == nil || s.configService.settingRepo == nil {
+ return false
+ }
+ vals, err := s.configService.settingRepo.GetMultiple(ctx, []string{
+ SettingPaymentVisibleMethodWxpayEnabled,
+ SettingPaymentVisibleMethodWxpaySource,
+ })
+ if err != nil {
+ return false
+ }
+ if vals[SettingPaymentVisibleMethodWxpayEnabled] != "true" {
+ return false
+ }
+ return NormalizeVisibleMethodSource(payment.TypeWxpay, vals[SettingPaymentVisibleMethodWxpaySource]) == VisibleMethodSourceOfficialWechat
+}
+
func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, limitAmount float64, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan, sel *payment.InstanceSelection) (*CreateOrderResponse, error) {
prov, err := provider.CreateProvider(sel.ProviderKey, sel.InstanceID, sel.Config)
if err != nil {
@@ -239,16 +281,18 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
}
resumeToken := ""
if resume := s.paymentResume(); resume != nil {
- resumeToken, err = resume.CreateToken(ResumeTokenClaims{
- OrderID: order.ID,
- UserID: order.UserID,
- ProviderInstanceID: sel.InstanceID,
- ProviderKey: sel.ProviderKey,
- PaymentType: req.PaymentType,
- CanonicalReturnURL: canonicalReturnURL,
- })
- if err != nil {
- return nil, fmt.Errorf("create payment resume token: %w", err)
+ if resume.isSigningConfigured() {
+ resumeToken, err = resume.CreateToken(ResumeTokenClaims{
+ OrderID: order.ID,
+ UserID: order.UserID,
+ ProviderInstanceID: sel.InstanceID,
+ ProviderKey: sel.ProviderKey,
+ PaymentType: req.PaymentType,
+ CanonicalReturnURL: canonicalReturnURL,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("create payment resume token: %w", err)
+ }
}
}
providerReturnURL, err := buildPaymentReturnURL(canonicalReturnURL, order.ID, resumeToken)
diff --git a/backend/internal/service/payment_order_jsapi_test.go b/backend/internal/service/payment_order_jsapi_test.go
new file mode 100644
index 00000000..08492432
--- /dev/null
+++ b/backend/internal/service/payment_order_jsapi_test.go
@@ -0,0 +1,112 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+)
+
+const jsapiTestEncryptionKey = "0123456789abcdef0123456789abcdef"
+
+func TestSelectCreateOrderInstancePrefersJSAPICompatibleWxpayInstance(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ compatibleConfig := mustEncryptJSAPITestConfig(t, map[string]string{
+ "appId": "wx-merchant-app",
+ "mpAppId": "wx-mp-app",
+ "mchId": "mch-compatible",
+ "privateKey": "private-key",
+ "apiV3Key": jsapiTestEncryptionKey,
+ "publicKey": "public-key",
+ "publicKeyId": "key-compatible",
+ "certSerial": "serial-compatible",
+ })
+ incompatibleConfig := mustEncryptJSAPITestConfig(t, map[string]string{
+ "appId": "wx-merchant-other",
+ "mpAppId": "wx-mp-other",
+ "mchId": "mch-incompatible",
+ "privateKey": "private-key",
+ "apiV3Key": jsapiTestEncryptionKey,
+ "publicKey": "public-key",
+ "publicKeyId": "key-incompatible",
+ "certSerial": "serial-incompatible",
+ })
+
+ compatible, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-compatible").
+ SetConfig(compatibleConfig).
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ SetSortOrder(1).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create compatible wxpay instance: %v", err)
+ }
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-incompatible").
+ SetConfig(incompatibleConfig).
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ SetSortOrder(2).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create incompatible wxpay instance: %v", err)
+ }
+
+ configService := &PaymentConfigService{
+ entClient: client,
+ settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{
+ SettingPaymentVisibleMethodWxpayEnabled: "true",
+ SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
+ }},
+ encryptionKey: []byte(jsapiTestEncryptionKey),
+ }
+ loadBalancer := newVisibleMethodLoadBalancer(
+ payment.NewDefaultLoadBalancer(client, []byte(jsapiTestEncryptionKey)),
+ configService,
+ )
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: loadBalancer,
+ configService: configService,
+ }
+
+ t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app")
+ t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wechat-secret")
+
+ sel, err := svc.selectCreateOrderInstance(ctx, CreateOrderRequest{
+ PaymentType: payment.TypeWxpay,
+ OpenID: "openid-123",
+ IsWeChatBrowser: true,
+ }, &PaymentConfig{LoadBalanceStrategy: string(payment.StrategyRoundRobin)}, 12.5)
+ if err != nil {
+ t.Fatalf("selectCreateOrderInstance returned error: %v", err)
+ }
+ if sel == nil {
+ t.Fatal("expected selected instance, got nil")
+ }
+ expectedInstanceID := fmt.Sprintf("%d", compatible.ID)
+ if sel.InstanceID != expectedInstanceID {
+ t.Fatalf("selected instance id = %q, want %q", sel.InstanceID, expectedInstanceID)
+ }
+}
+
+func mustEncryptJSAPITestConfig(t *testing.T, config map[string]string) string {
+ t.Helper()
+
+ data, err := json.Marshal(config)
+ if err != nil {
+ t.Fatalf("marshal config: %v", err)
+ }
+ encrypted, err := payment.Encrypt(string(data), []byte(jsapiTestEncryptionKey))
+ if err != nil {
+ t.Fatalf("encrypt config: %v", err)
+ }
+ return encrypted
+}
diff --git a/backend/internal/service/payment_resume_service.go b/backend/internal/service/payment_resume_service.go
index 64d1d125..05997b38 100644
--- a/backend/internal/service/payment_resume_service.go
+++ b/backend/internal/service/payment_resume_service.go
@@ -33,6 +33,9 @@ const (
VisibleMethodSourceEasyPayWechat = "easypay_wxpay"
wechatPaymentResumeTokenType = "wechat_payment_resume"
+
+ paymentResumeNotConfiguredCode = "PAYMENT_RESUME_NOT_CONFIGURED"
+ paymentResumeNotConfiguredMessage = "payment resume tokens require a configured signing key"
)
type ResumeTokenClaims struct {
@@ -70,6 +73,17 @@ func NewPaymentResumeService(signingKey []byte) *PaymentResumeService {
return &PaymentResumeService{signingKey: signingKey}
}
+func (s *PaymentResumeService) isSigningConfigured() bool {
+ return s != nil && len(s.signingKey) > 0
+}
+
+func (s *PaymentResumeService) ensureSigningKey() error {
+ if s.isSigningConfigured() {
+ return nil
+ }
+ return infraerrors.ServiceUnavailable(paymentResumeNotConfiguredCode, paymentResumeNotConfiguredMessage)
+}
+
func NormalizeVisibleMethod(method string) string {
return payment.GetBasePaymentType(strings.TrimSpace(method))
}
@@ -240,6 +254,9 @@ func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (stri
}
func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) {
+ if err := s.ensureSigningKey(); err != nil {
+ return "", err
+ }
if claims.OrderID <= 0 {
return "", fmt.Errorf("resume token requires order id")
}
@@ -250,6 +267,9 @@ func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, er
}
func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) {
+ if err := s.ensureSigningKey(); err != nil {
+ return nil, err
+ }
var claims ResumeTokenClaims
if err := s.parseSignedToken(token, &claims); err != nil {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid")
@@ -261,6 +281,9 @@ func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, err
}
func (s *PaymentResumeService) CreateWeChatPaymentResumeToken(claims WeChatPaymentResumeClaims) (string, error) {
+ if err := s.ensureSigningKey(); err != nil {
+ return "", err
+ }
claims.OpenID = strings.TrimSpace(claims.OpenID)
if claims.OpenID == "" {
return "", fmt.Errorf("wechat payment resume token requires openid")
@@ -282,6 +305,9 @@ func (s *PaymentResumeService) CreateWeChatPaymentResumeToken(claims WeChatPayme
}
func (s *PaymentResumeService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
+ if err := s.ensureSigningKey(); err != nil {
+ return nil, err
+ }
var claims WeChatPaymentResumeClaims
if err := s.parseSignedToken(token, &claims); err != nil {
return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payload is invalid")
@@ -330,11 +356,7 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
}
func (s *PaymentResumeService) sign(payload string) string {
- key := s.signingKey
- if len(key) == 0 {
- key = []byte(paymentResumeFallbackSigningKey)
- }
- mac := hmac.New(sha256.New, key)
+ mac := hmac.New(sha256.New, s.signingKey)
_, _ = mac.Write([]byte(payload))
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
}
diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go
index 24d50494..4ac89f0f 100644
--- a/backend/internal/service/payment_resume_service_test.go
+++ b/backend/internal/service/payment_resume_service_test.go
@@ -4,6 +4,10 @@ package service
import (
"context"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json"
"net/url"
"strconv"
"testing"
@@ -150,6 +154,27 @@ func TestPaymentResumeTokenRoundTrip(t *testing.T) {
}
}
+func TestCreateTokenRejectsMissingSigningKey(t *testing.T) {
+ t.Parallel()
+
+ svc := NewPaymentResumeService(nil)
+ _, err := svc.CreateToken(ResumeTokenClaims{OrderID: 42})
+ if err == nil {
+ t.Fatal("CreateToken should reject missing signing key")
+ }
+}
+
+func TestParseTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) {
+ t.Parallel()
+
+ token := mustCreateFallbackSignedToken(t, ResumeTokenClaims{OrderID: 42, UserID: 7})
+ svc := NewPaymentResumeService(nil)
+ _, err := svc.ParseToken(token)
+ if err == nil {
+ t.Fatal("ParseToken should reject tokens when signing key is missing")
+ }
+}
+
func TestWeChatPaymentResumeTokenRoundTrip(t *testing.T) {
t.Parallel()
@@ -183,6 +208,31 @@ func TestWeChatPaymentResumeTokenRoundTrip(t *testing.T) {
}
}
+func TestCreateWeChatPaymentResumeTokenRejectsMissingSigningKey(t *testing.T) {
+ t.Parallel()
+
+ svc := NewPaymentResumeService(nil)
+ _, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{OpenID: "openid-123"})
+ if err == nil {
+ t.Fatal("CreateWeChatPaymentResumeToken should reject missing signing key")
+ }
+}
+
+func TestParseWeChatPaymentResumeTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) {
+ t.Parallel()
+
+ token := mustCreateFallbackSignedToken(t, WeChatPaymentResumeClaims{
+ TokenType: wechatPaymentResumeTokenType,
+ OpenID: "openid-123",
+ PaymentType: payment.TypeWxpay,
+ })
+ svc := NewPaymentResumeService(nil)
+ _, err := svc.ParseWeChatPaymentResumeToken(token)
+ if err == nil {
+ t.Fatal("ParseWeChatPaymentResumeToken should reject tokens when signing key is missing")
+ }
+}
+
func TestNormalizeVisibleMethodSource(t *testing.T) {
t.Parallel()
@@ -315,3 +365,17 @@ func (c *captureLoadBalancer) SelectInstance(_ context.Context, providerKey stri
c.lastPaymentType = paymentType
return &payment.InstanceSelection{ProviderKey: providerKey, SupportedTypes: paymentType}, nil
}
+
+func mustCreateFallbackSignedToken(t *testing.T, claims any) string {
+ t.Helper()
+
+ payload, err := json.Marshal(claims)
+ if err != nil {
+ t.Fatalf("marshal claims: %v", err)
+ }
+ encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
+ mac := hmac.New(sha256.New, []byte(paymentResumeFallbackSigningKey))
+ _, _ = mac.Write([]byte(encodedPayload))
+ signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
+ return encodedPayload + "." + signature
+}
diff --git a/backend/internal/service/payment_webhook_provider.go b/backend/internal/service/payment_webhook_provider.go
index 289d63ed..82dc9ea3 100644
--- a/backend/internal/service/payment_webhook_provider.go
+++ b/backend/internal/service/payment_webhook_provider.go
@@ -16,33 +16,70 @@ import (
// It resolves the original provider instance from the order whenever possible and
// only falls back to a registry provider for legacy/single-instance scenarios.
func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) {
+ providers, err := s.GetWebhookProviders(ctx, providerKey, outTradeNo)
+ if err != nil {
+ return nil, err
+ }
+ if len(providers) == 0 {
+ return nil, payment.ErrProviderNotFound
+ }
+ return providers[0], nil
+}
+
+// GetWebhookProviders returns provider candidates that can verify the webhook.
+// Official WeChat Pay may require multiple candidates because the callback body
+// cannot be bound to a merchant before decryption.
+func (s *PaymentService) GetWebhookProviders(ctx context.Context, providerKey, outTradeNo string) ([]payment.Provider, error) {
if outTradeNo != "" {
order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx)
if err == nil {
if psHasPinnedProviderInstance(order) {
- return s.getPinnedOrderProvider(ctx, order)
+ prov, err := s.getPinnedOrderProvider(ctx, order)
+ if err != nil {
+ return nil, err
+ }
+ return []payment.Provider{prov}, nil
}
inst, err := s.getOrderProviderInstance(ctx, order)
if err != nil {
return nil, fmt.Errorf("load order provider instance: %w", err)
}
if inst != nil {
- return s.createProviderFromInstance(ctx, inst)
+ prov, err := s.createProviderFromInstance(ctx, inst)
+ if err != nil {
+ return nil, err
+ }
+ return []payment.Provider{prov}, nil
+ }
+ if strings.TrimSpace(providerKey) == payment.TypeWxpay {
+ return s.getEnabledWebhookProvidersByKey(ctx, providerKey)
}
if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey)
}
s.EnsureProviders(ctx)
- return s.registry.GetProviderByKey(providerKey)
+ prov, err := s.registry.GetProviderByKey(providerKey)
+ if err != nil {
+ return nil, err
+ }
+ return []payment.Provider{prov}, nil
}
}
+ if strings.TrimSpace(providerKey) == payment.TypeWxpay {
+ return s.getEnabledWebhookProvidersByKey(ctx, providerKey)
+ }
+
if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey)
}
s.EnsureProviders(ctx)
- return s.registry.GetProviderByKey(providerKey)
+ prov, err := s.registry.GetProviderByKey(providerKey)
+ if err != nil {
+ return nil, err
+ }
+ return []payment.Provider{prov}, nil
}
func (s *PaymentService) getPinnedOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
@@ -78,3 +115,34 @@ func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, pro
func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool {
return order != nil && order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != ""
}
+
+func (s *PaymentService) getEnabledWebhookProvidersByKey(ctx context.Context, providerKey string) ([]payment.Provider, error) {
+ providerKey = strings.TrimSpace(providerKey)
+ instances, err := s.entClient.PaymentProviderInstance.Query().
+ Where(
+ paymentproviderinstance.ProviderKeyEQ(providerKey),
+ paymentproviderinstance.EnabledEQ(true),
+ ).
+ Order(dbent.Asc(paymentproviderinstance.FieldSortOrder)).
+ All(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("query webhook provider instances: %w", err)
+ }
+ if len(instances) == 0 {
+ return nil, payment.ErrProviderNotFound
+ }
+
+ providers := make([]payment.Provider, 0, len(instances))
+ for _, inst := range instances {
+ prov, provErr := s.createProviderFromInstance(ctx, inst)
+ if provErr != nil {
+ slog.Warn("skip webhook provider instance", "provider", providerKey, "instanceID", inst.ID, "error", provErr)
+ continue
+ }
+ providers = append(providers, prov)
+ }
+ if len(providers) == 0 {
+ return nil, payment.ErrProviderNotFound
+ }
+ return providers, nil
+}
diff --git a/backend/internal/service/payment_webhook_provider_test.go b/backend/internal/service/payment_webhook_provider_test.go
index 4f0b6848..15b447c2 100644
--- a/backend/internal/service/payment_webhook_provider_test.go
+++ b/backend/internal/service/payment_webhook_provider_test.go
@@ -208,10 +208,28 @@ func TestGetOrderProviderInstanceLeavesProviderKeyMatchUnresolvedWhenTypeNotSupp
func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
+ wxpayConfigA := encryptWebhookProviderConfig(t, map[string]string{
+ "appId": "wx-app-a",
+ "mchId": "mch-a",
+ "privateKey": "private-key-a",
+ "apiV3Key": webhookProviderTestEncryptionKey,
+ "publicKey": "public-key-a",
+ "publicKeyId": "public-key-id-a",
+ "certSerial": "cert-serial-a",
+ })
+ wxpayConfigB := encryptWebhookProviderConfig(t, map[string]string{
+ "appId": "wx-app-b",
+ "mchId": "mch-b",
+ "privateKey": "private-key-b",
+ "apiV3Key": webhookProviderTestEncryptionKey,
+ "publicKey": "public-key-b",
+ "publicKeyId": "public-key-id-b",
+ "certSerial": "cert-serial-b",
+ })
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("wxpay-a").
- SetConfig("{}").
+ SetConfig(wxpayConfigA).
SetSupportedTypes("wxpay").
SetEnabled(true).
Save(ctx)
@@ -219,19 +237,51 @@ func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) {
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("wxpay-b").
- SetConfig("{}").
+ SetConfig(wxpayConfigB).
SetSupportedTypes("wxpay").
SetEnabled(true).
Save(ctx)
require.NoError(t, err)
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ registry: payment.NewRegistry(),
+ providersLoaded: true,
+ }
+
+ providers, err := svc.GetWebhookProviders(ctx, payment.TypeWxpay, "")
+ require.NoError(t, err)
+ require.Len(t, providers, 2)
+}
+
+func TestGetWebhookProvidersRejectAmbiguousFallbackForNonWxpay(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("alipay-a").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("alipay-b").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
svc := &PaymentService{
entClient: client,
registry: payment.NewRegistry(),
providersLoaded: true,
}
- _, err = svc.GetWebhookProvider(ctx, payment.TypeWxpay, "")
+ _, err = svc.GetWebhookProviders(ctx, payment.TypeAlipay, "")
require.Error(t, err)
require.Contains(t, err.Error(), "ambiguous")
}
@@ -260,8 +310,10 @@ func TestGetWebhookProviderAllowsSingleInstanceRegistryFallback(t *testing.T) {
providersLoaded: true,
}
- prov, err := svc.GetWebhookProvider(ctx, payment.TypeStripe, "")
+ providers, err := svc.GetWebhookProviders(ctx, payment.TypeStripe, "")
require.NoError(t, err)
+ require.Len(t, providers, 1)
+ prov := providers[0]
require.Equal(t, payment.TypeStripe, prov.ProviderKey())
}
@@ -308,7 +360,7 @@ func TestGetWebhookProviderRejectsRegistryFallbackForPinnedOrder(t *testing.T) {
providersLoaded: true,
}
- _, err = svc.GetWebhookProvider(ctx, payment.TypeWxpay, "sub2_test_pinned_order")
+ _, err = svc.GetWebhookProviders(ctx, payment.TypeWxpay, "sub2_test_pinned_order")
require.Error(t, err)
require.Contains(t, err.Error(), "provider instance")
}
From 7c6491c2d39254029ca6f0950a27249e400d0098 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 01:45:25 +0800
Subject: [PATCH 084/189] fix auth pending session hardening
---
backend/internal/handler/auth_handler.go | 49 ++++-
.../internal/handler/auth_linuxdo_oauth.go | 9 +
.../handler/auth_linuxdo_oauth_test.go | 52 ++++++
.../handler/auth_oauth_pending_flow.go | 57 ++++++
.../handler/auth_oauth_pending_flow_test.go | 172 ++++++++++++++++++
backend/internal/handler/auth_oidc_oauth.go | 9 +
.../internal/handler/auth_oidc_oauth_test.go | 52 ++++++
backend/internal/handler/auth_wechat_oauth.go | 9 +
.../handler/auth_wechat_oauth_test.go | 53 ++++++
.../internal/service/auth_oauth_email_flow.go | 2 +-
backend/internal/service/auth_service.go | 12 +-
.../auth_service_identity_sync_test.go | 46 +++--
12 files changed, 490 insertions(+), 32 deletions(-)
diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go
index 76ca153d..9801b3b3 100644
--- a/backend/internal/handler/auth_handler.go
+++ b/backend/internal/handler/auth_handler.go
@@ -1,6 +1,7 @@
package handler
import (
+ "context"
"log/slog"
"strings"
@@ -105,6 +106,34 @@ func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) {
})
}
+func (h *AuthHandler) ensureBackendModeAllowsUser(ctx context.Context, user *service.User) error {
+ if user == nil {
+ return infraerrors.Unauthorized("INVALID_USER", "user not found")
+ }
+ if h == nil || !h.isBackendModeEnabled(ctx) || user.IsAdmin() {
+ return nil
+ }
+ return infraerrors.Forbidden("BACKEND_MODE_ADMIN_ONLY", "Backend mode is active. Only admin login is allowed.")
+}
+
+func (h *AuthHandler) ensureBackendModeAllowsNewUserLogin(ctx context.Context) error {
+ if h == nil || !h.isBackendModeEnabled(ctx) {
+ return nil
+ }
+ return infraerrors.Forbidden("BACKEND_MODE_ADMIN_ONLY", "Backend mode is active. Only admin login is allowed.")
+}
+
+func (h *AuthHandler) isBackendModeEnabled(ctx context.Context) bool {
+ if h == nil || h.settingSvc == nil {
+ return false
+ }
+ settings, err := h.settingSvc.GetPublicSettings(ctx)
+ if err == nil && settings != nil {
+ return settings.BackendModeEnabled
+ }
+ return h.settingSvc.IsBackendModeEnabled(ctx)
+}
+
// Register handles user registration
// POST /api/v1/auth/register
func (h *AuthHandler) Register(c *gin.Context) {
@@ -178,6 +207,11 @@ func (h *AuthHandler) Login(c *gin.Context) {
}
_ = token // token 由 authService.Login 返回但此处由 respondWithTokenPair 重新生成
+ if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
// Check if TOTP 2FA is enabled for this user
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
// Create a temporary login session for 2FA
@@ -195,11 +229,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
return
}
- // Backend mode: only admin can login
- if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
- response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
- return
- }
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
h.respondWithTokenPair(c, user)
}
@@ -264,9 +294,8 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
return
}
- // Backend mode: only admin can login (check BEFORE deleting session)
- if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
- response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
+ if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
+ response.ErrorFrom(c, err)
return
}
@@ -330,6 +359,10 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
// Delete the login session (only after all checks pass)
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
+ if session.PendingOAuthBind == nil {
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+ }
+
h.respondWithTokenPair(c, user)
}
diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go
index 175b1e1f..c3a9041b 100644
--- a/backend/internal/handler/auth_linuxdo_oauth.go
+++ b/backend/internal/handler/auth_linuxdo_oauth.go
@@ -474,6 +474,14 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
email := strings.TrimSpace(session.ResolvedEmail)
username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
@@ -499,6 +507,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go
index fb57e570..7938f3e7 100644
--- a/backend/internal/handler/auth_linuxdo_oauth_test.go
+++ b/backend/internal/handler/auth_linuxdo_oauth_test.go
@@ -591,6 +591,58 @@ func TestCompleteLinuxDoOAuthRegistrationAppliesPendingAdoptionDecision(t *testi
require.NotNil(t, consumed.ConsumedAt)
}
+func TestCompleteLinuxDoOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("linuxdo-complete-invalid-session").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("linuxdo-invalid-subject-1").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("linuxdo-invalid-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": "bind_login_required",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-invalid-browser")})
+ c.Request = req
+
+ handler.CompleteLinuxDoOAuthRegistration(c)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler {
t.Helper()
handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg)
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
index 461810f1..6041e5dd 100644
--- a/backend/internal/handler/auth_oauth_pending_flow.go
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -253,6 +253,35 @@ func pendingSessionWantsInvitation(payload map[string]any) bool {
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required")
}
+func pendingOAuthCompletionIncludesTokenPayload(payload map[string]any) bool {
+ if len(payload) == 0 {
+ return false
+ }
+ for _, key := range []string{"access_token", "refresh_token"} {
+ if value := pendingSessionStringValue(payload, key); value != "" {
+ return true
+ }
+ }
+ return false
+}
+
+func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSession) error {
+ if session == nil {
+ return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+ if strings.TrimSpace(session.Intent) != oauthIntentLogin {
+ return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+ if session.TargetUserID != nil && *session.TargetUserID > 0 {
+ return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+ payload, _ := readCompletionResponse(session.LocalFlowState)
+ if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "bind_login_required") {
+ return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+ return nil
+}
+
func (r oauthAdoptionDecisionRequest) hasDecision() bool {
return r.AdoptDisplayName != nil || r.AdoptAvatar != nil
}
@@ -1090,6 +1119,10 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
response.ErrorFrom(c, infraerrors.Conflict("PENDING_AUTH_TARGET_USER_MISMATCH", "pending oauth session must be completed by the targeted user"))
return
}
+ if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
if err != nil {
@@ -1192,6 +1225,10 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
return
}
+ if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
tokenPair, user, err := h.authService.RegisterOAuthEmailAccount(
c.Request.Context(),
@@ -1215,6 +1252,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return
}
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
clearCookies()
@@ -1279,6 +1317,25 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
}
}
applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims)
+ if pendingOAuthCompletionIncludesTokenPayload(payload) {
+ if session.TargetUserID == nil || *session.TargetUserID <= 0 {
+ clearCookies()
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid"))
+ return
+ }
+ user, err := h.userService.GetByID(c.Request.Context(), *session.TargetUserID)
+ if err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+ }
if pendingSessionWantsInvitation(payload) {
if adoptionDecision.hasDecision() {
diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go
index d29e4b88..c2b83c73 100644
--- a/backend/internal/handler/auth_oauth_pending_flow_test.go
+++ b/backend/internal/handler/auth_oauth_pending_flow_test.go
@@ -523,6 +523,60 @@ func TestExchangePendingOAuthCompletionLoginFalseFalseBindsIdentityWithoutAdopti
require.NotNil(t, storedSession.ConsumedAt)
}
+func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayload(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ settingValues: map[string]string{
+ service.SettingKeyBackendModeEnabled: "true",
+ },
+ })
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("blocked@example.com").
+ SetUsername("blocked-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("blocked-backend-mode-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("blocked-subject-123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("blocked-backend-mode-browser-session-key").
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ "refresh_token": "refresh-token",
+ "expires_in": float64(3600),
+ "token_type": "Bearer",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("blocked-backend-mode-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusForbidden, recorder.Code)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, true)
ctx := context.Background()
@@ -773,6 +827,60 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
}
+func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ emailVerifyEnabled: true,
+ emailCache: &oauthPendingFlowEmailCacheStub{
+ verificationCodes: map[string]*service.VerificationCodeData{
+ "fresh@example.com": {
+ Code: "246810",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+ },
+ },
+ },
+ settingValues: map[string]string{
+ service.SettingKeyBackendModeEnabled: "true",
+ },
+ })
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("create-account-backend-mode-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-create-backend-mode-123").
+ SetBrowserSessionKey("create-account-backend-mode-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-backend-mode-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusForbidden, recorder.Code)
+
+ userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
@@ -842,6 +950,70 @@ func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
require.NotNil(t, storedSession.ConsumedAt)
}
+func TestBindOIDCOAuthLoginBlocksBackendModeBeforeTokenIssue(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ settingValues: map[string]string{
+ service.SettingKeyBackendModeEnabled: "true",
+ },
+ })
+ ctx := context.Background()
+
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-login-backend-mode-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-backend-mode-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("bind-login-backend-mode-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-backend-mode-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.BindOIDCOAuthLogin(ginCtx)
+
+ require.Equal(t, http.StatusForbidden, recorder.Code)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-bind-backend-mode-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
func TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go
index 0f9f1895..5901a953 100644
--- a/backend/internal/handler/auth_oidc_oauth.go
+++ b/backend/internal/handler/auth_oidc_oauth.go
@@ -516,6 +516,14 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
email := strings.TrimSpace(session.ResolvedEmail)
username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
@@ -541,6 +549,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go
index 07f5ef68..ba736db2 100644
--- a/backend/internal/handler/auth_oidc_oauth_test.go
+++ b/backend/internal/handler/auth_oidc_oauth_test.go
@@ -431,6 +431,58 @@ func TestCompleteOIDCOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.
require.NotNil(t, consumed.ConsumedAt)
}
+func TestCompleteOIDCOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("oidc-complete-invalid-session").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example.com").
+ SetProviderSubject("oidc-invalid-subject-1").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("oidc-invalid-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": "bind_login_required",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-invalid-browser")})
+ c.Request = req
+
+ handler.CompleteOIDCOAuthRegistration(c)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
type oidcProviderFixture struct {
Subject string
PreferredUsername string
diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go
index 45de30a8..5e697fb5 100644
--- a/backend/internal/handler/auth_wechat_oauth.go
+++ b/backend/internal/handler/auth_wechat_oauth.go
@@ -506,6 +506,14 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
email := strings.TrimSpace(session.ResolvedEmail)
username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
@@ -531,6 +539,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go
index c65f4cd1..cd34f52f 100644
--- a/backend/internal/handler/auth_wechat_oauth_test.go
+++ b/backend/internal/handler/auth_wechat_oauth_test.go
@@ -846,6 +846,59 @@ func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) {
require.Equal(t, repairedIdentity.ID, channel.IdentityID)
}
+func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) {
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("wechat-complete-invalid-session").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("wechat").
+ SetProviderKey("wechat-main").
+ SetProviderSubject("union-invalid-1").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("wechat-invalid-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "wechat_user",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": "bind_login_required",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ completeCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-invalid-browser")})
+ completeCtx.Request = req
+
+ handler.CompleteWeChatOAuthRegistration(completeCtx)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) {
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go
index 4ab4e245..2e0107ae 100644
--- a/backend/internal/service/auth_oauth_email_flow.go
+++ b/backend/internal/service/auth_oauth_email_flow.go
@@ -104,7 +104,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
return nil, nil, ErrServiceUnavailable
}
- s.postAuthUserBootstrap(ctx, user, signupSource, true)
+ s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
if invitationRedeemCode != nil {
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index 00fefd82..d63a8753 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -430,8 +430,6 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
if !user.IsActive() {
return "", nil, ErrUserNotActive
}
- s.backfillEmailIdentityOnSuccessfulLogin(ctx, user)
- s.touchUserLogin(ctx, user.ID)
// 生成JWT token
token, err := s.GenerateToken(user)
@@ -507,7 +505,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
}
} else {
user = newUser
- s.postAuthUserBootstrap(ctx, user, signupSource, true)
+ s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
}
} else {
@@ -527,8 +525,6 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
- s.touchUserLogin(ctx, user.ID)
-
token, err := s.GenerateToken(user)
if err != nil {
return "", nil, fmt.Errorf("generate token: %w", err)
@@ -634,7 +630,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, ErrServiceUnavailable
}
user = newUser
- s.postAuthUserBootstrap(ctx, user, signupSource, true)
+ s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
}
} else {
@@ -651,7 +647,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
}
} else {
user = newUser
- s.postAuthUserBootstrap(ctx, user, signupSource, true)
+ s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
@@ -676,8 +672,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
- s.touchUserLogin(ctx, user.ID)
-
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
return nil, nil, fmt.Errorf("generate token pair: %w", err)
diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go
index 95c9c933..85c13604 100644
--- a/backend/internal/service/auth_service_identity_sync_test.go
+++ b/backend/internal/service/auth_service_identity_sync_test.go
@@ -170,24 +170,26 @@ func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
require.NotNil(t, identity.VerifiedAt)
}
-func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) {
- svc, repo, client := newAuthServiceWithEnt(t, map[string]string{
+func TestAuthServiceLoginDefersLastLoginTouchUntilRecordSuccessfulLogin(t *testing.T) {
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
service.SettingKeyRegistrationEnabled: "true",
}, nil)
ctx := context.Background()
- user := &service.User{
- Email: "login@example.com",
- Role: service.RoleUser,
- Status: service.StatusActive,
- Balance: 1,
- Concurrency: 1,
- }
- require.NoError(t, user.SetPassword("password"))
- require.NoError(t, repo.Create(ctx, user))
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("login@example.com").
+ SetPasswordHash(passwordHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ SetBalance(1).
+ SetConcurrency(1).
+ Save(ctx)
+ require.NoError(t, err)
old := time.Now().Add(-2 * time.Hour).UTC().Round(time.Second)
- _, err := client.User.UpdateOneID(user.ID).
+ _, err = client.User.UpdateOneID(user.ID).
SetLastLoginAt(old).
SetLastActiveAt(old).
Save(ctx)
@@ -202,8 +204,20 @@ func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, storedUser.LastLoginAt)
require.NotNil(t, storedUser.LastActiveAt)
- require.True(t, storedUser.LastLoginAt.After(old))
- require.True(t, storedUser.LastActiveAt.After(old))
+ require.True(t, storedUser.LastLoginAt.Equal(old))
+ require.True(t, storedUser.LastActiveAt.Equal(old))
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("login@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ svc.RecordSuccessfulLogin(ctx, user.ID)
identity, err := client.AuthIdentity.Query().
Where(
@@ -273,6 +287,7 @@ func TestAuthServiceLogin_AppliesEmailFirstBindDefaultsOnlyWhenEmailIdentityIsNe
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, gotUser)
+ svc.RecordSuccessfulLogin(ctx, user.ID)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
@@ -343,6 +358,7 @@ func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyE
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, gotUser)
+ svc.RecordSuccessfulLogin(ctx, user.ID)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
@@ -380,6 +396,7 @@ func TestAuthServiceLogin_RetriesEmailFirstBindDefaultsAfterPreviousFailure(t *t
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, gotUser)
+ svc.RecordSuccessfulLogin(ctx, user.ID)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
@@ -392,6 +409,7 @@ func TestAuthServiceLogin_RetriesEmailFirstBindDefaultsAfterPreviousFailure(t *t
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, gotUser)
+ svc.RecordSuccessfulLogin(ctx, user.ID)
storedUser, err = client.User.Get(ctx, user.ID)
require.NoError(t, err)
From cd0338fbae08b2641f17ad1d4b29ebabffca64fc Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 01:48:23 +0800
Subject: [PATCH 085/189] fix frontend wechat oauth capability recovery
---
frontend/src/api/auth.ts | 33 +++++++++
frontend/src/api/user.ts | 4 +-
.../ProfileIdentityBindingsSection.vue | 16 ++---
.../user/profile/ProfileInfoCard.vue | 6 ++
.../ProfileIdentityBindingsSection.spec.ts | 72 +++++++++++++++++++
.../src/views/auth/WechatCallbackView.vue | 55 ++++++++------
.../auth/__tests__/WechatCallbackView.spec.ts | 55 ++++++++++++++
frontend/src/views/user/ProfileView.vue | 10 +++
8 files changed, 219 insertions(+), 32 deletions(-)
diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts
index 383bdcce..9cf0c7ad 100644
--- a/frontend/src/api/auth.ts
+++ b/frontend/src/api/auth.ts
@@ -352,6 +352,7 @@ export async function getPublicSettings(): Promise {
export type WeChatOAuthMode = 'open' | 'mp'
export type WeChatOAuthUnavailableReason =
| 'not_configured'
+ | 'capability_unknown'
| 'external_browser_required'
| 'wechat_browser_required'
@@ -369,6 +370,16 @@ export type WeChatOAuthPublicSettings = {
wechat_oauth_mp_enabled?: boolean
}
+export function hasExplicitWeChatOAuthCapabilities(
+ settings: WeChatOAuthPublicSettings | null | undefined,
+): settings is WeChatOAuthPublicSettings & {
+ wechat_oauth_open_enabled: boolean
+ wechat_oauth_mp_enabled: boolean
+} {
+ return typeof settings?.wechat_oauth_open_enabled === 'boolean'
+ && typeof settings?.wechat_oauth_mp_enabled === 'boolean'
+}
+
export function resolveWeChatOAuthStart(
settings: WeChatOAuthPublicSettings | null | undefined,
userAgent?: string
@@ -404,6 +415,28 @@ export function resolveWeChatOAuthStart(
return { mode: null, openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: 'not_configured' }
}
+export function resolveWeChatOAuthStartStrict(
+ settings: WeChatOAuthPublicSettings | null | undefined,
+ userAgent?: string,
+): ResolvedWeChatOAuthStart {
+ const normalizedUserAgent = (userAgent
+ ?? (typeof navigator !== 'undefined' ? navigator.userAgent : '')
+ ?? '').trim()
+ const isWeChatBrowser = /MicroMessenger/i.test(normalizedUserAgent)
+
+ if (!hasExplicitWeChatOAuthCapabilities(settings)) {
+ return {
+ mode: null,
+ openEnabled: false,
+ mpEnabled: false,
+ isWeChatBrowser,
+ unavailableReason: 'capability_unknown',
+ }
+ }
+
+ return resolveWeChatOAuthStart(settings, normalizedUserAgent)
+}
+
/**
* Send verification code to email
* @param request - Email and optional Turnstile token
diff --git a/frontend/src/api/user.ts b/frontend/src/api/user.ts
index 41ebea05..7b498303 100644
--- a/frontend/src/api/user.ts
+++ b/frontend/src/api/user.ts
@@ -5,8 +5,8 @@
import { apiClient } from './client'
import {
+ resolveWeChatOAuthStartStrict,
prepareOAuthBindAccessTokenCookie,
- resolveWeChatOAuthStart,
type WeChatOAuthPublicSettings,
} from './auth'
import type { User, ChangePasswordRequest, NotifyEmailEntry, UserAuthProvider } from '@/types'
@@ -107,7 +107,7 @@ function resolveWeChatOAuthBindingMode(
settings?: WeChatOAuthPublicSettings | null
): 'open' | 'mp' | null {
if (settings) {
- return resolveWeChatOAuthStart(settings).mode
+ return resolveWeChatOAuthStartStrict(settings).mode
}
return resolveWeChatOAuthMode()
}
diff --git a/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue
index d3d9814d..76c94ba1 100644
--- a/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue
+++ b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue
@@ -52,7 +52,11 @@
import { computed } from 'vue'
import { useI18n } from 'vue-i18n'
import { useRoute } from 'vue-router'
-import { resolveWeChatOAuthStart, type WeChatOAuthPublicSettings } from '@/api/auth'
+import {
+ hasExplicitWeChatOAuthCapabilities,
+ resolveWeChatOAuthStartStrict,
+ type WeChatOAuthPublicSettings,
+} from '@/api/auth'
import { startOAuthBinding } from '@/api/user'
import { useAppStore } from '@/stores'
import type { User, UserAuthBindingStatus, UserAuthProvider } from '@/types'
@@ -82,15 +86,11 @@ const route = useRoute()
const appStore = useAppStore()
const wechatOAuthSettings = computed(() => {
- if (appStore.cachedPublicSettings) {
+ if (hasExplicitWeChatOAuthCapabilities(appStore.cachedPublicSettings)) {
return appStore.cachedPublicSettings
}
- if (
- typeof props.wechatEnabled === 'boolean' ||
- typeof props.wechatOpenEnabled === 'boolean' ||
- typeof props.wechatMpEnabled === 'boolean'
- ) {
+ if (typeof props.wechatOpenEnabled === 'boolean' && typeof props.wechatMpEnabled === 'boolean') {
return {
wechat_oauth_enabled: props.wechatEnabled,
wechat_oauth_open_enabled: props.wechatOpenEnabled,
@@ -101,7 +101,7 @@ const wechatOAuthSettings = computed(() => {
return null
})
-const resolvedWeChatBinding = computed(() => resolveWeChatOAuthStart(wechatOAuthSettings.value))
+const resolvedWeChatBinding = computed(() => resolveWeChatOAuthStartStrict(wechatOAuthSettings.value))
function normalizeBindingStatus(binding: boolean | UserAuthBindingStatus | undefined): boolean | null {
if (typeof binding === 'boolean') {
diff --git a/frontend/src/components/user/profile/ProfileInfoCard.vue b/frontend/src/components/user/profile/ProfileInfoCard.vue
index d6273431..460e5c7f 100644
--- a/frontend/src/components/user/profile/ProfileInfoCard.vue
+++ b/frontend/src/components/user/profile/ProfileInfoCard.vue
@@ -133,6 +133,8 @@
:oidc-enabled="oidcEnabled"
:oidc-provider-name="oidcProviderName"
:wechat-enabled="wechatEnabled"
+ :wechat-open-enabled="wechatOpenEnabled"
+ :wechat-mp-enabled="wechatMpEnabled"
/>
@@ -156,12 +158,16 @@ const props = withDefaults(
oidcEnabled?: boolean
oidcProviderName?: string
wechatEnabled?: boolean
+ wechatOpenEnabled?: boolean
+ wechatMpEnabled?: boolean
}>(),
{
linuxdoEnabled: false,
oidcEnabled: false,
oidcProviderName: 'OIDC',
wechatEnabled: false,
+ wechatOpenEnabled: undefined,
+ wechatMpEnabled: undefined,
}
)
diff --git a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
index 7db2ecd7..4e194a39 100644
--- a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
+++ b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
@@ -100,6 +100,8 @@ describe('ProfileIdentityBindingsSection', () => {
oidcEnabled: true,
oidcProviderName: 'ExampleID',
wechatEnabled: true,
+ wechatOpenEnabled: true,
+ wechatMpEnabled: false,
},
})
@@ -152,4 +154,74 @@ describe('ProfileIdentityBindingsSection', () => {
expect(wrapper.find('[data-testid="profile-binding-wechat-action"]').exists()).toBe(false)
})
+
+ it('hides the WeChat bind action when only the legacy aggregate setting is present', () => {
+ const wrapper = mount(ProfileIdentityBindingsSection, {
+ global: {
+ plugins: [pinia],
+ },
+ props: {
+ user: createUser(),
+ linuxdoEnabled: false,
+ oidcEnabled: false,
+ wechatEnabled: true,
+ },
+ })
+
+ expect(wrapper.find('[data-testid="profile-binding-wechat-action"]').exists()).toBe(false)
+ })
+
+ it('uses explicit cached WeChat capabilities and ignores legacy prop fallbacks', () => {
+ const appStore = useAppStore()
+ appStore.cachedPublicSettings = {
+ registration_enabled: false,
+ email_verify_enabled: false,
+ force_email_on_third_party_signup: false,
+ registration_email_suffix_whitelist: [],
+ promo_code_enabled: true,
+ password_reset_enabled: false,
+ invitation_code_enabled: false,
+ turnstile_enabled: false,
+ turnstile_site_key: '',
+ site_name: 'Sub2API',
+ site_logo: '',
+ site_subtitle: '',
+ api_base_url: '',
+ contact_info: '',
+ doc_url: '',
+ home_content: '',
+ hide_ccs_import_button: false,
+ payment_enabled: false,
+ table_default_page_size: 20,
+ table_page_size_options: [10, 20, 50, 100],
+ custom_menu_items: [],
+ custom_endpoints: [],
+ linuxdo_oauth_enabled: false,
+ wechat_oauth_enabled: true,
+ wechat_oauth_open_enabled: true,
+ wechat_oauth_mp_enabled: false,
+ oidc_oauth_enabled: false,
+ oidc_oauth_provider_name: 'OIDC',
+ backend_mode_enabled: false,
+ version: 'test',
+ balance_low_notify_enabled: false,
+ account_quota_notify_enabled: false,
+ balance_low_notify_threshold: 0,
+ }
+ appStore.publicSettingsLoaded = true
+
+ const wrapper = mount(ProfileIdentityBindingsSection, {
+ global: {
+ plugins: [pinia],
+ },
+ props: {
+ user: createUser(),
+ linuxdoEnabled: false,
+ oidcEnabled: false,
+ wechatEnabled: true,
+ },
+ })
+
+ expect(wrapper.find('[data-testid="profile-binding-wechat-action"]').exists()).toBe(true)
+ })
})
diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue
index a5da35e5..35cd0032 100644
--- a/frontend/src/views/auth/WechatCallbackView.vue
+++ b/frontend/src/views/auth/WechatCallbackView.vue
@@ -292,12 +292,13 @@ import {
completeWeChatOAuthRegistration,
exchangePendingOAuthCompletion,
getAuthToken,
+ hasExplicitWeChatOAuthCapabilities,
getOAuthCompletionKind,
isOAuthLoginCompletion,
login2FA,
prepareOAuthBindAccessTokenCookie,
persistOAuthTokenContext,
- resolveWeChatOAuthStart,
+ resolveWeChatOAuthStartStrict,
type OAuthAdoptionDecision,
type PendingOAuthExchangeResponse
} from '@/api/auth'
@@ -368,41 +369,32 @@ function sanitizeRedirectPath(path: string | null | undefined): string {
return path
}
-function resolveWeChatOAuthMode(): 'open' | 'mp' {
- if (typeof navigator === 'undefined') {
- return 'open'
- }
- return /MicroMessenger/i.test(navigator.userAgent) ? 'mp' : 'open'
-}
-
-function normalizeWeChatOAuthMode(value: unknown): 'open' | 'mp' | null {
- return value === 'open' || value === 'mp' ? value : null
-}
-
async function ensurePublicSettingsLoaded(): Promise {
- if (appStore.cachedPublicSettings || appStore.publicSettingsLoaded) {
+ if (hasExplicitWeChatOAuthCapabilities(appStore.cachedPublicSettings)) {
return
}
- try {
- await appStore.fetchPublicSettings()
- } catch {
- // Fall back to legacy mode selection when public settings are unavailable.
+ if (appStore.publicSettingsLoaded) {
+ return
}
+
+ await appStore.fetchPublicSettings()
}
function resolveConfiguredWeChatOAuthMode(): 'open' | 'mp' | null {
- if (!appStore.cachedPublicSettings && !appStore.publicSettingsLoaded) {
+ if (!hasExplicitWeChatOAuthCapabilities(appStore.cachedPublicSettings)) {
return null
}
- return resolveWeChatOAuthStart(appStore.cachedPublicSettings).mode
+ return resolveWeChatOAuthStartStrict(appStore.cachedPublicSettings).mode
}
function resolveWeChatOAuthUnavailableMessage(): string {
- const resolved = resolveWeChatOAuthStart(appStore.cachedPublicSettings)
+ const resolved = resolveWeChatOAuthStartStrict(appStore.cachedPublicSettings)
switch (resolved.unavailableReason) {
+ case 'capability_unknown':
+ return 'WeChat sign-in availability could not be confirmed. Refresh and retry.'
case 'external_browser_required':
return 'This WeChat sign-in flow is only available in your system browser.'
case 'wechat_browser_required':
@@ -414,6 +406,17 @@ function resolveWeChatOAuthUnavailableMessage(): string {
}
}
+function resolveRuntimeWeChatOAuthMode(): 'open' | 'mp' {
+ if (typeof navigator === 'undefined') {
+ return 'open'
+ }
+ return /MicroMessenger/i.test(navigator.userAgent) ? 'mp' : 'open'
+}
+
+function normalizeWeChatOAuthMode(value: unknown): 'open' | 'mp' | null {
+ return value === 'open' || value === 'mp' ? value : null
+}
+
function resolveRequestedWeChatOAuthMode(): 'open' | 'mp' | null {
const configuredMode = resolveConfiguredWeChatOAuthMode()
if (configuredMode) {
@@ -421,7 +424,11 @@ function resolveRequestedWeChatOAuthMode(): 'open' | 'mp' | null {
}
const queryMode = normalizeWeChatOAuthMode(route.query.mode)
- return queryMode || resolveWeChatOAuthMode()
+ if (queryMode) {
+ return queryMode
+ }
+
+ return resolveRuntimeWeChatOAuthMode()
}
function resolveRedirectTarget(): string {
@@ -786,7 +793,11 @@ async function handleSubmitTotpChallenge() {
}
onMounted(async () => {
- await ensurePublicSettingsLoaded()
+ try {
+ await ensurePublicSettingsLoaded()
+ } catch {
+ // Binding recovery requires confirmed capability flags. Use the strict guard below.
+ }
if (typeof route.query.email === 'string') {
existingAccountEmail.value = route.query.email
diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
index 7f26f3c8..fed88890 100644
--- a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
@@ -201,6 +201,61 @@ describe('WechatCallbackView', () => {
expect(locationState.current.href).not.toContain('mode=mp')
})
+ it('falls back to the query mode when capability settings cannot be confirmed', async () => {
+ routeState.query = {
+ wechat_bind_existing: '1',
+ mode: 'mp',
+ redirect: '/profile',
+ }
+ fetchPublicSettingsMock.mockResolvedValue(null)
+ getAuthTokenMock.mockReturnValue('current-auth-token')
+
+ mount(WechatCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(prepareOAuthBindAccessTokenCookieMock).toHaveBeenCalledTimes(1)
+ expect(locationState.current.href).toContain('mode=mp')
+ })
+
+ it('ignores legacy aggregate wechat settings and reuses the query mode during bind recovery', async () => {
+ routeState.query = {
+ wechat_bind_existing: '1',
+ mode: 'open',
+ redirect: '/profile',
+ }
+ appStoreState.cachedPublicSettings = {
+ wechat_oauth_enabled: true,
+ }
+ appStoreState.publicSettingsLoaded = true
+ getAuthTokenMock.mockReturnValue('current-auth-token')
+
+ mount(WechatCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(prepareOAuthBindAccessTokenCookieMock).toHaveBeenCalledTimes(1)
+ expect(locationState.current.href).toContain('mode=open')
+ })
+
it('does not send adoption decisions during the initial exchange', async () => {
exchangePendingOAuthCompletionMock.mockResolvedValue({
access_token: 'access-token',
diff --git a/frontend/src/views/user/ProfileView.vue b/frontend/src/views/user/ProfileView.vue
index 14d7efea..9e72776f 100644
--- a/frontend/src/views/user/ProfileView.vue
+++ b/frontend/src/views/user/ProfileView.vue
@@ -28,6 +28,8 @@
:oidc-enabled="oidcOAuthEnabled"
:oidc-provider-name="oidcOAuthProviderName"
:wechat-enabled="wechatOAuthEnabled"
+ :wechat-open-enabled="wechatOAuthOpenEnabled"
+ :wechat-mp-enabled="wechatOAuthMPEnabled"
/>
(undefined)
+const wechatOAuthMPEnabled = ref
(undefined)
const oidcOAuthEnabled = ref(false)
const oidcOAuthProviderName = ref('OIDC')
@@ -132,6 +136,12 @@ onMounted(async () => {
systemDefaultThreshold.value = settings.balance_low_notify_threshold ?? 0
linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled ?? false
wechatOAuthEnabled.value = settings.wechat_oauth_enabled ?? false
+ wechatOAuthOpenEnabled.value = typeof settings.wechat_oauth_open_enabled === 'boolean'
+ ? settings.wechat_oauth_open_enabled
+ : undefined
+ wechatOAuthMPEnabled.value = typeof settings.wechat_oauth_mp_enabled === 'boolean'
+ ? settings.wechat_oauth_mp_enabled
+ : undefined
oidcOAuthEnabled.value = settings.oidc_oauth_enabled ?? false
oidcOAuthProviderName.value = settings.oidc_oauth_provider_name || 'OIDC'
})
From e12599c1b967c6349384c2ed954aa6590f976e16 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 02:08:04 +0800
Subject: [PATCH 086/189] fix settings auth source default persistence
---
.../internal/handler/admin/setting_handler.go | 58 ++++-
...tting_handler_auth_source_defaults_test.go | 123 ++++++++++
.../internal/service/auth_oauth_first_bind.go | 6 +-
backend/internal/service/auth_service.go | 12 +-
.../auth_service_identity_sync_test.go | 41 ++++
.../service/auth_service_register_test.go | 23 ++
backend/internal/service/setting_service.go | 210 +++++++++++++-----
7 files changed, 398 insertions(+), 75 deletions(-)
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index e5681208..f0e91f3a 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -1011,10 +1011,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}(),
}
- if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
- response.ErrorFrom(c, err)
- return
- }
authSourceDefaults := &service.AuthSourceDefaultSettings{
Email: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance),
@@ -1046,7 +1042,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
},
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
}
- if err := h.settingService.UpdateAuthSourceDefaultSettings(c.Request.Context(), authSourceDefaults); err != nil {
+ if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil {
response.ErrorFrom(c, err)
return
}
@@ -1086,7 +1082,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
- h.auditSettingsUpdate(c, previousSettings, settings, req)
+ h.auditSettingsUpdate(c, previousSettings, settings, previousAuthSourceDefaults, authSourceDefaults, req)
// 重新获取设置返回
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
@@ -1245,12 +1241,12 @@ func hasPaymentFields(req UpdateSettingsRequest) bool {
req.PaymentCancelRateLimitUnit != nil || req.PaymentCancelRateLimitMode != nil
}
-func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) {
+func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, beforeAuthSourceDefaults *service.AuthSourceDefaultSettings, afterAuthSourceDefaults *service.AuthSourceDefaultSettings, req UpdateSettingsRequest) {
if before == nil || after == nil {
return
}
- changed := diffSettings(before, after, req)
+ changed := diffSettings(before, after, beforeAuthSourceDefaults, afterAuthSourceDefaults, req)
if len(changed) == 0 {
return
}
@@ -1265,7 +1261,7 @@ func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.Sys
)
}
-func diffSettings(before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) []string {
+func diffSettings(before *service.SystemSettings, after *service.SystemSettings, beforeAuthSourceDefaults *service.AuthSourceDefaultSettings, afterAuthSourceDefaults *service.AuthSourceDefaultSettings, req UpdateSettingsRequest) []string {
changed := make([]string, 0, 20)
if before.RegistrationEnabled != after.RegistrationEnabled {
changed = append(changed, "registration_enabled")
@@ -1535,6 +1531,50 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if !equalNotifyEmailEntries(before.AccountQuotaNotifyEmails, after.AccountQuotaNotifyEmails) {
changed = append(changed, "account_quota_notify_emails")
}
+ changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
+ return changed
+}
+
+func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSourceDefaultSettings, after *service.AuthSourceDefaultSettings) []string {
+ if before == nil {
+ before = &service.AuthSourceDefaultSettings{}
+ }
+ if after == nil {
+ after = &service.AuthSourceDefaultSettings{}
+ }
+
+ type providerDefaultGrantField struct {
+ name string
+ before service.ProviderDefaultGrantSettings
+ after service.ProviderDefaultGrantSettings
+ }
+
+ fields := []providerDefaultGrantField{
+ {name: "email", before: before.Email, after: after.Email},
+ {name: "linuxdo", before: before.LinuxDo, after: after.LinuxDo},
+ {name: "oidc", before: before.OIDC, after: after.OIDC},
+ {name: "wechat", before: before.WeChat, after: after.WeChat},
+ }
+ for _, field := range fields {
+ if field.before.Balance != field.after.Balance {
+ changed = append(changed, "auth_source_default_"+field.name+"_balance")
+ }
+ if field.before.Concurrency != field.after.Concurrency {
+ changed = append(changed, "auth_source_default_"+field.name+"_concurrency")
+ }
+ if !equalDefaultSubscriptions(field.before.Subscriptions, field.after.Subscriptions) {
+ changed = append(changed, "auth_source_default_"+field.name+"_subscriptions")
+ }
+ if field.before.GrantOnSignup != field.after.GrantOnSignup {
+ changed = append(changed, "auth_source_default_"+field.name+"_grant_on_signup")
+ }
+ if field.before.GrantOnFirstBind != field.after.GrantOnFirstBind {
+ changed = append(changed, "auth_source_default_"+field.name+"_grant_on_first_bind")
+ }
+ }
+ if before.ForceEmailOnThirdPartySignup != after.ForceEmailOnThirdPartySignup {
+ changed = append(changed, "force_email_on_third_party_signup")
+ }
return changed
}
diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
index bf51fc68..cef531e0 100644
--- a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
+++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
+ "errors"
"net/http"
"net/http/httptest"
"testing"
@@ -66,6 +67,58 @@ func (s *settingHandlerRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
+type failingAuthSourceSettingsRepoStub struct {
+ values map[string]string
+ err error
+}
+
+func (s *failingAuthSourceSettingsRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *failingAuthSourceSettingsRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ panic("unexpected GetValue call")
+}
+
+func (s *failingAuthSourceSettingsRepoStub) Set(ctx context.Context, key, value string) error {
+ panic("unexpected Set call")
+}
+
+func (s *failingAuthSourceSettingsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *failingAuthSourceSettingsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ if _, ok := settings[service.SettingKeyAuthSourceDefaultEmailBalance]; ok {
+ return s.err
+ }
+ for key, value := range settings {
+ if s.values == nil {
+ s.values = map[string]string{}
+ }
+ s.values[key] = value
+ }
+ return nil
+}
+
+func (s *failingAuthSourceSettingsRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ out := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ out[key] = value
+ }
+ return out, nil
+}
+
+func (s *failingAuthSourceSettingsRepoStub) Delete(ctx context.Context, key string) error {
+ panic("unexpected Delete call")
+}
+
func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
@@ -221,3 +274,73 @@ func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(
require.Equal(t, http.StatusBadRequest, rec.Code)
require.NotContains(t, repo.values, service.SettingPaymentVisibleMethodAlipaySource)
}
+
+func TestSettingHandler_UpdateSettings_DoesNotPersistPartialSystemSettingsWhenAuthSourceDefaultsFail(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &failingAuthSourceSettingsRepoStub{
+ values: map[string]string{
+ service.SettingKeyRegistrationEnabled: "false",
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
+ },
+ err: errors.New("write auth source defaults failed"),
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "registration_enabled": true,
+ "promo_code_enabled": true,
+ "auth_source_default_email_balance": 12.75,
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+ require.Equal(t, "false", repo.values[service.SettingKeyRegistrationEnabled])
+ require.Equal(t, "9.5", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance])
+}
+
+func TestDiffSettings_IncludesAuthSourceDefaultsAndForceEmail(t *testing.T) {
+ changed := diffSettings(
+ &service.SystemSettings{},
+ &service.SystemSettings{},
+ &service.AuthSourceDefaultSettings{
+ Email: service.ProviderDefaultGrantSettings{
+ Balance: 0,
+ Concurrency: 5,
+ Subscriptions: nil,
+ GrantOnSignup: true,
+ GrantOnFirstBind: false,
+ },
+ ForceEmailOnThirdPartySignup: false,
+ },
+ &service.AuthSourceDefaultSettings{
+ Email: service.ProviderDefaultGrantSettings{
+ Balance: 12.5,
+ Concurrency: 7,
+ Subscriptions: []service.DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 30}},
+ GrantOnSignup: false,
+ GrantOnFirstBind: true,
+ },
+ ForceEmailOnThirdPartySignup: true,
+ },
+ UpdateSettingsRequest{},
+ )
+
+ require.Contains(t, changed, "auth_source_default_email_balance")
+ require.Contains(t, changed, "auth_source_default_email_concurrency")
+ require.Contains(t, changed, "auth_source_default_email_subscriptions")
+ require.Contains(t, changed, "auth_source_default_email_grant_on_signup")
+ require.Contains(t, changed, "auth_source_default_email_grant_on_first_bind")
+ require.Contains(t, changed, "force_email_on_third_party_signup")
+}
diff --git a/backend/internal/service/auth_oauth_first_bind.go b/backend/internal/service/auth_oauth_first_bind.go
index 422a2a88..b0da5069 100644
--- a/backend/internal/service/auth_oauth_first_bind.go
+++ b/backend/internal/service/auth_oauth_first_bind.go
@@ -44,13 +44,11 @@ func (s *AuthService) applyProviderDefaultSettingsOnFirstBind(
userID int64,
providerType string,
) error {
- defaults, err := s.settingService.GetAuthSourceDefaultSettings(ctx)
+ providerDefaults, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, providerType, true)
if err != nil {
return fmt.Errorf("load auth source defaults: %w", err)
}
-
- providerDefaults, ok := authSourceSignupSettings(defaults, providerType)
- if !ok || !providerDefaults.GrantOnFirstBind {
+ if !enabled {
return nil
}
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index d63a8753..5b6e5fef 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -716,20 +716,18 @@ func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource s
plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx)
plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx)
- defaults, err := s.settingService.GetAuthSourceDefaultSettings(ctx)
+ resolved, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, signupSource, false)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err)
return plan
}
-
- providerDefaults, ok := authSourceSignupSettings(defaults, signupSource)
- if !ok || !providerDefaults.GrantOnSignup {
+ if !enabled {
return plan
}
- plan.Balance = providerDefaults.Balance
- plan.Concurrency = providerDefaults.Concurrency
- plan.Subscriptions = providerDefaults.Subscriptions
+ plan.Balance = resolved.Balance
+ plan.Concurrency = resolved.Concurrency
+ plan.Subscriptions = resolved.Subscriptions
return plan
}
diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go
index 85c13604..4d2a840f 100644
--- a/backend/internal/service/auth_service_identity_sync_test.go
+++ b/backend/internal/service/auth_service_identity_sync_test.go
@@ -321,6 +321,47 @@ func TestAuthServiceLogin_AppliesEmailFirstBindDefaultsOnlyWhenEmailIdentityIsNe
require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
}
+func TestAuthServiceLogin_MergesEmailFirstBindSourceOverridesWithGlobalDefaults(t *testing.T) {
+ assigner := &authIdentityDefaultSubAssignerStub{}
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyDefaultSubscriptions: `[{"group_id":21,"validity_days":14}]`,
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "5",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, assigner)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("merged-first-bind@example.com").
+ SetUsername("merged-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(1.5).
+ SetConcurrency(2).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 10.0, storedUser.Balance)
+ require.Equal(t, 4, storedUser.Concurrency)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(21), assigner.calls[0].GroupID)
+ require.Equal(t, 14, assigner.calls[0].ValidityDays)
+ require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyExists(t *testing.T) {
assigner := &authIdentityDefaultSubAssignerStub{}
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go
index e0dce982..dbd18a20 100644
--- a/backend/internal/service/auth_service_register_test.go
+++ b/backend/internal/service/auth_service_register_test.go
@@ -584,6 +584,29 @@ func TestAuthService_Register_GrantOnSignupFalseFallsBackToGlobalDefaults(t *tes
require.Equal(t, 5, assigner.calls[0].ValidityDays)
}
+func TestAuthService_Register_GrantOnSignupMergesSourceOverridesWithGlobalDefaults(t *testing.T) {
+ repo := &userRepoStub{nextID: 54}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`,
+ SettingKeyAuthSourceDefaultEmailBalance: "9.5",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "5",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
+ }, nil)
+ service.defaultSubAssigner = assigner
+
+ _, user, err := service.Register(context.Background(), "email-merged@test.com", "password")
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.Equal(t, 9.5, user.Balance)
+ require.Equal(t, 2, user.Concurrency)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(31), assigner.calls[0].GroupID)
+ require.Equal(t, 5, assigner.calls[0].ValidityDays)
+}
+
func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefaultsOnSignup(t *testing.T) {
repo := &userRepoStub{nextID: 61}
assigner := &defaultSubscriptionAssignerStub{}
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index 58246111..8c879d52 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -569,12 +569,47 @@ func parseCustomMenuItemURLs(raw string) []string {
// UpdateSettings 更新系统设置
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
- if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil {
+ updates, err := s.buildSystemSettingsUpdates(ctx, settings)
+ if err != nil {
return err
}
+
+ err = s.settingRepo.SetMultiple(ctx, updates)
+ if err == nil {
+ s.refreshCachedSettings(settings)
+ }
+ return err
+}
+
+// UpdateSettingsWithAuthSourceDefaults persists system settings and auth-source defaults in a single write.
+func (s *SettingService) UpdateSettingsWithAuthSourceDefaults(ctx context.Context, settings *SystemSettings, authDefaults *AuthSourceDefaultSettings) error {
+ updates, err := s.buildSystemSettingsUpdates(ctx, settings)
+ if err != nil {
+ return err
+ }
+
+ authSourceUpdates, err := s.buildAuthSourceDefaultUpdates(ctx, authDefaults)
+ if err != nil {
+ return err
+ }
+ for key, value := range authSourceUpdates {
+ updates[key] = value
+ }
+
+ err = s.settingRepo.SetMultiple(ctx, updates)
+ if err == nil {
+ s.refreshCachedSettings(settings)
+ }
+ return err
+}
+
+func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, settings *SystemSettings) (map[string]string, error) {
+ if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil {
+ return nil, err
+ }
normalizedWhitelist, err := NormalizeRegistrationEmailSuffixWhitelist(settings.RegistrationEmailSuffixWhitelist)
if err != nil {
- return infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error())
+ return nil, infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error())
}
if normalizedWhitelist == nil {
normalizedWhitelist = []string{}
@@ -582,11 +617,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
settings.RegistrationEmailSuffixWhitelist = normalizedWhitelist
alipaySource, err := normalizeVisibleMethodSettingSource("alipay", settings.PaymentVisibleMethodAlipaySource, settings.PaymentVisibleMethodAlipayEnabled)
if err != nil {
- return err
+ return nil, err
}
wxpaySource, err := normalizeVisibleMethodSettingSource("wxpay", settings.PaymentVisibleMethodWxpaySource, settings.PaymentVisibleMethodWxpayEnabled)
if err != nil {
- return err
+ return nil, err
}
settings.PaymentVisibleMethodAlipaySource = alipaySource
settings.PaymentVisibleMethodWxpaySource = wxpaySource
@@ -598,7 +633,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
registrationEmailSuffixWhitelistJSON, err := json.Marshal(settings.RegistrationEmailSuffixWhitelist)
if err != nil {
- return fmt.Errorf("marshal registration email suffix whitelist: %w", err)
+ return nil, fmt.Errorf("marshal registration email suffix whitelist: %w", err)
}
updates[SettingKeyRegistrationEmailSuffixWhitelist] = string(registrationEmailSuffixWhitelistJSON)
updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
@@ -677,7 +712,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyTableDefaultPageSize] = strconv.Itoa(tableDefaultPageSize)
tablePageSizeOptionsJSON, err := json.Marshal(tablePageSizeOptions)
if err != nil {
- return fmt.Errorf("marshal table page size options: %w", err)
+ return nil, fmt.Errorf("marshal table page size options: %w", err)
}
updates[SettingKeyTablePageSizeOptions] = string(tablePageSizeOptionsJSON)
updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems
@@ -688,7 +723,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
if err != nil {
- return fmt.Errorf("marshal default subscriptions: %w", err)
+ return nil, fmt.Errorf("marshal default subscriptions: %w", err)
}
updates[SettingKeyDefaultSubscriptions] = string(defaultSubsJSON)
@@ -738,37 +773,66 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled)
updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails)
- err = s.settingRepo.SetMultiple(ctx, updates)
- if err == nil {
- // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
- versionBoundsSF.Forget("version_bounds")
- versionBoundsCache.Store(&cachedVersionBounds{
- min: settings.MinClaudeCodeVersion,
- max: settings.MaxClaudeCodeVersion,
- expiresAt: time.Now().Add(versionBoundsCacheTTL).UnixNano(),
- })
- backendModeSF.Forget("backend_mode")
- backendModeCache.Store(&cachedBackendMode{
- value: settings.BackendModeEnabled,
- expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
- })
- gatewayForwardingSF.Forget("gateway_forwarding")
- gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
- fingerprintUnification: settings.EnableFingerprintUnification,
- metadataPassthrough: settings.EnableMetadataPassthrough,
- cchSigning: settings.EnableCCHSigning,
- expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
- })
- openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey)
- openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
- enabled: settings.OpenAIAdvancedSchedulerEnabled,
- expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
- })
- if s.onUpdate != nil {
- s.onUpdate() // Invalidate cache after settings update
+ return updates, nil
+}
+
+func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, settings *AuthSourceDefaultSettings) (map[string]string, error) {
+ if settings == nil {
+ return nil, nil
+ }
+
+ for _, subscriptions := range [][]DefaultSubscriptionSetting{
+ settings.Email.Subscriptions,
+ settings.LinuxDo.Subscriptions,
+ settings.OIDC.Subscriptions,
+ settings.WeChat.Subscriptions,
+ } {
+ if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
+ return nil, err
}
}
- return err
+
+ updates := make(map[string]string, 21)
+ writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
+ writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
+ writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
+ writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
+ updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
+ return updates, nil
+}
+
+func (s *SettingService) refreshCachedSettings(settings *SystemSettings) {
+ if settings == nil {
+ return
+ }
+
+ // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
+ versionBoundsSF.Forget("version_bounds")
+ versionBoundsCache.Store(&cachedVersionBounds{
+ min: settings.MinClaudeCodeVersion,
+ max: settings.MaxClaudeCodeVersion,
+ expiresAt: time.Now().Add(versionBoundsCacheTTL).UnixNano(),
+ })
+ backendModeSF.Forget("backend_mode")
+ backendModeCache.Store(&cachedBackendMode{
+ value: settings.BackendModeEnabled,
+ expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
+ })
+ gatewayForwardingSF.Forget("gateway_forwarding")
+ gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
+ fingerprintUnification: settings.EnableFingerprintUnification,
+ metadataPassthrough: settings.EnableMetadataPassthrough,
+ cchSigning: settings.EnableCCHSigning,
+ expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
+ })
+ openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey)
+ openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
+ enabled: settings.OpenAIAdvancedSchedulerEnabled,
+ expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
+ })
+ if s.onUpdate != nil {
+ s.onUpdate() // Invalidate cache after settings update
+ }
}
func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error {
@@ -1067,29 +1131,43 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut
}, nil
}
+func (s *SettingService) ResolveAuthSourceGrantSettings(ctx context.Context, signupSource string, firstBind bool) (ProviderDefaultGrantSettings, bool, error) {
+ result := ProviderDefaultGrantSettings{
+ Balance: s.GetDefaultBalance(ctx),
+ Concurrency: s.GetDefaultConcurrency(ctx),
+ Subscriptions: s.GetDefaultSubscriptions(ctx),
+ }
+
+ defaults, err := s.GetAuthSourceDefaultSettings(ctx)
+ if err != nil {
+ return result, false, err
+ }
+
+ providerDefaults, ok := authSourceSignupSettings(defaults, signupSource)
+ if !ok {
+ return result, false, nil
+ }
+
+ enabled := providerDefaults.GrantOnSignup
+ if firstBind {
+ enabled = providerDefaults.GrantOnFirstBind
+ }
+ if !enabled {
+ return result, false, nil
+ }
+
+ return mergeProviderDefaultGrantSettings(result, providerDefaults), true, nil
+}
+
func (s *SettingService) UpdateAuthSourceDefaultSettings(ctx context.Context, settings *AuthSourceDefaultSettings) error {
- if settings == nil {
+ updates, err := s.buildAuthSourceDefaultUpdates(ctx, settings)
+ if err != nil {
+ return err
+ }
+ if len(updates) == 0 {
return nil
}
- for _, subscriptions := range [][]DefaultSubscriptionSetting{
- settings.Email.Subscriptions,
- settings.LinuxDo.Subscriptions,
- settings.OIDC.Subscriptions,
- settings.WeChat.Subscriptions,
- } {
- if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
- return err
- }
- }
-
- updates := make(map[string]string, 21)
- writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
- writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
- writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
- writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
- updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
-
if err := s.settingRepo.SetMultiple(ctx, updates); err != nil {
return fmt.Errorf("update auth source default settings: %w", err)
}
@@ -1594,6 +1672,28 @@ func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSource
updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind)
}
+func mergeProviderDefaultGrantSettings(globalDefaults ProviderDefaultGrantSettings, providerDefaults ProviderDefaultGrantSettings) ProviderDefaultGrantSettings {
+ result := ProviderDefaultGrantSettings{
+ Balance: globalDefaults.Balance,
+ Concurrency: globalDefaults.Concurrency,
+ Subscriptions: append([]DefaultSubscriptionSetting(nil), globalDefaults.Subscriptions...),
+ GrantOnSignup: providerDefaults.GrantOnSignup,
+ GrantOnFirstBind: providerDefaults.GrantOnFirstBind,
+ }
+
+ if providerDefaults.Balance != defaultAuthSourceBalance {
+ result.Balance = providerDefaults.Balance
+ }
+ if providerDefaults.Concurrency > 0 && providerDefaults.Concurrency != defaultAuthSourceConcurrency {
+ result.Concurrency = providerDefaults.Concurrency
+ }
+ if len(providerDefaults.Subscriptions) > 0 {
+ result.Subscriptions = append([]DefaultSubscriptionSetting(nil), providerDefaults.Subscriptions...)
+ }
+
+ return result
+}
+
func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) {
defaultPageSize := 20
if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil {
From a27a7add3da1388d62b6f365e86b9b4ff99283c8 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 02:08:34 +0800
Subject: [PATCH 087/189] fix payment resume result consistency
---
.../internal/service/payment_resume_lookup.go | 9 ++
.../service/payment_resume_lookup_test.go | 84 ++++++++++++++++
.../components/payment/PaymentStatusPanel.vue | 6 +-
.../__tests__/PaymentStatusPanel.spec.ts | 99 +++++++++++++++++++
frontend/src/views/user/PaymentResultView.vue | 28 +++---
.../user/__tests__/PaymentResultView.spec.ts | 52 ++++++++++
6 files changed, 264 insertions(+), 14 deletions(-)
create mode 100644 frontend/src/components/payment/__tests__/PaymentStatusPanel.spec.ts
diff --git a/backend/internal/service/payment_resume_lookup.go b/backend/internal/service/payment_resume_lookup.go
index 69033afd..048e489a 100644
--- a/backend/internal/service/payment_resume_lookup.go
+++ b/backend/internal/service/payment_resume_lookup.go
@@ -30,6 +30,15 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
if claims.PaymentType != "" && strings.TrimSpace(order.PaymentType) != claims.PaymentType {
return nil, fmt.Errorf("resume token payment type mismatch")
}
+ if order.Status == OrderStatusPending || order.Status == OrderStatusExpired {
+ result := s.checkPaid(ctx, order)
+ if result == checkPaidResultAlreadyPaid {
+ order, err = s.entClient.PaymentOrder.Get(ctx, order.ID)
+ if err != nil {
+ return nil, fmt.Errorf("reload order by resume token: %w", err)
+ }
+ }
+ }
return order, nil
}
diff --git a/backend/internal/service/payment_resume_lookup_test.go b/backend/internal/service/payment_resume_lookup_test.go
index d411398e..39ea24e7 100644
--- a/backend/internal/service/payment_resume_lookup_test.go
+++ b/backend/internal/service/payment_resume_lookup_test.go
@@ -11,6 +11,35 @@ import (
"github.com/stretchr/testify/require"
)
+type paymentResumeLookupProvider struct {
+ queryCount int
+}
+
+func (p *paymentResumeLookupProvider) Name() string { return "resume-lookup-provider" }
+
+func (p *paymentResumeLookupProvider) ProviderKey() string { return payment.TypeAlipay }
+
+func (p *paymentResumeLookupProvider) SupportedTypes() []payment.PaymentType {
+ return []payment.PaymentType{payment.TypeAlipay}
+}
+
+func (p *paymentResumeLookupProvider) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ panic("unexpected call")
+}
+
+func (p *paymentResumeLookupProvider) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) {
+ p.queryCount++
+ return &payment.QueryOrderResponse{Status: payment.ProviderStatusPending}, nil
+}
+
+func (p *paymentResumeLookupProvider) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
+ panic("unexpected call")
+}
+
+func (p *paymentResumeLookupProvider) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
+ panic("unexpected call")
+}
+
func TestGetPublicOrderByResumeTokenReturnsMatchingOrder(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
@@ -116,3 +145,58 @@ func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) {
require.Error(t, err)
require.Contains(t, err.Error(), "resume token")
}
+
+func TestGetPublicOrderByResumeTokenChecksUpstreamForPendingOrder(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("resume-refresh@example.com").
+ SetPasswordHash("hash").
+ SetUsername("resume-refresh-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("RESUME-PENDING").
+ SetOutTradeNo("sub2_resume_lookup_pending").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-pending").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := resumeSvc.CreateToken(ResumeTokenClaims{
+ OrderID: order.ID,
+ UserID: user.ID,
+ PaymentType: payment.TypeAlipay,
+ CanonicalReturnURL: "https://app.example.com/payment/result",
+ })
+ require.NoError(t, err)
+
+ registry := payment.NewRegistry()
+ provider := &paymentResumeLookupProvider{}
+ registry.Register(provider)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ resumeService: resumeSvc,
+ providersLoaded: true,
+ }
+
+ got, err := svc.GetPublicOrderByResumeToken(ctx, token)
+ require.NoError(t, err)
+ require.Equal(t, order.ID, got.ID)
+ require.Equal(t, 1, provider.queryCount)
+}
diff --git a/frontend/src/components/payment/PaymentStatusPanel.vue b/frontend/src/components/payment/PaymentStatusPanel.vue
index 8f5a5666..f48036b4 100644
--- a/frontend/src/components/payment/PaymentStatusPanel.vue
+++ b/frontend/src/components/payment/PaymentStatusPanel.vue
@@ -194,6 +194,10 @@ const countdownDisplay = computed(() => {
return m.toString().padStart(2, '0') + ':' + s.toString().padStart(2, '0')
})
+function isSuccessStatus(status: string | null | undefined): boolean {
+ return status === 'COMPLETED' || status === 'PAID' || status === 'RECHARGING'
+}
+
function reopenPopup() {
if (props.payUrl) {
const win = window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES)
@@ -222,7 +226,7 @@ async function pollStatus() {
if (!props.orderId || outcome.value) return
const order = await paymentStore.pollOrderStatus(props.orderId)
if (!order) return
- if (order.status === 'COMPLETED' || order.status === 'PAID') {
+ if (isSuccessStatus(order.status)) {
cleanup()
paidOrder.value = order
setOutcome('success')
diff --git a/frontend/src/components/payment/__tests__/PaymentStatusPanel.spec.ts b/frontend/src/components/payment/__tests__/PaymentStatusPanel.spec.ts
new file mode 100644
index 00000000..d7017e1f
--- /dev/null
+++ b/frontend/src/components/payment/__tests__/PaymentStatusPanel.spec.ts
@@ -0,0 +1,99 @@
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
+import { flushPromises, mount } from '@vue/test-utils'
+
+const pollOrderStatus = vi.hoisted(() => vi.fn())
+const cancelOrder = vi.hoisted(() => vi.fn())
+const showError = vi.hoisted(() => vi.fn())
+const toCanvas = vi.hoisted(() => vi.fn())
+
+vi.mock('vue-i18n', async () => {
+ const actual = await vi.importActual('vue-i18n')
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string) => key,
+ }),
+ }
+})
+
+vi.mock('@/stores/payment', () => ({
+ usePaymentStore: () => ({
+ pollOrderStatus,
+ }),
+}))
+
+vi.mock('@/stores', () => ({
+ useAppStore: () => ({
+ showError,
+ }),
+}))
+
+vi.mock('@/api/payment', () => ({
+ paymentAPI: {
+ cancelOrder,
+ },
+}))
+
+vi.mock('qrcode', () => ({
+ default: {
+ toCanvas,
+ },
+}))
+
+import PaymentStatusPanel from '../PaymentStatusPanel.vue'
+
+const orderFactory = (status: string) => ({
+ id: 42,
+ user_id: 9,
+ amount: 88,
+ pay_amount: 88,
+ fee_rate: 0,
+ payment_type: 'alipay',
+ out_trade_no: 'sub2_20260420abcd1234',
+ status,
+ order_type: 'balance',
+ created_at: '2026-04-20T12:00:00Z',
+ expires_at: '2099-01-01T12:30:00Z',
+ refund_amount: 0,
+})
+
+describe('PaymentStatusPanel', () => {
+ beforeEach(() => {
+ vi.useFakeTimers()
+ pollOrderStatus.mockReset()
+ cancelOrder.mockReset()
+ showError.mockReset()
+ toCanvas.mockReset().mockResolvedValue(undefined)
+ })
+
+ afterEach(() => {
+ vi.useRealTimers()
+ })
+
+ it('treats RECHARGING as a successful terminal state', async () => {
+ pollOrderStatus.mockResolvedValue(orderFactory('RECHARGING'))
+
+ const wrapper = mount(PaymentStatusPanel, {
+ props: {
+ orderId: 42,
+ qrCode: 'https://pay.example.com/qr/42',
+ expiresAt: '2099-01-01T12:30:00Z',
+ paymentType: 'alipay',
+ orderType: 'balance',
+ },
+ global: {
+ stubs: {
+ Icon: true,
+ },
+ },
+ })
+
+ await flushPromises()
+ await vi.advanceTimersByTimeAsync(3000)
+ await flushPromises()
+
+ expect(pollOrderStatus).toHaveBeenCalledWith(42)
+ expect(wrapper.text()).toContain('payment.result.success')
+ expect(wrapper.emitted('success')).toHaveLength(1)
+ })
+})
diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue
index d07bcc29..5177a7fd 100644
--- a/frontend/src/views/user/PaymentResultView.vue
+++ b/frontend/src/views/user/PaymentResultView.vue
@@ -177,6 +177,15 @@ function isPendingStatus(status: string | null | undefined): boolean {
return PENDING_STATUSES.has(normalizeOrderStatus(status))
}
+async function resolveOrderFromResumeToken(resumeToken: string): Promise {
+ try {
+ const result = await paymentAPI.resolveOrderPublicByResumeToken(resumeToken)
+ return result.data
+ } catch (_err: unknown) {
+ return null
+ }
+}
+
function clearStatusRefreshTimer(): void {
if (statusRefreshTimer !== null) {
clearTimeout(statusRefreshTimer)
@@ -230,15 +239,13 @@ onMounted(async () => {
}
}
- if (!order.value && resumeToken) {
- try {
- const result = await paymentAPI.resolveOrderPublicByResumeToken(resumeToken)
- order.value = result.data
+ if (resumeToken) {
+ const resolvedOrder = await resolveOrderFromResumeToken(resumeToken)
+ if (resolvedOrder) {
+ order.value = resolvedOrder
if (!orderId) {
- orderId = result.data.id
+ orderId = resolvedOrder.id
}
- } catch (_err: unknown) {
- // Resume token recovery failed; do not trust legacy public out_trade_no fallback.
}
}
@@ -278,12 +285,7 @@ onMounted(async () => {
const refreshOrder = async (): Promise => {
if (resumeToken) {
- try {
- const result = await paymentAPI.resolveOrderPublicByResumeToken(resumeToken)
- return result.data
- } catch (_err: unknown) {
- return null
- }
+ return await resolveOrderFromResumeToken(resumeToken)
}
if (orderId) {
diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
index c4e38523..86c55e25 100644
--- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
+++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
@@ -116,6 +116,58 @@ describe('PaymentResultView', () => {
expect(wrapper.text()).not.toContain('payment.result.failed')
})
+ it('prefers the public resume-token result over a stale restored DB snapshot', async () => {
+ routeState.query = {
+ resume_token: 'resume-authoritative',
+ order_id: '42',
+ status: 'success',
+ }
+ window.localStorage.setItem(PAYMENT_RECOVERY_STORAGE_KEY, JSON.stringify({
+ orderId: 42,
+ amount: 88,
+ qrCode: '',
+ expiresAt: '2099-01-01T00:10:00.000Z',
+ paymentType: 'alipay',
+ payUrl: 'https://pay.example.com/session/42',
+ clientSecret: '',
+ payAmount: 88,
+ orderType: 'balance',
+ paymentMode: 'popup',
+ resumeToken: 'resume-authoritative',
+ createdAt: Date.UTC(2099, 0, 1, 0, 0, 0),
+ }))
+ pollOrderStatus.mockResolvedValue({
+ ...orderFactory('PENDING'),
+ amount: 88,
+ pay_amount: 88,
+ fee_rate: 0,
+ })
+ resolveOrderPublicByResumeToken.mockResolvedValue({
+ data: {
+ ...orderFactory('PAID'),
+ amount: 100,
+ pay_amount: 103,
+ fee_rate: 3,
+ },
+ })
+
+ const wrapper = mount(PaymentResultView, {
+ global: {
+ stubs: {
+ OrderStatusBadge: true,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(pollOrderStatus).toHaveBeenCalledWith(42)
+ expect(resolveOrderPublicByResumeToken).toHaveBeenCalledWith('resume-authoritative')
+ expect(wrapper.text()).toContain('payment.result.success')
+ expect(wrapper.text()).toContain('103.00')
+ expect(wrapper.text()).toContain('100.00')
+ })
+
it('refreshes a pending resume-token result until the order becomes paid', async () => {
vi.useFakeTimers()
routeState.query = {
From ebe75244150437cba1d8f70990b89ae4e2e3009b Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 02:08:56 +0800
Subject: [PATCH 088/189] fix profile activity and migration remediation
---
.../handler/admin/admin_service_stub_test.go | 14 ++
.../admin/user_handler_activity_test.go | 120 +++++++++++
.../handler/auth_current_user_test.go | 15 +-
backend/internal/handler/user_handler.go | 88 +++++++-
backend/internal/handler/user_handler_test.go | 15 +-
backend/internal/service/user_service.go | 48 ++++-
backend/internal/service/user_service_test.go | 97 ++++++++-
.../__tests__/users.migrationReports.spec.ts | 28 +++
frontend/src/api/admin/users.ts | 36 ++++
.../profile/__tests__/ProfileInfoCard.spec.ts | 27 +++
.../AuthIdentityMigrationReportsView.vue | 198 ++++++++++++++++++
.../AuthIdentityMigrationReportsView.spec.ts | 45 +++-
12 files changed, 703 insertions(+), 28 deletions(-)
create mode 100644 backend/internal/handler/admin/user_handler_activity_test.go
diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go
index 8ecadfdf..dfdf327b 100644
--- a/backend/internal/handler/admin/admin_service_stub_test.go
+++ b/backend/internal/handler/admin/admin_service_stub_test.go
@@ -45,6 +45,14 @@ type stubAdminService struct {
sortOrder string
calls int
}
+ lastListUsers struct {
+ page int
+ pageSize int
+ filters service.UserListFilters
+ sortBy string
+ sortOrder string
+ calls int
+ }
lastListProxies struct {
protocol string
status string
@@ -139,6 +147,12 @@ func newStubAdminService() *stubAdminService {
}
func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters, sortBy, sortOrder string) ([]service.User, int64, error) {
+ s.lastListUsers.page = page
+ s.lastListUsers.pageSize = pageSize
+ s.lastListUsers.filters = filters
+ s.lastListUsers.sortBy = sortBy
+ s.lastListUsers.sortOrder = sortOrder
+ s.lastListUsers.calls++
return s.users, int64(len(s.users)), nil
}
diff --git a/backend/internal/handler/admin/user_handler_activity_test.go b/backend/internal/handler/admin/user_handler_activity_test.go
new file mode 100644
index 00000000..ac29086d
--- /dev/null
+++ b/backend/internal/handler/admin/user_handler_activity_test.go
@@ -0,0 +1,120 @@
+//go:build unit
+
+package admin
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestUserHandlerListIncludesActivityFieldsAndSortParams(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ lastLoginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC)
+ lastActiveAt := lastLoginAt.Add(30 * time.Minute)
+ lastUsedAt := lastLoginAt.Add(90 * time.Minute)
+
+ adminSvc := newStubAdminService()
+ adminSvc.users = []service.User{
+ {
+ ID: 7,
+ Email: "activity@example.com",
+ Username: "activity-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ LastLoginAt: &lastLoginAt,
+ LastActiveAt: &lastActiveAt,
+ LastUsedAt: &lastUsedAt,
+ CreatedAt: lastLoginAt.Add(-24 * time.Hour),
+ UpdatedAt: lastLoginAt,
+ },
+ }
+ handler := NewUserHandler(adminSvc, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(
+ http.MethodGet,
+ "/api/v1/admin/users?sort_by=last_used_at&sort_order=asc&search=activity",
+ nil,
+ )
+
+ handler.List(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ require.Equal(t, "last_used_at", adminSvc.lastListUsers.sortBy)
+ require.Equal(t, "asc", adminSvc.lastListUsers.sortOrder)
+ require.Equal(t, "activity", adminSvc.lastListUsers.filters.Search)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ Items []struct {
+ LastLoginAt *time.Time `json:"last_login_at"`
+ LastActiveAt *time.Time `json:"last_active_at"`
+ LastUsedAt *time.Time `json:"last_used_at"`
+ } `json:"items"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Len(t, resp.Data.Items, 1)
+ require.WithinDuration(t, lastLoginAt, *resp.Data.Items[0].LastLoginAt, time.Second)
+ require.WithinDuration(t, lastActiveAt, *resp.Data.Items[0].LastActiveAt, time.Second)
+ require.WithinDuration(t, lastUsedAt, *resp.Data.Items[0].LastUsedAt, time.Second)
+}
+
+func TestUserHandlerGetByIDIncludesActivityFields(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ lastLoginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC)
+ lastActiveAt := lastLoginAt.Add(30 * time.Minute)
+ lastUsedAt := lastLoginAt.Add(90 * time.Minute)
+
+ adminSvc := newStubAdminService()
+ adminSvc.users = []service.User{
+ {
+ ID: 8,
+ Email: "detail@example.com",
+ Username: "detail-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ LastLoginAt: &lastLoginAt,
+ LastActiveAt: &lastActiveAt,
+ LastUsedAt: &lastUsedAt,
+ CreatedAt: lastLoginAt.Add(-24 * time.Hour),
+ UpdatedAt: lastLoginAt,
+ },
+ }
+ handler := NewUserHandler(adminSvc, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Params = gin.Params{{Key: "id", Value: "8"}}
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/8", nil)
+
+ handler.GetByID(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ LastLoginAt *time.Time `json:"last_login_at"`
+ LastActiveAt *time.Time `json:"last_active_at"`
+ LastUsedAt *time.Time `json:"last_used_at"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.WithinDuration(t, lastLoginAt, *resp.Data.LastLoginAt, time.Second)
+ require.WithinDuration(t, lastActiveAt, *resp.Data.LastActiveAt, time.Second)
+ require.WithinDuration(t, lastUsedAt, *resp.Data.LastUsedAt, time.Second)
+}
diff --git a/backend/internal/handler/auth_current_user_test.go b/backend/internal/handler/auth_current_user_test.go
index dab95e29..31d92a36 100644
--- a/backend/internal/handler/auth_current_user_test.go
+++ b/backend/internal/handler/auth_current_user_test.go
@@ -71,8 +71,15 @@ func TestAuthHandlerGetCurrentUserReturnsProfileCompatibilityFields(t *testing.T
require.True(t, ok)
require.Equal(t, true, linuxdoBinding["bound"])
- _, hasAvatarSource := resp.Data["avatar_source"]
- require.False(t, hasAvatarSource)
- _, hasProfileSources := resp.Data["profile_sources"]
- require.False(t, hasProfileSources)
+ avatarSource, ok := resp.Data["avatar_source"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "linuxdo", avatarSource["provider"])
+ require.Equal(t, "linuxdo", avatarSource["source"])
+
+ profileSources, ok := resp.Data["profile_sources"].(map[string]any)
+ require.True(t, ok)
+ usernameSource, ok := profileSources["username"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "linuxdo", usernameSource["provider"])
+ require.Equal(t, "linuxdo", usernameSource["source"])
}
diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go
index 843b0bd9..b1ade5c0 100644
--- a/backend/internal/handler/user_handler.go
+++ b/backend/internal/handler/user_handler.go
@@ -2,6 +2,7 @@ package handler
import (
"context"
+ "strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -352,16 +353,22 @@ func userProfileResponseFromService(user *service.User, identities service.UserI
return userProfileResponse{}
}
bindings := userProfileBindingMap(identities)
+ profileSources, avatarSource, usernameSource := inferUserProfileSources(user, identities)
return userProfileResponse{
- User: *base,
- AvatarURL: user.AvatarURL,
- Identities: identities,
- AuthBindings: bindings,
- IdentityBindings: bindings,
- EmailBound: identities.Email.Bound,
- LinuxDoBound: identities.LinuxDo.Bound,
- OIDCBound: identities.OIDC.Bound,
- WeChatBound: identities.WeChat.Bound,
+ User: *base,
+ AvatarURL: user.AvatarURL,
+ AvatarSource: avatarSource,
+ UsernameSource: usernameSource,
+ DisplayNameSource: usernameSource,
+ NicknameSource: usernameSource,
+ ProfileSources: profileSources,
+ Identities: identities,
+ AuthBindings: bindings,
+ IdentityBindings: bindings,
+ EmailBound: identities.Email.Bound,
+ LinuxDoBound: identities.LinuxDo.Bound,
+ OIDCBound: identities.OIDC.Bound,
+ WeChatBound: identities.WeChat.Bound,
}
}
@@ -373,3 +380,66 @@ func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string
"wechat": identities.WeChat,
}
}
+
+func inferUserProfileSources(user *service.User, identities service.UserIdentitySummarySet) (
+ map[string]*userProfileSourceContext,
+ *userProfileSourceContext,
+ *userProfileSourceContext,
+) {
+ if user == nil {
+ return nil, nil, nil
+ }
+
+ thirdParty := thirdPartyIdentityProviders(identities)
+ var avatarSource *userProfileSourceContext
+ if strings.TrimSpace(user.AvatarURL) != "" && len(thirdParty) == 1 {
+ avatarSource = buildUserProfileSourceContext(thirdParty[0].Provider)
+ }
+
+ usernameValue := strings.TrimSpace(user.Username)
+ var usernameSource *userProfileSourceContext
+ for _, summary := range thirdParty {
+ if usernameValue != "" && usernameValue == strings.TrimSpace(summary.DisplayName) {
+ usernameSource = buildUserProfileSourceContext(summary.Provider)
+ break
+ }
+ }
+ if usernameSource == nil && usernameValue != "" && len(thirdParty) == 1 {
+ usernameSource = buildUserProfileSourceContext(thirdParty[0].Provider)
+ }
+
+ profileSources := map[string]*userProfileSourceContext{}
+ if avatarSource != nil {
+ profileSources["avatar"] = avatarSource
+ }
+ if usernameSource != nil {
+ profileSources["username"] = usernameSource
+ profileSources["display_name"] = usernameSource
+ profileSources["nickname"] = usernameSource
+ }
+ if len(profileSources) == 0 {
+ return nil, avatarSource, usernameSource
+ }
+ return profileSources, avatarSource, usernameSource
+}
+
+func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary {
+ out := make([]service.UserIdentitySummary, 0, 3)
+ for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} {
+ if summary.Bound {
+ out = append(out, summary)
+ }
+ }
+ return out
+}
+
+func buildUserProfileSourceContext(provider string) *userProfileSourceContext {
+ provider = strings.TrimSpace(provider)
+ if provider == "" {
+ return nil
+ }
+ return &userProfileSourceContext{
+ Provider: provider,
+ Source: provider,
+ }
+}
diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go
index 7c6460e8..aacfc332 100644
--- a/backend/internal/handler/user_handler_test.go
+++ b/backend/internal/handler/user_handler_test.go
@@ -285,6 +285,11 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
require.Equal(t, false, resp.Data["wechat_bound"])
require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"])
+ avatarSource, ok := resp.Data["avatar_source"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "linuxdo", avatarSource["provider"])
+ require.Equal(t, "linuxdo", avatarSource["source"])
+
authBindings, ok := resp.Data["auth_bindings"].(map[string]any)
require.True(t, ok)
linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any)
@@ -298,10 +303,12 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
require.True(t, ok)
require.Equal(t, true, emailBinding["bound"])
- _, hasAvatarSource := resp.Data["avatar_source"]
- require.False(t, hasAvatarSource)
- _, hasProfileSources := resp.Data["profile_sources"]
- require.False(t, hasProfileSources)
+ profileSources, ok := resp.Data["profile_sources"].(map[string]any)
+ require.True(t, ok)
+ usernameSource, ok := profileSources["username"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "linuxdo", usernameSource["provider"])
+ require.Equal(t, "linuxdo", usernameSource["source"])
}
func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index 7c2ca2d0..6994ec6c 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -161,6 +161,10 @@ type userAuthIdentityReader interface {
ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error)
}
+type userProfileIdentityTxRunner interface {
+ WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error
+}
+
// ChangePasswordRequest 修改密码请求
type ChangePasswordRequest struct {
CurrentPassword string `json:"current_password"`
@@ -249,9 +253,38 @@ func (s *UserService) PrepareIdentityBindingStart(_ context.Context, req StartUs
// UpdateProfile 更新用户资料
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, error) {
+ if txRunner, ok := s.userRepo.(userProfileIdentityTxRunner); ok {
+ var (
+ updated *User
+ oldConcurrency int
+ )
+ if err := txRunner.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
+ var err error
+ updated, oldConcurrency, err = s.updateProfile(txCtx, userID, req)
+ return err
+ }); err != nil {
+ return nil, err
+ }
+ if s.authCacheInvalidator != nil && updated != nil && updated.Concurrency != oldConcurrency {
+ s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
+ }
+ return updated, nil
+ }
+
+ updated, oldConcurrency, err := s.updateProfile(ctx, userID, req)
+ if err != nil {
+ return nil, err
+ }
+ if s.authCacheInvalidator != nil && updated.Concurrency != oldConcurrency {
+ s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
+ }
+ return updated, nil
+}
+
+func (s *UserService) updateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, int, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
- return nil, fmt.Errorf("get user: %w", err)
+ return nil, 0, fmt.Errorf("get user: %w", err)
}
oldConcurrency := user.Concurrency
@@ -260,10 +293,10 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
// 检查新邮箱是否已被使用
exists, err := s.userRepo.ExistsByEmail(ctx, *req.Email)
if err != nil {
- return nil, fmt.Errorf("check email exists: %w", err)
+ return nil, oldConcurrency, fmt.Errorf("check email exists: %w", err)
}
if exists && *req.Email != user.Email {
- return nil, ErrEmailExists
+ return nil, oldConcurrency, ErrEmailExists
}
user.Email = *req.Email
}
@@ -275,7 +308,7 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
if req.AvatarURL != nil {
avatar, err := s.SetAvatar(ctx, userID, *req.AvatarURL)
if err != nil {
- return nil, err
+ return nil, oldConcurrency, err
}
applyUserAvatar(user, avatar)
}
@@ -296,13 +329,10 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
}
if err := s.userRepo.Update(ctx, user); err != nil {
- return nil, fmt.Errorf("update user: %w", err)
- }
- if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency {
- s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
+ return nil, oldConcurrency, fmt.Errorf("update user: %w", err)
}
- return user, nil
+ return user, oldConcurrency, nil
}
func (s *UserService) SetAvatar(ctx context.Context, userID int64, raw string) (*UserAvatar, error) {
diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go
index e13fb95d..c0cc36cb 100644
--- a/backend/internal/service/user_service_test.go
+++ b/backend/internal/service/user_service_test.go
@@ -31,13 +31,26 @@ type mockUserRepo struct {
deleteAvatarFn func(ctx context.Context, userID int64) error
deleteAvatarIDs []int64
getAvatarFn func(ctx context.Context, userID int64) (*UserAvatar, error)
+ txCalls int
+}
+
+type mockUserRepoTxKey struct{}
+
+type mockUserRepoTxState struct {
+ getByIDUser *User
+ upsertAvatarArgs []UpsertUserAvatarInput
+ deleteAvatarIDs []int64
}
func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
-func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) {
+func (m *mockUserRepo) GetByID(ctx context.Context, _ int64) (*User, error) {
if m.getByIDErr != nil {
return nil, m.getByIDErr
}
+ if txState, _ := ctx.Value(mockUserRepoTxKey{}).(*mockUserRepoTxState); txState != nil && txState.getByIDUser != nil {
+ cloned := *txState.getByIDUser
+ return &cloned, nil
+ }
if m.getByIDUser != nil {
cloned := *m.getByIDUser
return &cloned, nil
@@ -61,6 +74,27 @@ func (m *mockUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*UserAv
return nil, nil
}
func (m *mockUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
+ if txState, _ := ctx.Value(mockUserRepoTxKey{}).(*mockUserRepoTxState); txState != nil {
+ txState.upsertAvatarArgs = append(txState.upsertAvatarArgs, input)
+ if txState.getByIDUser != nil {
+ txState.getByIDUser.AvatarURL = input.URL
+ txState.getByIDUser.AvatarSource = input.StorageProvider
+ txState.getByIDUser.AvatarMIME = input.ContentType
+ txState.getByIDUser.AvatarByteSize = input.ByteSize
+ txState.getByIDUser.AvatarSHA256 = input.SHA256
+ }
+ if m.upsertAvatarFn != nil {
+ return m.upsertAvatarFn(ctx, userID, input)
+ }
+ return &UserAvatar{
+ StorageProvider: input.StorageProvider,
+ StorageKey: input.StorageKey,
+ URL: input.URL,
+ ContentType: input.ContentType,
+ ByteSize: input.ByteSize,
+ SHA256: input.SHA256,
+ }, nil
+ }
m.upsertAvatarArgs = append(m.upsertAvatarArgs, input)
if m.upsertAvatarFn != nil {
return m.upsertAvatarFn(ctx, userID, input)
@@ -75,6 +109,20 @@ func (m *mockUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input
}, nil
}
func (m *mockUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ if txState, _ := ctx.Value(mockUserRepoTxKey{}).(*mockUserRepoTxState); txState != nil {
+ txState.deleteAvatarIDs = append(txState.deleteAvatarIDs, userID)
+ if txState.getByIDUser != nil {
+ txState.getByIDUser.AvatarURL = ""
+ txState.getByIDUser.AvatarSource = ""
+ txState.getByIDUser.AvatarMIME = ""
+ txState.getByIDUser.AvatarByteSize = 0
+ txState.getByIDUser.AvatarSHA256 = ""
+ }
+ if m.deleteAvatarFn != nil {
+ return m.deleteAvatarFn(ctx, userID)
+ }
+ return nil
+ }
m.deleteAvatarIDs = append(m.deleteAvatarIDs, userID)
if m.deleteAvatarFn != nil {
return m.deleteAvatarFn(ctx, userID)
@@ -116,6 +164,26 @@ func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64,
return nil
}
+func (m *mockUserRepo) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
+ m.txCalls++
+ txState := &mockUserRepoTxState{
+ upsertAvatarArgs: append([]UpsertUserAvatarInput(nil), m.upsertAvatarArgs...),
+ deleteAvatarIDs: append([]int64(nil), m.deleteAvatarIDs...),
+ }
+ if m.getByIDUser != nil {
+ userCopy := *m.getByIDUser
+ txState.getByIDUser = &userCopy
+ }
+ err := fn(context.WithValue(ctx, mockUserRepoTxKey{}, txState))
+ if err != nil {
+ return err
+ }
+ m.getByIDUser = txState.getByIDUser
+ m.upsertAvatarArgs = txState.upsertAvatarArgs
+ m.deleteAvatarIDs = txState.deleteAvatarIDs
+ return nil
+}
+
// --- mock: APIKeyAuthCacheInvalidator ---
type mockAuthCacheInvalidator struct {
@@ -360,6 +428,33 @@ func TestUpdateProfile_DeletesAvatarOnEmptyString(t *testing.T) {
require.Empty(t, updated.AvatarSource)
}
+func TestUpdateProfile_RollsBackAvatarMutationWhenUserUpdateFails(t *testing.T) {
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 11,
+ Email: "rollback@example.com",
+ AvatarURL: "https://cdn.example.com/original.png",
+ AvatarSource: "remote_url",
+ },
+ updateFn: func(context.Context, *User) error {
+ return errors.New("write user failed")
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ remoteURL := "https://cdn.example.com/new.png"
+ _, err := svc.UpdateProfile(context.Background(), 11, UpdateProfileRequest{
+ AvatarURL: &remoteURL,
+ })
+
+ require.EqualError(t, err, "update user: write user failed")
+ require.Equal(t, 1, repo.txCalls)
+ require.Empty(t, repo.upsertAvatarArgs)
+ require.Empty(t, repo.deleteAvatarIDs)
+ require.Equal(t, "https://cdn.example.com/original.png", repo.getByIDUser.AvatarURL)
+ require.Equal(t, "remote_url", repo.getByIDUser.AvatarSource)
+}
+
func TestGetProfile_HydratesAvatarFromRepository(t *testing.T) {
repo := &mockUserRepo{
getByIDUser: &User{
diff --git a/frontend/src/api/__tests__/users.migrationReports.spec.ts b/frontend/src/api/__tests__/users.migrationReports.spec.ts
index c9375cd4..ca889600 100644
--- a/frontend/src/api/__tests__/users.migrationReports.spec.ts
+++ b/frontend/src/api/__tests__/users.migrationReports.spec.ts
@@ -13,6 +13,7 @@ vi.mock('@/api/client', () => ({
}))
import {
+ bindUserAuthIdentity,
getAuthIdentityMigrationReportSummary,
listAuthIdentityMigrationReports,
resolveAuthIdentityMigrationReport,
@@ -81,4 +82,31 @@ describe('admin users auth identity migration reports API', () => {
})
expect(result).toBe(response)
})
+
+ it('binds a canonical auth identity to a user for remediation', async () => {
+ const response = {
+ identity_id: 11,
+ provider_type: 'oidc',
+ provider_key: 'https://issuer.example',
+ provider_subject: 'subject-123',
+ }
+ post.mockResolvedValue({ data: response })
+
+ const result = await bindUserAuthIdentity(42, {
+ provider_type: 'oidc',
+ provider_key: 'https://issuer.example',
+ provider_subject: 'subject-123',
+ issuer: 'https://issuer.example',
+ metadata: { source: 'migration-report' },
+ })
+
+ expect(post).toHaveBeenCalledWith('/admin/users/42/auth-identities', {
+ provider_type: 'oidc',
+ provider_key: 'https://issuer.example',
+ provider_subject: 'subject-123',
+ issuer: 'https://issuer.example',
+ metadata: { source: 'migration-report' },
+ })
+ expect(result).toBe(response)
+ })
})
diff --git a/frontend/src/api/admin/users.ts b/frontend/src/api/admin/users.ts
index 2d154cf8..be30eddc 100644
--- a/frontend/src/api/admin/users.ts
+++ b/frontend/src/api/admin/users.ts
@@ -24,6 +24,30 @@ export interface AuthIdentityMigrationReportSummary {
by_type: Record
}
+export interface AdminBindAuthIdentityChannelRequest {
+ channel: string
+ channel_app_id?: string
+ channel_subject: string
+ metadata?: Record
+}
+
+export interface AdminBindAuthIdentityRequest {
+ provider_type: string
+ provider_key: string
+ provider_subject: string
+ issuer?: string
+ metadata?: Record
+ channel?: AdminBindAuthIdentityChannelRequest
+}
+
+export interface AdminBoundAuthIdentity {
+ identity_id: number
+ provider_type: string
+ provider_key: string
+ provider_subject: string
+ channel_id?: number | null
+}
+
export interface ListAuthIdentityMigrationReportsParams {
page?: number
pageSize?: number
@@ -308,6 +332,17 @@ export async function resolveAuthIdentityMigrationReport(
return data
}
+export async function bindUserAuthIdentity(
+ userId: number,
+ input: AdminBindAuthIdentityRequest
+): Promise {
+ const { data } = await apiClient.post(
+ `/admin/users/${userId}/auth-identities`,
+ input
+ )
+ return data
+}
+
export const usersAPI = {
list,
getById,
@@ -321,6 +356,7 @@ export const usersAPI = {
getUserUsageStats,
getUserBalanceHistory,
replaceGroup,
+ bindUserAuthIdentity,
getAuthIdentityMigrationReportSummary,
listAuthIdentityMigrationReports,
resolveAuthIdentityMigrationReport
diff --git a/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts
index 238a898e..6e9e4dd5 100644
--- a/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts
+++ b/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts
@@ -62,6 +62,8 @@ vi.mock('vue-i18n', async (importOriginal) => {
if (key === 'profile.authBindings.providers.linuxdo') return 'LinuxDo'
if (key === 'profile.authBindings.providers.wechat') return 'WeChat'
if (key === 'profile.authBindings.providers.oidc') return params?.providerName || 'OIDC'
+ if (key === 'profile.authBindings.source.avatar') return `Avatar synced from ${params?.providerName || 'provider'}`
+ if (key === 'profile.authBindings.source.username') return `Username synced from ${params?.providerName || 'provider'}`
if (key === 'common.save') return 'Save'
if (key === 'common.delete') return 'Delete'
return key
@@ -169,4 +171,29 @@ describe('ProfileInfoCard', () => {
expect(authStoreState.user?.avatar_url).toBeNull()
expect(showSuccessMock).toHaveBeenCalledWith('Avatar removed')
})
+
+ it('renders third-party source hints from profile_sources', () => {
+ authStoreState.user = createUser({
+ avatar_url: 'https://cdn.example.com/linuxdo.png',
+ profile_sources: {
+ avatar: { provider: 'linuxdo', source: 'linuxdo' },
+ username: { provider: 'linuxdo', source: 'linuxdo' }
+ }
+ })
+
+ const wrapper = mount(ProfileInfoCard, {
+ props: {
+ user: authStoreState.user
+ },
+ global: {
+ stubs: {
+ Icon: true,
+ ProfileIdentityBindingsSection: true
+ }
+ }
+ })
+
+ expect(wrapper.text()).toContain('Avatar synced from LinuxDo')
+ expect(wrapper.text()).toContain('Username synced from LinuxDo')
+ })
})
diff --git a/frontend/src/views/admin/AuthIdentityMigrationReportsView.vue b/frontend/src/views/admin/AuthIdentityMigrationReportsView.vue
index 35c232b6..12fbf1b9 100644
--- a/frontend/src/views/admin/AuthIdentityMigrationReportsView.vue
+++ b/frontend/src/views/admin/AuthIdentityMigrationReportsView.vue
@@ -238,6 +238,83 @@
{{ resolving ? copy.resolving : copy.resolveAction }}
+
+
+
+ {{ copy.remediationTitle }}
+
+
+ {{ copy.remediationSubtitle }}
+
+
+
+
@@ -249,6 +326,7 @@ import { computed, onMounted, reactive, ref } from 'vue'
import { useI18n } from 'vue-i18n'
import { adminAPI } from '@/api/admin'
import type {
+ AdminBindAuthIdentityRequest,
AuthIdentityMigrationReport,
AuthIdentityMigrationReportSummary,
} from '@/api/admin/users'
@@ -294,6 +372,15 @@ const copy = computed(() => ({
resolvePlaceholder: text('填写本次处理动作、用户沟通结果或后续追踪信息。', 'Describe the action taken, user communication, or follow-up context.'),
resolveAction: text('提交 Resolve', 'Submit resolve'),
resolving: text('提交中...', 'Submitting...'),
+ remediationTitle: text('修复绑定', 'Remediation binding'),
+ remediationSubtitle: text('可直接把迁移报告中的身份信息绑定到指定用户;已识别字段会自动预填。', 'Bind the migrated identity directly to a user. Recognized fields are prefilled automatically.'),
+ remediationUserID: text('目标用户 ID', 'Target user ID'),
+ remediationProviderType: text('Provider Type', 'Provider type'),
+ remediationProviderKey: text('Provider Key', 'Provider key'),
+ remediationProviderSubject: text('Provider Subject', 'Provider subject'),
+ remediationIssuer: text('Issuer', 'Issuer'),
+ remediationAction: text('提交绑定修复', 'Submit remediation binding'),
+ remediationSubmitting: text('提交中...', 'Submitting...'),
}))
const summary = ref({
@@ -308,6 +395,7 @@ const resolutionNote = ref('')
const loading = ref(false)
const summaryLoading = ref(false)
const resolving = ref(false)
+const binding = ref(false)
const filters = reactive({
reportType: '',
@@ -319,6 +407,13 @@ const pagination = reactive({
total: 0,
})
const knownReportTypes = ref([])
+const remediation = reactive({
+ userID: '',
+ providerType: '',
+ providerKey: '',
+ providerSubject: '',
+ issuer: '',
+})
const columns: Column[] = [
{ key: 'status', label: text('状态', 'Status') },
@@ -352,6 +447,18 @@ const canResolve = computed(() =>
)
)
+const canBindRemediation = computed(() =>
+ Boolean(
+ selectedReport.value &&
+ !selectedReport.value.resolved_at &&
+ remediation.userID.trim() &&
+ remediation.providerType.trim() &&
+ remediation.providerKey.trim() &&
+ remediation.providerSubject.trim() &&
+ !binding.value
+ )
+)
+
const mergeKnownReportTypes = (...values: Array) => {
const merged = new Set(knownReportTypes.value)
for (const value of values) {
@@ -392,6 +499,7 @@ const loadReports = async () => {
if (selectedReport.value) {
const refreshed = response.items.find((report) => report.id === selectedReport.value?.id) ?? null
selectedReport.value = refreshed
+ applyRemediationDefaults(refreshed)
resolutionNote.value = refreshed?.resolved_at
? refreshed.resolution_note ?? ''
: resolutionNote.value
@@ -427,6 +535,7 @@ const handlePageSizeChange = async (pageSize: number) => {
const selectReport = (report: AuthIdentityMigrationReport) => {
selectedReport.value = report
resolutionNote.value = report.resolution_note ?? ''
+ applyRemediationDefaults(report)
}
const formatDetailsJson = (details: Record) => JSON.stringify(details ?? {}, null, 2)
@@ -458,6 +567,63 @@ const getDetailHighlights = (details: Record) => {
.map(([key, value]) => ({ key, value: String(value) }))
}
+const stringDetailValue = (details: Record, key: string) => {
+ const value = details[key]
+ return typeof value === 'string' ? value.trim() : ''
+}
+
+const numericDetailValue = (details: Record, key: string) => {
+ const value = details[key]
+ if (typeof value === 'number' && Number.isFinite(value)) {
+ return String(Math.trunc(value))
+ }
+ if (typeof value === 'string' && value.trim()) {
+ return value.trim()
+ }
+ return ''
+}
+
+const inferProviderTypeFromReport = (report: AuthIdentityMigrationReport) => {
+ const explicit = stringDetailValue(report.details, 'provider_type')
+ if (explicit) {
+ return explicit
+ }
+ if (report.report_type.includes('oidc')) {
+ return 'oidc'
+ }
+ if (report.report_type.includes('wechat')) {
+ return 'wechat'
+ }
+ if (report.report_type.includes('linuxdo')) {
+ return 'linuxdo'
+ }
+ return ''
+}
+
+const inferProviderKeyFromReport = (report: AuthIdentityMigrationReport, providerType: string) => {
+ const explicit = stringDetailValue(report.details, 'provider_key')
+ if (explicit) {
+ return explicit
+ }
+ if (providerType === 'wechat') {
+ return 'wechat-main'
+ }
+ return ''
+}
+
+const inferProviderSubjectFromReport = (report: AuthIdentityMigrationReport) =>
+ stringDetailValue(report.details, 'provider_subject')
+ || stringDetailValue(report.details, 'subject')
+ || stringDetailValue(report.details, 'unionid')
+
+const applyRemediationDefaults = (report: AuthIdentityMigrationReport | null) => {
+ remediation.userID = report ? numericDetailValue(report.details, 'user_id') : ''
+ remediation.providerType = report ? inferProviderTypeFromReport(report) : ''
+ remediation.providerKey = report ? inferProviderKeyFromReport(report, remediation.providerType) : ''
+ remediation.providerSubject = report ? inferProviderSubjectFromReport(report) : ''
+ remediation.issuer = report ? stringDetailValue(report.details, 'issuer') : ''
+}
+
const submitResolve = async () => {
if (!selectedReport.value) {
appStore.showError(text('请先选择一条报告', 'Select a report first'))
@@ -485,6 +651,38 @@ const submitResolve = async () => {
}
}
+const submitRemediationBinding = async () => {
+ if (!selectedReport.value) {
+ appStore.showError(text('请先选择一条报告', 'Select a report first'))
+ return
+ }
+
+ const userID = Number.parseInt(remediation.userID.trim(), 10)
+ if (!Number.isFinite(userID) || userID <= 0) {
+ appStore.showError(text('请输入有效的目标用户 ID', 'Enter a valid target user ID'))
+ return
+ }
+
+ const payload: AdminBindAuthIdentityRequest = {
+ provider_type: remediation.providerType.trim(),
+ provider_key: remediation.providerKey.trim(),
+ provider_subject: remediation.providerSubject.trim(),
+ issuer: remediation.issuer.trim() || undefined,
+ metadata: {},
+ }
+
+ binding.value = true
+ try {
+ await adminAPI.users.bindUserAuthIdentity(userID, payload)
+ appStore.showSuccess(text('修复绑定已提交', 'Remediation binding submitted'))
+ } catch (error) {
+ console.error('Failed to submit auth identity remediation binding:', error)
+ appStore.showError(text('提交修复绑定失败', 'Failed to submit remediation binding'))
+ } finally {
+ binding.value = false
+ }
+}
+
onMounted(async () => {
await refreshAll()
})
diff --git a/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts b/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts
index 406baaf1..20f57fa1 100644
--- a/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts
+++ b/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts
@@ -4,7 +4,13 @@ import { defineComponent, h } from 'vue'
import AuthIdentityMigrationReportsView from '../AuthIdentityMigrationReportsView.vue'
-const { getAuthIdentityMigrationReportSummary, listAuthIdentityMigrationReports, resolveAuthIdentityMigrationReport } = vi.hoisted(() => ({
+const {
+ bindUserAuthIdentity,
+ getAuthIdentityMigrationReportSummary,
+ listAuthIdentityMigrationReports,
+ resolveAuthIdentityMigrationReport,
+} = vi.hoisted(() => ({
+ bindUserAuthIdentity: vi.fn(),
getAuthIdentityMigrationReportSummary: vi.fn(),
listAuthIdentityMigrationReports: vi.fn(),
resolveAuthIdentityMigrationReport: vi.fn(),
@@ -18,6 +24,7 @@ const { showError, showSuccess } = vi.hoisted(() => ({
vi.mock('@/api/admin', () => ({
adminAPI: {
users: {
+ bindUserAuthIdentity,
getAuthIdentityMigrationReportSummary,
listAuthIdentityMigrationReports,
resolveAuthIdentityMigrationReport,
@@ -156,6 +163,7 @@ describe('AuthIdentityMigrationReportsView', () => {
getAuthIdentityMigrationReportSummary.mockReset()
listAuthIdentityMigrationReports.mockReset()
resolveAuthIdentityMigrationReport.mockReset()
+ bindUserAuthIdentity.mockReset()
showError.mockReset()
showSuccess.mockReset()
@@ -167,6 +175,12 @@ describe('AuthIdentityMigrationReportsView', () => {
resolved_by_user_id: 100,
resolution_note: 'resolved by admin',
})
+ bindUserAuthIdentity.mockResolvedValue({
+ identity_id: 77,
+ provider_type: 'oidc',
+ provider_key: 'https://issuer.example',
+ provider_subject: 'subject-123',
+ })
})
const mountView = () =>
@@ -241,6 +255,35 @@ describe('AuthIdentityMigrationReportsView', () => {
})
})
+ it('pre-fills and submits remediation binding for the selected report', async () => {
+ const wrapper = mountView()
+
+ await flushPromises()
+ await wrapper.get('[data-test="select-report-1"]').trigger('click')
+ await flushPromises()
+
+ expect((wrapper.get('[data-test="remediation-user-id"]').element as HTMLInputElement).value).toBe('42')
+ expect((wrapper.get('[data-test="remediation-provider-type"]').element as HTMLInputElement).value).toBe('oidc')
+ expect((wrapper.get('[data-test="remediation-provider-key"]').element as HTMLInputElement).value).toBe(
+ 'https://issuer.example'
+ )
+ expect((wrapper.get('[data-test="remediation-provider-subject"]').element as HTMLInputElement).value).toBe(
+ 'subject-123'
+ )
+
+ await wrapper.get('[data-test="remediation-submit"]').trigger('click')
+ await flushPromises()
+
+ expect(bindUserAuthIdentity).toHaveBeenCalledWith(42, {
+ provider_type: 'oidc',
+ provider_key: 'https://issuer.example',
+ provider_subject: 'subject-123',
+ issuer: undefined,
+ metadata: {},
+ })
+ expect(showSuccess).toHaveBeenCalled()
+ })
+
it('keeps report type filter options available from list data when summary fails', async () => {
getAuthIdentityMigrationReportSummary.mockRejectedValueOnce(new Error('summary failed'))
listAuthIdentityMigrationReports.mockResolvedValueOnce(listResponse)
From 07bde2b665b0220e823f49bcad7bc90d945d7918 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 02:17:49 +0800
Subject: [PATCH 089/189] docs: align payment guide with visible routing
---
docs/PAYMENT.md | 23 ++++++++++++++++-------
docs/PAYMENT_CN.md | 23 ++++++++++++++++-------
2 files changed, 32 insertions(+), 14 deletions(-)
diff --git a/docs/PAYMENT.md b/docs/PAYMENT.md
index 755b313a..981863a9 100644
--- a/docs/PAYMENT.md
+++ b/docs/PAYMENT.md
@@ -22,11 +22,11 @@ Sub2API has a built-in payment system that enables user self-service top-up with
| Provider | Payment Methods | Description |
|----------|----------------|-------------|
| **EasyPay** | Alipay, WeChat Pay | Third-party aggregation via EasyPay protocol |
-| **Alipay (Direct)** | PC Page Pay, H5 Mobile Pay | Direct integration with Alipay Open Platform, auto-switches by device |
-| **WeChat Pay (Direct)** | Native QR Code, H5 Pay | Direct integration with WeChat Pay APIv3, mobile-first H5 |
+| **Alipay (Direct)** | Desktop QR code, mobile Alipay redirect | Direct integration with Alipay Open Platform, returning desktop QR codes and mobile WAP/app launch links |
+| **WeChat Pay (Direct)** | Native QR, H5, MP/JSAPI Pay | Direct integration with WeChat Pay APIv3 with environment-aware routing |
| **Stripe** | Card, Alipay, WeChat Pay, Link, etc. | International payments, multi-currency support |
-> Alipay/WeChat Pay direct and EasyPay can coexist. Direct channels connect to payment APIs directly with lower fees; EasyPay aggregates through third-party platforms with easier setup.
+> Alipay/WeChat Pay direct and EasyPay can both exist as backend provider instances, but the frontend always exposes only two visible buttons: `Alipay` and `WeChat Pay`. Admins choose exactly one source for each visible method: direct or EasyPay. Direct channels connect to payment APIs directly with lower fees; EasyPay aggregates through third-party platforms with easier setup.
> **EasyPay Provider Recommendations**: Both options below are third-party aggregators compatible with the EasyPay protocol. Pick based on the funding channel and settlement currency you need:
>
@@ -65,6 +65,15 @@ Configure the following in Admin Dashboard **Settings → Payment Settings**:
| **Max Pending Orders** | Maximum concurrent pending orders per user | 3 |
| **Load Balance Strategy** | Strategy for selecting provider instances | Least Amount |
+### Frontend Visible Method Routing
+
+The current payment UX keeps the frontend method list unified and does not expose provider brands directly:
+
+- **Alipay**: when enabled, this button must be routed to either `Alipay (Direct)` or `EasyPay Alipay`
+- **WeChat Pay**: when enabled, this button must be routed to either `WeChat Pay (Direct)` or `EasyPay WeChat`
+- Each visible method can route to only one source at a time
+- If a visible method is enabled without a selected source, the frontend will not expose that method
+
### Load Balance Strategies
| Strategy | Description |
@@ -113,7 +122,7 @@ Compatible with any payment service that implements the EasyPay protocol.
### Alipay (Direct)
-Direct integration with Alipay Open Platform. Supports PC page pay and H5 mobile pay.
+Direct integration with Alipay Open Platform. Desktop flows return a QR code for in-page display, while mobile flows return an Alipay WAP/app redirect URL.
| Parameter | Description | Required |
|-----------|-------------|----------|
@@ -123,7 +132,7 @@ Direct integration with Alipay Open Platform. Supports PC page pay and H5 mobile
### WeChat Pay (Direct)
-Direct integration with WeChat Pay APIv3. Supports Native QR code and H5 payment.
+Direct integration with WeChat Pay APIv3. Supports Native QR code payment, H5 payment, and MP/JSAPI payment inside the WeChat environment.
| Parameter | Description | Required |
|-----------|-------------|----------|
@@ -220,8 +229,8 @@ User selects amount and payment method
▼
User completes payment
├─ EasyPay → QR code / H5 redirect
- ├─ Alipay → PC page pay / H5 mobile pay
- ├─ WeChat Pay → Native QR / H5 pay
+ ├─ Alipay → Desktop QR / mobile Alipay redirect
+ ├─ WeChat Pay → Desktop Native QR / non-WeChat H5 / in-WeChat JSAPI
└─ Stripe → Payment Element (card/Alipay/WeChat/etc.)
│
▼
diff --git a/docs/PAYMENT_CN.md b/docs/PAYMENT_CN.md
index aca3c866..d6d83ed1 100644
--- a/docs/PAYMENT_CN.md
+++ b/docs/PAYMENT_CN.md
@@ -22,11 +22,11 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
| 服务商 | 支付方式 | 说明 |
|--------|---------|------|
| **EasyPay(易支付)** | 支付宝、微信支付 | 兼容易支付协议的第三方聚合支付 |
-| **支付宝官方** | 支付宝 PC 页面支付、H5 手机网站支付 | 直接对接支付宝开放平台,自动根据终端切换 |
-| **微信官方** | Native 扫码支付、H5 支付 | 直接对接微信支付 APIv3,移动端优先 H5 |
+| **支付宝官方** | 桌面二维码扫码、移动端支付宝跳转 | 直接对接支付宝开放平台,桌面端返回二维码,移动端返回 WAP/唤起链接 |
+| **微信官方** | Native 扫码、H5、公众号/JSAPI 支付 | 直接对接微信支付 APIv3,按终端环境自动分流 |
| **Stripe** | 银行卡、支付宝、微信支付、Link 等 | 国际支付,支持多币种 |
-> 支付宝官方 / 微信官方与易支付可以共存。官方渠道直接对接 API,资金直达商户账户,手续费更低;易支付通过第三方平台聚合,接入门槛更低。
+> 支付宝官方 / 微信官方与易支付可以同时作为后台服务商实例存在,但前台始终只展示 `支付宝`、`微信支付` 两个可见按钮。管理员需要分别为这两个按钮选择唯一支付来源:官方或易支付。官方渠道直接对接 API,资金直达商户账户,手续费更低;易支付通过第三方平台聚合,接入门槛更低。
> **易支付服务商推荐**:以下两家均为兼容易支付协议的第三方聚合支付,按资金通道与结算方式选择:
>
@@ -65,6 +65,15 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
| **最大待支付订单数** | 同一用户最大并行待支付订单数 | 3 |
| **负载均衡策略** | 多服务商实例时的选择策略 | 最少金额 |
+### 前台可见支付方式路由
+
+当前版本对用户统一展示支付方式,不区分官方渠道还是易支付:
+
+- **支付宝**:后台启用后,需要额外指定该按钮路由到 `支付宝官方` 或 `易支付支付宝`
+- **微信支付**:后台启用后,需要额外指定该按钮路由到 `微信官方` 或 `易支付微信`
+- 同一个可见支付方式在同一时刻只能路由到一个来源
+- 支付来源未选择时,即使对应按钮被开启,前台也不会暴露该支付方式
+
### 负载均衡策略
| 策略 | 说明 |
@@ -113,7 +122,7 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
### 支付宝官方
-直接对接支付宝开放平台,支持 PC 页面支付和 H5 手机网站支付。
+直接对接支付宝开放平台。桌面端返回二维码供页面内展示和扫码,移动端返回支付宝手机网站支付跳转链接。
| 参数 | 说明 | 必填 |
|------|------|------|
@@ -123,7 +132,7 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
### 微信官方
-直接对接微信支付 APIv3,支持 Native 扫码支付和 H5 支付。
+直接对接微信支付 APIv3,支持 Native 扫码支付、H5 支付,以及在微信环境内的公众号/JSAPI 支付。
| 参数 | 说明 | 必填 |
|------|------|------|
@@ -220,8 +229,8 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
▼
用户完成支付
├─ EasyPay → 扫码 / H5 跳转
- ├─ 支付宝官方 → PC 页面支付 / H5 手机网站支付
- ├─ 微信官方 → Native 扫码 / H5 支付
+ ├─ 支付宝官方 → 桌面二维码 / 移动端支付宝跳转
+ ├─ 微信官方 → 桌面 Native 扫码 / 非微信 H5 / 微信内 JSAPI
└─ Stripe → Payment Element(银行卡/支付宝/微信等)
│
▼
From f11b7d510570f9d1f96754f4623248796508c9ec Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 02:25:22 +0800
Subject: [PATCH 090/189] docs: align payment defaults and wechat capability
notes
---
docs/PAYMENT.md | 4 ++--
docs/PAYMENT_CN.md | 4 ++--
.../2026-04-20-auth-identity-payment-foundation-design.md | 4 ++--
3 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/docs/PAYMENT.md b/docs/PAYMENT.md
index 981863a9..2cc8f566 100644
--- a/docs/PAYMENT.md
+++ b/docs/PAYMENT.md
@@ -61,9 +61,9 @@ Configure the following in Admin Dashboard **Settings → Payment Settings**:
| **Minimum Amount** | Minimum single top-up amount | 1 |
| **Maximum Amount** | Maximum single top-up amount (empty = unlimited) | - |
| **Daily Limit** | Per-user daily cumulative limit (empty = unlimited) | - |
-| **Order Timeout** | Order timeout in minutes (minimum 1) | 5 |
+| **Order Timeout** | Order timeout in minutes (minimum 1) | 30 |
| **Max Pending Orders** | Maximum concurrent pending orders per user | 3 |
-| **Load Balance Strategy** | Strategy for selecting provider instances | Least Amount |
+| **Load Balance Strategy** | Strategy for selecting provider instances | Round Robin |
### Frontend Visible Method Routing
diff --git a/docs/PAYMENT_CN.md b/docs/PAYMENT_CN.md
index d6d83ed1..95bfb990 100644
--- a/docs/PAYMENT_CN.md
+++ b/docs/PAYMENT_CN.md
@@ -61,9 +61,9 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
| **最低金额** | 单笔最低充值金额 | 1 |
| **最高金额** | 单笔最高充值金额(留空表示不限制) | - |
| **每日限额** | 每用户每日累计充值上限(留空表示不限制) | - |
-| **订单超时时间** | 订单超时分钟数,至少 1 分钟 | 5 |
+| **订单超时时间** | 订单超时分钟数,至少 1 分钟 | 30 |
| **最大待支付订单数** | 同一用户最大并行待支付订单数 | 3 |
-| **负载均衡策略** | 多服务商实例时的选择策略 | 最少金额 |
+| **负载均衡策略** | 多服务商实例时的选择策略 | 轮询 |
### 前台可见支付方式路由
diff --git a/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md b/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
index 790861b7..23823cf0 100644
--- a/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
+++ b/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
@@ -459,7 +459,7 @@ Frontend submits display method only. Backend resolves display method to active
- force email on third-party signup
- LinuxDo client settings
- OIDC issuer/client settings and provider display name
-- WeChat `open` and `mp` settings with config-valid and health indicators
+- WeChat `open` / `mp` capability indicators derived from environment-backed configuration, surfaced to the frontend/admin read models as effective availability rather than full in-panel credential editing
### Source default settings
@@ -476,7 +476,7 @@ Per source (`email`, `linuxdo`, `oidc`, `wechat`):
- active source for `alipay`
- active source for `wechat`
- source-specific credentials and enablement
-- WeChat capability matrix:
+- effective WeChat payment capabilities may differ by enabled provider instances and selected visible-method source:
- QR available
- H5 available
- MP/JSAPI available
From 09351e945991a582984dfe688b21a47665bd6f5f Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 08:23:26 +0800
Subject: [PATCH 091/189] fix auth completion and payment resume hardening
---
.../handler/auth_oauth_pending_flow.go | 20 +++++-
.../handler/auth_oauth_pending_flow_test.go | 62 +++++++++++++++++++
.../service/payment_order_lifecycle.go | 14 +----
.../service/payment_resume_lookup_test.go | 45 ++++++++++++++
.../service/payment_resume_service.go | 27 ++++++++
.../service/payment_resume_service_test.go | 41 ++++++++++++
frontend/src/views/user/PaymentResultView.vue | 26 --------
.../user/__tests__/PaymentResultView.spec.ts | 11 ++--
8 files changed, 199 insertions(+), 47 deletions(-)
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
index 6041e5dd..b7ff17c3 100644
--- a/backend/internal/handler/auth_oauth_pending_flow.go
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -1350,10 +1350,24 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
return
}
if !adoptionDecision.hasDecision() {
- response.Success(c, payload)
- return
+ adoptionRequired, _ := payload["adoption_required"].(bool)
+ if adoptionRequired {
+ response.Success(c, payload)
+ return
+ }
}
- decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision)
+
+ decisionReq := adoptionDecision
+ if !decisionReq.hasDecision() {
+ adoptDisplayName := false
+ adoptAvatar := false
+ decisionReq = oauthAdoptionDecisionRequest{
+ AdoptDisplayName: &adoptDisplayName,
+ AdoptAvatar: &adoptAvatar,
+ }
+ }
+
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, decisionReq)
if err != nil {
response.ErrorFrom(c, err)
return
diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go
index c2b83c73..913acddc 100644
--- a/backend/internal/handler/auth_oauth_pending_flow_test.go
+++ b/backend/internal/handler/auth_oauth_pending_flow_test.go
@@ -523,6 +523,68 @@ func TestExchangePendingOAuthCompletionLoginFalseFalseBindsIdentityWithoutAdopti
require.NotNil(t, storedSession.ConsumedAt)
}
+func TestExchangePendingOAuthCompletionLoginWithoutDecisionStillBindsIdentity(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("login-nodecision@example.com").
+ SetUsername("legacy-name").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("login-nodecision-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("login-nodecision-123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("login-nodecision-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "login-nodecision-user",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("login-nodecision-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("login-nodecision-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+
+ storedSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+}
+
func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayload(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
settingValues: map[string]string{
diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go
index 24d8b7a2..50a56ad0 100644
--- a/backend/internal/service/payment_order_lifecycle.go
+++ b/backend/internal/service/payment_order_lifecycle.go
@@ -190,8 +190,9 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo
return o, nil
}
-// VerifyOrderPublic verifies payment status without user authentication.
-// Used by the payment result page when the user's session has expired.
+// VerifyOrderPublic returns the currently persisted public order state without
+// triggering any upstream reconciliation. Signed resume-token recovery is the
+// only public recovery path allowed to query upstream state.
func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) {
o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)).
@@ -199,15 +200,6 @@ func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo strin
if err != nil {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
- if o.Status == OrderStatusPending || o.Status == OrderStatusExpired {
- result := s.checkPaid(ctx, o)
- if result == checkPaidResultAlreadyPaid {
- o, err = s.entClient.PaymentOrder.Get(ctx, o.ID)
- if err != nil {
- return nil, fmt.Errorf("reload order: %w", err)
- }
- }
- }
return o, nil
}
diff --git a/backend/internal/service/payment_resume_lookup_test.go b/backend/internal/service/payment_resume_lookup_test.go
index 39ea24e7..3a50b5bc 100644
--- a/backend/internal/service/payment_resume_lookup_test.go
+++ b/backend/internal/service/payment_resume_lookup_test.go
@@ -200,3 +200,48 @@ func TestGetPublicOrderByResumeTokenChecksUpstreamForPendingOrder(t *testing.T)
require.Equal(t, order.ID, got.ID)
require.Equal(t, 1, provider.queryCount)
}
+
+func TestVerifyOrderPublicDoesNotCheckUpstreamForPendingOrder(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("public-verify@example.com").
+ SetPasswordHash("hash").
+ SetUsername("public-verify-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("PUBLIC-VERIFY").
+ SetOutTradeNo("sub2_public_verify_pending").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-public-verify").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ registry := payment.NewRegistry()
+ provider := &paymentResumeLookupProvider{}
+ registry.Register(provider)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ providersLoaded: true,
+ }
+
+ got, err := svc.VerifyOrderPublic(ctx, order.OutTradeNo)
+ require.NoError(t, err)
+ require.Equal(t, order.ID, got.ID)
+ require.Equal(t, 0, provider.queryCount)
+}
diff --git a/backend/internal/service/payment_resume_service.go b/backend/internal/service/payment_resume_service.go
index 05997b38..486aaac0 100644
--- a/backend/internal/service/payment_resume_service.go
+++ b/backend/internal/service/payment_resume_service.go
@@ -36,6 +36,9 @@ const (
paymentResumeNotConfiguredCode = "PAYMENT_RESUME_NOT_CONFIGURED"
paymentResumeNotConfiguredMessage = "payment resume tokens require a configured signing key"
+
+ paymentResumeTokenTTL = 24 * time.Hour
+ wechatPaymentResumeTokenTTL = 15 * time.Minute
)
type ResumeTokenClaims struct {
@@ -46,6 +49,7 @@ type ResumeTokenClaims struct {
PaymentType string `json:"pt,omitempty"`
CanonicalReturnURL string `json:"ru,omitempty"`
IssuedAt int64 `json:"iat"`
+ ExpiresAt int64 `json:"exp,omitempty"`
}
type WeChatPaymentResumeClaims struct {
@@ -58,6 +62,7 @@ type WeChatPaymentResumeClaims struct {
RedirectTo string `json:"rd,omitempty"`
Scope string `json:"scp,omitempty"`
IssuedAt int64 `json:"iat"`
+ ExpiresAt int64 `json:"exp,omitempty"`
}
type PaymentResumeService struct {
@@ -263,6 +268,9 @@ func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, er
if claims.IssuedAt == 0 {
claims.IssuedAt = time.Now().Unix()
}
+ if claims.ExpiresAt == 0 {
+ claims.ExpiresAt = time.Now().Add(paymentResumeTokenTTL).Unix()
+ }
return s.createSignedToken(claims)
}
@@ -277,6 +285,9 @@ func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, err
if claims.OrderID <= 0 {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token missing order id")
}
+ if err := validatePaymentResumeExpiry(claims.ExpiresAt, "INVALID_RESUME_TOKEN", "resume token has expired"); err != nil {
+ return nil, err
+ }
return &claims, nil
}
@@ -291,6 +302,9 @@ func (s *PaymentResumeService) CreateWeChatPaymentResumeToken(claims WeChatPayme
if claims.IssuedAt == 0 {
claims.IssuedAt = time.Now().Unix()
}
+ if claims.ExpiresAt == 0 {
+ claims.ExpiresAt = time.Now().Add(wechatPaymentResumeTokenTTL).Unix()
+ }
if normalized := NormalizeVisibleMethod(claims.PaymentType); normalized != "" {
claims.PaymentType = normalized
}
@@ -319,6 +333,9 @@ func (s *PaymentResumeService) ParseWeChatPaymentResumeToken(token string) (*WeC
if claims.OpenID == "" {
return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token missing openid")
}
+ if err := validatePaymentResumeExpiry(claims.ExpiresAt, "INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token has expired"); err != nil {
+ return nil, err
+ }
if normalized := NormalizeVisibleMethod(claims.PaymentType); normalized != "" {
claims.PaymentType = normalized
}
@@ -355,6 +372,16 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
return json.Unmarshal(payload, dest)
}
+func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
+ if expiresAt <= 0 {
+ return nil
+ }
+ if time.Now().Unix() > expiresAt {
+ return infraerrors.BadRequest(code, message)
+ }
+ return nil
+}
+
func (s *PaymentResumeService) sign(payload string) string {
mac := hmac.New(sha256.New, s.signingKey)
_, _ = mac.Write([]byte(payload))
diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go
index 4ac89f0f..12d67be2 100644
--- a/backend/internal/service/payment_resume_service_test.go
+++ b/backend/internal/service/payment_resume_service_test.go
@@ -11,6 +11,7 @@ import (
"net/url"
"strconv"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
@@ -175,6 +176,26 @@ func TestParseTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T)
}
}
+func TestParseTokenRejectsExpiredToken(t *testing.T) {
+ t.Parallel()
+
+ svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := svc.CreateToken(ResumeTokenClaims{
+ OrderID: 42,
+ UserID: 7,
+ IssuedAt: time.Now().Add(-25 * time.Hour).Unix(),
+ ExpiresAt: time.Now().Add(-1 * time.Hour).Unix(),
+ })
+ if err != nil {
+ t.Fatalf("CreateToken returned error: %v", err)
+ }
+
+ _, err = svc.ParseToken(token)
+ if err == nil {
+ t.Fatal("ParseToken should reject expired tokens")
+ }
+}
+
func TestWeChatPaymentResumeTokenRoundTrip(t *testing.T) {
t.Parallel()
@@ -233,6 +254,26 @@ func TestParseWeChatPaymentResumeTokenRejectsFallbackSignedTokenWhenSigningKeyMi
}
}
+func TestParseWeChatPaymentResumeTokenRejectsExpiredToken(t *testing.T) {
+ t.Parallel()
+
+ svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
+ OpenID: "openid-123",
+ PaymentType: payment.TypeWxpay,
+ IssuedAt: time.Now().Add(-30 * time.Minute).Unix(),
+ ExpiresAt: time.Now().Add(-1 * time.Minute).Unix(),
+ })
+ if err != nil {
+ t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
+ }
+
+ _, err = svc.ParseWeChatPaymentResumeToken(token)
+ if err == nil {
+ t.Fatal("ParseWeChatPaymentResumeToken should reject expired tokens")
+ }
+}
+
func TestNormalizeVisibleMethodSource(t *testing.T) {
t.Parallel()
diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue
index 5177a7fd..57e81f40 100644
--- a/frontend/src/views/user/PaymentResultView.vue
+++ b/frontend/src/views/user/PaymentResultView.vue
@@ -219,7 +219,6 @@ onMounted(async () => {
const routeOrderId = Number(route.query.order_id) || 0
const outTradeNo = String(route.query.out_trade_no || '')
let orderId = 0
- let canUseLegacyPublicVerify = false
if (resumeToken && typeof window !== 'undefined') {
const restored = readPaymentRecoverySnapshot(
@@ -264,23 +263,12 @@ onMounted(async () => {
const hasLegacyFallbackContext = typeof route.query.trade_status === 'string'
&& route.query.trade_status.trim() !== ''
if (!order.value && !resumeToken && !orderId && outTradeNo && hasLegacyFallbackContext) {
- canUseLegacyPublicVerify = true
returnInfo.value = {
outTradeNo,
money: String(route.query.money || ''),
type: String(route.query.type || ''),
tradeStatus: String(route.query.trade_status || ''),
}
-
- try {
- const result = await paymentAPI.verifyOrderPublic(outTradeNo)
- order.value = result.data
- } catch (_err: unknown) {
- try {
- const result = await paymentAPI.verifyOrder(outTradeNo)
- order.value = result.data
- } catch (_e: unknown) { /* fall through */ }
- }
}
const refreshOrder = async (): Promise => {
@@ -292,20 +280,6 @@ onMounted(async () => {
return await paymentStore.pollOrderStatus(orderId)
}
- if (canUseLegacyPublicVerify && outTradeNo) {
- try {
- const result = await paymentAPI.verifyOrderPublic(outTradeNo)
- return result.data
- } catch (_err: unknown) {
- try {
- const result = await paymentAPI.verifyOrder(outTradeNo)
- return result.data
- } catch (_e: unknown) {
- return null
- }
- }
- }
-
return null
}
diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
index 86c55e25..56e64793 100644
--- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
+++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
@@ -225,16 +225,13 @@ describe('PaymentResultView', () => {
expect(verifyOrder).not.toHaveBeenCalled()
})
- it('keeps legacy out_trade_no verification as a fallback when no order context is available', async () => {
+ it('does not use anonymous out_trade_no verification when no signed resume context is available', async () => {
routeState.query = {
out_trade_no: 'legacy-123',
trade_status: 'TRADE_SUCCESS',
}
- verifyOrderPublic.mockResolvedValue({
- data: orderFactory('PAID'),
- })
- const wrapper = mount(PaymentResultView, {
+ mount(PaymentResultView, {
global: {
stubs: {
OrderStatusBadge: true,
@@ -244,8 +241,8 @@ describe('PaymentResultView', () => {
await flushPromises()
- expect(verifyOrderPublic).toHaveBeenCalledWith('legacy-123')
- expect(wrapper.text()).toContain('payment.result.success')
+ expect(verifyOrderPublic).not.toHaveBeenCalled()
+ expect(verifyOrder).not.toHaveBeenCalled()
})
it('does not use public out_trade_no verification for bare order numbers without legacy return markers', async () => {
From 2626e8f22c689c018a73d760eba34f33e5813334 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 08:28:48 +0800
Subject: [PATCH 092/189] fix legacy email identity backfill grants
---
backend/internal/service/auth_service.go | 15 +++++--
.../auth_service_identity_sync_test.go | 44 +++++++++----------
2 files changed, 31 insertions(+), 28 deletions(-)
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index 5b6e5fef..c3d9cc32 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -796,7 +796,7 @@ func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context
if s == nil || user == nil || user.ID <= 0 {
return
}
- identity, created := s.ensureEmailAuthIdentity(ctx, user)
+ identity, created := s.ensureEmailAuthIdentity(ctx, user, "auth_service_login_backfill")
if s.shouldApplyEmailFirstBindDefaults(ctx, user.ID, identity, created) {
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, "email"); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to apply email first bind defaults: user_id=%d err=%v", user.ID, err)
@@ -810,13 +810,17 @@ func (s *AuthService) shouldApplyEmailFirstBindDefaults(
identity *dbent.AuthIdentity,
created bool,
) bool {
+ source := emailAuthIdentitySource(identity.Metadata)
+ if source == "auth_service_login_backfill" {
+ return false
+ }
if created {
return true
}
if s == nil || s.entClient == nil || userID <= 0 || identity == nil || identity.UserID != userID {
return false
}
- if emailAuthIdentitySource(identity.Metadata) != "auth_service_dual_write" {
+ if source != "auth_service_dual_write" {
return false
}
@@ -863,7 +867,7 @@ func (s *AuthService) hasProviderGrantRecord(
return rows.Next(), rows.Err()
}
-func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) (*dbent.AuthIdentity, bool) {
+func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User, source string) (*dbent.AuthIdentity, bool) {
if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
return nil, false
}
@@ -872,6 +876,9 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) (
if email == "" || isReservedEmail(email) {
return nil, false
}
+ if strings.TrimSpace(source) == "" {
+ source = "auth_service_dual_write"
+ }
client := s.entClient
if tx := dbent.TxFromContext(ctx); tx != nil {
@@ -900,7 +907,7 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) (
SetProviderSubject(email).
SetVerifiedAt(time.Now().UTC()).
SetMetadata(map[string]any{
- "source": "auth_service_dual_write",
+ "source": strings.TrimSpace(source),
}).
OnConflictColumns(
authidentity.FieldProviderType,
diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go
index 4d2a840f..2233e427 100644
--- a/backend/internal/service/auth_service_identity_sync_test.go
+++ b/backend/internal/service/auth_service_identity_sync_test.go
@@ -259,7 +259,7 @@ func TestAuthServiceRecordSuccessfulLoginBackfillsEmailIdentity(t *testing.T) {
require.Equal(t, user.ID, identity.UserID)
}
-func TestAuthServiceLogin_AppliesEmailFirstBindDefaultsOnlyWhenEmailIdentityIsNew(t *testing.T) {
+func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenBackfillingLegacyEmailIdentity(t *testing.T) {
assigner := &authIdentityDefaultSubAssignerStub{}
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
service.SettingKeyRegistrationEnabled: "true",
@@ -291,11 +291,9 @@ func TestAuthServiceLogin_AppliesEmailFirstBindDefaultsOnlyWhenEmailIdentityIsNe
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
- require.Equal(t, 10.0, storedUser.Balance)
- require.Equal(t, 6, storedUser.Concurrency)
- require.Len(t, assigner.calls, 1)
- require.Equal(t, int64(11), assigner.calls[0].GroupID)
- require.Equal(t, 30, assigner.calls[0].ValidityDays)
+ require.Equal(t, 1.5, storedUser.Balance)
+ require.Equal(t, 2, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
identityCount, err := client.AuthIdentity.Query().
Where(
@@ -306,7 +304,7 @@ func TestAuthServiceLogin_AppliesEmailFirstBindDefaultsOnlyWhenEmailIdentityIsNe
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, identityCount)
- require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
token, gotUser, err = svc.Login(ctx, user.Email, "password")
require.NoError(t, err)
@@ -315,13 +313,13 @@ func TestAuthServiceLogin_AppliesEmailFirstBindDefaultsOnlyWhenEmailIdentityIsNe
storedUser, err = client.User.Get(ctx, user.ID)
require.NoError(t, err)
- require.Equal(t, 10.0, storedUser.Balance)
- require.Equal(t, 6, storedUser.Concurrency)
- require.Len(t, assigner.calls, 1)
- require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+ require.Equal(t, 1.5, storedUser.Balance)
+ require.Equal(t, 2, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
}
-func TestAuthServiceLogin_MergesEmailFirstBindSourceOverridesWithGlobalDefaults(t *testing.T) {
+func TestAuthServiceLogin_DoesNotApplyMergedEmailFirstBindDefaultsWhenBackfillingLegacyEmailIdentity(t *testing.T) {
assigner := &authIdentityDefaultSubAssignerStub{}
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
service.SettingKeyRegistrationEnabled: "true",
@@ -354,12 +352,10 @@ func TestAuthServiceLogin_MergesEmailFirstBindSourceOverridesWithGlobalDefaults(
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
- require.Equal(t, 10.0, storedUser.Balance)
- require.Equal(t, 4, storedUser.Concurrency)
- require.Len(t, assigner.calls, 1)
- require.Equal(t, int64(21), assigner.calls[0].GroupID)
- require.Equal(t, 14, assigner.calls[0].ValidityDays)
- require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+ require.Equal(t, 1.5, storedUser.Balance)
+ require.Equal(t, 2, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
}
func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyExists(t *testing.T) {
@@ -409,7 +405,7 @@ func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyE
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
}
-func TestAuthServiceLogin_RetriesEmailFirstBindDefaultsAfterPreviousFailure(t *testing.T) {
+func TestAuthServiceLogin_DoesNotRetryEmailFirstBindDefaultsForBackfilledEmailIdentity(t *testing.T) {
assigner := &flakyAuthIdentityDefaultSubAssignerStub{failuresRemaining: 1}
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
service.SettingKeyRegistrationEnabled: "true",
@@ -443,7 +439,7 @@ func TestAuthServiceLogin_RetriesEmailFirstBindDefaultsAfterPreviousFailure(t *t
require.NoError(t, err)
require.Equal(t, 1.5, storedUser.Balance)
require.Equal(t, 2, storedUser.Concurrency)
- require.Len(t, assigner.calls, 1)
+ require.Empty(t, assigner.calls)
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
token, gotUser, err = svc.Login(ctx, user.Email, "password")
@@ -454,10 +450,10 @@ func TestAuthServiceLogin_RetriesEmailFirstBindDefaultsAfterPreviousFailure(t *t
storedUser, err = client.User.Get(ctx, user.ID)
require.NoError(t, err)
- require.Equal(t, 10.0, storedUser.Balance)
- require.Equal(t, 6, storedUser.Concurrency)
- require.Len(t, assigner.calls, 2)
- require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+ require.Equal(t, 1.5, storedUser.Balance)
+ require.Equal(t, 2, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
}
func countProviderGrantRecords(
From 07f23aaa7d413b6077194413bdecacf9874a3720 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 08:35:53 +0800
Subject: [PATCH 093/189] fix wxpay config contract and h5 scene info
---
backend/internal/payment/provider/wxpay.go | 15 +++-
.../internal/payment/provider/wxpay_test.go | 83 +++++++++++++++++++
.../payment/__tests__/providerConfig.spec.ts | 20 +++++
.../src/components/payment/providerConfig.ts | 7 +-
frontend/src/i18n/locales/en.ts | 3 +
frontend/src/i18n/locales/zh.ts | 3 +
6 files changed, 127 insertions(+), 4 deletions(-)
create mode 100644 frontend/src/components/payment/__tests__/providerConfig.spec.ts
diff --git a/backend/internal/payment/provider/wxpay.go b/backend/internal/payment/provider/wxpay.go
index 47569ee3..7d51dff0 100644
--- a/backend/internal/payment/provider/wxpay.go
+++ b/backend/internal/payment/provider/wxpay.go
@@ -250,13 +250,12 @@ func (w *Wxpay) prepayNative(ctx context.Context, c *core.Client, req payment.Cr
func (w *Wxpay) prepayH5(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) {
svc := h5.H5ApiService{Client: c}
cur := wxpayCurrency
- tp := wxpayH5Type
resp, _, err := wxpayH5Prepay(ctx, svc, h5.PrepayRequest{
Appid: core.String(w.config["appId"]), Mchid: core.String(w.config["mchId"]),
Description: core.String(req.Subject), OutTradeNo: core.String(req.OrderID),
NotifyUrl: core.String(notifyURL),
Amount: &h5.Amount{Total: core.Int64(totalFen), Currency: &cur},
- SceneInfo: &h5.SceneInfo{PayerClientIp: core.String(req.ClientIP), H5Info: &h5.H5Info{Type: &tp}},
+ SceneInfo: &h5.SceneInfo{PayerClientIp: core.String(req.ClientIP), H5Info: buildWxpayH5Info(w.config)},
})
if err != nil {
return nil, fmt.Errorf("wxpay h5 prepay: %w", err)
@@ -272,6 +271,18 @@ func (w *Wxpay) prepayH5(ctx context.Context, c *core.Client, req payment.Create
return &payment.CreatePaymentResponse{TradeNo: req.OrderID, PayURL: h5URL}, nil
}
+func buildWxpayH5Info(config map[string]string) *h5.H5Info {
+ tp := wxpayH5Type
+ info := &h5.H5Info{Type: &tp}
+ if appName := strings.TrimSpace(config["h5AppName"]); appName != "" {
+ info.AppName = core.String(appName)
+ }
+ if appURL := strings.TrimSpace(config["h5AppUrl"]); appURL != "" {
+ info.AppUrl = core.String(appURL)
+ }
+ return info
+}
+
func resolveWxpayCreateMode(req payment.CreatePaymentRequest) (string, error) {
if strings.TrimSpace(req.OpenID) != "" {
return wxpayModeJSAPI, nil
diff --git a/backend/internal/payment/provider/wxpay_test.go b/backend/internal/payment/provider/wxpay_test.go
index 8489e261..6d0006be 100644
--- a/backend/internal/payment/provider/wxpay_test.go
+++ b/backend/internal/payment/provider/wxpay_test.go
@@ -487,3 +487,86 @@ func TestCreatePaymentWithOpenIDReturnsJSAPIResult(t *testing.T) {
t.Fatalf("jsapi paySign = %q, want %q", resp.JSAPI.PaySign, "signed-payload")
}
}
+
+func TestCreatePaymentMobileH5IncludesConfiguredSceneInfo(t *testing.T) {
+ origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment
+ origNativePrepay := wxpayNativePrepay
+ origH5Prepay := wxpayH5Prepay
+ t.Cleanup(func() {
+ wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay
+ wxpayNativePrepay = origNativePrepay
+ wxpayH5Prepay = origH5Prepay
+ })
+
+ jsapiCalls := 0
+ nativeCalls := 0
+ h5Calls := 0
+ wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) {
+ jsapiCalls++
+ return &jsapi.PrepayWithRequestPaymentResponse{}, nil, nil
+ }
+ wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) {
+ nativeCalls++
+ return &native.PrepayResponse{}, nil, nil
+ }
+ wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) {
+ h5Calls++
+ if req.SceneInfo == nil {
+ t.Fatal("expected scene_info, got nil")
+ }
+ if got := wxSV(req.SceneInfo.PayerClientIp); got != "203.0.113.10" {
+ t.Fatalf("scene_info payer_client_ip = %q, want %q", got, "203.0.113.10")
+ }
+ if req.SceneInfo.H5Info == nil {
+ t.Fatal("expected scene_info.h5_info, got nil")
+ }
+ if got := wxSV(req.SceneInfo.H5Info.Type); got != wxpayH5Type {
+ t.Fatalf("scene_info.h5_info.type = %q, want %q", got, wxpayH5Type)
+ }
+ if got := wxSV(req.SceneInfo.H5Info.AppName); got != "Sub2API" {
+ t.Fatalf("scene_info.h5_info.app_name = %q, want %q", got, "Sub2API")
+ }
+ if got := wxSV(req.SceneInfo.H5Info.AppUrl); got != "https://app.example.com" {
+ t.Fatalf("scene_info.h5_info.app_url = %q, want %q", got, "https://app.example.com")
+ }
+ return &h5.PrepayResponse{
+ H5Url: core.String("https://wx.tenpay.example/h5pay?prepay_id=1"),
+ }, nil, nil
+ }
+
+ provider := &Wxpay{
+ config: map[string]string{
+ "appId": "wx123",
+ "mchId": "mch123",
+ "h5AppName": "Sub2API",
+ "h5AppUrl": "https://app.example.com",
+ },
+ coreClient: &core.Client{},
+ }
+
+ resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{
+ OrderID: "sub2_99",
+ Amount: "66.88",
+ PaymentType: payment.TypeWxpay,
+ Subject: "Balance Recharge",
+ NotifyURL: "https://merchant.example/payment/notify",
+ ReturnURL: "https://merchant.example/payment/result?resume_token=resume-99",
+ ClientIP: "203.0.113.10",
+ IsMobile: true,
+ })
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if jsapiCalls != 0 {
+ t.Fatalf("jsapi prepay calls = %d, want 0", jsapiCalls)
+ }
+ if nativeCalls != 0 {
+ t.Fatalf("native prepay calls = %d, want 0", nativeCalls)
+ }
+ if h5Calls != 1 {
+ t.Fatalf("h5 prepay calls = %d, want 1", h5Calls)
+ }
+ if !strings.Contains(resp.PayURL, "redirect_url=") {
+ t.Fatalf("pay_url = %q, want redirect_url query appended", resp.PayURL)
+ }
+}
diff --git a/frontend/src/components/payment/__tests__/providerConfig.spec.ts b/frontend/src/components/payment/__tests__/providerConfig.spec.ts
new file mode 100644
index 00000000..6a4c9c26
--- /dev/null
+++ b/frontend/src/components/payment/__tests__/providerConfig.spec.ts
@@ -0,0 +1,20 @@
+import { describe, expect, it } from 'vitest'
+import { PROVIDER_CONFIG_FIELDS } from '@/components/payment/providerConfig'
+
+function findField(key: string) {
+ const fields = PROVIDER_CONFIG_FIELDS.wxpay || []
+ return fields.find(field => field.key === key)
+}
+
+describe('PROVIDER_CONFIG_FIELDS.wxpay', () => {
+ it('keeps admin form validation aligned with backend-required credentials', () => {
+ expect(findField('publicKeyId')?.optional).toBeFalsy()
+ expect(findField('certSerial')?.optional).toBeFalsy()
+ })
+
+ it('exposes optional mp and H5 metadata fields for WeChat-specific flows', () => {
+ expect(findField('mpAppId')?.optional).toBe(true)
+ expect(findField('h5AppName')?.optional).toBe(true)
+ expect(findField('h5AppUrl')?.optional).toBe(true)
+ })
+})
diff --git a/frontend/src/components/payment/providerConfig.ts b/frontend/src/components/payment/providerConfig.ts
index a83787fd..5d52eddf 100644
--- a/frontend/src/components/payment/providerConfig.ts
+++ b/frontend/src/components/payment/providerConfig.ts
@@ -83,12 +83,15 @@ export const PROVIDER_CONFIG_FIELDS: Record = {
],
wxpay: [
{ key: 'appId', label: 'App ID', sensitive: false },
+ { key: 'mpAppId', label: '', sensitive: false, optional: true },
{ key: 'mchId', label: '', sensitive: false },
{ key: 'privateKey', label: '', sensitive: true },
{ key: 'apiV3Key', label: '', sensitive: true },
{ key: 'publicKey', label: '', sensitive: true },
- { key: 'publicKeyId', label: '', sensitive: false, optional: true },
- { key: 'certSerial', label: '', sensitive: false, optional: true },
+ { key: 'publicKeyId', label: '', sensitive: false },
+ { key: 'certSerial', label: '', sensitive: false },
+ { key: 'h5AppName', label: '', sensitive: false, optional: true },
+ { key: 'h5AppUrl', label: '', sensitive: false, optional: true },
],
stripe: [
{ key: 'secretKey', label: '', sensitive: true },
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 10160d73..78db9e8a 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -4655,10 +4655,13 @@ export default {
callbackBaseUrl: 'Callback Base URL',
field_privateKey: 'Private Key',
field_publicKey: 'Public Key',
+ field_mpAppId: 'MP App ID',
field_mchId: 'Merchant ID',
field_apiV3Key: 'API v3 Key',
field_publicKeyId: 'Public Key ID',
field_certSerial: 'Certificate Serial',
+ field_h5AppName: 'H5 App Name',
+ field_h5AppUrl: 'H5 App URL',
field_secretKey: 'Secret Key',
field_publishableKey: 'Publishable Key',
field_webhookSecret: 'Webhook Secret',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 8676df4b..9bb265e1 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -4819,10 +4819,13 @@ export default {
callbackBaseUrl: '回调基础地址',
field_privateKey: '私钥',
field_publicKey: '公钥',
+ field_mpAppId: '公众号 App ID',
field_mchId: '商户号',
field_apiV3Key: 'API v3 密钥',
field_publicKeyId: '公钥 ID',
field_certSerial: '证书序列号',
+ field_h5AppName: 'H5 应用名称',
+ field_h5AppUrl: 'H5 应用地址',
field_secretKey: '密钥',
field_publishableKey: '公开密钥',
field_webhookSecret: 'Webhook 密钥',
From 6da08262d733b83cc15dcc21e32083398d9cf21d Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 08:53:59 +0800
Subject: [PATCH 094/189] feat avatar compress uploads to 20kb
---
backend/go.mod | 18 +--
backend/go.sum | 46 +++-----
backend/internal/service/user_service.go | 55 +++++++++
backend/internal/service/user_service_test.go | 54 +++++++++
.../user/profile/ProfileInfoCard.vue | 72 +++++++++++-
.../profile/__tests__/ProfileInfoCard.spec.ts | 104 +++++++++++++++++-
frontend/src/i18n/locales/en.ts | 8 +-
frontend/src/i18n/locales/zh.ts | 8 +-
8 files changed, 321 insertions(+), 44 deletions(-)
diff --git a/backend/go.mod b/backend/go.mod
index 66b6cc25..627851bf 100644
--- a/backend/go.mod
+++ b/backend/go.mod
@@ -39,10 +39,11 @@ require (
github.com/wechatpay-apiv3/wechatpay-go v0.2.21
github.com/zeromicro/go-zero v1.9.4
go.uber.org/zap v1.24.0
- golang.org/x/crypto v0.48.0
- golang.org/x/net v0.49.0
- golang.org/x/sync v0.19.0
- golang.org/x/term v0.40.0
+ golang.org/x/crypto v0.49.0
+ golang.org/x/image v0.39.0
+ golang.org/x/net v0.52.0
+ golang.org/x/sync v0.20.0
+ golang.org/x/term v0.41.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
modernc.org/sqlite v1.44.3
@@ -103,7 +104,6 @@ require (
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
- github.com/google/subcommands v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
@@ -172,10 +172,10 @@ require (
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
- golang.org/x/mod v0.32.0 // indirect
- golang.org/x/sys v0.41.0 // indirect
- golang.org/x/text v0.34.0 // indirect
- golang.org/x/tools v0.41.0 // indirect
+ golang.org/x/mod v0.34.0 // indirect
+ golang.org/x/sys v0.42.0 // indirect
+ golang.org/x/text v0.36.0 // indirect
+ golang.org/x/tools v0.43.0 // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
diff --git a/backend/go.sum b/backend/go.sum
index 9312af63..f1c864f5 100644
--- a/backend/go.sum
+++ b/backend/go.sum
@@ -162,8 +162,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
-github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
-github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
@@ -183,8 +181,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
-github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
-github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -220,8 +216,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
-github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
-github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@@ -255,8 +249,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
-github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
-github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@@ -286,8 +278,6 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
-github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
-github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
@@ -320,8 +310,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
-github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
-github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
@@ -413,16 +401,18 @@ go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
-golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
-golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
+golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
+golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
-golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
-golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
-golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
-golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
-golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
-golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
+golang.org/x/image v0.39.0 h1:skVYidAEVKgn8lZ602XO75asgXBgLj9G/FE3RbuPFww=
+golang.org/x/image v0.39.0/go.mod h1:sIbmppfU+xFLPIG0FoVUTvyBMmgng1/XAMhQ2ft0hpA=
+golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
+golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
+golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
+golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
+golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
+golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -432,16 +422,16 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
-golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
-golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
-golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
-golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
-golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
+golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
+golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
+golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
+golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
+golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
+golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
-golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
-golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
+golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
+golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index 6994ec6c..9c7d4747 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -1,6 +1,7 @@
package service
import (
+ "bytes"
"context"
"crypto/sha256"
"crypto/subtle"
@@ -9,11 +10,19 @@ import (
"fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "image"
+ "image/color"
+ stddraw "image/draw"
+ _ "image/gif"
+ "image/jpeg"
+ _ "image/png"
"log/slog"
"net/url"
"sort"
"strings"
"time"
+
+ xdraw "golang.org/x/image/draw"
)
var (
@@ -31,6 +40,7 @@ var (
const (
maxNotifyEmails = 3 // Maximum number of notification emails per user
maxInlineAvatarBytes = 100 * 1024
+ targetAvatarBytes = 20 * 1024
// User-level rate limiting for notify email verification codes
notifyCodeUserRateLimit = 5
@@ -39,6 +49,11 @@ const (
defaultUserIdentityRedirect = "/settings/profile"
)
+var (
+ avatarScaleSteps = []float64{1, 0.92, 0.84, 0.76, 0.68, 0.6, 0.52, 0.44, 0.36}
+ avatarQualitySteps = []int{88, 80, 72, 64, 56, 48, 40, 32}
+)
+
// UserListFilters contains all filter options for listing users
type UserListFilters struct {
Status string // User status filter
@@ -432,6 +447,14 @@ func normalizeInlineUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
return UpsertUserAvatarInput{}, ErrAvatarTooLarge
}
+ if len(decoded) > targetAvatarBytes {
+ decoded, contentType, err = compressInlineAvatar(decoded)
+ if err != nil {
+ return UpsertUserAvatarInput{}, err
+ }
+ raw = "data:" + contentType + ";base64," + base64.StdEncoding.EncodeToString(decoded)
+ }
+
sum := sha256.Sum256(decoded)
return UpsertUserAvatarInput{
StorageProvider: "inline",
@@ -442,6 +465,38 @@ func normalizeInlineUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
}, nil
}
+func compressInlineAvatar(decoded []byte) ([]byte, string, error) {
+ src, _, err := image.Decode(bytes.NewReader(decoded))
+ if err != nil {
+ return nil, "", ErrAvatarInvalid
+ }
+
+ srcBounds := src.Bounds()
+ if srcBounds.Empty() {
+ return nil, "", ErrAvatarInvalid
+ }
+
+ for _, scale := range avatarScaleSteps {
+ width := max(1, int(float64(srcBounds.Dx())*scale))
+ height := max(1, int(float64(srcBounds.Dy())*scale))
+ dst := image.NewRGBA(image.Rect(0, 0, width, height))
+ stddraw.Draw(dst, dst.Bounds(), &image.Uniform{C: color.White}, image.Point{}, stddraw.Src)
+ xdraw.CatmullRom.Scale(dst, dst.Bounds(), src, srcBounds, stddraw.Over, nil)
+
+ for _, quality := range avatarQualitySteps {
+ var buf bytes.Buffer
+ if err := jpeg.Encode(&buf, dst, &jpeg.Options{Quality: quality}); err != nil {
+ return nil, "", ErrAvatarInvalid
+ }
+ if buf.Len() <= targetAvatarBytes {
+ return buf.Bytes(), "image/jpeg", nil
+ }
+ }
+ }
+
+ return nil, "", ErrAvatarTooLarge
+}
+
func (s *UserService) buildEmailIdentitySummary(user *User) UserIdentitySummary {
summary := UserIdentitySummary{
Provider: "email",
diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go
index c0cc36cb..d771cb75 100644
--- a/backend/internal/service/user_service_test.go
+++ b/backend/internal/service/user_service_test.go
@@ -3,11 +3,14 @@
package service
import (
+ "bytes"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
+ "image"
+ "image/png"
"sync"
"sync/atomic"
"testing"
@@ -361,6 +364,57 @@ func TestUpdateProfile_StoresInlineAvatarWithinLimit(t *testing.T) {
require.Equal(t, hex.EncodeToString(expectedSum[:]), updated.AvatarSHA256)
}
+func TestUpdateProfile_CompressesInlineAvatarToTwentyKilobytes(t *testing.T) {
+ var encoded bytes.Buffer
+ for _, size := range []int{192, 224, 256, 288} {
+ encoded.Reset()
+ var img image.RGBA
+ img.Rect = image.Rect(0, 0, size, size)
+ img.Stride = size * 4
+ img.Pix = make([]byte, size*size*4)
+ for y := 0; y < size; y++ {
+ for x := 0; x < size; x++ {
+ offset := y*img.Stride + x*4
+ img.Pix[offset] = uint8((x*x + y*17) % 255)
+ img.Pix[offset+1] = uint8((y*y + x*29) % 255)
+ img.Pix[offset+2] = uint8(((x * y) + x*13 + y*7) % 255)
+ img.Pix[offset+3] = 0xff
+ }
+ }
+ require.NoError(t, png.Encode(&encoded, &img))
+ if encoded.Len() > 20*1024 && encoded.Len() <= maxInlineAvatarBytes {
+ break
+ }
+ }
+ require.Greater(t, encoded.Len(), 20*1024)
+ require.LessOrEqual(t, encoded.Len(), maxInlineAvatarBytes)
+
+ dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(encoded.Bytes())
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 17,
+ Email: "avatar-compress@example.com",
+ Username: "avatar-compress",
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ updated, err := svc.UpdateProfile(context.Background(), 17, UpdateProfileRequest{
+ AvatarURL: &dataURL,
+ })
+ require.NoError(t, err)
+ require.Len(t, repo.upsertAvatarArgs, 1)
+ require.Equal(t, "inline", repo.upsertAvatarArgs[0].StorageProvider)
+ require.LessOrEqual(t, repo.upsertAvatarArgs[0].ByteSize, 20*1024)
+ require.Equal(t, "image/jpeg", repo.upsertAvatarArgs[0].ContentType)
+ require.Contains(t, repo.upsertAvatarArgs[0].URL, "data:image/jpeg;base64,")
+ require.Equal(t, "inline", updated.AvatarSource)
+ require.Equal(t, "image/jpeg", updated.AvatarMIME)
+ require.LessOrEqual(t, updated.AvatarByteSize, 20*1024)
+ require.Contains(t, updated.AvatarURL, "data:image/jpeg;base64,")
+ require.NotEmpty(t, updated.AvatarSHA256)
+}
+
func TestUpdateProfile_RejectsInlineAvatarOverLimit(t *testing.T) {
raw := make([]byte, maxInlineAvatarBytes+1)
dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw)
diff --git a/frontend/src/components/user/profile/ProfileInfoCard.vue b/frontend/src/components/user/profile/ProfileInfoCard.vue
index 460e5c7f..dd7ae0ea 100644
--- a/frontend/src/components/user/profile/ProfileInfoCard.vue
+++ b/frontend/src/components/user/profile/ProfileInfoCard.vue
@@ -176,6 +176,9 @@ const authStore = useAuthStore()
const appStore = useAppStore()
const maxAvatarBytes = 100 * 1024
+const targetAvatarUploadBytes = 20 * 1024
+const avatarScaleSteps = [1, 0.92, 0.84, 0.76, 0.68, 0.6, 0.52, 0.44, 0.36]
+const avatarQualitySteps = [0.92, 0.84, 0.76, 0.68, 0.6, 0.52, 0.44, 0.36]
const avatarDraft = ref(props.user?.avatar_url?.trim() || '')
const avatarSaving = ref(false)
@@ -341,6 +344,72 @@ function readFileAsDataURL(file: File): Promise {
})
}
+function loadImage(dataURL: string): Promise {
+ return new Promise((resolve, reject) => {
+ const image = new Image()
+ image.onload = () => resolve(image)
+ image.onerror = () => reject(new Error(t('profile.avatar.readFailed')))
+ image.src = dataURL
+ })
+}
+
+function canvasToBlob(canvas: HTMLCanvasElement, type: string, quality: number): Promise {
+ return new Promise((resolve, reject) => {
+ canvas.toBlob((blob) => {
+ if (!blob) {
+ reject(new Error(t('profile.avatar.compressFailed')))
+ return
+ }
+ resolve(blob)
+ }, type, quality)
+ })
+}
+
+async function compressAvatarFile(file: File): Promise {
+ const sourceDataURL = await readFileAsDataURL(file)
+ const image = await loadImage(sourceDataURL)
+ const canvas = document.createElement('canvas')
+ const ctx = canvas.getContext('2d')
+ if (!ctx) {
+ throw new Error(t('profile.avatar.compressFailed'))
+ }
+
+ for (const scale of avatarScaleSteps) {
+ const width = Math.max(1, Math.round(image.naturalWidth * scale))
+ const height = Math.max(1, Math.round(image.naturalHeight * scale))
+ canvas.width = width
+ canvas.height = height
+ ctx.clearRect(0, 0, width, height)
+ ctx.drawImage(image, 0, 0, width, height)
+
+ for (const quality of avatarQualitySteps) {
+ const blob = await canvasToBlob(canvas, 'image/webp', quality)
+ if (blob.size <= targetAvatarUploadBytes) {
+ const fileName = file.name.replace(/\.[^.]+$/, '') || 'avatar'
+ return new File([blob], `${fileName}.webp`, { type: 'image/webp' })
+ }
+ }
+ }
+
+ throw new Error(t('profile.avatar.compressTooLarge'))
+}
+
+async function prepareAvatarUpload(file: File): Promise {
+ if (!file.type.startsWith('image/')) {
+ throw new Error(t('profile.avatar.invalidType'))
+ }
+ if (file.type === 'image/gif') {
+ if (file.size > targetAvatarUploadBytes) {
+ throw new Error(t('profile.avatar.gifTooLarge'))
+ }
+ return file
+ }
+ if (file.size <= targetAvatarUploadBytes) {
+ return file
+ }
+ return compressAvatarFile(file)
+}
+
async function handleAvatarFileChange(event: Event) {
const input = event.target as HTMLInputElement | null
const file = input?.files?.[0]
@@ -360,7 +429,8 @@ async function handleAvatarFileChange(event: Event) {
}
try {
- const dataURL = await readFileAsDataURL(file)
+ const preparedFile = await prepareAvatarUpload(file)
+ const dataURL = await readFileAsDataURL(preparedFile)
const normalized = validateAvatarInput(dataURL)
if (!normalized) {
return
diff --git a/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts
index 6e9e4dd5..4c2b25ca 100644
--- a/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts
+++ b/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts
@@ -1,5 +1,5 @@
import { mount } from '@vue/test-utils'
-import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import ProfileInfoCard from '@/components/user/profile/ProfileInfoCard.vue'
import type { User } from '@/types'
@@ -51,11 +51,15 @@ vi.mock('vue-i18n', async (importOriginal) => {
if (key === 'profile.avatar.inputLabel') return 'Avatar URL or data URL'
if (key === 'profile.avatar.inputPlaceholder') return 'https://cdn.example.com/avatar.png'
if (key === 'profile.avatar.uploadAction') return 'Upload image'
- if (key === 'profile.avatar.uploadHint') return 'Images must be 100KB or smaller'
+ if (key === 'profile.avatar.uploadHint') return 'Images must be 100KB or smaller and will be compressed to 20KB'
if (key === 'profile.avatar.saveSuccess') return 'Avatar updated'
if (key === 'profile.avatar.deleteSuccess') return 'Avatar removed'
if (key === 'profile.avatar.invalidType') return 'Please choose an image file'
if (key === 'profile.avatar.fileTooLarge') return 'Avatar image must be 100KB or smaller'
+ if (key === 'profile.avatar.gifTooLarge') return 'GIF avatars must already be 20KB or smaller'
+ if (key === 'profile.avatar.compressTooLarge') return 'Unable to compress this image below 20KB'
+ if (key === 'profile.avatar.compressFailed') return 'Failed to compress the selected image'
+ if (key === 'profile.avatar.readFailed') return 'Failed to read the selected image'
if (key === 'profile.avatar.invalidValue') return 'Enter a valid avatar URL or image data URL'
if (key === 'profile.avatar.emptyDeleteHint') return 'Avatar already removed'
if (key === 'profile.authBindings.providers.email') return 'Email'
@@ -92,6 +96,63 @@ function createUser(overrides: Partial = {}): User {
}
}
+async function flushAsyncWork(): Promise {
+ await Promise.resolve()
+ await Promise.resolve()
+}
+
+const originalFileReader = globalThis.FileReader
+const originalImage = globalThis.Image
+const originalCreateElement = document.createElement.bind(document)
+
+function installAvatarCompressionMocks() {
+ class MockFileReader {
+ result: string | ArrayBuffer | null = null
+ onload: ((this: FileReader, ev: ProgressEvent) => any) | null = null
+ onerror: ((this: FileReader, ev: ProgressEvent) => any) | null = null
+ error: DOMException | null = null
+
+ readAsDataURL(blob: Blob) {
+ if (blob.type === 'image/webp') {
+ this.result = 'data:image/webp;base64,' + Buffer.from('compressed-avatar').toString('base64')
+ } else {
+ this.result = 'data:image/png;base64,' + Buffer.from('original-avatar').toString('base64')
+ }
+ this.onload?.call(this as unknown as FileReader, new ProgressEvent('load'))
+ }
+ }
+
+ class MockImage {
+ naturalWidth = 1200
+ naturalHeight = 1200
+ onload: (() => void) | null = null
+ onerror: (() => void) | null = null
+
+ set src(_value: string) {
+ this.onload?.()
+ }
+ }
+
+ globalThis.FileReader = MockFileReader as unknown as typeof FileReader
+ globalThis.Image = MockImage as unknown as typeof Image
+ vi.spyOn(document, 'createElement').mockImplementation(((tagName: string, options?: ElementCreationOptions) => {
+ if (tagName === 'canvas') {
+ return {
+ width: 0,
+ height: 0,
+ getContext: () => ({
+ clearRect: vi.fn(),
+ drawImage: vi.fn(),
+ }),
+ toBlob: (callback: BlobCallback) => {
+ callback(new Blob([new Uint8Array(8 * 1024)], { type: 'image/webp' }))
+ },
+ } as unknown as HTMLCanvasElement
+ }
+ return originalCreateElement(tagName, options)
+ }) as typeof document.createElement)
+}
+
describe('ProfileInfoCard', () => {
beforeEach(() => {
updateProfileMock.mockReset()
@@ -100,6 +161,12 @@ describe('ProfileInfoCard', () => {
authStoreState.user = null
})
+ afterEach(() => {
+ globalThis.FileReader = originalFileReader
+ globalThis.Image = originalImage
+ vi.restoreAllMocks()
+ })
+
it('saves a remote avatar URL and updates the auth store', async () => {
const updatedUser = createUser({ avatar_url: 'https://cdn.example.com/new.png' })
updateProfileMock.mockResolvedValue(updatedUser)
@@ -148,6 +215,39 @@ describe('ProfileInfoCard', () => {
expect(showErrorMock).toHaveBeenCalledWith('Avatar image must be 100KB or smaller')
})
+ it('compresses uploaded images under 100KB before saving', async () => {
+ installAvatarCompressionMocks()
+ const updatedUser = createUser({ avatar_url: 'data:image/webp;base64,Y29tcHJlc3NlZC1hdmF0YXI=' })
+ updateProfileMock.mockResolvedValue(updatedUser)
+ authStoreState.user = createUser()
+
+ const wrapper = mount(ProfileInfoCard, {
+ props: {
+ user: authStoreState.user
+ },
+ global: {
+ stubs: {
+ Icon: true,
+ ProfileIdentityBindingsSection: true
+ }
+ }
+ })
+
+ const fileInput = wrapper.get('[data-testid="profile-avatar-file-input"]')
+ Object.defineProperty(fileInput.element, 'files', {
+ value: [new File([new Uint8Array(80 * 1024)], 'avatar.png', { type: 'image/png' })],
+ configurable: true
+ })
+
+ await fileInput.trigger('change')
+ await flushAsyncWork()
+ await wrapper.get('[data-testid="profile-avatar-save"]').trigger('click')
+
+ expect(updateProfileMock).toHaveBeenCalledWith({
+ avatar_url: 'data:image/webp;base64,Y29tcHJlc3NlZC1hdmF0YXI='
+ })
+ })
+
it('deletes the current avatar', async () => {
const updatedUser = createUser({ avatar_url: null })
updateProfileMock.mockResolvedValue(updatedUser)
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 78db9e8a..ec9c1ea3 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -943,15 +943,19 @@ export default {
},
avatar: {
title: 'Profile Avatar',
- description: 'Set your avatar with a remote image URL or upload a small image.',
+ description: 'Set your avatar with a remote image URL or upload an image under 100KB. Uploaded images are compressed to 20KB.',
inputLabel: 'Avatar URL or data URL',
inputPlaceholder: 'https://cdn.example.com/avatar.png',
uploadAction: 'Upload image',
- uploadHint: 'Images must be 100KB or smaller',
+ uploadHint: 'Uploaded images must be 100KB or smaller. Static images are compressed to 20KB.',
saveSuccess: 'Avatar updated',
deleteSuccess: 'Avatar removed',
invalidType: 'Please choose an image file',
fileTooLarge: 'Avatar image must be 100KB or smaller',
+ gifTooLarge: 'GIF avatars must already be 20KB or smaller',
+ compressTooLarge: 'Unable to compress this image below 20KB. Try a smaller image.',
+ compressFailed: 'Failed to compress the selected image.',
+ readFailed: 'Failed to read the selected image.',
invalidValue: 'Enter a valid avatar URL or image data URL',
emptyDeleteHint: 'Avatar is already empty',
},
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 9bb265e1..9941d323 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -947,15 +947,19 @@ export default {
},
avatar: {
title: '资料头像',
- description: '支持填写远程图片 URL,或上传不超过 100KB 的头像图片。',
+ description: '支持填写远程图片 URL,或上传不超过 100KB 的头像图片;上传图片会自动压缩到 20KB 以内。',
inputLabel: '头像 URL 或 data URL',
inputPlaceholder: 'https://cdn.example.com/avatar.png',
uploadAction: '上传图片',
- uploadHint: '图片大小需不超过 100KB',
+ uploadHint: '上传图片需不超过 100KB,静态图片会自动压缩到 20KB 以内',
saveSuccess: '头像已更新',
deleteSuccess: '头像已删除',
invalidType: '请选择图片文件',
fileTooLarge: '头像图片必须不超过 100KB',
+ gifTooLarge: 'GIF 头像必须在 20KB 以内',
+ compressTooLarge: '无法将图片压缩到 20KB 以内,请换一张更小的图片',
+ compressFailed: '压缩所选图片失败',
+ readFailed: '读取所选图片失败',
invalidValue: '请输入有效的头像 URL 或图片 data URL',
emptyDeleteHint: '当前没有可删除的头像',
},
From dcd5c43da469809cc1d722d984e5eba9a7507357 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 10:00:06 +0800
Subject: [PATCH 095/189] feat: complete email binding and pending oauth
verification flows
---
backend/cmd/server/wire_gen.go | 2 +-
.../handler/auth_oauth_pending_flow.go | 154 +++++++--
.../handler/auth_oauth_pending_flow_test.go | 240 ++++++++++++-
backend/internal/handler/user_handler.go | 86 ++++-
backend/internal/handler/user_handler_test.go | 119 ++++++-
backend/internal/server/routes/auth.go | 6 +
.../server/routes/auth_rate_limit_test.go | 1 +
backend/internal/server/routes/user.go | 2 +
.../internal/service/auth_email_binding.go | 128 +++++++
.../internal/service/auth_oauth_email_flow.go | 164 +++++++--
.../service/auth_oauth_email_flow_test.go | 251 ++++++++++++++
.../service/auth_service_email_bind_test.go | 316 ++++++++++++++++++
frontend/src/api/auth.ts | 11 +
frontend/src/api/user.ts | 15 +
.../auth/PendingOAuthCreateAccountForm.vue | 23 +-
.../PendingOAuthCreateAccountForm.spec.ts | 45 ++-
.../ProfileIdentityBindingsSection.vue | 225 +++++++++++--
.../ProfileIdentityBindingsSection.spec.ts | 79 ++++-
frontend/src/i18n/locales/en.ts | 6 +
frontend/src/i18n/locales/zh.ts | 6 +
frontend/src/types/index.ts | 2 +
frontend/src/views/auth/EmailVerifyView.vue | 16 +-
.../src/views/auth/LinuxDoCallbackView.vue | 27 ++
frontend/src/views/auth/OidcCallbackView.vue | 27 ++
.../src/views/auth/WechatCallbackView.vue | 27 ++
.../auth/__tests__/EmailVerifyView.spec.ts | 81 +++++
.../__tests__/LinuxDoCallbackView.spec.ts | 52 ++-
.../auth/__tests__/OidcCallbackView.spec.ts | 53 ++-
.../auth/__tests__/WechatCallbackView.spec.ts | 60 +++-
29 files changed, 2117 insertions(+), 107 deletions(-)
create mode 100644 backend/internal/service/auth_email_binding.go
create mode 100644 backend/internal/service/auth_oauth_email_flow_test.go
create mode 100644 backend/internal/service/auth_service_email_bind_test.go
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 1d39fa1e..3b474c4a 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -79,7 +79,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
- userHandler := handler.NewUserHandler(userService, emailService, emailCache)
+ userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
index b7ff17c3..1d3b113f 100644
--- a/backend/internal/handler/auth_oauth_pending_flow.go
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/predicate"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -27,7 +28,7 @@ import (
const (
oauthPendingBrowserCookiePath = "/api/v1/auth/oauth"
oauthPendingBrowserCookieName = "oauth_pending_browser_session"
- oauthPendingSessionCookiePath = "/api/v1/auth/oauth/pending"
+ oauthPendingSessionCookiePath = "/api/v1/auth/oauth"
oauthPendingSessionCookieName = "oauth_pending_session"
oauthPendingCookieMaxAgeSec = 10 * 60
@@ -66,6 +67,13 @@ type createPendingOAuthAccountRequest struct {
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
+type sendPendingOAuthVerifyCodeRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ TurnstileToken string `json:"turnstile_token,omitempty"`
+ PendingAuthToken string `json:"pending_auth_token,omitempty"`
+ PendingOAuthToken string `json:"pending_oauth_token,omitempty"`
+}
+
func (r bindPendingOAuthLoginRequest) adoptionDecision() oauthAdoptionDecisionRequest {
return oauthAdoptionDecisionRequest{
AdoptDisplayName: r.AdoptDisplayName,
@@ -448,6 +456,43 @@ func (h *AuthHandler) CreatePendingOAuthAccount(c *gin.Context) {
h.createPendingOAuthAccount(c, "")
}
+// SendPendingOAuthVerifyCode sends a verification code for a browser-bound
+// pending OAuth account-creation flow.
+// POST /api/v1/auth/oauth/pending/send-verify-code
+func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
+ var req sendPendingOAuthVerifyCodeRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ _, session, _, err := readPendingOAuthBrowserSession(c, h)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, SendVerifyCodeResponse{
+ Message: "Verification code sent successfully",
+ Countdown: result.Countdown,
+ })
+}
+
func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
c *gin.Context,
sessionID int64,
@@ -1084,6 +1129,41 @@ func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gi
return payload
}
+func (h *AuthHandler) transitionPendingOAuthAccountToBindLogin(
+ c *gin.Context,
+ client *dbent.Client,
+ session *dbent.PendingAuthSession,
+ email string,
+ decision oauthAdoptionDecisionRequest,
+) (*dbent.PendingAuthSession, error) {
+ existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email)
+ if err != nil {
+ return nil, err
+ }
+
+ completionResponse := mergePendingCompletionResponse(session, map[string]any{
+ "step": "bind_login_required",
+ "email": email,
+ })
+ session, err = updatePendingOAuthSessionProgress(
+ c.Request.Context(),
+ client,
+ session,
+ "adopt_existing_user_by_email",
+ email,
+ &existingUser.ID,
+ completionResponse,
+ )
+ if err != nil {
+ return nil, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err)
+ }
+
+ if _, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, decision); err != nil {
+ return nil, err
+ }
+ return session, nil
+}
+
func writeOAuthTokenPairResponse(c *gin.Context, tokenPair *service.TokenPair) {
c.JSON(http.StatusOK, gin.H{
"access_token": tokenPair.AccessToken,
@@ -1199,29 +1279,11 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
return
}
if existingUser != nil {
- completionResponse := mergePendingCompletionResponse(session, map[string]any{
- "step": "bind_login_required",
- "email": email,
- })
- session, err = updatePendingOAuthSessionProgress(
- c.Request.Context(),
- client,
- session,
- "adopt_existing_user_by_email",
- email,
- &existingUser.ID,
- completionResponse,
- )
+ session, err = h.transitionPendingOAuthAccountToBindLogin(c, client, session, email, req.adoptionDecision())
if err != nil {
- response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err))
- return
- }
-
- if _, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()); err != nil {
response.ErrorFrom(c, err)
return
}
-
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
return
}
@@ -1239,27 +1301,77 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
strings.TrimSpace(session.ProviderType),
)
if err != nil {
+ if errors.Is(err, service.ErrEmailExists) {
+ session, err = h.transitionPendingOAuthAccountToBindLogin(c, client, session, email, req.adoptionDecision())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
+ return
+ }
response.ErrorFrom(c, err)
return
}
+ rollbackCreatedUser := func(originalErr error) bool {
+ if user == nil || user.ID <= 0 {
+ return false
+ }
+ if rollbackErr := h.authService.RollbackOAuthEmailAccountCreation(
+ c.Request.Context(),
+ user.ID,
+ strings.TrimSpace(req.InvitationCode),
+ ); rollbackErr != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer(
+ "PENDING_AUTH_ACCOUNT_ROLLBACK_FAILED",
+ "failed to rollback pending oauth account creation",
+ ).WithCause(fmt.Errorf("original error: %w; rollback error: %v", originalErr, rollbackErr)))
+ return true
+ }
+ user = nil
+ return false
+ }
+
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
if err != nil {
+ if rollbackCreatedUser(err) {
+ return
+ }
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil {
+ if rollbackCreatedUser(err) {
+ return
+ }
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return
}
- h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+
+ if err := h.authService.FinalizeOAuthEmailAccount(
+ c.Request.Context(),
+ user,
+ strings.TrimSpace(req.InvitationCode),
+ strings.TrimSpace(session.ProviderType),
+ ); err != nil {
+ if rollbackCreatedUser(err) {
+ return
+ }
+ response.ErrorFrom(c, err)
+ return
+ }
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
+ if rollbackCreatedUser(err) {
+ return
+ }
clearCookies()
response.ErrorFrom(c, err)
return
}
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
clearCookies()
writeOAuthTokenPairResponse(c, tokenPair)
}
diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go
index 913acddc..1013a082 100644
--- a/backend/internal/handler/auth_oauth_pending_flow_test.go
+++ b/backend/internal/handler/auth_oauth_pending_flow_test.go
@@ -5,6 +5,7 @@ import (
"context"
"database/sql"
"encoding/json"
+ "errors"
"net/http"
"net/http/httptest"
"testing"
@@ -15,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/redeemcode"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -61,6 +63,18 @@ func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t *
require.Equal(t, true, payload["adoption_required"])
}
+func TestSetOAuthPendingSessionCookieUsesProviderCompletionPathPrefix(t *testing.T) {
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ ginCtx.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback", nil)
+
+ setOAuthPendingSessionCookie(ginCtx, "pending-session-token", false)
+
+ cookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, cookie)
+ require.Equal(t, "/api/v1/auth/oauth", cookie.Path)
+}
+
func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
@@ -943,6 +957,81 @@ func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T)
require.Nil(t, storedSession.ConsumedAt)
}
+func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, true, "fresh@example.com", "246810")
+ ctx := context.Background()
+
+ conflictOwner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentity.Create().
+ SetUserID(conflictOwner.ID).
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-conflict-123").
+ SetMetadata(map[string]any{
+ "username": "owner-user",
+ }).
+ Save(ctx)
+ require.NoError(t, err)
+
+ invitation, err := client.RedeemCode.Create().
+ SetCode("INVITE123").
+ SetType(service.RedeemTypeInvitation).
+ SetStatus(service.StatusUnused).
+ SetValue(0).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("create-account-conflict-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-conflict-123").
+ SetBrowserSessionKey("create-account-conflict-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","invitation_code":"INVITE123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-conflict-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusInternalServerError, recorder.Code)
+
+ userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ storedInvitation, err := client.RedeemCode.Get(ctx, invitation.ID)
+ require.NoError(t, err)
+ require.Equal(t, service.StatusUnused, storedInvitation.Status)
+ require.Nil(t, storedInvitation.UsedBy)
+ require.Nil(t, storedInvitation.UsedAt)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
@@ -1529,6 +1618,8 @@ type oauthPendingFlowTestHandlerOptions struct {
defaultSubAssigner service.DefaultSubscriptionAssigner
totpCache service.TotpCache
totpEncryptor service.SecretEncryptor
+ redeemRepoFactory func(client *dbent.Client) service.RedeemCodeRepository
+ userRepoOptions oauthPendingFlowUserRepoOptions
}
func newOAuthPendingFlowTestHandlerWithDependencies(
@@ -1590,7 +1681,17 @@ CREATE TABLE IF NOT EXISTS user_avatars (
settingValues[key] = value
}
settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{values: settingValues}, cfg)
- userRepo := &oauthPendingFlowUserRepo{client: client}
+ userRepo := &oauthPendingFlowUserRepo{
+ client: client,
+ options: options.userRepoOptions,
+ }
+ redeemRepo := service.RedeemCodeRepository(nil)
+ if options.redeemRepoFactory != nil {
+ redeemRepo = options.redeemRepoFactory(client)
+ }
+ if redeemRepo == nil {
+ redeemRepo = &oauthPendingFlowRedeemCodeRepo{client: client}
+ }
var emailService *service.EmailService
if options.emailCache != nil {
emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{
@@ -1602,7 +1703,7 @@ CREATE TABLE IF NOT EXISTS user_avatars (
authSvc := service.NewAuthService(
client,
userRepo,
- nil,
+ redeemRepo,
&oauthPendingFlowRefreshTokenCacheStub{},
cfg,
settingSvc,
@@ -1797,6 +1898,127 @@ func (s *oauthPendingFlowRefreshTokenCacheStub) IsTokenInFamily(context.Context,
return false, nil
}
+type oauthPendingFlowRedeemCodeRepo struct {
+ client *dbent.Client
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) Create(context.Context, *service.RedeemCode) error {
+ panic("unexpected Create call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) CreateBatch(context.Context, []service.RedeemCode) error {
+ panic("unexpected CreateBatch call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) GetByID(context.Context, int64) (*service.RedeemCode, error) {
+ panic("unexpected GetByID call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) {
+ entity, err := r.client.RedeemCode.Query().Where(redeemcode.CodeEQ(code)).Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, service.ErrRedeemCodeNotFound
+ }
+ return nil, err
+ }
+ notes := ""
+ if entity.Notes != nil {
+ notes = *entity.Notes
+ }
+ return &service.RedeemCode{
+ ID: entity.ID,
+ Code: entity.Code,
+ Type: entity.Type,
+ Value: entity.Value,
+ Status: entity.Status,
+ UsedBy: entity.UsedBy,
+ UsedAt: entity.UsedAt,
+ Notes: notes,
+ CreatedAt: entity.CreatedAt,
+ GroupID: entity.GroupID,
+ ValidityDays: entity.ValidityDays,
+ }, nil
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) Update(ctx context.Context, code *service.RedeemCode) error {
+ if code == nil {
+ return nil
+ }
+ update := r.client.RedeemCode.UpdateOneID(code.ID).
+ SetCode(code.Code).
+ SetType(code.Type).
+ SetValue(code.Value).
+ SetStatus(code.Status).
+ SetNotes(code.Notes).
+ SetValidityDays(code.ValidityDays)
+ if code.UsedBy != nil {
+ update = update.SetUsedBy(*code.UsedBy)
+ } else {
+ update = update.ClearUsedBy()
+ }
+ if code.UsedAt != nil {
+ update = update.SetUsedAt(*code.UsedAt)
+ } else {
+ update = update.ClearUsedAt()
+ }
+ if code.GroupID != nil {
+ update = update.SetGroupID(*code.GroupID)
+ } else {
+ update = update.ClearGroupID()
+ }
+ _, err := update.Save(ctx)
+ return err
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) Delete(context.Context, int64) error {
+ panic("unexpected Delete call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) Use(ctx context.Context, id, userID int64) error {
+ affected, err := r.client.RedeemCode.Update().
+ Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)).
+ SetStatus(service.StatusUsed).
+ SetUsedBy(userID).
+ SetUsedAt(time.Now().UTC()).
+ Save(ctx)
+ if err != nil {
+ return err
+ }
+ if affected == 0 {
+ return service.ErrRedeemCodeUsed
+ }
+ return nil
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) List(context.Context, pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected List call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected ListWithFilters call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) ListByUser(context.Context, int64, int) ([]service.RedeemCode, error) {
+ panic("unexpected ListByUser call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected ListByUserPaginated call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
+ panic("unexpected SumPositiveBalanceByUser call")
+}
+
+type oauthPendingFlowFailingUseRedeemRepo struct {
+ *oauthPendingFlowRedeemCodeRepo
+}
+
+func (r *oauthPendingFlowFailingUseRedeemRepo) Use(context.Context, int64, int64) error {
+ return errors.New("forced invitation use failure")
+}
+
func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
t.Helper()
@@ -1872,6 +2094,11 @@ func countProviderGrantRecords(
type oauthPendingFlowUserRepo struct {
client *dbent.Client
+ options oauthPendingFlowUserRepoOptions
+}
+
+type oauthPendingFlowUserRepoOptions struct {
+ rejectDeleteWhileAuthIdentityExists bool
}
func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.User) error {
@@ -1953,6 +2180,15 @@ func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.Use
}
func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error {
+ if r.options.rejectDeleteWhileAuthIdentityExists {
+ count, err := r.client.AuthIdentity.Query().Where(authidentity.UserIDEQ(id)).Count(ctx)
+ if err != nil {
+ return err
+ }
+ if count > 0 {
+ return errors.New("cannot delete user while auth identities still exist")
+ }
+ }
return r.client.User.DeleteOneID(id).Exec(ctx)
}
diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go
index b1ade5c0..a6a7be9a 100644
--- a/backend/internal/handler/user_handler.go
+++ b/backend/internal/handler/user_handler.go
@@ -15,14 +15,21 @@ import (
// UserHandler handles user-related requests
type UserHandler struct {
userService *service.UserService
+ authService *service.AuthService
emailService *service.EmailService
emailCache service.EmailCache
}
// NewUserHandler creates a new UserHandler
-func NewUserHandler(userService *service.UserService, emailService *service.EmailService, emailCache service.EmailCache) *UserHandler {
+func NewUserHandler(
+ userService *service.UserService,
+ authService *service.AuthService,
+ emailService *service.EmailService,
+ emailCache service.EmailCache,
+) *UserHandler {
return &UserHandler{
userService: userService,
+ authService: authService,
emailService: emailService,
emailCache: emailCache,
}
@@ -157,6 +164,16 @@ type StartIdentityBindingRequest struct {
RedirectTo string `json:"redirect_to"`
}
+type BindEmailIdentityRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ VerifyCode string `json:"verify_code" binding:"required"`
+ Password string `json:"password" binding:"required,min=6"`
+}
+
+type SendEmailBindingCodeRequest struct {
+ Email string `json:"email" binding:"required,email"`
+}
+
// StartIdentityBinding returns the backend authorize URL for starting a third-party identity bind flow.
// POST /api/v1/user/auth-identities/bind/start
func (h *UserHandler) StartIdentityBinding(c *gin.Context) {
@@ -183,6 +200,73 @@ func (h *UserHandler) StartIdentityBinding(c *gin.Context) {
response.Success(c, result)
}
+// BindEmailIdentity verifies and binds a local email identity for the current user.
+// POST /api/v1/user/account-bindings/email
+func (h *UserHandler) BindEmailIdentity(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+ if h.authService == nil {
+ response.InternalError(c, "Auth service not configured")
+ return
+ }
+
+ var req BindEmailIdentityRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ updatedUser, err := h.authService.BindEmailIdentity(
+ c.Request.Context(),
+ subject.UserID,
+ req.Email,
+ req.VerifyCode,
+ req.Password,
+ )
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
+}
+
+// SendEmailBindingCode sends a verification code for the current user's email binding flow.
+// POST /api/v1/user/account-bindings/email/send-code
+func (h *UserHandler) SendEmailBindingCode(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+ if h.authService == nil {
+ response.InternalError(c, "Auth service not configured")
+ return
+ }
+
+ var req SendEmailBindingCodeRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if err := h.authService.SendEmailIdentityBindCode(c.Request.Context(), subject.UserID, req.Email); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"message": "Verification code sent successfully"})
+}
+
// SendNotifyEmailCodeRequest represents the request to send notify email verification code
type SendNotifyEmailCodeRequest struct {
Email string `json:"email" binding:"required,email"`
diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go
index aacfc332..72b28293 100644
--- a/backend/internal/handler/user_handler_test.go
+++ b/backend/internal/handler/user_handler_test.go
@@ -11,6 +11,7 @@ import (
"testing"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -122,7 +123,7 @@ func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
Status: service.StatusActive,
},
}
- handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil)
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
recorder := httptest.NewRecorder()
@@ -180,7 +181,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
},
},
}
- handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil)
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@@ -262,7 +263,7 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
},
},
}
- handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil)
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@@ -311,6 +312,116 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
require.Equal(t, "linuxdo", usernameSource["source"])
}
+type userHandlerEmailCacheStub struct {
+ data *service.VerificationCodeData
+}
+
+func (s *userHandlerEmailCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) {
+ return s.data, nil
+}
+
+func (s *userHandlerEmailCacheStub) SetVerificationCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) DeleteVerificationCode(context.Context, string) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
+ return nil, nil
+}
+
+func (s *userHandlerEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
+ return nil, nil
+}
+
+func (s *userHandlerEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) DeletePasswordResetToken(context.Context, string) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
+ return false
+}
+
+func (s *userHandlerEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+
+func (s *userHandlerEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
+ return 0, nil
+}
+
+func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 11,
+ Email: "legacy-user" + service.LinuxDoConnectSyntheticEmailDomain,
+ Username: "legacy-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ },
+ }
+ emailCache := &userHandlerEmailCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ },
+ }
+ emailService := service.NewEmailService(nil, emailCache)
+ authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil)
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
+
+ body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/account-bindings/email", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+ c.Params = gin.Params{{Key: "provider", Value: "email"}}
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
+
+ handler.BindEmailIdentity(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ Email string `json:"email"`
+ EmailBound bool `json:"email_bound"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, "new@example.com", resp.Data.Email)
+ require.True(t, resp.Data.EmailBound)
+}
+
func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -323,7 +434,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
Status: service.StatusActive,
},
}
- handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil)
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`)
recorder := httptest.NewRecorder()
diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go
index 20d3d9b4..f1032eb5 100644
--- a/backend/internal/server/routes/auth.go
+++ b/backend/internal/server/routes/auth.go
@@ -74,6 +74,12 @@ func RegisterAuthRoutes(
}),
h.Auth.ExchangePendingOAuthCompletion,
)
+ auth.POST("/oauth/pending/send-verify-code",
+ rateLimiter.LimitWithOptions("oauth-pending-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.SendPendingOAuthVerifyCode,
+ )
auth.POST("/oauth/pending/create-account",
rateLimiter.LimitWithOptions("oauth-pending-create-account", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
diff --git a/backend/internal/server/routes/auth_rate_limit_test.go b/backend/internal/server/routes/auth_rate_limit_test.go
index 4f411cec..07a66efb 100644
--- a/backend/internal/server/routes/auth_rate_limit_test.go
+++ b/backend/internal/server/routes/auth_rate_limit_test.go
@@ -52,6 +52,7 @@ func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) {
"/api/v1/auth/login",
"/api/v1/auth/login/2fa",
"/api/v1/auth/send-verify-code",
+ "/api/v1/auth/oauth/pending/send-verify-code",
}
for _, path := range paths {
diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go
index ccbe23ce..46baa80a 100644
--- a/backend/internal/server/routes/user.go
+++ b/backend/internal/server/routes/user.go
@@ -25,6 +25,8 @@ func RegisterUserRoutes(
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
+ user.POST("/account-bindings/email/send-code", h.User.SendEmailBindingCode)
+ user.POST("/account-bindings/email", h.User.BindEmailIdentity)
user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding)
// 通知邮箱管理
diff --git a/backend/internal/service/auth_email_binding.go b/backend/internal/service/auth_email_binding.go
new file mode 100644
index 00000000..b999660b
--- /dev/null
+++ b/backend/internal/service/auth_email_binding.go
@@ -0,0 +1,128 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/mail"
+ "strings"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+// BindEmailIdentity verifies and binds a local email/password identity to the current user.
+func (s *AuthService) BindEmailIdentity(
+ ctx context.Context,
+ userID int64,
+ email string,
+ verifyCode string,
+ password string,
+) (*User, error) {
+ if s == nil {
+ return nil, ErrServiceUnavailable
+ }
+
+ normalizedEmail, err := normalizeEmailForIdentityBinding(email)
+ if err != nil {
+ return nil, err
+ }
+ if isReservedEmail(normalizedEmail) {
+ return nil, ErrEmailReserved
+ }
+ if strings.TrimSpace(password) == "" {
+ return nil, ErrPasswordRequired
+ }
+ if err := s.VerifyOAuthEmailCode(ctx, normalizedEmail, verifyCode); err != nil {
+ return nil, err
+ }
+
+ currentUser, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+
+ existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail)
+ switch {
+ case err == nil && existingUser != nil && existingUser.ID != userID:
+ return nil, ErrEmailExists
+ case err != nil && !errors.Is(err, ErrUserNotFound):
+ return nil, ErrServiceUnavailable
+ }
+
+ hashedPassword, err := s.HashPassword(password)
+ if err != nil {
+ return nil, fmt.Errorf("hash password: %w", err)
+ }
+
+ firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email)
+ currentUser.Email = normalizedEmail
+ currentUser.PasswordHash = hashedPassword
+ if err := s.userRepo.Update(ctx, currentUser); err != nil {
+ if errors.Is(err, ErrEmailExists) {
+ return nil, ErrEmailExists
+ }
+ return nil, ErrServiceUnavailable
+ }
+
+ if firstRealEmailBind {
+ if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, userID, "email"); err != nil {
+ return nil, fmt.Errorf("apply email first bind defaults: %w", err)
+ }
+ }
+
+ return currentUser, nil
+}
+
+// SendEmailIdentityBindCode sends a verification code for authenticated email binding flows.
+func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int64, email string) error {
+ if s == nil {
+ return ErrServiceUnavailable
+ }
+
+ normalizedEmail, err := normalizeEmailForIdentityBinding(email)
+ if err != nil {
+ return err
+ }
+ if isReservedEmail(normalizedEmail) {
+ return ErrEmailReserved
+ }
+ if s.emailService == nil {
+ return ErrServiceUnavailable
+ }
+ if _, err := s.userRepo.GetByID(ctx, userID); err != nil {
+ if errors.Is(err, ErrUserNotFound) {
+ return ErrUserNotFound
+ }
+ return ErrServiceUnavailable
+ }
+
+ existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail)
+ switch {
+ case err == nil && existingUser != nil && existingUser.ID != userID:
+ return ErrEmailExists
+ case err != nil && !errors.Is(err, ErrUserNotFound):
+ return ErrServiceUnavailable
+ }
+
+ siteName := "Sub2API"
+ if s.settingService != nil {
+ siteName = s.settingService.GetSiteName(ctx)
+ }
+ return s.emailService.SendVerifyCode(ctx, normalizedEmail, siteName)
+}
+
+func normalizeEmailForIdentityBinding(email string) (string, error) {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ if normalized == "" || len(normalized) > 255 {
+ return "", infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
+ }
+ if _, err := mail.ParseAddress(normalized); err != nil {
+ return "", infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
+ }
+ return normalized, nil
+}
+
+func hasBindableEmailIdentitySubject(email string) bool {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ return normalized != "" && !isReservedEmail(normalized)
+}
diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go
index 2e0107ae..ce25222c 100644
--- a/backend/internal/service/auth_oauth_email_flow.go
+++ b/backend/internal/service/auth_oauth_email_flow.go
@@ -4,9 +4,71 @@ import (
"context"
"errors"
"fmt"
+ "net/mail"
"strings"
+ "time"
)
+func normalizeOAuthSignupSource(signupSource string) string {
+ signupSource = strings.TrimSpace(strings.ToLower(signupSource))
+ if signupSource == "" {
+ return "email"
+ }
+ return signupSource
+}
+
+// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth
+// account-creation flows without relying on the public registration gate.
+func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
+ email = strings.TrimSpace(strings.ToLower(email))
+ if email == "" {
+ return nil, ErrEmailVerifyRequired
+ }
+ if _, err := mail.ParseAddress(email); err != nil {
+ return nil, ErrEmailVerifyRequired
+ }
+ if isReservedEmail(email) {
+ return nil, ErrEmailReserved
+ }
+ if s == nil || s.emailService == nil {
+ return nil, ErrServiceUnavailable
+ }
+
+ siteName := "Sub2API"
+ if s.settingService != nil {
+ siteName = s.settingService.GetSiteName(ctx)
+ }
+ if err := s.emailService.SendVerifyCode(ctx, email, siteName); err != nil {
+ return nil, err
+ }
+ return &SendVerifyCodeResult{
+ Countdown: int(verifyCodeCooldown / time.Second),
+ }, nil
+}
+
+func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, invitationCode string) (*RedeemCode, error) {
+ if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
+ return nil, nil
+ }
+ if s.redeemRepo == nil {
+ return nil, ErrServiceUnavailable
+ }
+
+ invitationCode = strings.TrimSpace(invitationCode)
+ if invitationCode == "" {
+ return nil, ErrInvitationCodeRequired
+ }
+
+ redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
+ if err != nil {
+ return nil, ErrInvitationCodeInvalid
+ }
+ if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
+ return nil, ErrInvitationCodeInvalid
+ }
+ return redeemCode, nil
+}
+
// VerifyOAuthEmailCode verifies the locally entered email verification code for
// third-party signup and binding flows. This is intentionally independent from
// the global registration email verification toggle.
@@ -54,19 +116,8 @@ func (s *AuthService) RegisterOAuthEmailAccount(
return nil, nil, err
}
- var invitationRedeemCode *RedeemCode
- if s.settingService.IsInvitationCodeEnabled(ctx) {
- if invitationCode == "" {
- return nil, nil, ErrInvitationCodeRequired
- }
- redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
- if err != nil {
- return nil, nil, ErrInvitationCodeInvalid
- }
- if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
- return nil, nil, ErrInvitationCodeInvalid
- }
- invitationRedeemCode = redeemCode
+ if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil {
+ return nil, nil, err
}
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
@@ -104,22 +155,91 @@ func (s *AuthService) RegisterOAuthEmailAccount(
return nil, nil, ErrServiceUnavailable
}
- s.postAuthUserBootstrap(ctx, user, signupSource, false)
- s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
-
- if invitationRedeemCode != nil {
- if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
- return nil, nil, ErrInvitationCodeInvalid
- }
- }
-
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
+ _ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, "")
return nil, nil, fmt.Errorf("generate token pair: %w", err)
}
return tokenPair, user, nil
}
+// FinalizeOAuthEmailAccount applies invitation usage and normal signup bootstrap
+// only after the pending OAuth flow has fully reached its last reversible step.
+func (s *AuthService) FinalizeOAuthEmailAccount(
+ ctx context.Context,
+ user *User,
+ invitationCode string,
+ signupSource string,
+) error {
+ if s == nil || user == nil || user.ID <= 0 {
+ return ErrServiceUnavailable
+ }
+
+ signupSource = normalizeOAuthSignupSource(signupSource)
+ invitationRedeemCode, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode)
+ if err != nil {
+ return err
+ }
+ if invitationRedeemCode != nil {
+ if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
+ return ErrInvitationCodeInvalid
+ }
+ }
+
+ s.postAuthUserBootstrap(ctx, user, signupSource, false)
+ grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
+ return nil
+}
+
+// RollbackOAuthEmailAccountCreation removes a partially-created local account
+// and restores any invitation code already consumed by that account.
+func (s *AuthService) RollbackOAuthEmailAccountCreation(ctx context.Context, userID int64, invitationCode string) error {
+ if s == nil || s.userRepo == nil || userID <= 0 {
+ return ErrServiceUnavailable
+ }
+ if err := s.restoreOAuthRegistrationInvitation(ctx, invitationCode, userID); err != nil {
+ return err
+ }
+ if err := s.userRepo.Delete(ctx, userID); err != nil {
+ return fmt.Errorf("delete created oauth user: %w", err)
+ }
+ return nil
+}
+
+func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, invitationCode string, userID int64) error {
+ if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
+ return nil
+ }
+ if s.redeemRepo == nil {
+ return ErrServiceUnavailable
+ }
+
+ invitationCode = strings.TrimSpace(invitationCode)
+ if invitationCode == "" || userID <= 0 {
+ return nil
+ }
+
+ redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
+ if err != nil {
+ if errors.Is(err, ErrRedeemCodeNotFound) {
+ return nil
+ }
+ return fmt.Errorf("load invitation code: %w", err)
+ }
+ if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUsed || redeemCode.UsedBy == nil || *redeemCode.UsedBy != userID {
+ return nil
+ }
+
+ redeemCode.Status = StatusUnused
+ redeemCode.UsedBy = nil
+ redeemCode.UsedAt = nil
+ if err := s.redeemRepo.Update(ctx, redeemCode); err != nil {
+ return fmt.Errorf("restore invitation code: %w", err)
+ }
+ return nil
+}
+
// ValidatePasswordCredentials checks the local password without completing the
// login flow. This is used by pending third-party account adoption flows before
// the external identity has been bound.
diff --git a/backend/internal/service/auth_oauth_email_flow_test.go b/backend/internal/service/auth_oauth_email_flow_test.go
new file mode 100644
index 00000000..a77dda72
--- /dev/null
+++ b/backend/internal/service/auth_oauth_email_flow_test.go
@@ -0,0 +1,251 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+)
+
+type redeemCodeRepoStub struct {
+ codesByCode map[string]*RedeemCode
+ useCalls []struct {
+ id int64
+ userID int64
+ }
+ updateCalls []*RedeemCode
+}
+
+func (s *redeemCodeRepoStub) Create(context.Context, *RedeemCode) error {
+ panic("unexpected Create call")
+}
+
+func (s *redeemCodeRepoStub) CreateBatch(context.Context, []RedeemCode) error {
+ panic("unexpected CreateBatch call")
+}
+
+func (s *redeemCodeRepoStub) GetByID(context.Context, int64) (*RedeemCode, error) {
+ panic("unexpected GetByID call")
+}
+
+func (s *redeemCodeRepoStub) GetByCode(_ context.Context, code string) (*RedeemCode, error) {
+ if s.codesByCode == nil {
+ return nil, ErrRedeemCodeNotFound
+ }
+ redeemCode, ok := s.codesByCode[code]
+ if !ok {
+ return nil, ErrRedeemCodeNotFound
+ }
+ cloned := *redeemCode
+ return &cloned, nil
+}
+
+func (s *redeemCodeRepoStub) Update(_ context.Context, code *RedeemCode) error {
+ if code == nil {
+ return nil
+ }
+ cloned := *code
+ s.updateCalls = append(s.updateCalls, &cloned)
+ if s.codesByCode == nil {
+ s.codesByCode = make(map[string]*RedeemCode)
+ }
+ s.codesByCode[cloned.Code] = &cloned
+ return nil
+}
+
+func (s *redeemCodeRepoStub) Delete(context.Context, int64) error {
+ panic("unexpected Delete call")
+}
+
+func (s *redeemCodeRepoStub) Use(_ context.Context, id, userID int64) error {
+ for code, redeemCode := range s.codesByCode {
+ if redeemCode.ID != id {
+ continue
+ }
+ now := time.Now().UTC()
+ redeemCode.Status = StatusUsed
+ redeemCode.UsedBy = &userID
+ redeemCode.UsedAt = &now
+ s.codesByCode[code] = redeemCode
+ s.useCalls = append(s.useCalls, struct {
+ id int64
+ userID int64
+ }{id: id, userID: userID})
+ return nil
+ }
+ return ErrRedeemCodeNotFound
+}
+
+func (s *redeemCodeRepoStub) List(context.Context, pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected List call")
+}
+
+func (s *redeemCodeRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected ListWithFilters call")
+}
+
+func (s *redeemCodeRepoStub) ListByUser(context.Context, int64, int) ([]RedeemCode, error) {
+ panic("unexpected ListByUser call")
+}
+
+func (s *redeemCodeRepoStub) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected ListByUserPaginated call")
+}
+
+func (s *redeemCodeRepoStub) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
+ panic("unexpected SumPositiveBalanceByUser call")
+}
+
+func newOAuthEmailFlowAuthService(
+ userRepo UserRepository,
+ redeemRepo RedeemCodeRepository,
+ refreshTokenCache RefreshTokenCache,
+ settings map[string]string,
+ emailCache EmailCache,
+) *AuthService {
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 3.5,
+ UserConcurrency: 2,
+ },
+ }
+
+ settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
+ emailService := NewEmailService(&settingRepoStub{values: settings}, emailCache)
+
+ return NewAuthService(
+ nil,
+ userRepo,
+ redeemRepo,
+ refreshTokenCache,
+ cfg,
+ settingService,
+ emailService,
+ nil,
+ nil,
+ nil,
+ nil,
+ )
+}
+
+func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFails(t *testing.T) {
+ userRepo := &userRepoStub{nextID: 42}
+ redeemRepo := &redeemCodeRepoStub{
+ codesByCode: map[string]*RedeemCode{
+ "INVITE123": {
+ ID: 7,
+ Code: "INVITE123",
+ Type: RedeemTypeInvitation,
+ Status: StatusUnused,
+ },
+ },
+ }
+ emailCache := &emailCacheStub{
+ data: &VerificationCodeData{
+ Code: "246810",
+ Attempts: 0,
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+ },
+ }
+ authService := newOAuthEmailFlowAuthService(
+ userRepo,
+ redeemRepo,
+ nil,
+ map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyInvitationCodeEnabled: "true",
+ SettingKeyEmailVerifyEnabled: "true",
+ },
+ emailCache,
+ )
+
+ tokenPair, user, err := authService.RegisterOAuthEmailAccount(
+ context.Background(),
+ "fresh@example.com",
+ "secret-123",
+ "246810",
+ "INVITE123",
+ "oidc",
+ )
+
+ require.Nil(t, tokenPair)
+ require.Nil(t, user)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "generate token pair")
+ require.Equal(t, []int64{42}, userRepo.deletedIDs)
+ require.Len(t, userRepo.created, 1)
+ require.Empty(t, redeemRepo.useCalls)
+ require.Empty(t, redeemRepo.updateCalls)
+}
+
+func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) {
+ userRepo := &userRepoStub{}
+ redeemRepo := &redeemCodeRepoStub{
+ codesByCode: map[string]*RedeemCode{
+ "INVITE123": {
+ ID: 7,
+ Code: "INVITE123",
+ Type: RedeemTypeInvitation,
+ Status: StatusUsed,
+ UsedBy: func() *int64 {
+ v := int64(42)
+ return &v
+ }(),
+ UsedAt: func() *time.Time {
+ v := time.Now().UTC()
+ return &v
+ }(),
+ },
+ },
+ }
+ authService := newOAuthEmailFlowAuthService(
+ userRepo,
+ redeemRepo,
+ &refreshTokenCacheStub{},
+ map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyInvitationCodeEnabled: "true",
+ },
+ &emailCacheStub{},
+ )
+
+ err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "INVITE123")
+
+ require.NoError(t, err)
+ require.Equal(t, []int64{42}, userRepo.deletedIDs)
+ require.Len(t, redeemRepo.updateCalls, 1)
+ require.Equal(t, StatusUnused, redeemRepo.updateCalls[0].Status)
+ require.Nil(t, redeemRepo.updateCalls[0].UsedBy)
+ require.Nil(t, redeemRepo.updateCalls[0].UsedAt)
+}
+
+func TestRollbackOAuthEmailAccountCreationPropagatesDeleteError(t *testing.T) {
+ userRepo := &userRepoStub{deleteErr: errors.New("delete failed")}
+ authService := newOAuthEmailFlowAuthService(
+ userRepo,
+ &redeemCodeRepoStub{},
+ &refreshTokenCacheStub{},
+ map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ },
+ &emailCacheStub{},
+ )
+
+ err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "")
+
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "delete created oauth user")
+}
diff --git a/backend/internal/service/auth_service_email_bind_test.go b/backend/internal/service/auth_service_email_bind_test.go
new file mode 100644
index 00000000..899a736d
--- /dev/null
+++ b/backend/internal/service/auth_service_email_bind_test.go
@@ -0,0 +1,316 @@
+//go:build unit
+
+package service_test
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/repository"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+type emailBindDefaultSubAssignerStub struct {
+ calls []*service.AssignSubscriptionInput
+}
+
+func (s *emailBindDefaultSubAssignerStub) AssignOrExtendSubscription(
+ _ context.Context,
+ input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+ cloned := *input
+ s.calls = append(s.calls, &cloned)
+ return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
+}
+
+func newAuthServiceForEmailBind(
+ t *testing.T,
+ settings map[string]string,
+ emailCache service.EmailCache,
+ defaultSubAssigner service.DefaultSubscriptionAssigner,
+) (*service.AuthService, service.UserRepository, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_service_email_bind?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+ _, err = db.Exec(`
+CREATE TABLE IF NOT EXISTS user_provider_default_grants (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ provider_type TEXT NOT NULL,
+ grant_reason TEXT NOT NULL DEFAULT 'first_bind',
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ UNIQUE(user_id, provider_type, grant_reason)
+)`)
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ repo := repository.NewUserRepository(client, db)
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-bind-email-secret",
+ ExpireHour: 1,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 3.5,
+ UserConcurrency: 2,
+ },
+ }
+
+ settingRepo := &emailBindSettingRepoStub{values: settings}
+ settingSvc := service.NewSettingService(settingRepo, cfg)
+
+ var emailSvc *service.EmailService
+ if emailCache != nil {
+ emailSvc = service.NewEmailService(settingRepo, emailCache)
+ }
+
+ svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner)
+ return svc, repo, client
+}
+
+func TestAuthServiceBindEmailIdentity_UpdatesEmailAndAppliesFirstBindDefaults(t *testing.T) {
+ assigner := &emailBindDefaultSubAssignerStub{}
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, cache, assigner)
+
+ ctx := context.Background()
+ user, err := client.User.Create().
+ SetEmail("legacy-user" + service.LinuxDoConnectSyntheticEmailDomain).
+ SetUsername("legacy-user").
+ SetPasswordHash("old-hash").
+ SetBalance(2.5).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, " NewEmail@Example.com ", "123456", "new-password")
+ require.NoError(t, err)
+ require.NotNil(t, updatedUser)
+ require.Equal(t, "newemail@example.com", updatedUser.Email)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, "newemail@example.com", storedUser.Email)
+ require.Equal(t, 11.0, storedUser.Balance)
+ require.Equal(t, 5, storedUser.Concurrency)
+ require.True(t, svc.CheckPassword("new-password", storedUser.PasswordHash))
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("newemail@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, identityCount)
+
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, user.ID, assigner.calls[0].UserID)
+ require.Equal(t, int64(11), assigner.calls[0].GroupID)
+ require.Equal(t, 30, assigner.calls[0].ValidityDays)
+ require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceBindEmailIdentity_RejectsExistingEmailOnAnotherUser(t *testing.T) {
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
+
+ ctx := context.Background()
+ sourceUser, err := client.User.Create().
+ SetEmail("source-user" + service.OIDCConnectSyntheticEmailDomain).
+ SetUsername("source-user").
+ SetPasswordHash("old-hash").
+ SetBalance(1).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.User.Create().
+ SetEmail("taken@example.com").
+ SetUsername("taken-user").
+ SetPasswordHash("hash").
+ SetBalance(1).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, sourceUser.ID, "taken@example.com", "123456", "new-password")
+ require.ErrorIs(t, err, service.ErrEmailExists)
+ require.Nil(t, updatedUser)
+
+ storedUser, err := client.User.Get(ctx, sourceUser.ID)
+ require.NoError(t, err)
+ require.Equal(t, "source-user"+service.OIDCConnectSyntheticEmailDomain, storedUser.Email)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, sourceUser.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceBindEmailIdentity_RejectsReservedEmail(t *testing.T) {
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
+
+ ctx := context.Background()
+ user, err := client.User.Create().
+ SetEmail("source-user@example.com").
+ SetUsername("source-user").
+ SetPasswordHash("old-hash").
+ SetBalance(1).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "reserved"+service.LinuxDoConnectSyntheticEmailDomain, "123456", "new-password")
+ require.ErrorIs(t, err, service.ErrEmailReserved)
+ require.Nil(t, updatedUser)
+}
+
+type emailBindSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *emailBindSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *emailBindSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ if v, ok := s.values[key]; ok {
+ return v, nil
+ }
+ return "", service.ErrSettingNotFound
+}
+
+func (s *emailBindSettingRepoStub) Set(context.Context, string, string) error {
+ panic("unexpected Set call")
+}
+
+func (s *emailBindSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if v, ok := s.values[key]; ok {
+ out[key] = v
+ }
+ }
+ return out, nil
+}
+
+func (s *emailBindSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ panic("unexpected SetMultiple call")
+}
+
+func (s *emailBindSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *emailBindSettingRepoStub) Delete(context.Context, string) error {
+ panic("unexpected Delete call")
+}
+
+type emailBindCacheStub struct {
+ data *service.VerificationCodeData
+ err error
+}
+
+func (s *emailBindCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) {
+ if s.err != nil {
+ return nil, s.err
+ }
+ return s.data, nil
+}
+
+func (s *emailBindCacheStub) SetVerificationCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) DeleteVerificationCode(context.Context, string) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
+ return nil, nil
+}
+
+func (s *emailBindCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
+ return nil, nil
+}
+
+func (s *emailBindCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) DeletePasswordResetToken(context.Context, string) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
+ return false
+}
+
+func (s *emailBindCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+
+func (s *emailBindCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
+ return 0, nil
+}
diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts
index 9cf0c7ad..0f768018 100644
--- a/frontend/src/api/auth.ts
+++ b/frontend/src/api/auth.ts
@@ -449,6 +449,16 @@ export async function sendVerifyCode(
return data
}
+export async function sendPendingOAuthVerifyCode(
+ request: SendVerifyCodeRequest
+): Promise {
+ const { data } = await apiClient.post(
+ '/auth/oauth/pending/send-verify-code',
+ request
+ )
+ return data
+}
+
/**
* Validate promo code response
*/
@@ -638,6 +648,7 @@ export const authAPI = {
clearAuthToken,
getPublicSettings,
sendVerifyCode,
+ sendPendingOAuthVerifyCode,
validatePromoCode,
validateInvitationCode,
forgotPassword,
diff --git a/frontend/src/api/user.ts b/frontend/src/api/user.ts
index 7b498303..502bf151 100644
--- a/frontend/src/api/user.ts
+++ b/frontend/src/api/user.ts
@@ -89,6 +89,19 @@ export async function toggleNotifyEmail(email: string, disabled: boolean): Promi
return data
}
+export async function sendEmailBindingCode(email: string): Promise {
+ await apiClient.post('/user/account-bindings/email/send-code', { email })
+}
+
+export async function bindEmailIdentity(payload: {
+ email: string
+ verify_code: string
+ password: string
+}): Promise {
+ const { data } = await apiClient.post('/user/account-bindings/email', payload)
+ return data
+}
+
export type BindableOAuthProvider = Exclude
interface BuildOAuthBindingStartURLOptions {
@@ -158,6 +171,8 @@ export const userAPI = {
verifyNotifyEmail,
removeNotifyEmail,
toggleNotifyEmail,
+ sendEmailBindingCode,
+ bindEmailIdentity,
buildOAuthBindingStartURL,
startOAuthBinding
}
diff --git a/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue b/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue
index 36e78d36..8e05617f 100644
--- a/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue
+++ b/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue
@@ -58,11 +58,20 @@
{{ t('auth.verificationCodeHint') }}
+
{{ isSubmitting ? t('common.processing') : 'Create account' }}
@@ -92,12 +101,13 @@
import { onMounted, onUnmounted, ref, watch } from 'vue'
import { useI18n } from 'vue-i18n'
import TurnstileWidget from '@/components/TurnstileWidget.vue'
-import { getPublicSettings, sendVerifyCode } from '@/api/auth'
+import { getPublicSettings, sendPendingOAuthVerifyCode } from '@/api/auth'
export type PendingOAuthCreateAccountPayload = {
email: string
password: string
verifyCode: string
+ invitationCode?: string
}
const props = defineProps<{
@@ -117,10 +127,12 @@ const { t } = useI18n()
const email = ref('')
const password = ref('')
const verifyCode = ref('')
+const invitationCode = ref('')
const isSendingCode = ref(false)
const sendCodeError = ref('')
const sendCodeSuccess = ref(false)
const countdown = ref(0)
+const invitationCodeEnabled = ref(false)
const turnstileEnabled = ref(false)
const turnstileSiteKey = ref('')
const turnstileToken = ref('')
@@ -203,7 +215,7 @@ async function handleSendCode() {
sendCodeSuccess.value = false
try {
- const response = await sendVerifyCode({
+ const response = await sendPendingOAuthVerifyCode({
email: trimmedEmail,
turnstile_token: turnstileEnabled.value ? turnstileToken.value : undefined
})
@@ -228,7 +240,8 @@ function handleSubmit() {
emit('submit', {
email: trimmedEmail,
password: password.value,
- verifyCode: verifyCode.value.trim()
+ verifyCode: verifyCode.value.trim(),
+ invitationCode: invitationCode.value.trim() || undefined
})
}
@@ -239,9 +252,11 @@ function emitSwitchToBind() {
onMounted(async () => {
try {
const settings = await getPublicSettings()
+ invitationCodeEnabled.value = settings.invitation_code_enabled === true
turnstileEnabled.value = settings.turnstile_enabled === true
turnstileSiteKey.value = settings.turnstile_site_key || ''
} catch {
+ invitationCodeEnabled.value = false
turnstileEnabled.value = false
turnstileSiteKey.value = ''
}
diff --git a/frontend/src/components/auth/__tests__/PendingOAuthCreateAccountForm.spec.ts b/frontend/src/components/auth/__tests__/PendingOAuthCreateAccountForm.spec.ts
index 3160c0b7..f2375302 100644
--- a/frontend/src/components/auth/__tests__/PendingOAuthCreateAccountForm.spec.ts
+++ b/frontend/src/components/auth/__tests__/PendingOAuthCreateAccountForm.spec.ts
@@ -4,6 +4,7 @@ import { flushPromises, mount } from '@vue/test-utils'
import PendingOAuthCreateAccountForm from '../PendingOAuthCreateAccountForm.vue'
const sendVerifyCode = vi.fn()
+const sendPendingOAuthVerifyCode = vi.fn()
const getPublicSettings = vi.fn()
vi.mock('vue-i18n', async () => {
@@ -21,6 +22,7 @@ vi.mock('@/api/auth', async () => {
return {
...actual,
sendVerifyCode: (...args: any[]) => sendVerifyCode(...args),
+ sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCode(...args),
getPublicSettings: (...args: any[]) => getPublicSettings(...args)
}
})
@@ -28,6 +30,7 @@ vi.mock('@/api/auth', async () => {
describe('PendingOAuthCreateAccountForm', () => {
beforeEach(() => {
sendVerifyCode.mockReset()
+ sendPendingOAuthVerifyCode.mockReset()
getPublicSettings.mockReset()
getPublicSettings.mockResolvedValue({
turnstile_enabled: false,
@@ -61,8 +64,42 @@ describe('PendingOAuthCreateAccountForm', () => {
])
})
+ it('shows and emits invitation code when invitation-only signup is enabled', async () => {
+ getPublicSettings.mockResolvedValue({
+ invitation_code_enabled: true,
+ turnstile_enabled: false,
+ turnstile_site_key: ''
+ })
+
+ const wrapper = mount(PendingOAuthCreateAccountForm, {
+ props: {
+ providerName: 'LinuxDo',
+ testIdPrefix: 'linuxdo',
+ initialEmail: 'prefill@example.com',
+ isSubmitting: false
+ }
+ })
+
+ await flushPromises()
+ await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123')
+ await wrapper.get('[data-testid="linuxdo-create-account-verify-code"]').setValue('246810')
+ await wrapper.get('[data-testid="linuxdo-create-account-invitation-code"]').setValue(' INVITE123 ')
+ await wrapper.get('form').trigger('submit.prevent')
+
+ expect(wrapper.emitted('submit')).toEqual([
+ [
+ {
+ email: 'prefill@example.com',
+ password: 'secret-123',
+ verifyCode: '246810',
+ invitationCode: 'INVITE123'
+ }
+ ]
+ ])
+ })
+
it('sends a verify code for the trimmed email value', async () => {
- sendVerifyCode.mockResolvedValue({
+ sendPendingOAuthVerifyCode.mockResolvedValue({
message: 'sent',
countdown: 60
})
@@ -80,7 +117,7 @@ describe('PendingOAuthCreateAccountForm', () => {
await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click')
await flushPromises()
- expect(sendVerifyCode).toHaveBeenCalledWith({
+ expect(sendPendingOAuthVerifyCode).toHaveBeenCalledWith({
email: 'user@example.com'
})
})
@@ -90,7 +127,7 @@ describe('PendingOAuthCreateAccountForm', () => {
turnstile_enabled: true,
turnstile_site_key: 'site-key'
})
- sendVerifyCode.mockResolvedValue({
+ sendPendingOAuthVerifyCode.mockResolvedValue({
message: 'sent',
countdown: 60
})
@@ -120,7 +157,7 @@ describe('PendingOAuthCreateAccountForm', () => {
await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click')
await flushPromises()
- expect(sendVerifyCode).toHaveBeenCalledWith({
+ expect(sendPendingOAuthVerifyCode).toHaveBeenCalledWith({
email: 'user@example.com',
turnstile_token: 'turnstile-token'
})
diff --git a/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue
index 76c94ba1..ccc1cbd0 100644
--- a/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue
+++ b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue
@@ -13,35 +13,96 @@
-
-
- {{ item.label }}
+
-
-
- {{
- item.bound
- ? t('profile.authBindings.status.bound')
- : t('profile.authBindings.status.notBound')
- }}
-
-
-
- {{ t('profile.authBindings.bindAction', { providerName: item.label }) }}
-
+
+
+ {{ t('profile.authBindings.bindAction', { providerName: item.label }) }}
+
+
@@ -49,7 +110,7 @@
diff --git a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
index 4e194a39..ec4aed5d 100644
--- a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
+++ b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
@@ -2,7 +2,7 @@ import { mount } from '@vue/test-utils'
import { createPinia, setActivePinia } from 'pinia'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import ProfileIdentityBindingsSection from '@/components/user/profile/ProfileIdentityBindingsSection.vue'
-import { useAppStore } from '@/stores'
+import { useAppStore, useAuthStore } from '@/stores'
import type { User } from '@/types'
const routeState = vi.hoisted(() => ({
@@ -15,10 +15,24 @@ const locationState = vi.hoisted(() => ({
let pinia: ReturnType
+const userApiMocks = vi.hoisted(() => ({
+ sendEmailBindingCode: vi.fn(),
+ bindEmailIdentity: vi.fn(),
+}))
+
vi.mock('vue-router', () => ({
useRoute: () => routeState,
}))
+vi.mock('@/api/user', async (importOriginal) => {
+ const actual = await importOriginal()
+ return {
+ ...actual,
+ sendEmailBindingCode: (...args: any[]) => userApiMocks.sendEmailBindingCode(...args),
+ bindEmailIdentity: (...args: any[]) => userApiMocks.bindEmailIdentity(...args),
+ }
+})
+
vi.mock('vue-i18n', async (importOriginal) => {
const actual = await importOriginal()
return {
@@ -34,6 +48,13 @@ vi.mock('vue-i18n', async (importOriginal) => {
if (key === 'profile.authBindings.providers.wechat') return 'WeChat'
if (key === 'profile.authBindings.providers.oidc') return params?.providerName || 'OIDC'
if (key === 'profile.authBindings.bindAction') return `Bind ${params?.providerName || ''}`.trim()
+ if (key === 'profile.authBindings.emailPlaceholder') return 'Email address'
+ if (key === 'profile.authBindings.codePlaceholder') return 'Verification code'
+ if (key === 'profile.authBindings.passwordPlaceholder') return 'Set password'
+ if (key === 'profile.authBindings.sendCodeAction') return 'Send code'
+ if (key === 'profile.authBindings.confirmEmailBindAction') return 'Bind email'
+ if (key === 'profile.authBindings.codeSentTo') return `Code sent to ${params?.email || ''}`.trim()
+ if (key === 'profile.authBindings.bindSuccess') return 'Bind success'
return key
},
}),
@@ -76,6 +97,8 @@ describe('ProfileIdentityBindingsSection', () => {
const appStore = useAppStore()
appStore.cachedPublicSettings = null
appStore.publicSettingsLoaded = false
+ userApiMocks.sendEmailBindingCode.mockReset()
+ userApiMocks.bindEmailIdentity.mockReset()
})
afterEach(() => {
@@ -224,4 +247,58 @@ describe('ProfileIdentityBindingsSection', () => {
expect(wrapper.find('[data-testid="profile-binding-wechat-action"]').exists()).toBe(true)
})
+
+ it('sends email verification code and binds email from the profile card', async () => {
+ userApiMocks.sendEmailBindingCode.mockResolvedValue(undefined)
+ userApiMocks.bindEmailIdentity.mockResolvedValue(
+ createUser({
+ email: 'bound@example.com',
+ email_bound: true,
+ auth_bindings: {
+ email: { bound: true },
+ },
+ })
+ )
+
+ const appStore = useAppStore()
+ const authStore = useAuthStore()
+ authStore.user = createUser({
+ email: 'legacy-user@linuxdo-connect.invalid',
+ email_bound: false,
+ auth_bindings: {
+ email: { bound: false },
+ },
+ })
+ const showSuccessSpy = vi.spyOn(appStore, 'showSuccess')
+
+ const wrapper = mount(ProfileIdentityBindingsSection, {
+ global: {
+ plugins: [pinia],
+ },
+ props: {
+ user: authStore.user,
+ linuxdoEnabled: false,
+ oidcEnabled: false,
+ wechatEnabled: false,
+ },
+ })
+
+ await wrapper.get('[data-testid="profile-binding-email-input"]').setValue('bound@example.com')
+ await wrapper.get('[data-testid="profile-binding-email-send-code"]').trigger('click')
+
+ expect(userApiMocks.sendEmailBindingCode).toHaveBeenCalledWith('bound@example.com')
+ expect(showSuccessSpy).toHaveBeenCalledWith('Code sent to bound@example.com')
+
+ await wrapper.get('[data-testid="profile-binding-email-code-input"]').setValue('123456')
+ await wrapper.get('[data-testid="profile-binding-email-password-input"]').setValue('new-password')
+ await wrapper.get('[data-testid="profile-binding-email-submit"]').trigger('click')
+
+ expect(userApiMocks.bindEmailIdentity).toHaveBeenCalledWith({
+ email: 'bound@example.com',
+ verify_code: '123456',
+ password: 'new-password',
+ })
+ expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Bound')
+ expect(authStore.user?.email).toBe('bound@example.com')
+ })
})
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index ec9c1ea3..2b41a3c3 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -964,6 +964,12 @@ export default {
description: 'View current bindings and connect another provider to this account.',
bindAction: 'Bind {providerName}',
bindSuccess: 'Account linked successfully',
+ emailPlaceholder: 'Enter email address',
+ codePlaceholder: 'Enter verification code',
+ passwordPlaceholder: 'Set a login password',
+ sendCodeAction: 'Send code',
+ confirmEmailBindAction: 'Bind email',
+ codeSentTo: 'Code sent to {email}',
status: {
bound: 'Bound',
notBound: 'Not bound',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 9941d323..b60a69d6 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -968,6 +968,12 @@ export default {
description: '查看当前绑定状态,并将更多第三方登录方式关联到这个账号。',
bindAction: '绑定 {providerName}',
bindSuccess: '账号绑定成功',
+ emailPlaceholder: '输入邮箱地址',
+ codePlaceholder: '输入验证码',
+ passwordPlaceholder: '设置登录密码',
+ sendCodeAction: '发送验证码',
+ confirmEmailBindAction: '绑定邮箱',
+ codeSentTo: '验证码已发送到 {email}',
status: {
bound: '已绑定',
notBound: '未绑定',
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index 07341919..bfc11cb2 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -118,6 +118,8 @@ export interface RegisterRequest {
export interface SendVerifyCodeRequest {
email: string
turnstile_token?: string
+ pending_auth_token?: string
+ pending_oauth_token?: string
}
export interface SendVerifyCodeResponse {
diff --git a/frontend/src/views/auth/EmailVerifyView.vue b/frontend/src/views/auth/EmailVerifyView.vue
index 84dd4667..d7bf6b7a 100644
--- a/frontend/src/views/auth/EmailVerifyView.vue
+++ b/frontend/src/views/auth/EmailVerifyView.vue
@@ -176,7 +176,12 @@ import { AuthLayout } from '@/components/layout'
import Icon from '@/components/icons/Icon.vue'
import TurnstileWidget from '@/components/TurnstileWidget.vue'
import { useAuthStore, useAppStore } from '@/stores'
-import { persistOAuthTokenContext, getPublicSettings, sendVerifyCode } from '@/api/auth'
+import {
+ persistOAuthTokenContext,
+ getPublicSettings,
+ sendPendingOAuthVerifyCode,
+ sendVerifyCode,
+} from '@/api/auth'
import { apiClient } from '@/api/client'
import { buildAuthErrorMessage } from '@/utils/authError'
import {
@@ -355,18 +360,21 @@ async function sendCode(): Promise {
errorMessage.value = ''
try {
- if (!isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) {
+ if (!pendingAuthToken.value && !isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) {
errorMessage.value = buildEmailSuffixNotAllowedMessage()
appStore.showError(errorMessage.value)
return
}
- const response = await sendVerifyCode({
+ const requestPayload = {
email: email.value,
[pendingAuthTokenField.value]: pendingAuthToken.value || undefined,
// 优先使用重发时新获取的 token(因为初始 token 可能已被使用)
turnstile_token: resendTurnstileToken.value || initialTurnstileToken.value || undefined
- } as Parameters[0])
+ } as Parameters[0]
+ const response = pendingAuthToken.value
+ ? await sendPendingOAuthVerifyCode(requestPayload)
+ : await sendVerifyCode(requestPayload)
codeSent.value = true
startCountdown(response.countdown)
diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue
index 6c923b0a..735c6582 100644
--- a/frontend/src/views/auth/LinuxDoCallbackView.vue
+++ b/frontend/src/views/auth/LinuxDoCallbackView.vue
@@ -444,6 +444,28 @@ function getRequestErrorMessage(error: unknown, fallback: string): string {
return err.response?.data?.detail || err.response?.data?.message || err.message || fallback
}
+function isCreateAccountRecoveryError(error: unknown): boolean {
+ const data = (error as {
+ response?: {
+ data?: {
+ reason?: string
+ error?: string
+ code?: string
+ step?: string
+ intent?: string
+ }
+ }
+ }).response?.data
+ const states = [data?.reason, data?.error, data?.code, data?.step, data?.intent]
+ .map(value => value?.trim().toLowerCase())
+ .filter((value): value is string => Boolean(value))
+
+ return states.includes('email_exists') ||
+ states.includes('bind_login_required') ||
+ states.includes('bind_login') ||
+ states.includes('adopt_existing_user_by_email')
+}
+
async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) {
if (getOAuthCompletionKind(completion) === 'bind') {
const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile')
@@ -540,10 +562,15 @@ async function handleCreateAccount(payload: PendingOAuthCreateAccountPayload) {
email: payload.email,
password: payload.password,
verify_code: payload.verifyCode || undefined,
+ invitation_code: payload.invitationCode || undefined,
...serializeAdoptionDecision(currentAdoptionDecision())
})
await finalizePendingAccountResponse(data)
} catch (e: unknown) {
+ if (isCreateAccountRecoveryError(e)) {
+ switchToBindLoginMode(payload.email)
+ return
+ }
accountActionError.value = getRequestErrorMessage(e, t('auth.loginFailed'))
} finally {
isSubmitting.value = false
diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue
index 840e4964..019cab54 100644
--- a/frontend/src/views/auth/OidcCallbackView.vue
+++ b/frontend/src/views/auth/OidcCallbackView.vue
@@ -488,6 +488,28 @@ function getRequestErrorMessage(error: unknown, fallback: string): string {
return err.response?.data?.detail || err.response?.data?.message || err.message || fallback
}
+function isCreateAccountRecoveryError(error: unknown): boolean {
+ const data = (error as {
+ response?: {
+ data?: {
+ reason?: string
+ error?: string
+ code?: string
+ step?: string
+ intent?: string
+ }
+ }
+ }).response?.data
+ const states = [data?.reason, data?.error, data?.code, data?.step, data?.intent]
+ .map(value => value?.trim().toLowerCase())
+ .filter((value): value is string => Boolean(value))
+
+ return states.includes('email_exists') ||
+ states.includes('bind_login_required') ||
+ states.includes('bind_login') ||
+ states.includes('adopt_existing_user_by_email')
+}
+
async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) {
if (getOAuthCompletionKind(completion) === 'bind') {
const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile')
@@ -584,10 +606,15 @@ async function handleCreateAccount(payload: PendingOAuthCreateAccountPayload) {
email: payload.email,
password: payload.password,
verify_code: payload.verifyCode || undefined,
+ invitation_code: payload.invitationCode || undefined,
...serializeAdoptionDecision(currentAdoptionDecision())
})
await finalizePendingAccountResponse(data)
} catch (e: unknown) {
+ if (isCreateAccountRecoveryError(e)) {
+ switchToBindLoginMode(payload.email)
+ return
+ }
accountActionError.value = getRequestErrorMessage(e, t('auth.loginFailed'))
} finally {
isSubmitting.value = false
diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue
index 35cd0032..36e3140c 100644
--- a/frontend/src/views/auth/WechatCallbackView.vue
+++ b/frontend/src/views/auth/WechatCallbackView.vue
@@ -647,6 +647,28 @@ function getRequestErrorMessage(error: unknown, fallback: string): string {
return err.response?.data?.detail || err.response?.data?.message || err.message || fallback
}
+function isCreateAccountRecoveryError(error: unknown): boolean {
+ const data = (error as {
+ response?: {
+ data?: {
+ reason?: string
+ error?: string
+ code?: string
+ step?: string
+ intent?: string
+ }
+ }
+ }).response?.data
+ const states = [data?.reason, data?.error, data?.code, data?.step, data?.intent]
+ .map(value => value?.trim().toLowerCase())
+ .filter((value): value is string => Boolean(value))
+
+ return states.includes('email_exists') ||
+ states.includes('bind_login_required') ||
+ states.includes('bind_login') ||
+ states.includes('adopt_existing_user_by_email')
+}
+
async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) {
if (getOAuthCompletionKind(completion) === 'bind') {
const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile')
@@ -739,10 +761,15 @@ async function handleCreateAccount(payload: PendingOAuthCreateAccountPayload) {
email: payload.email,
password: payload.password,
verify_code: payload.verifyCode || undefined,
+ invitation_code: payload.invitationCode || undefined,
...serializeAdoptionDecision(currentAdoptionDecision())
})
await finalizePendingAccountResponse(data)
} catch (e: unknown) {
+ if (isCreateAccountRecoveryError(e)) {
+ switchToBindLoginMode(payload.email)
+ return
+ }
accountActionError.value = getRequestErrorMessage(e, t('auth.loginFailed'))
} finally {
isSubmitting.value = false
diff --git a/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts b/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts
index f6dff076..c231d6e7 100644
--- a/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts
+++ b/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts
@@ -11,6 +11,7 @@ const {
clearPendingAuthSessionMock,
getPublicSettingsMock,
sendVerifyCodeMock,
+ sendPendingOAuthVerifyCodeMock,
persistOAuthTokenContextMock,
apiClientPostMock,
authStoreState,
@@ -23,6 +24,7 @@ const {
clearPendingAuthSessionMock: vi.fn(),
getPublicSettingsMock: vi.fn(),
sendVerifyCodeMock: vi.fn(),
+ sendPendingOAuthVerifyCodeMock: vi.fn(),
persistOAuthTokenContextMock: vi.fn(),
apiClientPostMock: vi.fn(),
authStoreState: {
@@ -80,6 +82,7 @@ vi.mock('@/api/auth', async () => {
...actual,
getPublicSettings: (...args: any[]) => getPublicSettingsMock(...args),
sendVerifyCode: (...args: any[]) => sendVerifyCodeMock(...args),
+ sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCodeMock(...args),
persistOAuthTokenContext: (...args: any[]) => persistOAuthTokenContextMock(...args),
}
})
@@ -100,6 +103,7 @@ describe('EmailVerifyView', () => {
clearPendingAuthSessionMock.mockReset()
getPublicSettingsMock.mockReset()
sendVerifyCodeMock.mockReset()
+ sendPendingOAuthVerifyCodeMock.mockReset()
persistOAuthTokenContextMock.mockReset()
apiClientPostMock.mockReset()
authStoreState.pendingAuthSession = null
@@ -112,9 +116,86 @@ describe('EmailVerifyView', () => {
registration_email_suffix_whitelist: [],
})
sendVerifyCodeMock.mockResolvedValue({ countdown: 60 })
+ sendPendingOAuthVerifyCodeMock.mockResolvedValue({ countdown: 60 })
setTokenMock.mockResolvedValue({})
})
+ it('uses the pending oauth verify-code endpoint when register data carries a pending auth session', async () => {
+ authStoreState.pendingAuthSession = {
+ token: 'pending-token-1',
+ token_field: 'pending_auth_token',
+ provider: 'wechat',
+ redirect: '/profile',
+ }
+ sessionStorage.setItem(
+ 'register_data',
+ JSON.stringify({
+ email: 'fresh@example.com',
+ password: 'secret-123',
+ })
+ )
+
+ mount(EmailVerifyView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ TurnstileWidget: true,
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({
+ email: 'fresh@example.com',
+ pending_auth_token: 'pending-token-1',
+ })
+ expect(sendVerifyCodeMock).not.toHaveBeenCalled()
+ })
+
+ it('skips the registration email suffix whitelist for pending oauth verification', async () => {
+ authStoreState.pendingAuthSession = {
+ token: 'pending-token-2',
+ token_field: 'pending_auth_token',
+ provider: 'oidc',
+ redirect: '/profile',
+ }
+ getPublicSettingsMock.mockResolvedValue({
+ turnstile_enabled: false,
+ turnstile_site_key: '',
+ site_name: 'Sub2API',
+ registration_email_suffix_whitelist: ['allowed.com'],
+ })
+ sessionStorage.setItem(
+ 'register_data',
+ JSON.stringify({
+ email: 'fresh@example.com',
+ password: 'secret-123',
+ })
+ )
+
+ mount(EmailVerifyView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ TurnstileWidget: true,
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({
+ email: 'fresh@example.com',
+ pending_auth_token: 'pending-token-2',
+ })
+ expect(showErrorMock).not.toHaveBeenCalled()
+ })
+
it('submits pending auth account creation when session storage has no pending metadata but auth store does', async () => {
authStoreState.pendingAuthSession = {
token: 'pending-token-1',
diff --git a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
index a04915b7..f612681a 100644
--- a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
@@ -15,6 +15,7 @@ const getPublicSettings = vi.fn()
const login2FA = vi.fn()
const apiClientPost = vi.fn()
const sendVerifyCode = vi.fn()
+const sendPendingOAuthVerifyCode = vi.fn()
vi.mock('vue-router', () => ({
useRoute: () => ({
@@ -61,7 +62,8 @@ vi.mock('@/api/auth', async () => {
completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args),
getPublicSettings: (...args: any[]) => getPublicSettings(...args),
login2FA: (...args: any[]) => login2FA(...args),
- sendVerifyCode: (...args: any[]) => sendVerifyCode(...args)
+ sendVerifyCode: (...args: any[]) => sendVerifyCode(...args),
+ sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCode(...args)
}
})
@@ -79,6 +81,7 @@ describe('LinuxDoCallbackView', () => {
login2FA.mockReset()
apiClientPost.mockReset()
sendVerifyCode.mockReset()
+ sendPendingOAuthVerifyCode.mockReset()
getPublicSettings.mockResolvedValue({
turnstile_enabled: false,
turnstile_site_key: ''
@@ -334,6 +337,11 @@ describe('LinuxDoCallbackView', () => {
})
it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => {
+ getPublicSettings.mockResolvedValue({
+ invitation_code_enabled: true,
+ turnstile_enabled: false,
+ turnstile_site_key: ''
+ })
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'email_required',
redirect: '/welcome',
@@ -370,6 +378,7 @@ describe('LinuxDoCallbackView', () => {
await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' new@example.com ')
await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123')
await wrapper.get('[data-testid="linuxdo-create-account-verify-code"]').setValue('246810')
+ await wrapper.get('[data-testid="linuxdo-create-account-invitation-code"]').setValue(' INVITE123 ')
await wrapper.get('[data-testid="linuxdo-create-account-submit"]').trigger('click')
await flushPromises()
@@ -377,6 +386,7 @@ describe('LinuxDoCallbackView', () => {
email: 'new@example.com',
password: 'secret-123',
verify_code: '246810',
+ invitation_code: 'INVITE123',
adopt_display_name: true,
adopt_avatar: false
})
@@ -384,12 +394,48 @@ describe('LinuxDoCallbackView', () => {
expect(replace).toHaveBeenCalledWith('/welcome')
})
+ it('switches to bind-login when create-account returns EMAIL_EXISTS', async () => {
+ exchangePendingOAuthCompletion.mockResolvedValue({
+ error: 'email_required',
+ redirect: '/welcome'
+ })
+ apiClientPost.mockRejectedValue({
+ response: {
+ data: {
+ reason: 'EMAIL_EXISTS',
+ message: 'email already exists'
+ }
+ }
+ })
+
+ const wrapper = mount(LinuxDoCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+ await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue('existing@example.com')
+ await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123')
+ await wrapper.get('[data-testid="linuxdo-create-account-submit"]').trigger('click')
+ await flushPromises()
+
+ expect((wrapper.get('[data-testid="linuxdo-bind-login-email"]').element as HTMLInputElement).value).toBe(
+ 'existing@example.com'
+ )
+ })
+
it('sends a verify code for pending oauth account creation', async () => {
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'email_required',
redirect: '/welcome'
})
- sendVerifyCode.mockResolvedValue({
+ sendPendingOAuthVerifyCode.mockResolvedValue({
message: 'sent',
countdown: 60
})
@@ -411,7 +457,7 @@ describe('LinuxDoCallbackView', () => {
await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click')
await flushPromises()
- expect(sendVerifyCode).toHaveBeenCalledWith({
+ expect(sendPendingOAuthVerifyCode).toHaveBeenCalledWith({
email: 'new@example.com'
})
})
diff --git a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
index 259fb282..0edcb931 100644
--- a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
@@ -15,6 +15,7 @@ const getPublicSettings = vi.fn()
const login2FA = vi.fn()
const apiClientPost = vi.fn()
const sendVerifyCode = vi.fn()
+const sendPendingOAuthVerifyCode = vi.fn()
vi.mock('vue-router', () => ({
useRoute: () => ({
@@ -66,7 +67,8 @@ vi.mock('@/api/auth', async () => {
completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args),
getPublicSettings: (...args: any[]) => getPublicSettings(...args),
login2FA: (...args: any[]) => login2FA(...args),
- sendVerifyCode: (...args: any[]) => sendVerifyCode(...args)
+ sendVerifyCode: (...args: any[]) => sendVerifyCode(...args),
+ sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCode(...args)
}
})
@@ -84,6 +86,7 @@ describe('OidcCallbackView', () => {
login2FA.mockReset()
apiClientPost.mockReset()
sendVerifyCode.mockReset()
+ sendPendingOAuthVerifyCode.mockReset()
getPublicSettings.mockResolvedValue({
oidc_oauth_provider_name: 'ExampleID',
turnstile_enabled: false,
@@ -312,6 +315,12 @@ describe('OidcCallbackView', () => {
})
it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => {
+ getPublicSettings.mockResolvedValue({
+ oidc_oauth_provider_name: 'ExampleID',
+ invitation_code_enabled: true,
+ turnstile_enabled: false,
+ turnstile_site_key: ''
+ })
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'email_required',
redirect: '/welcome',
@@ -348,6 +357,7 @@ describe('OidcCallbackView', () => {
await wrapper.get('[data-testid="oidc-create-account-email"]').setValue(' new@example.com ')
await wrapper.get('[data-testid="oidc-create-account-password"]').setValue('secret-123')
await wrapper.get('[data-testid="oidc-create-account-verify-code"]').setValue('246810')
+ await wrapper.get('[data-testid="oidc-create-account-invitation-code"]').setValue(' INVITE123 ')
await wrapper.get('[data-testid="oidc-create-account-submit"]').trigger('click')
await flushPromises()
@@ -355,6 +365,7 @@ describe('OidcCallbackView', () => {
email: 'new@example.com',
password: 'secret-123',
verify_code: '246810',
+ invitation_code: 'INVITE123',
adopt_display_name: true,
adopt_avatar: false
})
@@ -362,12 +373,48 @@ describe('OidcCallbackView', () => {
expect(replace).toHaveBeenCalledWith('/welcome')
})
+ it('switches to bind-login when create-account returns EMAIL_EXISTS', async () => {
+ exchangePendingOAuthCompletion.mockResolvedValue({
+ error: 'email_required',
+ redirect: '/welcome'
+ })
+ apiClientPost.mockRejectedValue({
+ response: {
+ data: {
+ reason: 'EMAIL_EXISTS',
+ message: 'email already exists'
+ }
+ }
+ })
+
+ const wrapper = mount(OidcCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+ await wrapper.get('[data-testid="oidc-create-account-email"]').setValue('existing@example.com')
+ await wrapper.get('[data-testid="oidc-create-account-password"]').setValue('secret-123')
+ await wrapper.get('[data-testid="oidc-create-account-submit"]').trigger('click')
+ await flushPromises()
+
+ expect((wrapper.get('[data-testid="oidc-bind-login-email"]').element as HTMLInputElement).value).toBe(
+ 'existing@example.com'
+ )
+ })
+
it('sends a verify code for pending oauth account creation', async () => {
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'email_required',
redirect: '/welcome'
})
- sendVerifyCode.mockResolvedValue({
+ sendPendingOAuthVerifyCode.mockResolvedValue({
message: 'sent',
countdown: 60
})
@@ -389,7 +436,7 @@ describe('OidcCallbackView', () => {
await wrapper.get('[data-testid="oidc-create-account-send-code"]').trigger('click')
await flushPromises()
- expect(sendVerifyCode).toHaveBeenCalledWith({
+ expect(sendPendingOAuthVerifyCode).toHaveBeenCalledWith({
email: 'new@example.com'
})
})
diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
index fed88890..e02060f6 100644
--- a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
@@ -8,6 +8,8 @@ const {
login2FAMock,
apiClientPostMock,
sendVerifyCodeMock,
+ sendPendingOAuthVerifyCodeMock,
+ getPublicSettingsMock,
prepareOAuthBindAccessTokenCookieMock,
getAuthTokenMock,
replaceMock,
@@ -24,6 +26,8 @@ const {
login2FAMock: vi.fn(),
apiClientPostMock: vi.fn(),
sendVerifyCodeMock: vi.fn(),
+ sendPendingOAuthVerifyCodeMock: vi.fn(),
+ getPublicSettingsMock: vi.fn(),
prepareOAuthBindAccessTokenCookieMock: vi.fn(),
getAuthTokenMock: vi.fn(),
replaceMock: vi.fn(),
@@ -130,6 +134,8 @@ vi.mock('@/api/auth', async () => {
completeWeChatOAuthRegistration: (...args: any[]) => completeWeChatOAuthRegistrationMock(...args),
login2FA: (...args: any[]) => login2FAMock(...args),
sendVerifyCode: (...args: any[]) => sendVerifyCodeMock(...args),
+ sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCodeMock(...args),
+ getPublicSettings: (...args: any[]) => getPublicSettingsMock(...args),
prepareOAuthBindAccessTokenCookie: (...args: any[]) => prepareOAuthBindAccessTokenCookieMock(...args),
getAuthToken: (...args: any[]) => getAuthTokenMock(...args),
}
@@ -142,6 +148,8 @@ describe('WechatCallbackView', () => {
login2FAMock.mockReset()
apiClientPostMock.mockReset()
sendVerifyCodeMock.mockReset()
+ sendPendingOAuthVerifyCodeMock.mockReset()
+ getPublicSettingsMock.mockReset()
replaceMock.mockReset()
setTokenMock.mockReset()
showSuccessMock.mockReset()
@@ -167,6 +175,11 @@ describe('WechatCallbackView', () => {
configurable: true,
value: 'Mozilla/5.0',
})
+ getPublicSettingsMock.mockResolvedValue({
+ invitation_code_enabled: false,
+ turnstile_enabled: false,
+ turnstile_site_key: '',
+ })
})
it('overrides an incompatible query mode with the configured open capability during bind recovery', async () => {
@@ -478,6 +491,11 @@ describe('WechatCallbackView', () => {
})
it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => {
+ getPublicSettingsMock.mockResolvedValue({
+ invitation_code_enabled: true,
+ turnstile_enabled: false,
+ turnstile_site_key: '',
+ })
exchangePendingOAuthCompletionMock.mockResolvedValue({
error: 'email_required',
redirect: '/welcome',
@@ -514,6 +532,7 @@ describe('WechatCallbackView', () => {
await wrapper.get('[data-testid="wechat-create-account-email"]').setValue(' new@example.com ')
await wrapper.get('[data-testid="wechat-create-account-password"]').setValue('secret-123')
await wrapper.get('[data-testid="wechat-create-account-verify-code"]').setValue('246810')
+ await wrapper.get('[data-testid="wechat-create-account-invitation-code"]').setValue(' INVITE123 ')
await wrapper.get('[data-testid="wechat-create-account-submit"]').trigger('click')
await flushPromises()
@@ -521,6 +540,7 @@ describe('WechatCallbackView', () => {
email: 'new@example.com',
password: 'secret-123',
verify_code: '246810',
+ invitation_code: 'INVITE123',
adopt_display_name: true,
adopt_avatar: false,
})
@@ -528,12 +548,48 @@ describe('WechatCallbackView', () => {
expect(replaceMock).toHaveBeenCalledWith('/welcome')
})
+ it('switches to bind-login when create-account returns EMAIL_EXISTS', async () => {
+ exchangePendingOAuthCompletionMock.mockResolvedValue({
+ error: 'email_required',
+ redirect: '/welcome',
+ })
+ apiClientPostMock.mockRejectedValue({
+ response: {
+ data: {
+ reason: 'EMAIL_EXISTS',
+ message: 'email already exists',
+ },
+ },
+ })
+
+ const wrapper = mount(WechatCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+ await wrapper.get('[data-testid="wechat-create-account-email"]').setValue('existing@example.com')
+ await wrapper.get('[data-testid="wechat-create-account-password"]').setValue('secret-123')
+ await wrapper.get('[data-testid="wechat-create-account-submit"]').trigger('click')
+ await flushPromises()
+
+ expect((wrapper.get('[data-testid="wechat-bind-login-email"]').element as HTMLInputElement).value).toBe(
+ 'existing@example.com'
+ )
+ })
+
it('sends a verify code for pending oauth account creation', async () => {
exchangePendingOAuthCompletionMock.mockResolvedValue({
error: 'email_required',
redirect: '/welcome',
})
- sendVerifyCodeMock.mockResolvedValue({
+ sendPendingOAuthVerifyCodeMock.mockResolvedValue({
message: 'sent',
countdown: 60,
})
@@ -555,7 +611,7 @@ describe('WechatCallbackView', () => {
await wrapper.get('[data-testid="wechat-create-account-send-code"]').trigger('click')
await flushPromises()
- expect(sendVerifyCodeMock).toHaveBeenCalledWith({
+ expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({
email: 'new@example.com',
})
})
From 7e89bca5e660b69de795a8a634fa17f5c24bea14 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 10:41:29 +0800
Subject: [PATCH 096/189] fix: tighten pending oauth email routing and binding
state
---
.../handler/auth_oauth_pending_flow.go | 163 ++++++++++++++++--
.../handler/auth_oauth_pending_flow_test.go | 148 ++++++++++++++--
.../internal/service/auth_oauth_email_flow.go | 122 ++++++++++++-
backend/internal/service/user_service.go | 29 +++-
frontend/src/api/auth.ts | 10 +-
.../ProfileIdentityBindingsSection.vue | 7 +-
.../ProfileIdentityBindingsSection.spec.ts | 23 +++
frontend/src/views/auth/EmailVerifyView.vue | 78 ++++++++-
.../auth/__tests__/EmailVerifyView.spec.ts | 159 +++++++++++++++++
9 files changed, 685 insertions(+), 54 deletions(-)
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
index 1d3b113f..8a3006f3 100644
--- a/backend/internal/handler/auth_oauth_pending_flow.go
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -8,6 +8,7 @@ import (
"net/http"
"net/url"
"strings"
+ "time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
@@ -35,6 +36,8 @@ const (
oauthCompletionResponseKey = "completion_response"
)
+var pendingOAuthCreateAccountPreCommitHook func(context.Context, *dbent.PendingAuthSession) error
+
type oauthPendingSessionPayload struct {
Intent string
Identity service.PendingAuthIdentityKey
@@ -481,6 +484,26 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
return
}
+ client := h.entClient()
+ if client == nil {
+ response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
+ return
+ }
+
+ email := strings.TrimSpace(strings.ToLower(req.Email))
+ if existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email); err == nil && existingUser != nil {
+ session, err = h.transitionPendingOAuthAccountToBindLogin(c, client, session, email, oauthAdoptionDecisionRequest{})
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
+ return
+ } else if err != nil && !errors.Is(err, service.ErrUserNotFound) {
+ response.ErrorFrom(c, err)
+ return
+ }
+
result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email)
if err != nil {
response.ErrorFrom(c, err)
@@ -946,11 +969,46 @@ func applyPendingOAuthBinding(
return nil
}
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ return applyPendingOAuthBindingTx(ctx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults)
+ }
+
+ tx, err := client.Tx(ctx)
+ if err != nil {
+ return err
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ txCtx := dbent.NewTxContext(ctx, tx)
+ if err := applyPendingOAuthBindingTx(txCtx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults); err != nil {
+ return err
+ }
+ return tx.Commit()
+}
+
+func applyPendingOAuthBindingTx(
+ ctx context.Context,
+ tx *dbent.Tx,
+ authService *service.AuthService,
+ userService *service.UserService,
+ session *dbent.PendingAuthSession,
+ decision *dbent.IdentityAdoptionDecision,
+ overrideUserID *int64,
+ forceBind bool,
+ applyFirstBindDefaults bool,
+) error {
+ if tx == nil || session == nil {
+ return nil
+ }
+ if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) {
+ return nil
+ }
+
targetUserID := int64(0)
if overrideUserID != nil && *overrideUserID > 0 {
targetUserID = *overrideUserID
} else {
- resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session)
+ resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, tx.Client(), session)
if err != nil {
return err
}
@@ -974,22 +1032,15 @@ func applyPendingOAuthBinding(
}
}
- tx, err := client.Tx(ctx)
- if err != nil {
- return err
- }
- defer func() { _ = tx.Rollback() }()
- txCtx := dbent.NewTxContext(ctx, tx)
-
if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
if err := tx.Client().User.UpdateOneID(targetUserID).
SetUsername(adoptedDisplayName).
- Exec(txCtx); err != nil {
+ Exec(ctx); err != nil {
return err
}
}
- identity, err := ensurePendingOAuthIdentityForUser(txCtx, tx, session, targetUserID)
+ identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID)
if err != nil {
return err
}
@@ -1009,31 +1060,71 @@ func applyPendingOAuthBinding(
if issuer := oauthIdentityIssuer(session); issuer != nil {
updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer))
}
- if _, err := updateIdentity.Save(txCtx); err != nil {
+ if _, err := updateIdentity.Save(ctx); err != nil {
return err
}
if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) {
if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
SetIdentityID(identity.ID).
- Save(txCtx); err != nil {
+ Save(ctx); err != nil {
return err
}
}
if applyFirstBindDefaults && authService != nil {
- if err := authService.ApplyProviderDefaultSettingsOnFirstBind(txCtx, targetUserID, session.ProviderType); err != nil {
+ if err := authService.ApplyProviderDefaultSettingsOnFirstBind(ctx, targetUserID, session.ProviderType); err != nil {
return err
}
}
if shouldAdoptAvatar && userService != nil {
- if _, err := userService.SetAvatar(txCtx, targetUserID, adoptedAvatarURL); err != nil {
+ if _, err := userService.SetAvatar(ctx, targetUserID, adoptedAvatarURL); err != nil {
return err
}
}
- return tx.Commit()
+ return nil
+}
+
+func consumePendingOAuthBrowserSessionTx(
+ ctx context.Context,
+ tx *dbent.Tx,
+ session *dbent.PendingAuthSession,
+) error {
+ if tx == nil || session == nil {
+ return service.ErrPendingAuthSessionNotFound
+ }
+
+ storedSession, err := tx.Client().PendingAuthSession.Get(ctx, session.ID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return service.ErrPendingAuthSessionNotFound
+ }
+ return err
+ }
+
+ now := time.Now().UTC()
+ if storedSession.ConsumedAt != nil {
+ return service.ErrPendingAuthSessionConsumed
+ }
+ if !storedSession.ExpiresAt.IsZero() && now.After(storedSession.ExpiresAt) {
+ return service.ErrPendingAuthSessionExpired
+ }
+ if strings.TrimSpace(storedSession.BrowserSessionKey) != "" &&
+ strings.TrimSpace(storedSession.BrowserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) {
+ return service.ErrPendingAuthBrowserMismatch
+ }
+
+ if _, err := tx.Client().PendingAuthSession.UpdateOneID(storedSession.ID).
+ SetConsumedAt(now).
+ SetCompletionCodeHash("").
+ ClearCompletionCodeExpiresAt().
+ Save(ctx); err != nil {
+ return err
+ }
+
+ return nil
}
func applyPendingOAuthAdoption(
@@ -1256,7 +1347,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
return
}
- pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
+ _, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -1341,7 +1432,20 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
response.ErrorFrom(c, err)
return
}
- if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil {
+
+ tx, err := client.Tx(c.Request.Context())
+ if err != nil {
+ if rollbackCreatedUser(err) {
+ return
+ }
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+ return
+ }
+ defer func() { _ = tx.Rollback() }()
+ txCtx := dbent.NewTxContext(c.Request.Context(), tx)
+
+ if err := applyPendingOAuthBinding(txCtx, client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil {
+ _ = tx.Rollback()
if rollbackCreatedUser(err) {
return
}
@@ -1350,11 +1454,12 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
}
if err := h.authService.FinalizeOAuthEmailAccount(
- c.Request.Context(),
+ txCtx,
user,
strings.TrimSpace(req.InvitationCode),
strings.TrimSpace(session.ProviderType),
); err != nil {
+ _ = tx.Rollback()
if rollbackCreatedUser(err) {
return
}
@@ -1362,7 +1467,8 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
return
}
- if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
+ if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil {
+ _ = tx.Rollback()
if rollbackCreatedUser(err) {
return
}
@@ -1371,6 +1477,25 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
return
}
+ if pendingOAuthCreateAccountPreCommitHook != nil {
+ if err := pendingOAuthCreateAccountPreCommitHook(txCtx, session); err != nil {
+ _ = tx.Rollback()
+ if rollbackCreatedUser(err) {
+ return
+ }
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+ return
+ }
+ }
+
+ if err := tx.Commit(); err != nil {
+ if rollbackCreatedUser(err) {
+ return
+ }
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+ return
+ }
+
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
clearCookies()
writeOAuthTokenPairResponse(c, tokenPair)
diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go
index 1013a082..008c9da2 100644
--- a/backend/internal/handler/auth_oauth_pending_flow_test.go
+++ b/backend/internal/handler/auth_oauth_pending_flow_test.go
@@ -903,6 +903,63 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
}
+func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("existing-email-send-code-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-existing-send-code-123").
+ SetBrowserSessionKey("existing-email-send-code-browser-session-key").
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": "email_required",
+ },
+ }).
+ SetRedirectTo("/dashboard").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/send-verify-code", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-send-code-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.SendPendingOAuthVerifyCode(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.Equal(t, "pending_session", payload["auth_result"])
+ require.Equal(t, "bind_login_required", payload["step"])
+ require.Equal(t, "owner@example.com", payload["email"])
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Equal(t, "adopt_existing_user_by_email", storedSession.Intent)
+ require.NotNil(t, storedSession.TargetUserID)
+ require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
+ require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
+}
+
func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
emailVerifyEnabled: true,
@@ -1032,6 +1089,78 @@ func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T
require.Nil(t, storedSession.ConsumedAt)
}
+func TestCreateOIDCOAuthAccountRollsBackPostBindFailureBeforeIdentityCanCommit(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ emailVerifyEnabled: true,
+ emailCache: &oauthPendingFlowEmailCacheStub{
+ verificationCodes: map[string]*service.VerificationCodeData{
+ "fresh@example.com": {
+ Code: "246810",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+ },
+ },
+ },
+ userRepoOptions: oauthPendingFlowUserRepoOptions{
+ rejectDeleteWhileAuthIdentityExists: true,
+ },
+ })
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("create-account-finalize-failure-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-finalize-failure-123").
+ SetBrowserSessionKey("create-account-finalize-failure-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ pendingOAuthCreateAccountPreCommitHook = func(context.Context, *dbent.PendingAuthSession) error {
+ return errors.New("forced post-bind failure")
+ }
+ t.Cleanup(func() {
+ pendingOAuthCreateAccountPreCommitHook = nil
+ })
+
+ body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-finalize-failure-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusInternalServerError, recorder.Code)
+
+ userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-finalize-failure-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
@@ -1618,7 +1747,6 @@ type oauthPendingFlowTestHandlerOptions struct {
defaultSubAssigner service.DefaultSubscriptionAssigner
totpCache service.TotpCache
totpEncryptor service.SecretEncryptor
- redeemRepoFactory func(client *dbent.Client) service.RedeemCodeRepository
userRepoOptions oauthPendingFlowUserRepoOptions
}
@@ -1685,13 +1813,7 @@ CREATE TABLE IF NOT EXISTS user_avatars (
client: client,
options: options.userRepoOptions,
}
- redeemRepo := service.RedeemCodeRepository(nil)
- if options.redeemRepoFactory != nil {
- redeemRepo = options.redeemRepoFactory(client)
- }
- if redeemRepo == nil {
- redeemRepo = &oauthPendingFlowRedeemCodeRepo{client: client}
- }
+ redeemRepo := &oauthPendingFlowRedeemCodeRepo{client: client}
var emailService *service.EmailService
if options.emailCache != nil {
emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{
@@ -2011,14 +2133,6 @@ func (r *oauthPendingFlowRedeemCodeRepo) SumPositiveBalanceByUser(context.Contex
panic("unexpected SumPositiveBalanceByUser call")
}
-type oauthPendingFlowFailingUseRedeemRepo struct {
- *oauthPendingFlowRedeemCodeRepo
-}
-
-func (r *oauthPendingFlowFailingUseRedeemRepo) Use(context.Context, int64, int64) error {
- return errors.New("forced invitation use failure")
-}
-
func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
t.Helper()
@@ -2093,7 +2207,7 @@ func countProviderGrantRecords(
}
type oauthPendingFlowUserRepo struct {
- client *dbent.Client
+ client *dbent.Client
options oauthPendingFlowUserRepoOptions
}
diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go
index ce25222c..ea558ae2 100644
--- a/backend/internal/service/auth_oauth_email_flow.go
+++ b/backend/internal/service/auth_oauth_email_flow.go
@@ -7,6 +7,9 @@ import (
"net/mail"
"strings"
"time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/redeemcode"
)
func normalizeOAuthSignupSource(signupSource string) string {
@@ -50,7 +53,7 @@ func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, i
if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
return nil, nil
}
- if s.redeemRepo == nil {
+ if s.redeemRepo == nil && s.oauthEmailFlowClient(ctx) == nil {
return nil, ErrServiceUnavailable
}
@@ -59,7 +62,7 @@ func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, i
return nil, ErrInvitationCodeRequired
}
- redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
+ redeemCode, err := s.loadOAuthRegistrationInvitation(ctx, invitationCode)
if err != nil {
return nil, ErrInvitationCodeInvalid
}
@@ -181,12 +184,12 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
return err
}
if invitationRedeemCode != nil {
- if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
+ if err := s.useOAuthRegistrationInvitation(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return ErrInvitationCodeInvalid
}
}
- s.postAuthUserBootstrap(ctx, user, signupSource, false)
+ s.updateOAuthSignupSource(ctx, user.ID, signupSource)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
return nil
@@ -211,7 +214,7 @@ func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, in
if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
return nil
}
- if s.redeemRepo == nil {
+ if s.redeemRepo == nil && s.oauthEmailFlowClient(ctx) == nil {
return ErrServiceUnavailable
}
@@ -220,7 +223,7 @@ func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, in
return nil
}
- redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
+ redeemCode, err := s.loadOAuthRegistrationInvitation(ctx, invitationCode)
if err != nil {
if errors.Is(err, ErrRedeemCodeNotFound) {
return nil
@@ -234,12 +237,115 @@ func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, in
redeemCode.Status = StatusUnused
redeemCode.UsedBy = nil
redeemCode.UsedAt = nil
- if err := s.redeemRepo.Update(ctx, redeemCode); err != nil {
+ if err := s.updateOAuthRegistrationInvitation(ctx, redeemCode); err != nil {
return fmt.Errorf("restore invitation code: %w", err)
}
return nil
}
+func (s *AuthService) oauthEmailFlowClient(ctx context.Context) *dbent.Client {
+ if s == nil || s.entClient == nil {
+ return nil
+ }
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ return tx.Client()
+ }
+ return s.entClient
+}
+
+func (s *AuthService) loadOAuthRegistrationInvitation(ctx context.Context, invitationCode string) (*RedeemCode, error) {
+ if client := s.oauthEmailFlowClient(ctx); client != nil {
+ entity, err := client.RedeemCode.Query().Where(redeemcode.CodeEQ(invitationCode)).Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, ErrRedeemCodeNotFound
+ }
+ return nil, err
+ }
+ return &RedeemCode{
+ ID: entity.ID,
+ Code: entity.Code,
+ Type: entity.Type,
+ Value: entity.Value,
+ Status: entity.Status,
+ UsedBy: entity.UsedBy,
+ UsedAt: entity.UsedAt,
+ Notes: oauthEmailFlowStringValue(entity.Notes),
+ CreatedAt: entity.CreatedAt,
+ GroupID: entity.GroupID,
+ ValidityDays: entity.ValidityDays,
+ }, nil
+ }
+ return s.redeemRepo.GetByCode(ctx, invitationCode)
+}
+
+func (s *AuthService) useOAuthRegistrationInvitation(ctx context.Context, invitationID, userID int64) error {
+ if client := s.oauthEmailFlowClient(ctx); client != nil {
+ affected, err := client.RedeemCode.Update().
+ Where(redeemcode.IDEQ(invitationID), redeemcode.StatusEQ(StatusUnused)).
+ SetStatus(StatusUsed).
+ SetUsedBy(userID).
+ SetUsedAt(time.Now().UTC()).
+ Save(ctx)
+ if err != nil {
+ return err
+ }
+ if affected == 0 {
+ return ErrRedeemCodeUsed
+ }
+ return nil
+ }
+ return s.redeemRepo.Use(ctx, invitationID, userID)
+}
+
+func (s *AuthService) updateOAuthRegistrationInvitation(ctx context.Context, code *RedeemCode) error {
+ if code == nil {
+ return nil
+ }
+ if client := s.oauthEmailFlowClient(ctx); client != nil {
+ update := client.RedeemCode.UpdateOneID(code.ID).
+ SetCode(code.Code).
+ SetType(code.Type).
+ SetValue(code.Value).
+ SetStatus(code.Status).
+ SetNotes(code.Notes).
+ SetValidityDays(code.ValidityDays)
+ if code.UsedBy != nil {
+ update = update.SetUsedBy(*code.UsedBy)
+ } else {
+ update = update.ClearUsedBy()
+ }
+ if code.UsedAt != nil {
+ update = update.SetUsedAt(*code.UsedAt)
+ } else {
+ update = update.ClearUsedAt()
+ }
+ if code.GroupID != nil {
+ update = update.SetGroupID(*code.GroupID)
+ } else {
+ update = update.ClearGroupID()
+ }
+ _, err := update.Save(ctx)
+ return err
+ }
+ return s.redeemRepo.Update(ctx, code)
+}
+
+func (s *AuthService) updateOAuthSignupSource(ctx context.Context, userID int64, signupSource string) {
+ client := s.oauthEmailFlowClient(ctx)
+ if client == nil || userID <= 0 || strings.TrimSpace(signupSource) == "" {
+ return
+ }
+ _ = client.User.UpdateOneID(userID).SetSignupSource(signupSource).Exec(ctx)
+}
+
+func oauthEmailFlowStringValue(value *string) string {
+ if value == nil {
+ return ""
+ }
+ return *value
+}
+
// ValidatePasswordCredentials checks the local password without completing the
// login flow. This is used by pending third-party account adoption flows before
// the external identity has been bound.
@@ -269,7 +375,7 @@ func (s *AuthService) ValidatePasswordCredentials(ctx context.Context, email, pa
func (s *AuthService) RecordSuccessfulLogin(ctx context.Context, userID int64) {
if s != nil && s.userRepo != nil && userID > 0 {
user, err := s.userRepo.GetByID(ctx, userID)
- if err == nil {
+ if err == nil && user != nil && !isReservedEmail(user.Email) {
s.backfillEmailIdentityOnSuccessfulLogin(ctx, user)
}
}
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index 9c7d4747..e6053984 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -240,7 +240,7 @@ func (s *UserService) GetProfileIdentitySummaries(ctx context.Context, userID in
}
return UserIdentitySummarySet{
- Email: s.buildEmailIdentitySummary(user),
+ Email: s.buildEmailIdentitySummary(user, records),
LinuxDo: s.buildProviderIdentitySummary("linuxdo", records),
OIDC: s.buildProviderIdentitySummary("oidc", records),
WeChat: s.buildProviderIdentitySummary("wechat", records),
@@ -497,7 +497,7 @@ func compressInlineAvatar(decoded []byte) ([]byte, string, error) {
return nil, "", ErrAvatarTooLarge
}
-func (s *UserService) buildEmailIdentitySummary(user *User) UserIdentitySummary {
+func (s *UserService) buildEmailIdentitySummary(user *User, records []UserAuthIdentityRecord) UserIdentitySummary {
summary := UserIdentitySummary{
Provider: "email",
CanBind: false,
@@ -508,11 +508,34 @@ func (s *UserService) buildEmailIdentitySummary(user *User) UserIdentitySummary
return summary
}
+ filtered := filterUserAuthIdentities(records, "email")
+ if len(filtered) > 0 {
+ primary := selectPrimaryUserAuthIdentity(filtered)
+ email := strings.TrimSpace(firstStringIdentityValue(primary.Metadata, "email"))
+ if email == "" {
+ email = strings.TrimSpace(primary.ProviderSubject)
+ }
+ if email == "" || isReservedEmail(email) {
+ email = strings.TrimSpace(user.Email)
+ }
+ if email == "" || isReservedEmail(email) {
+ email = strings.TrimSpace(primary.ProviderKey)
+ }
+
+ summary.Bound = true
+ summary.BoundCount = len(filtered)
+ summary.DisplayName = email
+ summary.SubjectHint = maskEmailIdentity(email)
+ summary.ProviderKey = strings.TrimSpace(primary.ProviderKey)
+ summary.VerifiedAt = primary.VerifiedAt
+ return summary
+ }
+
+ // Compatibility fallback for legacy normal-email users that predate auth_identities backfill.
email := strings.TrimSpace(user.Email)
if email == "" || isReservedEmail(email) {
return summary
}
-
summary.Bound = true
summary.BoundCount = 1
summary.DisplayName = email
diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts
index 0f768018..89964c3c 100644
--- a/frontend/src/api/auth.ts
+++ b/frontend/src/api/auth.ts
@@ -208,6 +208,12 @@ export type PendingOAuthExchangeResponse = PendingOAuthBindLoginResponse
export interface PendingOAuthCreateAccountResponse extends OAuthTokenResponse {}
+export interface PendingOAuthSendVerifyCodeResponse extends SendVerifyCodeResponse {
+ auth_result?: string
+ provider?: string
+ redirect?: string
+}
+
export type OAuthCompletionKind = 'login' | 'bind'
export interface OAuthAdoptionDecision {
@@ -451,8 +457,8 @@ export async function sendVerifyCode(
export async function sendPendingOAuthVerifyCode(
request: SendVerifyCodeRequest
-): Promise {
- const { data } = await apiClient.post(
+): Promise {
+ const { data } = await apiClient.post(
'/auth/oauth/pending/send-verify-code',
request
)
diff --git a/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue
index ccc1cbd0..653b4e33 100644
--- a/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue
+++ b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue
@@ -209,7 +209,12 @@ function getBindingStatus(provider: UserAuthProvider): boolean {
function getBindingStatusForUser(user: User | null | undefined, provider: UserAuthProvider): boolean {
if (provider === 'email') {
- return typeof user?.email_bound === 'boolean' ? user.email_bound : Boolean(user?.email)
+ if (typeof user?.email_bound === 'boolean') {
+ return user.email_bound
+ }
+ const nested = user?.auth_bindings?.email ?? user?.identity_bindings?.email
+ const normalized = normalizeBindingStatus(nested)
+ return normalized ?? false
}
const directFlag = user?.[`${provider}_bound` as keyof User]
diff --git a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
index ec4aed5d..c07acf18 100644
--- a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
+++ b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
@@ -301,4 +301,27 @@ describe('ProfileIdentityBindingsSection', () => {
expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Bound')
expect(authStore.user?.email).toBe('bound@example.com')
})
+
+ it('keeps the email binding form visible when the user still lacks an email identity', () => {
+ const wrapper = mount(ProfileIdentityBindingsSection, {
+ global: {
+ plugins: [pinia],
+ },
+ props: {
+ user: createUser({
+ email: 'legacy@example.com',
+ email_bound: false,
+ auth_bindings: {
+ email: { bound: false },
+ },
+ }),
+ linuxdoEnabled: false,
+ oidcEnabled: false,
+ wechatEnabled: false,
+ },
+ })
+
+ expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Not bound')
+ expect(wrapper.get('[data-testid="profile-binding-email-input"]').exists()).toBe(true)
+ })
})
diff --git a/frontend/src/views/auth/EmailVerifyView.vue b/frontend/src/views/auth/EmailVerifyView.vue
index d7bf6b7a..01829765 100644
--- a/frontend/src/views/auth/EmailVerifyView.vue
+++ b/frontend/src/views/auth/EmailVerifyView.vue
@@ -179,6 +179,8 @@ import { useAuthStore, useAppStore } from '@/stores'
import {
persistOAuthTokenContext,
getPublicSettings,
+ isOAuthLoginCompletion,
+ type PendingOAuthSendVerifyCodeResponse,
sendPendingOAuthVerifyCode,
sendVerifyCode,
} from '@/api/auth'
@@ -216,10 +218,13 @@ type PendingAuthSessionSummary = {
redirect?: string
}
type PendingOAuthCreateAccountResponse = {
+ auth_result?: string
access_token: string
refresh_token?: string
expires_in?: number
token_type?: string
+ provider?: string
+ redirect?: string
}
const email = ref('')
@@ -353,6 +358,46 @@ function onTurnstileError(): void {
errors.value.turnstile = t('auth.turnstileFailed')
}
+function isPendingOAuthFlow(): boolean {
+ return Boolean(pendingProvider.value.trim())
+}
+
+function shouldBypassRegistrationEmailPolicy(): boolean {
+ return isPendingOAuthFlow() || Boolean(pendingAuthToken.value.trim())
+}
+
+function resolvePendingOAuthCallbackRoute(provider: string): string {
+ switch (provider.trim().toLowerCase()) {
+ case 'linuxdo':
+ return '/auth/linuxdo/callback'
+ case 'oidc':
+ return '/auth/oidc/callback'
+ case 'wechat':
+ return '/auth/wechat/callback'
+ default:
+ return '/auth/callback'
+ }
+}
+
+function isPendingOAuthSessionResponse(data: PendingOAuthCreateAccountResponse): boolean {
+ return data.auth_result === 'pending_session'
+}
+
+function getPendingOAuthSendCodeSessionResponse(
+ data: PendingOAuthSendVerifyCodeResponse,
+): PendingOAuthSendVerifyCodeResponse | null {
+ return data.auth_result === 'pending_session' ? data : null
+}
+
+function persistPendingOAuthSession(provider: string, redirect?: string): void {
+ authStore.setPendingAuthSession({
+ token: pendingAuthToken.value,
+ token_field: pendingAuthTokenField.value,
+ provider: provider.trim() || pendingProvider.value.trim(),
+ redirect: redirect || pendingRedirect.value || undefined,
+ })
+}
+
// ==================== Send Code ====================
async function sendCode(): Promise {
@@ -360,7 +405,7 @@ async function sendCode(): Promise {
errorMessage.value = ''
try {
- if (!pendingAuthToken.value && !isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) {
+ if (!shouldBypassRegistrationEmailPolicy() && !isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) {
errorMessage.value = buildEmailSuffixNotAllowedMessage()
appStore.showError(errorMessage.value)
return
@@ -372,10 +417,25 @@ async function sendCode(): Promise {
// 优先使用重发时新获取的 token(因为初始 token 可能已被使用)
turnstile_token: resendTurnstileToken.value || initialTurnstileToken.value || undefined
} as Parameters[0]
- const response = pendingAuthToken.value
+ const response = isPendingOAuthFlow()
? await sendPendingOAuthVerifyCode(requestPayload)
: await sendVerifyCode(requestPayload)
+ const pendingSendCodeSession = isPendingOAuthFlow()
+ ? getPendingOAuthSendCodeSessionResponse(response as PendingOAuthSendVerifyCodeResponse)
+ : null
+ if (pendingSendCodeSession) {
+ sessionStorage.removeItem('register_data')
+ persistPendingOAuthSession(
+ pendingSendCodeSession.provider || pendingProvider.value,
+ pendingSendCodeSession.redirect,
+ )
+ await router.push(
+ resolvePendingOAuthCallbackRoute(pendingSendCodeSession.provider || pendingProvider.value),
+ )
+ return
+ }
+
codeSent.value = true
startCountdown(response.countdown)
@@ -438,13 +498,13 @@ async function handleVerify(): Promise {
isLoading.value = true
try {
- if (!isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) {
+ if (!shouldBypassRegistrationEmailPolicy() && !isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) {
errorMessage.value = buildEmailSuffixNotAllowedMessage()
appStore.showError(errorMessage.value)
return
}
- if (pendingProvider.value) {
+ if (isPendingOAuthFlow()) {
const { data } = await apiClient.post(
'/auth/oauth/pending/create-account',
{
@@ -456,6 +516,16 @@ async function handleVerify(): Promise {
adopt_avatar: pendingAdoptionDecision.value?.adoptAvatar
}
)
+ if (isPendingOAuthSessionResponse(data)) {
+ sessionStorage.removeItem('register_data')
+ persistPendingOAuthSession(data.provider || pendingProvider.value, data.redirect)
+ await router.push(resolvePendingOAuthCallbackRoute(data.provider || pendingProvider.value))
+ return
+ }
+ if (!isOAuthLoginCompletion(data)) {
+ throw new Error(t('auth.verifyFailed'))
+ }
+
persistOAuthTokenContext(data)
await authStore.setToken(data.access_token)
authStore.clearPendingAuthSession?.()
diff --git a/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts b/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts
index c231d6e7..9f67a994 100644
--- a/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts
+++ b/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts
@@ -8,6 +8,7 @@ const {
showErrorMock,
registerMock,
setTokenMock,
+ setPendingAuthSessionMock,
clearPendingAuthSessionMock,
getPublicSettingsMock,
sendVerifyCodeMock,
@@ -21,6 +22,7 @@ const {
showErrorMock: vi.fn(),
registerMock: vi.fn(),
setTokenMock: vi.fn(),
+ setPendingAuthSessionMock: vi.fn(),
clearPendingAuthSessionMock: vi.fn(),
getPublicSettingsMock: vi.fn(),
sendVerifyCodeMock: vi.fn(),
@@ -68,6 +70,7 @@ vi.mock('@/stores', () => ({
pendingAuthSession: authStoreState.pendingAuthSession,
register: (...args: any[]) => registerMock(...args),
setToken: (...args: any[]) => setTokenMock(...args),
+ setPendingAuthSession: (...args: any[]) => setPendingAuthSessionMock(...args),
clearPendingAuthSession: (...args: any[]) => clearPendingAuthSessionMock(...args),
}),
useAppStore: () => ({
@@ -100,6 +103,7 @@ describe('EmailVerifyView', () => {
showErrorMock.mockReset()
registerMock.mockReset()
setTokenMock.mockReset()
+ setPendingAuthSessionMock.mockReset()
clearPendingAuthSessionMock.mockReset()
getPublicSettingsMock.mockReset()
sendVerifyCodeMock.mockReset()
@@ -196,6 +200,97 @@ describe('EmailVerifyView', () => {
expect(showErrorMock).not.toHaveBeenCalled()
})
+ it('uses the pending oauth verify-code endpoint when auth store only carries the pending provider', async () => {
+ authStoreState.pendingAuthSession = {
+ token: '',
+ token_field: 'pending_oauth_token',
+ provider: 'oidc',
+ redirect: '/profile',
+ }
+ getPublicSettingsMock.mockResolvedValue({
+ turnstile_enabled: false,
+ turnstile_site_key: '',
+ site_name: 'Sub2API',
+ registration_email_suffix_whitelist: ['allowed.com'],
+ })
+ sessionStorage.setItem(
+ 'register_data',
+ JSON.stringify({
+ email: 'fresh@example.com',
+ password: 'secret-123',
+ })
+ )
+
+ mount(EmailVerifyView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ TurnstileWidget: true,
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({
+ email: 'fresh@example.com',
+ pending_oauth_token: undefined,
+ })
+ expect(sendVerifyCodeMock).not.toHaveBeenCalled()
+ expect(showErrorMock).not.toHaveBeenCalled()
+ })
+
+ it('returns to the oauth callback flow when pending send-code detects an existing account email', async () => {
+ authStoreState.pendingAuthSession = {
+ token: '',
+ token_field: 'pending_oauth_token',
+ provider: 'oidc',
+ redirect: '/profile/security',
+ }
+ getPublicSettingsMock.mockResolvedValue({
+ turnstile_enabled: false,
+ turnstile_site_key: '',
+ site_name: 'Sub2API',
+ registration_email_suffix_whitelist: ['allowed.com'],
+ })
+ sendPendingOAuthVerifyCodeMock.mockResolvedValue({
+ auth_result: 'pending_session',
+ provider: 'oidc',
+ redirect: '/profile/security',
+ })
+ sessionStorage.setItem(
+ 'register_data',
+ JSON.stringify({
+ email: 'fresh@example.com',
+ password: 'secret-123',
+ })
+ )
+
+ mount(EmailVerifyView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ TurnstileWidget: true,
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(setPendingAuthSessionMock).toHaveBeenCalledWith({
+ token: '',
+ token_field: 'pending_oauth_token',
+ provider: 'oidc',
+ redirect: '/profile/security',
+ })
+ expect(pushMock).toHaveBeenCalledWith('/auth/oidc/callback')
+ expect(showErrorMock).not.toHaveBeenCalled()
+ })
+
it('submits pending auth account creation when session storage has no pending metadata but auth store does', async () => {
authStoreState.pendingAuthSession = {
token: 'pending-token-1',
@@ -252,6 +347,70 @@ describe('EmailVerifyView', () => {
expect(registerMock).not.toHaveBeenCalled()
})
+ it('returns to the oauth callback flow when pending account creation becomes bind-login', async () => {
+ authStoreState.pendingAuthSession = {
+ token: '',
+ token_field: 'pending_oauth_token',
+ provider: 'oidc',
+ redirect: '/profile/security',
+ }
+ getPublicSettingsMock.mockResolvedValue({
+ turnstile_enabled: false,
+ turnstile_site_key: '',
+ site_name: 'Sub2API',
+ registration_email_suffix_whitelist: ['allowed.com'],
+ })
+ sessionStorage.setItem(
+ 'register_data',
+ JSON.stringify({
+ email: 'fresh@example.com',
+ password: 'secret-123',
+ })
+ )
+ apiClientPostMock.mockResolvedValue({
+ data: {
+ auth_result: 'pending_session',
+ provider: 'oidc',
+ step: 'bind_login_required',
+ redirect: '/profile/security',
+ email: 'fresh@example.com',
+ },
+ })
+
+ const wrapper = mount(EmailVerifyView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ TurnstileWidget: true,
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+ await wrapper.get('#code').setValue('123456')
+ await wrapper.get('form').trigger('submit.prevent')
+ await flushPromises()
+
+ expect(apiClientPostMock).toHaveBeenCalledWith('/auth/oauth/pending/create-account', {
+ email: 'fresh@example.com',
+ password: 'secret-123',
+ verify_code: '123456',
+ })
+ expect(setPendingAuthSessionMock).toHaveBeenCalledWith({
+ token: '',
+ token_field: 'pending_oauth_token',
+ provider: 'oidc',
+ redirect: '/profile/security',
+ })
+ expect(pushMock).toHaveBeenCalledWith('/auth/oidc/callback')
+ expect(setTokenMock).not.toHaveBeenCalled()
+ expect(persistOAuthTokenContextMock).not.toHaveBeenCalled()
+ expect(clearPendingAuthSessionMock).not.toHaveBeenCalled()
+ expect(showSuccessMock).not.toHaveBeenCalled()
+ })
+
it('keeps the normal email registration flow unchanged', async () => {
sessionStorage.setItem(
'register_data',
From f3986501663c864594f1444440f5e36199595983 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 11:00:08 +0800
Subject: [PATCH 097/189] fix: harden oidc compat email and email bind tx
---
backend/internal/handler/auth_oidc_oauth.go | 69 ++++++-
.../internal/handler/auth_oidc_oauth_test.go | 121 +++++++++++++
.../internal/service/auth_email_binding.go | 169 ++++++++++++++++++
.../service/auth_service_email_bind_test.go | 71 ++++++++
4 files changed, 424 insertions(+), 6 deletions(-)
diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go
index 5901a953..6d19e9d6 100644
--- a/backend/internal/handler/auth_oidc_oauth.go
+++ b/backend/internal/handler/auth_oidc_oauth.go
@@ -19,6 +19,7 @@ import (
"strings"
"time"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
@@ -323,18 +324,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
if emailVerified == nil {
emailVerified = idClaims.EmailVerified
}
- if cfg.RequireEmailVerified {
- if emailVerified == nil || !*emailVerified {
- redirectOAuthError(c, frontendCallback, "email_not_verified", "email is not verified", "")
- return
- }
- }
if userInfoClaims.Subject != "" && idClaims.Subject != "" && strings.TrimSpace(userInfoClaims.Subject) != strings.TrimSpace(idClaims.Subject) {
redirectOAuthError(c, frontendCallback, "subject_mismatch", "userinfo subject does not match id_token", "")
return
}
identityKey := oidcIdentityKey(issuer, subject)
+ compatEmail := strings.TrimSpace(firstNonEmpty(userInfoClaims.Email, idClaims.Email))
email := oidcSyntheticEmailFromIdentityKey(identityKey)
username := firstNonEmpty(
userInfoClaims.Username,
@@ -357,6 +353,9 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
"suggested_avatar_url": userInfoClaims.AvatarURL,
}
+ if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) {
+ upstreamClaims["compat_email"] = compatEmail
+ }
if intent == oauthIntentBindCurrentUser {
targetUserID, err := h.readOAuthBindUserIDFromCookie(c, oidcOAuthBindUserCookieName)
if err != nil {
@@ -416,6 +415,40 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
+ compatEmailUser, err := h.findOIDCCompatEmailUser(c.Request.Context(), compatEmail)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if compatEmailUser != nil {
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: "adopt_existing_user_by_email",
+ Identity: identityRef,
+ TargetUserID: &compatEmailUser.ID,
+ ResolvedEmail: compatEmailUser.Email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{
+ "redirect": redirectTo,
+ "step": "bind_login_required",
+ "email": compatEmailUser.Email,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ if cfg.RequireEmailVerified {
+ if emailVerified == nil || !*emailVerified {
+ redirectOAuthError(c, frontendCallback, "email_not_verified", "email is not verified", "")
+ return
+ }
+ }
+
if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
if err := h.createOAuthEmailRequiredPendingSession(c, identityRef, redirectTo, browserSessionKey, upstreamClaims); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
@@ -473,6 +506,30 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
redirectToFrontendCallback(c, frontendCallback)
}
+func (h *AuthHandler) findOIDCCompatEmailUser(ctx context.Context, email string) (*dbent.User, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ email = strings.TrimSpace(strings.ToLower(email))
+ if email == "" ||
+ strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
+ return nil, nil
+ }
+
+ userEntity, err := findUserByNormalizedEmail(ctx, client, email)
+ if err != nil {
+ if errors.Is(err, service.ErrUserNotFound) {
+ return nil, nil
+ }
+ return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err)
+ }
+ return userEntity, nil
+}
+
type completeOIDCOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go
index ba736db2..5cd8e0ea 100644
--- a/backend/internal/handler/auth_oidc_oauth_test.go
+++ b/backend/internal/handler/auth_oidc_oauth_test.go
@@ -245,6 +245,127 @@ func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingUser(t *testing.T
require.Nil(t, completion["error"])
}
+func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) {
+ cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
+ Subject: "oidc-subject-compat",
+ PreferredUsername: "oidc_compat",
+ DisplayName: "OIDC Compat Display",
+ AvatarURL: "https://cdn.example/oidc-compat.png",
+ Email: "legacy@example.com",
+ EmailVerified: true,
+ })
+ defer cleanup()
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
+ defer client.Close()
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail("legacy@example.com").
+ SetUsername("legacy-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-compat", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-compat"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "adopt_existing_user_by_email", session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, existingUser.ID, *session.TargetUserID)
+ require.Equal(t, existingUser.Email, session.ResolvedEmail)
+ require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"])
+
+ completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.Equal(t, "/dashboard", completion["redirect"])
+ require.Equal(t, "bind_login_required", completion["step"])
+ require.Equal(t, existingUser.Email, completion["email"])
+ _, hasAccessToken := completion["access_token"]
+ require.False(t, hasAccessToken)
+}
+
+func TestOIDCOAuthCallbackAllowsCompatEmailBindWhenUpstreamEmailIsUnverified(t *testing.T) {
+ cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
+ Subject: "oidc-subject-unverified-compat",
+ PreferredUsername: "oidc_unverified",
+ DisplayName: "OIDC Unverified Compat Display",
+ AvatarURL: "https://cdn.example/oidc-unverified.png",
+ Email: "owner@example.com",
+ EmailVerified: false,
+ })
+ defer cleanup()
+ cfg.RequireEmailVerified = true
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
+ defer client.Close()
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-unverified-compat", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-unverified-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/settings/connections"))
+ req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-unverified-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-unverified-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-unverified-compat"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "adopt_existing_user_by_email", session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, existingUser.ID, *session.TargetUserID)
+ require.Equal(t, existingUser.Email, session.ResolvedEmail)
+ require.Equal(t, "owner@example.com", session.UpstreamIdentityClaims["compat_email"])
+
+ completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.Equal(t, "/settings/connections", completion["redirect"])
+ require.Equal(t, "bind_login_required", completion["step"])
+ require.Equal(t, existingUser.Email, completion["email"])
+}
+
func TestOIDCOAuthCallbackCreatesInvitationPendingSessionWhenSignupRequiresInvite(t *testing.T) {
cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
Subject: "oidc-subject-invite",
diff --git a/backend/internal/service/auth_email_binding.go b/backend/internal/service/auth_email_binding.go
index b999660b..58f8e647 100644
--- a/backend/internal/service/auth_email_binding.go
+++ b/backend/internal/service/auth_email_binding.go
@@ -6,7 +6,10 @@ import (
"fmt"
"net/mail"
"strings"
+ "time"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
@@ -55,6 +58,13 @@ func (s *AuthService) BindEmailIdentity(
}
firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email)
+ if firstRealEmailBind && s.entClient != nil {
+ if err := s.bindEmailIdentityWithDefaultsTx(ctx, currentUser, normalizedEmail, hashedPassword); err != nil {
+ return nil, err
+ }
+ return currentUser, nil
+ }
+
currentUser.Email = normalizedEmail
currentUser.PasswordHash = hashedPassword
if err := s.userRepo.Update(ctx, currentUser); err != nil {
@@ -126,3 +136,162 @@ func hasBindableEmailIdentitySubject(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return normalized != "" && !isReservedEmail(normalized)
}
+
+func (s *AuthService) bindEmailIdentityWithDefaultsTx(
+ ctx context.Context,
+ currentUser *User,
+ email string,
+ hashedPassword string,
+) error {
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ return s.bindEmailIdentityWithDefaults(ctx, tx.Client(), currentUser, email, hashedPassword)
+ }
+
+ tx, err := s.entClient.Tx(ctx)
+ if err != nil {
+ return ErrServiceUnavailable
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ txCtx := dbent.NewTxContext(ctx, tx)
+ if err := s.bindEmailIdentityWithDefaults(txCtx, tx.Client(), currentUser, email, hashedPassword); err != nil {
+ return err
+ }
+ if err := tx.Commit(); err != nil {
+ return ErrServiceUnavailable
+ }
+ return nil
+}
+
+func (s *AuthService) bindEmailIdentityWithDefaults(
+ ctx context.Context,
+ client *dbent.Client,
+ currentUser *User,
+ email string,
+ hashedPassword string,
+) error {
+ if client == nil || currentUser == nil || currentUser.ID <= 0 {
+ return ErrServiceUnavailable
+ }
+
+ oldEmail := currentUser.Email
+ if _, err := client.User.UpdateOneID(currentUser.ID).
+ SetEmail(email).
+ SetPasswordHash(hashedPassword).
+ Save(ctx); err != nil {
+ if dbent.IsConstraintError(err) {
+ return ErrEmailExists
+ }
+ return ErrServiceUnavailable
+ }
+
+ if err := replaceBoundEmailAuthIdentityWithClient(ctx, client, currentUser.ID, oldEmail, email, "auth_service_email_bind"); err != nil {
+ if errors.Is(err, ErrEmailExists) {
+ return ErrEmailExists
+ }
+ return ErrServiceUnavailable
+ }
+
+ if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, currentUser.ID, "email"); err != nil {
+ return fmt.Errorf("apply email first bind defaults: %w", err)
+ }
+
+ updatedUser, err := client.User.Get(ctx, currentUser.ID)
+ if err != nil {
+ return ErrServiceUnavailable
+ }
+ currentUser.Email = updatedUser.Email
+ currentUser.PasswordHash = updatedUser.PasswordHash
+ currentUser.Balance = updatedUser.Balance
+ currentUser.Concurrency = updatedUser.Concurrency
+ currentUser.UpdatedAt = updatedUser.UpdatedAt
+ return nil
+}
+
+func replaceBoundEmailAuthIdentityWithClient(
+ ctx context.Context,
+ client *dbent.Client,
+ userID int64,
+ oldEmail string,
+ newEmail string,
+ source string,
+) error {
+ newSubject := normalizeBoundEmailAuthIdentitySubject(newEmail)
+ if err := ensureBoundEmailAuthIdentityWithClient(ctx, client, userID, newSubject, source); err != nil {
+ return err
+ }
+
+ oldSubject := normalizeBoundEmailAuthIdentitySubject(oldEmail)
+ if oldSubject == "" || oldSubject == newSubject {
+ return nil
+ }
+
+ _, err := client.AuthIdentity.Delete().
+ Where(
+ authidentity.UserIDEQ(userID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ(oldSubject),
+ ).
+ Exec(ctx)
+ return err
+}
+
+func ensureBoundEmailAuthIdentityWithClient(
+ ctx context.Context,
+ client *dbent.Client,
+ userID int64,
+ subject string,
+ source string,
+) error {
+ if client == nil || userID <= 0 || subject == "" {
+ return nil
+ }
+
+ if strings.TrimSpace(source) == "" {
+ source = "auth_service_email_bind"
+ }
+
+ if err := client.AuthIdentity.Create().
+ SetUserID(userID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject(subject).
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": strings.TrimSpace(source)}).
+ OnConflictColumns(
+ authidentity.FieldProviderType,
+ authidentity.FieldProviderKey,
+ authidentity.FieldProviderSubject,
+ ).
+ DoNothing().
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ(subject),
+ ).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil
+ }
+ return err
+ }
+ if identity.UserID != userID {
+ return ErrEmailExists
+ }
+ return nil
+}
+
+func normalizeBoundEmailAuthIdentitySubject(email string) string {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ if normalized == "" || isReservedEmail(normalized) {
+ return ""
+ }
+ return normalized
+}
diff --git a/backend/internal/service/auth_service_email_bind_test.go b/backend/internal/service/auth_service_email_bind_test.go
index 899a736d..fd5f499b 100644
--- a/backend/internal/service/auth_service_email_bind_test.go
+++ b/backend/internal/service/auth_service_email_bind_test.go
@@ -5,6 +5,7 @@ package service_test
import (
"context"
"database/sql"
+ "errors"
"testing"
"time"
@@ -34,6 +35,20 @@ func (s *emailBindDefaultSubAssignerStub) AssignOrExtendSubscription(
return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
}
+type flakyEmailBindDefaultSubAssignerStub struct {
+ err error
+ calls []*service.AssignSubscriptionInput
+}
+
+func (s *flakyEmailBindDefaultSubAssignerStub) AssignOrExtendSubscription(
+ _ context.Context,
+ input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+ cloned := *input
+ s.calls = append(s.calls, &cloned)
+ return nil, false, s.err
+}
+
func newAuthServiceForEmailBind(
t *testing.T,
settings map[string]string,
@@ -187,6 +202,62 @@ func TestAuthServiceBindEmailIdentity_RejectsExistingEmailOnAnotherUser(t *testi
require.Equal(t, 0, countProviderGrantRecords(t, client, sourceUser.ID, "email", "first_bind"))
}
+func TestAuthServiceBindEmailIdentity_RollsBackWhenFirstBindDefaultsFail(t *testing.T) {
+ assigner := &flakyEmailBindDefaultSubAssignerStub{err: errors.New("temporary assign failure")}
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, cache, assigner)
+
+ ctx := context.Background()
+ originalEmail := "legacy-rollback" + service.LinuxDoConnectSyntheticEmailDomain
+ user, err := client.User.Create().
+ SetEmail(originalEmail).
+ SetUsername("legacy-rollback").
+ SetPasswordHash("old-hash").
+ SetBalance(2.5).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "rollback@example.com", "123456", "new-password")
+ require.ErrorContains(t, err, "apply email first bind defaults")
+ require.ErrorContains(t, err, "temporary assign failure")
+ require.Nil(t, updatedUser)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, originalEmail, storedUser.Email)
+ require.Equal(t, "old-hash", storedUser.PasswordHash)
+ require.Equal(t, 2.5, storedUser.Balance)
+ require.Equal(t, 1, storedUser.Concurrency)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("rollback@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 0, identityCount)
+
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
func TestAuthServiceBindEmailIdentity_RejectsReservedEmail(t *testing.T) {
cache := &emailBindCacheStub{
data: &service.VerificationCodeData{
From 33b208ab6faba4144300b1a9728d673063c95ae5 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 11:00:18 +0800
Subject: [PATCH 098/189] fix: restore legacy oauth callback fragment
compatibility
---
.../src/views/auth/LinuxDoCallbackView.vue | 91 +++++++++++++++----
frontend/src/views/auth/OidcCallbackView.vue | 91 +++++++++++++++----
.../__tests__/LinuxDoCallbackView.spec.ts | 68 ++++++++++++++
.../auth/__tests__/OidcCallbackView.spec.ts | 68 ++++++++++++++
4 files changed, 282 insertions(+), 36 deletions(-)
diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue
index 735c6582..a775075d 100644
--- a/frontend/src/views/auth/LinuxDoCallbackView.vue
+++ b/frontend/src/views/auth/LinuxDoCallbackView.vue
@@ -249,6 +249,7 @@ import {
login2FA,
persistOAuthTokenContext,
type OAuthAdoptionDecision,
+ type OAuthTokenResponse,
type PendingOAuthExchangeResponse
} from '@/api/auth'
@@ -278,6 +279,7 @@ const pendingAccountAction = ref<'none' | 'create_account' | 'bind_login'>('none
const pendingAccountEmail = ref('')
const bindLoginEmail = ref('')
const bindLoginPassword = ref('')
+const legacyPendingOAuthToken = ref('')
const accountActionError = ref('')
const canReturnToCreateAccount = ref(false)
const bindSuccessMessage = t('profile.authBindings.bindSuccess')
@@ -315,6 +317,30 @@ function parseFragmentParams(): URLSearchParams {
return new URLSearchParams(hash)
}
+function readLegacyFragmentLogin(params: URLSearchParams): OAuthTokenResponse | null {
+ const accessToken = params.get('access_token')?.trim() || ''
+ if (!accessToken) {
+ return null
+ }
+
+ const completion: OAuthTokenResponse = {
+ access_token: accessToken
+ }
+ const refreshToken = params.get('refresh_token')?.trim() || ''
+ if (refreshToken) {
+ completion.refresh_token = refreshToken
+ }
+ const expiresIn = Number.parseInt(params.get('expires_in')?.trim() || '', 10)
+ if (Number.isFinite(expiresIn) && expiresIn > 0) {
+ completion.expires_in = expiresIn
+ }
+ const tokenType = params.get('token_type')?.trim() || ''
+ if (tokenType) {
+ completion.token_type = tokenType
+ }
+ return completion
+}
+
function sanitizeRedirectPath(path: string | null | undefined): string {
if (!path) return '/dashboard'
if (!path.startsWith('/')) return '/dashboard'
@@ -521,10 +547,18 @@ async function handleSubmitInvitation() {
isSubmitting.value = true
try {
- const tokenData = await completeLinuxDoOAuthRegistration(
- invitationCode.value.trim(),
- currentAdoptionDecision()
- )
+ const tokenData = legacyPendingOAuthToken.value
+ ? (
+ await apiClient.post('/auth/oauth/linuxdo/complete-registration', {
+ pending_oauth_token: legacyPendingOAuthToken.value,
+ invitation_code: invitationCode.value.trim(),
+ ...serializeAdoptionDecision(currentAdoptionDecision())
+ })
+ ).data
+ : await completeLinuxDoOAuthRegistration(
+ invitationCode.value.trim(),
+ currentAdoptionDecision()
+ )
persistOAuthTokenContext(tokenData)
await authStore.setToken(tokenData.access_token)
appStore.showSuccess(t('auth.loginSuccess'))
@@ -621,51 +655,72 @@ async function handleSubmitTotpChallenge() {
onMounted(async () => {
const params = parseFragmentParams()
+ const legacyLogin = readLegacyFragmentLogin(params)
+ const legacyPendingToken = params.get('pending_oauth_token')?.trim() || ''
const error = params.get('error')
const errorDesc = params.get('error_description') || params.get('error_message') || ''
-
- if (error) {
- errorMessage.value = errorDesc || error
- appStore.showError(errorMessage.value)
- isProcessing.value = false
- return
- }
+ const redirect = sanitizeRedirectPath(
+ params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard'
+ )
try {
+ if (legacyLogin) {
+ persistOAuthTokenContext(legacyLogin)
+ await authStore.setToken(legacyLogin.access_token)
+ appStore.showSuccess(t('auth.loginSuccess'))
+ await router.replace(redirect)
+ return
+ }
+
+ if (error === 'invitation_required' && legacyPendingToken) {
+ legacyPendingOAuthToken.value = legacyPendingToken
+ redirectTo.value = redirect
+ needsInvitation.value = true
+ isProcessing.value = false
+ return
+ }
+
+ if (error) {
+ errorMessage.value = errorDesc || error
+ appStore.showError(errorMessage.value)
+ isProcessing.value = false
+ return
+ }
+
const completion = await exchangePendingOAuthCompletion()
- const redirect = sanitizeRedirectPath(
+ const completionRedirect = sanitizeRedirectPath(
completion.redirect || (route.query.redirect as string | undefined) || '/dashboard'
)
applyAdoptionSuggestionState(completion)
- redirectTo.value = redirect
+ redirectTo.value = completionRedirect
if (completion.error === 'invitation_required') {
needsInvitation.value = true
isProcessing.value = false
- persistPendingAuthSession(redirect)
+ persistPendingAuthSession(completionRedirect)
return
}
if (applyTotpChallenge(completion as LinuxDoPendingActionResponse)) {
- persistPendingAuthSession(redirect)
+ persistPendingAuthSession(completionRedirect)
return
}
applyPendingAccountAction(completion as LinuxDoPendingActionResponse)
if (pendingAccountAction.value !== 'none') {
isProcessing.value = false
- persistPendingAuthSession(redirect)
+ persistPendingAuthSession(completionRedirect)
return
}
if (adoptionRequired.value && hasSuggestedProfile(completion)) {
needsAdoptionConfirmation.value = true
isProcessing.value = false
- persistPendingAuthSession(redirect)
+ persistPendingAuthSession(completionRedirect)
return
}
- await finalizeCompletion(completion, redirect)
+ await finalizeCompletion(completion, completionRedirect)
} catch (e: unknown) {
clearPendingAuthSession()
errorMessage.value = getRequestErrorMessage(e, t('auth.loginFailed'))
diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue
index 019cab54..e15e752f 100644
--- a/frontend/src/views/auth/OidcCallbackView.vue
+++ b/frontend/src/views/auth/OidcCallbackView.vue
@@ -259,6 +259,7 @@ import {
login2FA,
persistOAuthTokenContext,
type OAuthAdoptionDecision,
+ type OAuthTokenResponse,
type PendingOAuthExchangeResponse
} from '@/api/auth'
@@ -287,6 +288,7 @@ const pendingAccountAction = ref<'none' | 'create_account' | 'bind_login'>('none
const pendingAccountEmail = ref('')
const bindLoginEmail = ref('')
const bindLoginPassword = ref('')
+const legacyPendingOAuthToken = ref('')
const accountActionError = ref('')
const canReturnToCreateAccount = ref(false)
const bindSuccessMessage = t('profile.authBindings.bindSuccess')
@@ -331,6 +333,30 @@ function parseFragmentParams(): URLSearchParams {
return new URLSearchParams(hash)
}
+function readLegacyFragmentLogin(params: URLSearchParams): OAuthTokenResponse | null {
+ const accessToken = params.get('access_token')?.trim() || ''
+ if (!accessToken) {
+ return null
+ }
+
+ const completion: OAuthTokenResponse = {
+ access_token: accessToken
+ }
+ const refreshToken = params.get('refresh_token')?.trim() || ''
+ if (refreshToken) {
+ completion.refresh_token = refreshToken
+ }
+ const expiresIn = Number.parseInt(params.get('expires_in')?.trim() || '', 10)
+ if (Number.isFinite(expiresIn) && expiresIn > 0) {
+ completion.expires_in = expiresIn
+ }
+ const tokenType = params.get('token_type')?.trim() || ''
+ if (tokenType) {
+ completion.token_type = tokenType
+ }
+ return completion
+}
+
function sanitizeRedirectPath(path: string | null | undefined): string {
if (!path) return '/dashboard'
if (!path.startsWith('/')) return '/dashboard'
@@ -565,10 +591,18 @@ async function handleSubmitInvitation() {
isSubmitting.value = true
try {
- const tokenData = await completeOIDCOAuthRegistration(
- invitationCode.value.trim(),
- currentAdoptionDecision()
- )
+ const tokenData = legacyPendingOAuthToken.value
+ ? (
+ await apiClient.post('/auth/oauth/oidc/complete-registration', {
+ pending_oauth_token: legacyPendingOAuthToken.value,
+ invitation_code: invitationCode.value.trim(),
+ ...serializeAdoptionDecision(currentAdoptionDecision())
+ })
+ ).data
+ : await completeOIDCOAuthRegistration(
+ invitationCode.value.trim(),
+ currentAdoptionDecision()
+ )
persistOAuthTokenContext(tokenData)
await authStore.setToken(tokenData.access_token)
appStore.showSuccess(t('auth.loginSuccess'))
@@ -667,51 +701,72 @@ onMounted(async () => {
void loadProviderName()
const params = parseFragmentParams()
+ const legacyLogin = readLegacyFragmentLogin(params)
+ const legacyPendingToken = params.get('pending_oauth_token')?.trim() || ''
const error = params.get('error')
const errorDesc = params.get('error_description') || params.get('error_message') || ''
-
- if (error) {
- errorMessage.value = errorDesc || error
- appStore.showError(errorMessage.value)
- isProcessing.value = false
- return
- }
+ const redirect = sanitizeRedirectPath(
+ params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard'
+ )
try {
+ if (legacyLogin) {
+ persistOAuthTokenContext(legacyLogin)
+ await authStore.setToken(legacyLogin.access_token)
+ appStore.showSuccess(t('auth.loginSuccess'))
+ await router.replace(redirect)
+ return
+ }
+
+ if (error === 'invitation_required' && legacyPendingToken) {
+ legacyPendingOAuthToken.value = legacyPendingToken
+ redirectTo.value = redirect
+ needsInvitation.value = true
+ isProcessing.value = false
+ return
+ }
+
+ if (error) {
+ errorMessage.value = errorDesc || error
+ appStore.showError(errorMessage.value)
+ isProcessing.value = false
+ return
+ }
+
const completion = await exchangePendingOAuthCompletion() as PendingOidcCompletion
- const redirect = sanitizeRedirectPath(
+ const completionRedirect = sanitizeRedirectPath(
completion.redirect || (route.query.redirect as string | undefined) || '/dashboard'
)
applyAdoptionSuggestionState(completion)
- redirectTo.value = redirect
+ redirectTo.value = completionRedirect
if (completion.error === 'invitation_required') {
needsInvitation.value = true
isProcessing.value = false
- persistPendingAuthSession(redirect)
+ persistPendingAuthSession(completionRedirect)
return
}
if (applyTotpChallenge(completion)) {
- persistPendingAuthSession(redirect)
+ persistPendingAuthSession(completionRedirect)
return
}
applyPendingAccountAction(completion)
if (pendingAccountAction.value !== 'none') {
isProcessing.value = false
- persistPendingAuthSession(redirect)
+ persistPendingAuthSession(completionRedirect)
return
}
if (adoptionRequired.value && hasSuggestedProfile(completion)) {
needsAdoptionConfirmation.value = true
isProcessing.value = false
- persistPendingAuthSession(redirect)
+ persistPendingAuthSession(completionRedirect)
return
}
- await finalizeCompletion(completion, redirect)
+ await finalizeCompletion(completion, completionRedirect)
} catch (e: unknown) {
clearPendingAuthSession()
errorMessage.value = getRequestErrorMessage(e, t('auth.loginFailed'))
diff --git a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
index f612681a..0daf5d9a 100644
--- a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
@@ -86,6 +86,74 @@ describe('LinuxDoCallbackView', () => {
turnstile_enabled: false,
turnstile_site_key: ''
})
+ window.location.hash = ''
+ localStorage.clear()
+ })
+
+ it('accepts the legacy fragment token success callback without pending-session exchange', async () => {
+ window.location.hash =
+ '#access_token=legacy-access-token&refresh_token=legacy-refresh-token&expires_in=3600&token_type=Bearer&redirect=%2Flegacy-dashboard'
+ setToken.mockResolvedValue({})
+
+ mount(LinuxDoCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletion).not.toHaveBeenCalled()
+ expect(setToken).toHaveBeenCalledWith('legacy-access-token')
+ expect(localStorage.getItem('refresh_token')).toBe('legacy-refresh-token')
+ expect(localStorage.getItem('token_expires_at')).not.toBeNull()
+ expect(showSuccess).toHaveBeenCalledWith('auth.loginSuccess')
+ expect(replace).toHaveBeenCalledWith('/legacy-dashboard')
+ })
+
+ it('accepts the legacy pending oauth invitation fragment without pending-session exchange', async () => {
+ window.location.hash = '#error=invitation_required&pending_oauth_token=legacy-pending-token&redirect=%2Flegacy-invite'
+ apiClientPost.mockResolvedValue({
+ data: {
+ access_token: 'legacy-access-token',
+ refresh_token: 'legacy-refresh-token',
+ expires_in: 3600,
+ token_type: 'Bearer'
+ }
+ })
+ setToken.mockResolvedValue({})
+
+ const wrapper = mount(LinuxDoCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletion).not.toHaveBeenCalled()
+ await wrapper.find('input[type="text"]').setValue('invite-code')
+ await wrapper.find('button').trigger('click')
+ await flushPromises()
+
+ expect(apiClientPost).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', {
+ adopt_display_name: true,
+ adopt_avatar: true,
+ pending_oauth_token: 'legacy-pending-token',
+ invitation_code: 'invite-code'
+ })
+ expect(setToken).toHaveBeenCalledWith('legacy-access-token')
+ expect(replace).toHaveBeenCalledWith('/legacy-invite')
})
it('does not send adoption decisions during the initial exchange', async () => {
diff --git a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
index 0edcb931..18128f17 100644
--- a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
@@ -92,6 +92,74 @@ describe('OidcCallbackView', () => {
turnstile_enabled: false,
turnstile_site_key: ''
})
+ window.location.hash = ''
+ localStorage.clear()
+ })
+
+ it('accepts the legacy fragment token success callback without pending-session exchange', async () => {
+ window.location.hash =
+ '#access_token=legacy-access-token&refresh_token=legacy-refresh-token&expires_in=3600&token_type=Bearer&redirect=%2Flegacy-dashboard'
+ setToken.mockResolvedValue({})
+
+ mount(OidcCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletion).not.toHaveBeenCalled()
+ expect(setToken).toHaveBeenCalledWith('legacy-access-token')
+ expect(localStorage.getItem('refresh_token')).toBe('legacy-refresh-token')
+ expect(localStorage.getItem('token_expires_at')).not.toBeNull()
+ expect(showSuccess).toHaveBeenCalledWith('auth.loginSuccess')
+ expect(replace).toHaveBeenCalledWith('/legacy-dashboard')
+ })
+
+ it('accepts the legacy pending oauth invitation fragment without pending-session exchange', async () => {
+ window.location.hash = '#error=invitation_required&pending_oauth_token=legacy-pending-token&redirect=%2Flegacy-invite'
+ apiClientPost.mockResolvedValue({
+ data: {
+ access_token: 'legacy-access-token',
+ refresh_token: 'legacy-refresh-token',
+ expires_in: 3600,
+ token_type: 'Bearer'
+ }
+ })
+ setToken.mockResolvedValue({})
+
+ const wrapper = mount(OidcCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletion).not.toHaveBeenCalled()
+ await wrapper.find('input[type="text"]').setValue('invite-code')
+ await wrapper.find('button').trigger('click')
+ await flushPromises()
+
+ expect(apiClientPost).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', {
+ adopt_display_name: true,
+ adopt_avatar: true,
+ pending_oauth_token: 'legacy-pending-token',
+ invitation_code: 'invite-code'
+ })
+ expect(setToken).toHaveBeenCalledWith('legacy-access-token')
+ expect(replace).toHaveBeenCalledWith('/legacy-invite')
})
it('does not send adoption decisions during the initial exchange', async () => {
From 9742796ee727e678fabf1e4028640a764cbac895 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 11:41:02 +0800
Subject: [PATCH 099/189] fix: retire public payment verify and backfill trade
no
---
backend/internal/handler/payment_handler.go | 34 +--
.../handler/payment_handler_resume_test.go | 53 ++++
backend/internal/server/routes/payment.go | 5 +-
.../service/payment_order_lifecycle.go | 14 +-
.../service/payment_order_lifecycle_test.go | 251 ++++++++++++++++++
frontend/src/api/__tests__/payment.spec.ts | 36 +++
frontend/src/api/payment.ts | 5 -
.../user/__tests__/PaymentResultView.spec.ts | 8 -
8 files changed, 369 insertions(+), 37 deletions(-)
create mode 100644 backend/internal/service/payment_order_lifecycle_test.go
create mode 100644 frontend/src/api/__tests__/payment.spec.ts
diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go
index 5fd6b43e..273aea73 100644
--- a/backend/internal/handler/payment_handler.go
+++ b/backend/internal/handler/payment_handler.go
@@ -2,6 +2,7 @@ package handler
import (
"fmt"
+ "net/http"
"strconv"
"strings"
@@ -459,29 +460,20 @@ type PublicOrderResult struct {
Status string `json:"status"`
}
-// VerifyOrderPublic verifies payment status without requiring authentication.
-// Returns limited order info (no user details) to prevent information leakage.
+var errPaymentPublicOrderVerifyRemoved = infraerrors.New(
+ http.StatusGone,
+ "PAYMENT_PUBLIC_ORDER_VERIFY_REMOVED",
+ "public payment order verification by out_trade_no has been removed; use resume_token recovery instead",
+).WithMetadata(map[string]string{
+ "replacement_endpoint": "/api/v1/payment/public/orders/resolve",
+ "replacement_field": "resume_token",
+})
+
+// VerifyOrderPublic is kept as a compatibility shim for the removed anonymous
+// out_trade_no lookup endpoint and always returns HTTP 410 Gone.
// POST /api/v1/payment/public/orders/verify
func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) {
- var req VerifyOrderRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- response.BadRequest(c, "Invalid request: "+err.Error())
- return
- }
- order, err := h.paymentService.VerifyOrderPublic(c.Request.Context(), req.OutTradeNo)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, PublicOrderResult{
- ID: order.ID,
- OutTradeNo: order.OutTradeNo,
- Amount: order.Amount,
- PayAmount: order.PayAmount,
- PaymentType: order.PaymentType,
- OrderType: order.OrderType,
- Status: order.Status,
- })
+ response.ErrorFrom(c, errPaymentPublicOrderVerifyRemoved)
}
// ResolveOrderPublicByResumeToken resolves a payment order from a signed resume token.
diff --git a/backend/internal/handler/payment_handler_resume_test.go b/backend/internal/handler/payment_handler_resume_test.go
index 323f7292..28da15d9 100644
--- a/backend/internal/handler/payment_handler_resume_test.go
+++ b/backend/internal/handler/payment_handler_resume_test.go
@@ -3,10 +3,24 @@
package handler
import (
+ "bytes"
+ "database/sql"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
"testing"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
)
func TestApplyWeChatPaymentResumeClaims(t *testing.T) {
@@ -59,3 +73,42 @@ func TestApplyWeChatPaymentResumeClaimsRejectsPaymentTypeMismatch(t *testing.T)
t.Fatal("applyWeChatPaymentResumeClaims should reject mismatched payment types")
}
}
+
+func TestVerifyOrderPublicReturnsGone(t *testing.T) {
+ t.Parallel()
+
+ gin.SetMode(gin.TestMode)
+
+ db, err := sql.Open("sqlite", "file:payment_handler_public_verify?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
+ h := NewPaymentHandler(paymentSvc, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ ctx, _ := gin.CreateTestContext(recorder)
+ ctx.Request = httptest.NewRequest(
+ http.MethodPost,
+ "/api/v1/payment/public/orders/verify",
+ bytes.NewBufferString(`{"out_trade_no":"legacy-order-no"}`),
+ )
+ ctx.Request.Header.Set("Content-Type", "application/json")
+
+ h.VerifyOrderPublic(ctx)
+
+ require.Equal(t, http.StatusGone, recorder.Code)
+
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, http.StatusGone, resp.Code)
+ require.Equal(t, "PAYMENT_PUBLIC_ORDER_VERIFY_REMOVED", resp.Reason)
+ require.Contains(t, resp.Message, "removed")
+}
diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go
index dff14a70..ec340d94 100644
--- a/backend/internal/server/routes/payment.go
+++ b/backend/internal/server/routes/payment.go
@@ -44,8 +44,9 @@ func RegisterPaymentRoutes(
}
// --- Public payment endpoints (no auth) ---
- // Payment result page needs to verify order status without login
- // (user session may have expired during provider redirect).
+ // Signed resume-token recovery is the supported public lookup path.
+ // The legacy anonymous out_trade_no verify endpoint is kept only as a
+ // compatibility shim that returns HTTP 410 Gone.
public := v1.Group("/payment/public")
{
public.POST("/orders/verify", paymentHandler.VerifyOrderPublic)
diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go
index 50a56ad0..1564c36d 100644
--- a/backend/internal/service/payment_order_lifecycle.go
+++ b/backend/internal/service/payment_order_lifecycle.go
@@ -151,7 +151,19 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
return ""
}
if resp.Status == payment.ProviderStatusPaid {
- if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: o.PaymentTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil {
+ notificationTradeNo := o.PaymentTradeNo
+ if upstreamTradeNo := resp.TradeNo; upstreamTradeNo != "" && upstreamTradeNo != notificationTradeNo {
+ if _, updateErr := s.entClient.PaymentOrder.Update().
+ Where(paymentorder.IDEQ(o.ID)).
+ SetPaymentTradeNo(upstreamTradeNo).
+ Save(ctx); updateErr != nil {
+ slog.Error("persist upstream trade no during checkPaid failed", "orderID", o.ID, "tradeNo", upstreamTradeNo, "error", updateErr)
+ } else {
+ o.PaymentTradeNo = upstreamTradeNo
+ }
+ notificationTradeNo = upstreamTradeNo
+ }
+ if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: notificationTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil {
slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err)
// Still return already_paid — order was paid, fulfillment can be retried
}
diff --git a/backend/internal/service/payment_order_lifecycle_test.go b/backend/internal/service/payment_order_lifecycle_test.go
new file mode 100644
index 00000000..3d4773a4
--- /dev/null
+++ b/backend/internal/service/payment_order_lifecycle_test.go
@@ -0,0 +1,251 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+type paymentOrderLifecycleQueryProvider struct {
+ lastQueryTradeNo string
+ resp *payment.QueryOrderResponse
+}
+
+type paymentOrderLifecycleRedeemRepo struct {
+ codesByCode map[string]*RedeemCode
+ useCalls []struct {
+ id int64
+ userID int64
+ }
+}
+
+func (p *paymentOrderLifecycleQueryProvider) Name() string {
+ return "payment-order-lifecycle-query-provider"
+}
+
+func (p *paymentOrderLifecycleQueryProvider) ProviderKey() string { return payment.TypeAlipay }
+
+func (p *paymentOrderLifecycleQueryProvider) SupportedTypes() []payment.PaymentType {
+ return []payment.PaymentType{payment.TypeAlipay}
+}
+
+func (p *paymentOrderLifecycleQueryProvider) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ panic("unexpected call")
+}
+
+func (p *paymentOrderLifecycleQueryProvider) QueryOrder(_ context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
+ p.lastQueryTradeNo = tradeNo
+ return p.resp, nil
+}
+
+func (p *paymentOrderLifecycleQueryProvider) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
+ panic("unexpected call")
+}
+
+func (p *paymentOrderLifecycleQueryProvider) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) Create(context.Context, *RedeemCode) error {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) CreateBatch(context.Context, []RedeemCode) error {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) GetByID(_ context.Context, id int64) (*RedeemCode, error) {
+ for _, code := range r.codesByCode {
+ if code.ID != id {
+ continue
+ }
+ cloned := *code
+ return &cloned, nil
+ }
+ return nil, ErrRedeemCodeNotFound
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) GetByCode(_ context.Context, code string) (*RedeemCode, error) {
+ redeemCode, ok := r.codesByCode[code]
+ if !ok {
+ return nil, ErrRedeemCodeNotFound
+ }
+ cloned := *redeemCode
+ return &cloned, nil
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) Update(context.Context, *RedeemCode) error {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) Delete(context.Context, int64) error {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) Use(_ context.Context, id, userID int64) error {
+ for code, redeemCode := range r.codesByCode {
+ if redeemCode.ID != id {
+ continue
+ }
+ now := time.Now().UTC()
+ redeemCode.Status = StatusUsed
+ redeemCode.UsedBy = &userID
+ redeemCode.UsedAt = &now
+ r.codesByCode[code] = redeemCode
+ r.useCalls = append(r.useCalls, struct {
+ id int64
+ userID int64
+ }{id: id, userID: userID})
+ return nil
+ }
+ return ErrRedeemCodeNotFound
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) List(context.Context, pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) ListByUser(context.Context, int64, int) ([]RedeemCode, error) {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
+ panic("unexpected call")
+}
+
+func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentOrderLifecycleTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("checkpaid@example.com").
+ SetPasswordHash("hash").
+ SetUsername("checkpaid-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("CHECKPAID-UPSTREAM-TRADE-NO").
+ SetOutTradeNo("sub2_checkpaid_trade_no_missing").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ userRepo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: user.ID,
+ Email: user.Email,
+ Username: user.Username,
+ Balance: 0,
+ },
+ }
+ userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error {
+ require.Equal(t, user.ID, id)
+ if userRepo.getByIDUser != nil {
+ userRepo.getByIDUser.Balance += amount
+ }
+ return nil
+ }
+ redeemRepo := &paymentOrderLifecycleRedeemRepo{
+ codesByCode: map[string]*RedeemCode{
+ order.RechargeCode: {
+ ID: 1,
+ Code: order.RechargeCode,
+ Type: RedeemTypeBalance,
+ Value: order.Amount,
+ Status: StatusUnused,
+ },
+ },
+ }
+ redeemService := NewRedeemService(
+ redeemRepo,
+ userRepo,
+ nil,
+ nil,
+ nil,
+ client,
+ nil,
+ )
+ registry := payment.NewRegistry()
+ provider := &paymentOrderLifecycleQueryProvider{
+ resp: &payment.QueryOrderResponse{
+ TradeNo: "upstream-trade-123",
+ Status: payment.ProviderStatusPaid,
+ Amount: 88,
+ },
+ }
+ registry.Register(provider)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ redeemService: redeemService,
+ userRepo: userRepo,
+ providersLoaded: true,
+ }
+
+ got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo)
+ require.Equal(t, OrderStatusCompleted, got.Status)
+ require.Equal(t, "upstream-trade-123", got.PaymentTradeNo)
+
+ reloaded, err := client.PaymentOrder.Get(ctx, order.ID)
+ require.NoError(t, err)
+ require.Equal(t, OrderStatusCompleted, reloaded.Status)
+ require.Equal(t, "upstream-trade-123", reloaded.PaymentTradeNo)
+
+ require.Equal(t, 88.0, userRepo.getByIDUser.Balance)
+ require.Len(t, redeemRepo.useCalls, 1)
+ require.Equal(t, int64(1), redeemRepo.useCalls[0].id)
+ require.Equal(t, user.ID, redeemRepo.useCalls[0].userID)
+}
+
+func newPaymentOrderLifecycleTestClient(t *testing.T) *dbent.Client {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:payment_order_lifecycle?mode=memory&cache=shared&_fk=1")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+ return client
+}
diff --git a/frontend/src/api/__tests__/payment.spec.ts b/frontend/src/api/__tests__/payment.spec.ts
new file mode 100644
index 00000000..3006484e
--- /dev/null
+++ b/frontend/src/api/__tests__/payment.spec.ts
@@ -0,0 +1,36 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+const { get, post } = vi.hoisted(() => ({
+ get: vi.fn(),
+ post: vi.fn(),
+}))
+
+vi.mock('@/api/client', () => ({
+ apiClient: {
+ get,
+ post,
+ },
+}))
+
+import { paymentAPI } from '@/api/payment'
+
+describe('payment api', () => {
+ beforeEach(() => {
+ get.mockReset()
+ post.mockReset()
+ get.mockResolvedValue({ data: {} })
+ post.mockResolvedValue({ data: {} })
+ })
+
+ it('does not expose anonymous public out_trade_no verification', () => {
+ expect(Object.prototype.hasOwnProperty.call(paymentAPI, 'verifyOrderPublic')).toBe(false)
+ })
+
+ it('keeps signed public resume-token resolve endpoint', async () => {
+ await paymentAPI.resolveOrderPublicByResumeToken('resume-token-123')
+
+ expect(post).toHaveBeenCalledWith('/payment/public/orders/resolve', {
+ resume_token: 'resume-token-123',
+ })
+ })
+})
diff --git a/frontend/src/api/payment.ts b/frontend/src/api/payment.ts
index 91b16866..e866e184 100644
--- a/frontend/src/api/payment.ts
+++ b/frontend/src/api/payment.ts
@@ -67,11 +67,6 @@ export const paymentAPI = {
return apiClient.post('/payment/orders/verify', { out_trade_no: outTradeNo })
},
- /** Verify order payment status without auth (public endpoint for result page) */
- verifyOrderPublic(outTradeNo: string) {
- return apiClient.post('/payment/public/orders/verify', { out_trade_no: outTradeNo })
- },
-
/** Resolve an order from a signed resume token without auth */
resolveOrderPublicByResumeToken(resumeToken: string) {
return apiClient.post('/payment/public/orders/resolve', { resume_token: resumeToken })
diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
index 56e64793..34ced07a 100644
--- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
+++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
@@ -7,7 +7,6 @@ const routeState = vi.hoisted(() => ({
const routerPush = vi.hoisted(() => vi.fn())
const pollOrderStatus = vi.hoisted(() => vi.fn())
-const verifyOrderPublic = vi.hoisted(() => vi.fn())
const verifyOrder = vi.hoisted(() => vi.fn())
const resolveOrderPublicByResumeToken = vi.hoisted(() => vi.fn())
@@ -38,7 +37,6 @@ vi.mock('@/stores/payment', () => ({
vi.mock('@/api/payment', () => ({
paymentAPI: {
- verifyOrderPublic,
verifyOrder,
resolveOrderPublicByResumeToken,
},
@@ -67,7 +65,6 @@ describe('PaymentResultView', () => {
routeState.query = {}
routerPush.mockReset()
pollOrderStatus.mockReset()
- verifyOrderPublic.mockReset()
verifyOrder.mockReset()
resolveOrderPublicByResumeToken.mockReset()
window.localStorage.clear()
@@ -110,7 +107,6 @@ describe('PaymentResultView', () => {
await flushPromises()
expect(pollOrderStatus).toHaveBeenCalledWith(42)
- expect(verifyOrderPublic).not.toHaveBeenCalled()
expect(wrapper.text()).toContain('payment.result.processing')
expect(wrapper.text()).not.toContain('payment.result.success')
expect(wrapper.text()).not.toContain('payment.result.failed')
@@ -221,7 +217,6 @@ describe('PaymentResultView', () => {
await flushPromises()
expect(resolveOrderPublicByResumeToken).toHaveBeenCalledWith('resume-fail')
- expect(verifyOrderPublic).not.toHaveBeenCalled()
expect(verifyOrder).not.toHaveBeenCalled()
})
@@ -241,7 +236,6 @@ describe('PaymentResultView', () => {
await flushPromises()
- expect(verifyOrderPublic).not.toHaveBeenCalled()
expect(verifyOrder).not.toHaveBeenCalled()
})
@@ -260,7 +254,6 @@ describe('PaymentResultView', () => {
await flushPromises()
- expect(verifyOrderPublic).not.toHaveBeenCalled()
expect(verifyOrder).not.toHaveBeenCalled()
})
@@ -284,7 +277,6 @@ describe('PaymentResultView', () => {
expect(resolveOrderPublicByResumeToken).toHaveBeenCalledWith('resume-77')
expect(wrapper.text()).toContain('payment.result.success')
- expect(verifyOrderPublic).not.toHaveBeenCalled()
})
it('normalizes aliased payment methods before rendering the label', async () => {
From 440536a93da91405472cd0072728b70ad4d9a733 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 11:41:14 +0800
Subject: [PATCH 100/189] docs: align wechat payment required fields
---
docs/PAYMENT.md | 4 ++--
docs/PAYMENT_CN.md | 4 ++--
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/docs/PAYMENT.md b/docs/PAYMENT.md
index 2cc8f566..9322f7bf 100644
--- a/docs/PAYMENT.md
+++ b/docs/PAYMENT.md
@@ -141,8 +141,8 @@ Direct integration with WeChat Pay APIv3. Supports Native QR code payment, H5 pa
| **Merchant API Private Key** | Merchant API private key (PEM format) | Yes |
| **APIv3 Key** | 32-byte APIv3 key | Yes |
| **WeChat Pay Public Key** | WeChat Pay public key (PEM format) | Yes |
-| **WeChat Pay Public Key ID** | WeChat Pay public key ID | No |
-| **Certificate Serial Number** | Merchant certificate serial number | No |
+| **WeChat Pay Public Key ID** | WeChat Pay public key ID | Yes |
+| **Certificate Serial Number** | Merchant certificate serial number | Yes |
### Stripe
diff --git a/docs/PAYMENT_CN.md b/docs/PAYMENT_CN.md
index 95bfb990..0fbc198a 100644
--- a/docs/PAYMENT_CN.md
+++ b/docs/PAYMENT_CN.md
@@ -141,8 +141,8 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
| **商户 API 私钥** | 商户 API 私钥(PEM 格式) | 是 |
| **APIv3 密钥** | 32 位 APIv3 密钥 | 是 |
| **微信支付公钥** | 微信支付公钥(PEM 格式) | 是 |
-| **微信支付公钥 ID** | 微信支付公钥 ID | 否 |
-| **商户证书序列号** | 商户证书序列号 | 否 |
+| **微信支付公钥 ID** | 微信支付公钥 ID | 是 |
+| **商户证书序列号** | 商户证书序列号 | 是 |
### Stripe
From 960b2bb8e61ed33a7679a00eec5faff2490b5637 Mon Sep 17 00:00:00 2001
From: shaw
Date: Tue, 21 Apr 2026 11:14:40 +0800
Subject: [PATCH 101/189] feat(legal): add CLA with automated GitHub Actions
enforcement
Introduce Individual Contributor License Agreement (ICLA) to enable
dual licensing (LGPL-V3 open source + future closed-source releases).
- CLA.md: Apache ICLA-style license grant with moral rights waiver,
patent license, electronic signature clause, and assignability
- .github/workflows/cla.yml: CLA Assistant Lite bot that auto-checks
PRs, posts signing prompts, and stores signatures on a separate
`cla-signatures` branch to keep main branch history clean
---
.github/workflows/cla.yml | 59 +++++++++++++++++++++++++++++++
CLA.md | 73 +++++++++++++++++++++++++++++++++++++++
2 files changed, 132 insertions(+)
create mode 100644 .github/workflows/cla.yml
create mode 100644 CLA.md
diff --git a/.github/workflows/cla.yml b/.github/workflows/cla.yml
new file mode 100644
index 00000000..67c8d6e9
--- /dev/null
+++ b/.github/workflows/cla.yml
@@ -0,0 +1,59 @@
+name: "CLA Assistant"
+
+on:
+ issue_comment:
+ types: [created]
+ pull_request_target:
+ types: [opened, reopened, closed, synchronize]
+
+permissions:
+ actions: write
+ contents: write
+ pull-requests: write
+ statuses: write
+
+jobs:
+ cla-check:
+ if: |
+ github.event_name == 'issue_comment' ||
+ (github.event_name == 'pull_request_target' && github.event.action != 'closed')
+ runs-on: ubuntu-latest
+ steps:
+ - name: "CLA Assistant"
+ if: |
+ (github.event.comment.body == 'recheck' ||
+ github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') ||
+ github.event_name == 'pull_request_target'
+ uses: contributor-assistant/github-action@v2.6.1
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ with:
+ path-to-signatures: "cla.json"
+ path-to-document: "https://github.com/Wei-Shaw/sub2api/blob/main/CLA.md"
+ branch: "cla-signatures"
+ allowlist: "dependabot[bot],renovate[bot],bot*"
+ lock-pullrequest-aftermerge: false
+ custom-notsigned-prcomment: |
+ Thank you for your contribution! Before we can merge this PR, we need $you to sign our [Contributor License Agreement (CLA)](https://github.com/Wei-Shaw/sub2api/blob/main/CLA.md).
+
+ **To sign**, please reply with the following comment:
+
+ > I have read the CLA Document and I hereby sign the CLA
+
+ You only need to sign once — it will be valid for all your future contributions to this project.
+ custom-pr-sign-comment: "I have read the CLA Document and I hereby sign the CLA"
+ custom-allsigned-prcomment: "All contributors have signed the CLA. ✅"
+
+ cla-lock:
+ if: github.event_name == 'pull_request_target' && github.event.action == 'closed' && github.event.pull_request.merged == true
+ runs-on: ubuntu-latest
+ steps:
+ - name: "Lock merged PR"
+ uses: contributor-assistant/github-action@v2.6.1
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ with:
+ path-to-signatures: "cla.json"
+ path-to-document: "https://github.com/Wei-Shaw/sub2api/blob/main/CLA.md"
+ branch: "cla-signatures"
+ lock-pullrequest-aftermerge: true
diff --git a/CLA.md b/CLA.md
new file mode 100644
index 00000000..ed0d74b8
--- /dev/null
+++ b/CLA.md
@@ -0,0 +1,73 @@
+# Sub2API Individual Contributor License Agreement (v1.0)
+
+Thank you for your interest in contributing to Sub2API ("the Project"). This Contributor License Agreement ("Agreement") documents the rights granted by contributors to the Project.
+
+By signing this Agreement, you accept and agree to the following terms and conditions for your present and future contributions submitted to the Project.
+
+## 1. Definitions
+
+- **"You" (or "Your")** means the copyright owner or legal entity authorized by the copyright owner that is making this Agreement.
+- **"Contribution"** means any original work of authorship, including any modifications or additions to an existing work, that is intentionally submitted by You to the Project for inclusion in, or documentation of, any of the products owned or managed by the Project. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Project or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Project for the purpose of discussing and improving the Project, but excluding communication that is conspicuously marked or otherwise designated in writing by You as "Not a Contribution."
+- **"Project Owner"** means Wesley Liddick, or any individual or legal entity to whom Wesley Liddick has explicitly assigned or transferred ownership of the Project in writing, and their respective successors and assigns.
+
+## 2. Grant of Copyright License
+
+Subject to the terms and conditions of this Agreement, You hereby grant to the Project Owner a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense, and distribute Your Contributions and such derivative works. This license includes, without limitation, the right to sublicense, assign, and transfer these rights to any third party, including without limitation any successor, assignee, or acquiring entity of the Project or the Project Owner, and to use Your Contributions under any license, including proprietary or commercial licenses.
+
+## 3. Moral Rights
+
+To the fullest extent permitted by applicable law, You irrevocably waive and agree not to assert any moral rights (including rights of attribution and integrity) that You may have in Your Contributions, and agree that the Project Owner and its licensees may use, modify, and distribute Your Contributions without attribution or other obligations arising from moral rights.
+
+## 4. Grant of Patent License
+
+Subject to the terms and conditions of this Agreement, You hereby grant to the Project Owner a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer Your Contributions, where such license applies only to those patent claims licensable by You that are necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Project to which such Contribution(s) was submitted.
+
+## 5. Representations and Warranties
+
+You represent and warrant that:
+
+(a) You are legally entitled to grant the above licenses.
+
+(b) If Your employer(s) has rights to intellectual property that You create that includes Your Contributions, You have received permission to make Contributions on behalf of that employer, or that Your employer has waived such rights for Your Contributions to the Project.
+
+(c) Each of Your Contributions is Your original creation, or You have sufficient rights to submit it under the terms of this Agreement. You agree to provide, upon request, reasonable documentation or explanation of any third-party materials included in Your Contributions.
+
+## 6. No Warranty
+
+Your Contributions are provided on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are not expected to provide support for Your Contributions, except to the extent You desire to provide support.
+
+## 7. No Obligation
+
+You understand that the decision to include Your Contribution in any product or project is entirely at the discretion of the Project Owner, and this Agreement does not obligate the Project Owner to use Your Contribution.
+
+## 8. Retention of Rights
+
+You retain ownership of the copyright in Your Contributions. This Agreement does not transfer any copyright or other intellectual property rights from You to the Project Owner. This Agreement only grants the licenses described above.
+
+## 9. Term and Termination
+
+This Agreement shall remain in effect indefinitely. You may terminate this Agreement prospectively by providing written notice to the Project Owner, but such termination shall not affect the licenses granted for Contributions submitted prior to the effective date of termination. The licenses granted herein for Contributions submitted prior to termination are perpetual and irrevocable.
+
+## 10. Electronic Signature
+
+You agree that Your electronic signature (including but not limited to typing a specific phrase in a pull request, issue, or other electronic communication) is legally binding and has the same force and effect as a handwritten signature. You consent to the use of electronic means to enter into this Agreement and acknowledge that this Agreement is enforceable as if executed in a traditional written format.
+
+## 11. General Provisions
+
+**Entire Agreement.** This Agreement constitutes the entire agreement between You and the Project Owner with respect to Your Contributions and supersedes all prior or contemporaneous understandings regarding such subject matter.
+
+**Severability.** If any provision of this Agreement is held to be unenforceable or invalid, that provision will be enforced to the maximum extent possible and the remaining provisions will remain in full force and effect.
+
+**No Waiver.** The failure of the Project Owner to enforce any provision of this Agreement shall not constitute a waiver of that provision or any other provision.
+
+**Amendment.** This Agreement may only be modified by a written instrument signed by both parties. Modifications to this Agreement apply only to Contributions submitted after the modified Agreement is published and accepted by You. Prior Contributions remain governed by the version of the Agreement in effect at the time of submission.
+
+**Notification.** Notices under this Agreement shall be sent to the Project Owner via a GitHub issue on the Project repository. Notices are effective upon receipt.
+
+---
+
+**By signing this CLA, you acknowledge that you have read and understood this Agreement and agree to be bound by its terms.**
+
+To sign, reply in the pull request with:
+
+> I have read the CLA Document and I hereby sign the CLA
From 561405ab0033ef795cdc1fc752fabcfcbc333b65 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 12:41:27 +0800
Subject: [PATCH 102/189] feat: add payment order provider snapshots
---
backend/ent/migrate/schema.go | 15 +--
backend/ent/mutation.go | 75 ++++++++++-
backend/ent/paymentorder.go | 16 +++
backend/ent/paymentorder/paymentorder.go | 3 +
backend/ent/paymentorder/where.go | 10 ++
backend/ent/paymentorder_create.go | 70 +++++++++++
backend/ent/paymentorder_update.go | 36 ++++++
backend/ent/runtime/runtime.go | 16 +--
backend/ent/schema/payment_order.go | 3 +
.../internal/handler/admin/payment_handler.go | 25 +++-
backend/internal/handler/payment_handler.go | 27 +++-
backend/internal/service/payment_order.go | 49 +++++++-
.../payment_order_provider_snapshot_test.go | 116 ++++++++++++++++++
...17_add_payment_order_provider_snapshot.sql | 2 +
14 files changed, 440 insertions(+), 23 deletions(-)
create mode 100644 backend/internal/service/payment_order_provider_snapshot_test.go
create mode 100644 backend/migrations/117_add_payment_order_provider_snapshot.sql
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index 230ea060..81f6a664 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -655,6 +655,7 @@ var (
{Name: "subscription_days", Type: field.TypeInt, Nullable: true},
{Name: "provider_instance_id", Type: field.TypeString, Nullable: true, Size: 64},
{Name: "provider_key", Type: field.TypeString, Nullable: true, Size: 30},
+ {Name: "provider_snapshot", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "status", Type: field.TypeString, Size: 30, Default: "PENDING"},
{Name: "refund_amount", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,2)"}},
{Name: "refund_reason", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
@@ -683,7 +684,7 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "payment_orders_users_payment_orders",
- Columns: []*schema.Column{PaymentOrdersColumns[38]},
+ Columns: []*schema.Column{PaymentOrdersColumns[39]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
@@ -697,32 +698,32 @@ var (
{
Name: "paymentorder_user_id",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[38]},
+ Columns: []*schema.Column{PaymentOrdersColumns[39]},
},
{
Name: "paymentorder_status",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[20]},
+ Columns: []*schema.Column{PaymentOrdersColumns[21]},
},
{
Name: "paymentorder_expires_at",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[28]},
+ Columns: []*schema.Column{PaymentOrdersColumns[29]},
},
{
Name: "paymentorder_created_at",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[36]},
+ Columns: []*schema.Column{PaymentOrdersColumns[37]},
},
{
Name: "paymentorder_paid_at",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[29]},
+ Columns: []*schema.Column{PaymentOrdersColumns[30]},
},
{
Name: "paymentorder_payment_type_paid_at",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[9], PaymentOrdersColumns[29]},
+ Columns: []*schema.Column{PaymentOrdersColumns[9], PaymentOrdersColumns[30]},
},
{
Name: "paymentorder_order_type",
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index 5227015c..ec4a4070 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -15386,6 +15386,7 @@ type PaymentOrderMutation struct {
addsubscription_days *int
provider_instance_id *string
provider_key *string
+ provider_snapshot *map[string]interface{}
status *string
refund_amount *float64
addrefund_amount *float64
@@ -16471,6 +16472,55 @@ func (m *PaymentOrderMutation) ResetProviderKey() {
delete(m.clearedFields, paymentorder.FieldProviderKey)
}
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (m *PaymentOrderMutation) SetProviderSnapshot(value map[string]interface{}) {
+ m.provider_snapshot = &value
+}
+
+// ProviderSnapshot returns the value of the "provider_snapshot" field in the mutation.
+func (m *PaymentOrderMutation) ProviderSnapshot() (r map[string]interface{}, exists bool) {
+ v := m.provider_snapshot
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderSnapshot returns the old "provider_snapshot" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldProviderSnapshot(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderSnapshot is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderSnapshot requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderSnapshot: %w", err)
+ }
+ return oldValue.ProviderSnapshot, nil
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (m *PaymentOrderMutation) ClearProviderSnapshot() {
+ m.provider_snapshot = nil
+ m.clearedFields[paymentorder.FieldProviderSnapshot] = struct{}{}
+}
+
+// ProviderSnapshotCleared returns if the "provider_snapshot" field was cleared in this mutation.
+func (m *PaymentOrderMutation) ProviderSnapshotCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldProviderSnapshot]
+ return ok
+}
+
+// ResetProviderSnapshot resets all changes to the "provider_snapshot" field.
+func (m *PaymentOrderMutation) ResetProviderSnapshot() {
+ m.provider_snapshot = nil
+ delete(m.clearedFields, paymentorder.FieldProviderSnapshot)
+}
+
// SetStatus sets the "status" field.
func (m *PaymentOrderMutation) SetStatus(s string) {
m.status = &s
@@ -17330,7 +17380,7 @@ func (m *PaymentOrderMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *PaymentOrderMutation) Fields() []string {
- fields := make([]string, 0, 38)
+ fields := make([]string, 0, 39)
if m.user != nil {
fields = append(fields, paymentorder.FieldUserID)
}
@@ -17391,6 +17441,9 @@ func (m *PaymentOrderMutation) Fields() []string {
if m.provider_key != nil {
fields = append(fields, paymentorder.FieldProviderKey)
}
+ if m.provider_snapshot != nil {
+ fields = append(fields, paymentorder.FieldProviderSnapshot)
+ }
if m.status != nil {
fields = append(fields, paymentorder.FieldStatus)
}
@@ -17493,6 +17546,8 @@ func (m *PaymentOrderMutation) Field(name string) (ent.Value, bool) {
return m.ProviderInstanceID()
case paymentorder.FieldProviderKey:
return m.ProviderKey()
+ case paymentorder.FieldProviderSnapshot:
+ return m.ProviderSnapshot()
case paymentorder.FieldStatus:
return m.Status()
case paymentorder.FieldRefundAmount:
@@ -17578,6 +17633,8 @@ func (m *PaymentOrderMutation) OldField(ctx context.Context, name string) (ent.V
return m.OldProviderInstanceID(ctx)
case paymentorder.FieldProviderKey:
return m.OldProviderKey(ctx)
+ case paymentorder.FieldProviderSnapshot:
+ return m.OldProviderSnapshot(ctx)
case paymentorder.FieldStatus:
return m.OldStatus(ctx)
case paymentorder.FieldRefundAmount:
@@ -17763,6 +17820,13 @@ func (m *PaymentOrderMutation) SetField(name string, value ent.Value) error {
}
m.SetProviderKey(v)
return nil
+ case paymentorder.FieldProviderSnapshot:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderSnapshot(v)
+ return nil
case paymentorder.FieldStatus:
v, ok := value.(string)
if !ok {
@@ -18033,6 +18097,9 @@ func (m *PaymentOrderMutation) ClearedFields() []string {
if m.FieldCleared(paymentorder.FieldProviderKey) {
fields = append(fields, paymentorder.FieldProviderKey)
}
+ if m.FieldCleared(paymentorder.FieldProviderSnapshot) {
+ fields = append(fields, paymentorder.FieldProviderSnapshot)
+ }
if m.FieldCleared(paymentorder.FieldRefundReason) {
fields = append(fields, paymentorder.FieldRefundReason)
}
@@ -18104,6 +18171,9 @@ func (m *PaymentOrderMutation) ClearField(name string) error {
case paymentorder.FieldProviderKey:
m.ClearProviderKey()
return nil
+ case paymentorder.FieldProviderSnapshot:
+ m.ClearProviderSnapshot()
+ return nil
case paymentorder.FieldRefundReason:
m.ClearRefundReason()
return nil
@@ -18202,6 +18272,9 @@ func (m *PaymentOrderMutation) ResetField(name string) error {
case paymentorder.FieldProviderKey:
m.ResetProviderKey()
return nil
+ case paymentorder.FieldProviderSnapshot:
+ m.ResetProviderSnapshot()
+ return nil
case paymentorder.FieldStatus:
m.ResetStatus()
return nil
diff --git a/backend/ent/paymentorder.go b/backend/ent/paymentorder.go
index a58823ee..b131b8c8 100644
--- a/backend/ent/paymentorder.go
+++ b/backend/ent/paymentorder.go
@@ -3,6 +3,7 @@
package ent
import (
+ "encoding/json"
"fmt"
"strings"
"time"
@@ -58,6 +59,8 @@ type PaymentOrder struct {
ProviderInstanceID *string `json:"provider_instance_id,omitempty"`
// ProviderKey holds the value of the "provider_key" field.
ProviderKey *string `json:"provider_key,omitempty"`
+ // ProviderSnapshot holds the value of the "provider_snapshot" field.
+ ProviderSnapshot map[string]interface{} `json:"provider_snapshot,omitempty"`
// Status holds the value of the "status" field.
Status string `json:"status,omitempty"`
// RefundAmount holds the value of the "refund_amount" field.
@@ -125,6 +128,8 @@ func (*PaymentOrder) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
+ case paymentorder.FieldProviderSnapshot:
+ values[i] = new([]byte)
case paymentorder.FieldForceRefund:
values[i] = new(sql.NullBool)
case paymentorder.FieldAmount, paymentorder.FieldPayAmount, paymentorder.FieldFeeRate, paymentorder.FieldRefundAmount:
@@ -285,6 +290,14 @@ func (_m *PaymentOrder) assignValues(columns []string, values []any) error {
_m.ProviderKey = new(string)
*_m.ProviderKey = value.String
}
+ case paymentorder.FieldProviderSnapshot:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_snapshot", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.ProviderSnapshot); err != nil {
+ return fmt.Errorf("unmarshal field provider_snapshot: %w", err)
+ }
+ }
case paymentorder.FieldStatus:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field status", values[i])
@@ -522,6 +535,9 @@ func (_m *PaymentOrder) String() string {
builder.WriteString(*v)
}
builder.WriteString(", ")
+ builder.WriteString("provider_snapshot=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ProviderSnapshot))
+ builder.WriteString(", ")
builder.WriteString("status=")
builder.WriteString(_m.Status)
builder.WriteString(", ")
diff --git a/backend/ent/paymentorder/paymentorder.go b/backend/ent/paymentorder/paymentorder.go
index af9b1422..62883794 100644
--- a/backend/ent/paymentorder/paymentorder.go
+++ b/backend/ent/paymentorder/paymentorder.go
@@ -54,6 +54,8 @@ const (
FieldProviderInstanceID = "provider_instance_id"
// FieldProviderKey holds the string denoting the provider_key field in the database.
FieldProviderKey = "provider_key"
+ // FieldProviderSnapshot holds the string denoting the provider_snapshot field in the database.
+ FieldProviderSnapshot = "provider_snapshot"
// FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status"
// FieldRefundAmount holds the string denoting the refund_amount field in the database.
@@ -126,6 +128,7 @@ var Columns = []string{
FieldSubscriptionDays,
FieldProviderInstanceID,
FieldProviderKey,
+ FieldProviderSnapshot,
FieldStatus,
FieldRefundAmount,
FieldRefundReason,
diff --git a/backend/ent/paymentorder/where.go b/backend/ent/paymentorder/where.go
index 0f6b74a0..e96bf51e 100644
--- a/backend/ent/paymentorder/where.go
+++ b/backend/ent/paymentorder/where.go
@@ -1440,6 +1440,16 @@ func ProviderKeyContainsFold(v string) predicate.PaymentOrder {
return predicate.PaymentOrder(sql.FieldContainsFold(FieldProviderKey, v))
}
+// ProviderSnapshotIsNil applies the IsNil predicate on the "provider_snapshot" field.
+func ProviderSnapshotIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldProviderSnapshot))
+}
+
+// ProviderSnapshotNotNil applies the NotNil predicate on the "provider_snapshot" field.
+func ProviderSnapshotNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldProviderSnapshot))
+}
+
// StatusEQ applies the EQ predicate on the "status" field.
func StatusEQ(v string) predicate.PaymentOrder {
return predicate.PaymentOrder(sql.FieldEQ(FieldStatus, v))
diff --git a/backend/ent/paymentorder_create.go b/backend/ent/paymentorder_create.go
index 497ba52c..3ee24f8e 100644
--- a/backend/ent/paymentorder_create.go
+++ b/backend/ent/paymentorder_create.go
@@ -239,6 +239,12 @@ func (_c *PaymentOrderCreate) SetNillableProviderKey(v *string) *PaymentOrderCre
return _c
}
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (_c *PaymentOrderCreate) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderCreate {
+ _c.mutation.SetProviderSnapshot(v)
+ return _c
+}
+
// SetStatus sets the "status" field.
func (_c *PaymentOrderCreate) SetStatus(v string) *PaymentOrderCreate {
_c.mutation.SetStatus(v)
@@ -771,6 +777,10 @@ func (_c *PaymentOrderCreate) createSpec() (*PaymentOrder, *sqlgraph.CreateSpec)
_spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value)
_node.ProviderKey = &value
}
+ if value, ok := _c.mutation.ProviderSnapshot(); ok {
+ _spec.SetField(paymentorder.FieldProviderSnapshot, field.TypeJSON, value)
+ _node.ProviderSnapshot = value
+ }
if value, ok := _c.mutation.Status(); ok {
_spec.SetField(paymentorder.FieldStatus, field.TypeString, value)
_node.Status = value
@@ -1242,6 +1252,24 @@ func (u *PaymentOrderUpsert) ClearProviderKey() *PaymentOrderUpsert {
return u
}
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (u *PaymentOrderUpsert) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldProviderSnapshot, v)
+ return u
+}
+
+// UpdateProviderSnapshot sets the "provider_snapshot" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateProviderSnapshot() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldProviderSnapshot)
+ return u
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (u *PaymentOrderUpsert) ClearProviderSnapshot() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldProviderSnapshot)
+ return u
+}
+
// SetStatus sets the "status" field.
func (u *PaymentOrderUpsert) SetStatus(v string) *PaymentOrderUpsert {
u.Set(paymentorder.FieldStatus, v)
@@ -1942,6 +1970,27 @@ func (u *PaymentOrderUpsertOne) ClearProviderKey() *PaymentOrderUpsertOne {
})
}
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (u *PaymentOrderUpsertOne) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetProviderSnapshot(v)
+ })
+}
+
+// UpdateProviderSnapshot sets the "provider_snapshot" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateProviderSnapshot() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateProviderSnapshot()
+ })
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (u *PaymentOrderUpsertOne) ClearProviderSnapshot() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearProviderSnapshot()
+ })
+}
+
// SetStatus sets the "status" field.
func (u *PaymentOrderUpsertOne) SetStatus(v string) *PaymentOrderUpsertOne {
return u.Update(func(s *PaymentOrderUpsert) {
@@ -2853,6 +2902,27 @@ func (u *PaymentOrderUpsertBulk) ClearProviderKey() *PaymentOrderUpsertBulk {
})
}
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (u *PaymentOrderUpsertBulk) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetProviderSnapshot(v)
+ })
+}
+
+// UpdateProviderSnapshot sets the "provider_snapshot" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateProviderSnapshot() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateProviderSnapshot()
+ })
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (u *PaymentOrderUpsertBulk) ClearProviderSnapshot() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearProviderSnapshot()
+ })
+}
+
// SetStatus sets the "status" field.
func (u *PaymentOrderUpsertBulk) SetStatus(v string) *PaymentOrderUpsertBulk {
return u.Update(func(s *PaymentOrderUpsert) {
diff --git a/backend/ent/paymentorder_update.go b/backend/ent/paymentorder_update.go
index 9a901415..378e0dad 100644
--- a/backend/ent/paymentorder_update.go
+++ b/backend/ent/paymentorder_update.go
@@ -405,6 +405,18 @@ func (_u *PaymentOrderUpdate) ClearProviderKey() *PaymentOrderUpdate {
return _u
}
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (_u *PaymentOrderUpdate) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpdate {
+ _u.mutation.SetProviderSnapshot(v)
+ return _u
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (_u *PaymentOrderUpdate) ClearProviderSnapshot() *PaymentOrderUpdate {
+ _u.mutation.ClearProviderSnapshot()
+ return _u
+}
+
// SetStatus sets the "status" field.
func (_u *PaymentOrderUpdate) SetStatus(v string) *PaymentOrderUpdate {
_u.mutation.SetStatus(v)
@@ -941,6 +953,12 @@ func (_u *PaymentOrderUpdate) sqlSave(ctx context.Context) (_node int, err error
if _u.mutation.ProviderKeyCleared() {
_spec.ClearField(paymentorder.FieldProviderKey, field.TypeString)
}
+ if value, ok := _u.mutation.ProviderSnapshot(); ok {
+ _spec.SetField(paymentorder.FieldProviderSnapshot, field.TypeJSON, value)
+ }
+ if _u.mutation.ProviderSnapshotCleared() {
+ _spec.ClearField(paymentorder.FieldProviderSnapshot, field.TypeJSON)
+ }
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(paymentorder.FieldStatus, field.TypeString, value)
}
@@ -1450,6 +1468,18 @@ func (_u *PaymentOrderUpdateOne) ClearProviderKey() *PaymentOrderUpdateOne {
return _u
}
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (_u *PaymentOrderUpdateOne) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpdateOne {
+ _u.mutation.SetProviderSnapshot(v)
+ return _u
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (_u *PaymentOrderUpdateOne) ClearProviderSnapshot() *PaymentOrderUpdateOne {
+ _u.mutation.ClearProviderSnapshot()
+ return _u
+}
+
// SetStatus sets the "status" field.
func (_u *PaymentOrderUpdateOne) SetStatus(v string) *PaymentOrderUpdateOne {
_u.mutation.SetStatus(v)
@@ -2016,6 +2046,12 @@ func (_u *PaymentOrderUpdateOne) sqlSave(ctx context.Context) (_node *PaymentOrd
if _u.mutation.ProviderKeyCleared() {
_spec.ClearField(paymentorder.FieldProviderKey, field.TypeString)
}
+ if value, ok := _u.mutation.ProviderSnapshot(); ok {
+ _spec.SetField(paymentorder.FieldProviderSnapshot, field.TypeJSON, value)
+ }
+ if _u.mutation.ProviderSnapshotCleared() {
+ _spec.ClearField(paymentorder.FieldProviderSnapshot, field.TypeJSON)
+ }
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(paymentorder.FieldStatus, field.TypeString, value)
}
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index b7118ac9..bdb7f7a9 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -728,37 +728,37 @@ func init() {
// paymentorder.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
paymentorder.ProviderKeyValidator = paymentorderDescProviderKey.Validators[0].(func(string) error)
// paymentorderDescStatus is the schema descriptor for status field.
- paymentorderDescStatus := paymentorderFields[20].Descriptor()
+ paymentorderDescStatus := paymentorderFields[21].Descriptor()
// paymentorder.DefaultStatus holds the default value on creation for the status field.
paymentorder.DefaultStatus = paymentorderDescStatus.Default.(string)
// paymentorder.StatusValidator is a validator for the "status" field. It is called by the builders before save.
paymentorder.StatusValidator = paymentorderDescStatus.Validators[0].(func(string) error)
// paymentorderDescRefundAmount is the schema descriptor for refund_amount field.
- paymentorderDescRefundAmount := paymentorderFields[21].Descriptor()
+ paymentorderDescRefundAmount := paymentorderFields[22].Descriptor()
// paymentorder.DefaultRefundAmount holds the default value on creation for the refund_amount field.
paymentorder.DefaultRefundAmount = paymentorderDescRefundAmount.Default.(float64)
// paymentorderDescForceRefund is the schema descriptor for force_refund field.
- paymentorderDescForceRefund := paymentorderFields[24].Descriptor()
+ paymentorderDescForceRefund := paymentorderFields[25].Descriptor()
// paymentorder.DefaultForceRefund holds the default value on creation for the force_refund field.
paymentorder.DefaultForceRefund = paymentorderDescForceRefund.Default.(bool)
// paymentorderDescRefundRequestedBy is the schema descriptor for refund_requested_by field.
- paymentorderDescRefundRequestedBy := paymentorderFields[27].Descriptor()
+ paymentorderDescRefundRequestedBy := paymentorderFields[28].Descriptor()
// paymentorder.RefundRequestedByValidator is a validator for the "refund_requested_by" field. It is called by the builders before save.
paymentorder.RefundRequestedByValidator = paymentorderDescRefundRequestedBy.Validators[0].(func(string) error)
// paymentorderDescClientIP is the schema descriptor for client_ip field.
- paymentorderDescClientIP := paymentorderFields[33].Descriptor()
+ paymentorderDescClientIP := paymentorderFields[34].Descriptor()
// paymentorder.ClientIPValidator is a validator for the "client_ip" field. It is called by the builders before save.
paymentorder.ClientIPValidator = paymentorderDescClientIP.Validators[0].(func(string) error)
// paymentorderDescSrcHost is the schema descriptor for src_host field.
- paymentorderDescSrcHost := paymentorderFields[34].Descriptor()
+ paymentorderDescSrcHost := paymentorderFields[35].Descriptor()
// paymentorder.SrcHostValidator is a validator for the "src_host" field. It is called by the builders before save.
paymentorder.SrcHostValidator = paymentorderDescSrcHost.Validators[0].(func(string) error)
// paymentorderDescCreatedAt is the schema descriptor for created_at field.
- paymentorderDescCreatedAt := paymentorderFields[36].Descriptor()
+ paymentorderDescCreatedAt := paymentorderFields[37].Descriptor()
// paymentorder.DefaultCreatedAt holds the default value on creation for the created_at field.
paymentorder.DefaultCreatedAt = paymentorderDescCreatedAt.Default.(func() time.Time)
// paymentorderDescUpdatedAt is the schema descriptor for updated_at field.
- paymentorderDescUpdatedAt := paymentorderFields[37].Descriptor()
+ paymentorderDescUpdatedAt := paymentorderFields[38].Descriptor()
// paymentorder.DefaultUpdatedAt holds the default value on creation for the updated_at field.
paymentorder.DefaultUpdatedAt = paymentorderDescUpdatedAt.Default.(func() time.Time)
// paymentorder.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
diff --git a/backend/ent/schema/payment_order.go b/backend/ent/schema/payment_order.go
index 64378de1..5815d032 100644
--- a/backend/ent/schema/payment_order.go
+++ b/backend/ent/schema/payment_order.go
@@ -95,6 +95,9 @@ func (PaymentOrder) Fields() []ent.Field {
Optional().
Nillable().
MaxLen(30),
+ field.JSON("provider_snapshot", map[string]any{}).
+ Optional().
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
// 状态
field.String("status").
diff --git a/backend/internal/handler/admin/payment_handler.go b/backend/internal/handler/admin/payment_handler.go
index b0ed6aed..84359cd9 100644
--- a/backend/internal/handler/admin/payment_handler.go
+++ b/backend/internal/handler/admin/payment_handler.go
@@ -3,6 +3,7 @@ package admin
import (
"strconv"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -66,7 +67,7 @@ func (h *PaymentHandler) ListOrders(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- response.Paginated(c, orders, int64(total), page, pageSize)
+ response.Paginated(c, sanitizeAdminPaymentOrdersForResponse(orders), int64(total), page, pageSize)
}
// GetOrderDetail returns detailed information about a single order.
@@ -82,7 +83,7 @@ func (h *PaymentHandler) GetOrderDetail(c *gin.Context) {
return
}
auditLogs, _ := h.paymentService.GetOrderAuditLogs(c.Request.Context(), orderID)
- response.Success(c, gin.H{"order": order, "auditLogs": auditLogs})
+ response.Success(c, gin.H{"order": sanitizeAdminPaymentOrderForResponse(order), "auditLogs": auditLogs})
}
// CancelOrder cancels a pending order (admin).
@@ -114,6 +115,26 @@ func (h *PaymentHandler) RetryFulfillment(c *gin.Context) {
response.Success(c, gin.H{"message": "fulfillment retried"})
}
+func sanitizeAdminPaymentOrdersForResponse(orders []*dbent.PaymentOrder) []*dbent.PaymentOrder {
+ if len(orders) == 0 {
+ return orders
+ }
+ out := make([]*dbent.PaymentOrder, 0, len(orders))
+ for _, order := range orders {
+ out = append(out, sanitizeAdminPaymentOrderForResponse(order))
+ }
+ return out
+}
+
+func sanitizeAdminPaymentOrderForResponse(order *dbent.PaymentOrder) *dbent.PaymentOrder {
+ if order == nil {
+ return nil
+ }
+ cloned := *order
+ cloned.ProviderSnapshot = nil
+ return &cloned
+}
+
// AdminProcessRefundRequest is the request body for admin refund processing.
type AdminProcessRefundRequest struct {
Amount float64 `json:"amount"`
diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go
index 273aea73..0fba4726 100644
--- a/backend/internal/handler/payment_handler.go
+++ b/backend/internal/handler/payment_handler.go
@@ -6,6 +6,7 @@ import (
"strconv"
"strings"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -327,7 +328,7 @@ func (h *PaymentHandler) GetMyOrders(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- response.Paginated(c, orders, int64(total), page, pageSize)
+ response.Paginated(c, sanitizePaymentOrdersForResponse(orders), int64(total), page, pageSize)
}
// GetOrder returns a single order for the authenticated user.
@@ -349,7 +350,7 @@ func (h *PaymentHandler) GetOrder(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- response.Success(c, order)
+ response.Success(c, sanitizePaymentOrderForResponse(order))
}
// CancelOrder cancels a pending order for the authenticated user.
@@ -445,7 +446,7 @@ func (h *PaymentHandler) VerifyOrder(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- response.Success(c, order)
+ response.Success(c, sanitizePaymentOrderForResponse(order))
}
// PublicOrderResult is the limited order info returned by the public verify endpoint.
@@ -523,6 +524,26 @@ func isMobile(c *gin.Context) bool {
return false
}
+func sanitizePaymentOrdersForResponse(orders []*dbent.PaymentOrder) []*dbent.PaymentOrder {
+ if len(orders) == 0 {
+ return orders
+ }
+ out := make([]*dbent.PaymentOrder, 0, len(orders))
+ for _, order := range orders {
+ out = append(out, sanitizePaymentOrderForResponse(order))
+ }
+ return out
+}
+
+func sanitizePaymentOrderForResponse(order *dbent.PaymentOrder) *dbent.PaymentOrder {
+ if order == nil {
+ return nil
+ }
+ cloned := *order
+ cloned.ProviderSnapshot = nil
+ return &cloned
+}
+
func isWeChatBrowser(c *gin.Context) bool {
return strings.Contains(strings.ToLower(c.GetHeader("User-Agent")), "micromessenger")
}
diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go
index 7d973b92..1f01bc11 100644
--- a/backend/internal/service/payment_order.go
+++ b/backend/internal/service/payment_order.go
@@ -73,7 +73,7 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest
if oauthResp != nil {
return oauthResp, nil
}
- order, err := s.createOrderInTx(ctx, req, user, plan, cfg, orderAmount, limitAmount, feeRate, payAmount)
+ order, err := s.createOrderInTx(ctx, req, user, plan, cfg, orderAmount, limitAmount, feeRate, payAmount, sel)
if err != nil {
return nil, err
}
@@ -122,7 +122,7 @@ func (s *PaymentService) validateSubOrder(ctx context.Context, req CreateOrderRe
return plan, nil
}
-func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderRequest, user *User, plan *dbent.SubscriptionPlan, cfg *PaymentConfig, orderAmount, limitAmount, feeRate, payAmount float64) (*dbent.PaymentOrder, error) {
+func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderRequest, user *User, plan *dbent.SubscriptionPlan, cfg *PaymentConfig, orderAmount, limitAmount, feeRate, payAmount float64, sel *payment.InstanceSelection) (*dbent.PaymentOrder, error) {
tx, err := s.entClient.Tx(ctx)
if err != nil {
return nil, fmt.Errorf("begin transaction: %w", err)
@@ -139,6 +139,13 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
tm = defaultOrderTimeoutMin
}
exp := time.Now().Add(time.Duration(tm) * time.Minute)
+ providerSnapshot := buildPaymentOrderProviderSnapshot(sel)
+ selectedInstanceID := ""
+ selectedProviderKey := ""
+ if sel != nil {
+ selectedInstanceID = strings.TrimSpace(sel.InstanceID)
+ selectedProviderKey = strings.TrimSpace(sel.ProviderKey)
+ }
b := tx.PaymentOrder.Create().
SetUserID(req.UserID).
SetUserEmail(user.Email).
@@ -159,6 +166,15 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
if req.SrcURL != "" {
b.SetSrcURL(req.SrcURL)
}
+ if selectedInstanceID != "" {
+ b.SetProviderInstanceID(selectedInstanceID)
+ }
+ if selectedProviderKey != "" {
+ b.SetProviderKey(selectedProviderKey)
+ }
+ if providerSnapshot != nil {
+ b.SetProviderSnapshot(providerSnapshot)
+ }
if plan != nil {
b.SetPlanID(plan.ID).SetSubscriptionGroupID(plan.GroupID).SetSubscriptionDays(psComputeValidityDays(plan.ValidityDays, plan.ValidityUnit))
}
@@ -192,6 +208,35 @@ func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, us
return nil
}
+func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection) map[string]any {
+ if sel == nil {
+ return nil
+ }
+
+ snapshot := map[string]any{}
+ snapshot["schema_version"] = 1
+
+ instanceID := strings.TrimSpace(sel.InstanceID)
+ if instanceID != "" {
+ snapshot["provider_instance_id"] = instanceID
+ }
+
+ providerKey := strings.TrimSpace(sel.ProviderKey)
+ if providerKey != "" {
+ snapshot["provider_key"] = providerKey
+ }
+
+ paymentMode := strings.TrimSpace(sel.PaymentMode)
+ if paymentMode != "" {
+ snapshot["payment_mode"] = paymentMode
+ }
+
+ if len(snapshot) == 1 {
+ return nil
+ }
+ return snapshot
+}
+
func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error {
if limit <= 0 {
return nil
diff --git a/backend/internal/service/payment_order_provider_snapshot_test.go b/backend/internal/service/payment_order_provider_snapshot_test.go
new file mode 100644
index 00000000..c75566bc
--- /dev/null
+++ b/backend/internal/service/payment_order_provider_snapshot_test.go
@@ -0,0 +1,116 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "strconv"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/stretchr/testify/require"
+)
+
+func TestBuildPaymentOrderProviderSnapshot_ExcludesSensitiveConfig(t *testing.T) {
+ t.Parallel()
+
+ sel := &payment.InstanceSelection{
+ InstanceID: "12",
+ ProviderKey: payment.TypeWxpay,
+ SupportedTypes: "wxpay,wxpay_direct",
+ PaymentMode: "popup",
+ Config: map[string]string{
+ "privateKey": "secret",
+ "apiV3Key": "secret-v3",
+ "appId": "wx-app-id",
+ },
+ }
+
+ snapshot := buildPaymentOrderProviderSnapshot(sel)
+ require.Equal(t, map[string]any{
+ "schema_version": 1,
+ "provider_instance_id": "12",
+ "provider_key": payment.TypeWxpay,
+ "payment_mode": "popup",
+ }, snapshot)
+ require.NotContains(t, snapshot, "config")
+ require.NotContains(t, snapshot, "privateKey")
+ require.NotContains(t, snapshot, "apiV3Key")
+ require.NotContains(t, snapshot, "supported_types")
+ require.NotContains(t, snapshot, "instance_name")
+}
+
+func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("snapshot@example.com").
+ SetPasswordHash("hash").
+ SetUsername("snapshot-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ instance, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("Primary Alipay").
+ SetConfig(`{"secretKey":"do-not-copy"}`).
+ SetSupportedTypes("alipay,alipay_direct").
+ SetPaymentMode("redirect").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{entClient: client}
+ order, err := svc.createOrderInTx(
+ ctx,
+ CreateOrderRequest{
+ UserID: user.ID,
+ PaymentType: payment.TypeAlipay,
+ OrderType: payment.OrderTypeBalance,
+ ClientIP: "127.0.0.1",
+ SrcHost: "app.example.com",
+ },
+ &User{
+ ID: user.ID,
+ Email: user.Email,
+ Username: user.Username,
+ },
+ nil,
+ &PaymentConfig{
+ MaxPendingOrders: 3,
+ OrderTimeoutMin: 30,
+ },
+ 88,
+ 88,
+ 0,
+ 88,
+ &payment.InstanceSelection{
+ InstanceID: strconv.FormatInt(instance.ID, 10),
+ ProviderKey: payment.TypeAlipay,
+ SupportedTypes: "alipay,alipay_direct",
+ PaymentMode: "redirect",
+ Config: map[string]string{
+ "secretKey": "do-not-copy",
+ },
+ },
+ )
+ require.NoError(t, err)
+ require.Equal(t, strconv.FormatInt(instance.ID, 10), valueOrEmpty(order.ProviderInstanceID))
+ require.Equal(t, payment.TypeAlipay, valueOrEmpty(order.ProviderKey))
+ require.Equal(t, float64(1), order.ProviderSnapshot["schema_version"])
+ require.Equal(t, strconv.FormatInt(instance.ID, 10), order.ProviderSnapshot["provider_instance_id"])
+ require.Equal(t, payment.TypeAlipay, order.ProviderSnapshot["provider_key"])
+ require.Equal(t, "redirect", order.ProviderSnapshot["payment_mode"])
+ require.NotContains(t, order.ProviderSnapshot, "config")
+ require.NotContains(t, order.ProviderSnapshot, "secretKey")
+ require.NotContains(t, order.ProviderSnapshot, "supported_types")
+ require.NotContains(t, order.ProviderSnapshot, "instance_name")
+}
+
+func valueOrEmpty(v *string) string {
+ if v == nil {
+ return ""
+ }
+ return *v
+}
diff --git a/backend/migrations/117_add_payment_order_provider_snapshot.sql b/backend/migrations/117_add_payment_order_provider_snapshot.sql
new file mode 100644
index 00000000..56a5fe2d
--- /dev/null
+++ b/backend/migrations/117_add_payment_order_provider_snapshot.sql
@@ -0,0 +1,2 @@
+ALTER TABLE payment_orders
+ADD COLUMN IF NOT EXISTS provider_snapshot JSONB;
From 35aeeaa6e17632ab08bd7f7f85d5edc396e4ce0f Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 12:50:55 +0800
Subject: [PATCH 103/189] fix: pin payment read paths to provider snapshots
---
.../internal/service/payment_fulfillment.go | 2 +-
.../service/payment_fulfillment_test.go | 24 +++
.../payment_order_provider_snapshot.go | 115 ++++++++++++++
backend/internal/service/payment_refund.go | 4 +
.../service/payment_webhook_provider.go | 2 +-
.../service/payment_webhook_provider_test.go | 150 ++++++++++++++++++
6 files changed, 295 insertions(+), 2 deletions(-)
create mode 100644 backend/internal/service/payment_order_provider_snapshot.go
diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go
index 83bac21d..9cb03cca 100644
--- a/backend/internal/service/payment_fulfillment.go
+++ b/backend/internal/service/payment_fulfillment.go
@@ -45,7 +45,7 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
if inst, instErr := s.getOrderProviderInstance(ctx, o); instErr == nil && inst != nil {
instanceProviderKey = inst.ProviderKey
}
- expectedProviderKey := expectedNotificationProviderKey(s.registry, o.PaymentType, psStringValue(o.ProviderKey), instanceProviderKey)
+ expectedProviderKey := expectedNotificationProviderKeyForOrder(s.registry, o, instanceProviderKey)
if expectedProviderKey != "" && strings.TrimSpace(pk) != "" && !strings.EqualFold(expectedProviderKey, strings.TrimSpace(pk)) {
s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_MISMATCH", pk, map[string]any{
"expectedProvider": expectedProviderKey,
diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go
index 712129b0..3ce82973 100644
--- a/backend/internal/service/payment_fulfillment_test.go
+++ b/backend/internal/service/payment_fulfillment_test.go
@@ -7,6 +7,7 @@ import (
"errors"
"testing"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/stretchr/testify/assert"
)
@@ -240,3 +241,26 @@ func TestExpectedNotificationProviderKeyPrefersOrderSnapshotProviderKey(t *testi
expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay, ""),
)
}
+
+func TestExpectedNotificationProviderKeyForOrderUsesSnapshotProviderKey(t *testing.T) {
+ t.Parallel()
+
+ registry := payment.NewRegistry()
+ registry.Register(paymentFulfillmentTestProvider{
+ key: payment.TypeAlipay,
+ supportedTypes: []payment.PaymentType{payment.TypeAlipay},
+ })
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 1,
+ "provider_key": payment.TypeEasyPay,
+ },
+ }
+
+ assert.Equal(t,
+ payment.TypeEasyPay,
+ expectedNotificationProviderKeyForOrder(registry, order, ""),
+ )
+}
diff --git a/backend/internal/service/payment_order_provider_snapshot.go b/backend/internal/service/payment_order_provider_snapshot.go
new file mode 100644
index 00000000..9a0aa106
--- /dev/null
+++ b/backend/internal/service/payment_order_provider_snapshot.go
@@ -0,0 +1,115 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "strconv"
+ "strings"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+)
+
+type paymentOrderProviderSnapshot struct {
+ SchemaVersion int
+ ProviderInstanceID string
+ ProviderKey string
+ PaymentMode string
+}
+
+func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSnapshot {
+ if order == nil || len(order.ProviderSnapshot) == 0 {
+ return nil
+ }
+
+ snapshot := &paymentOrderProviderSnapshot{
+ SchemaVersion: psSnapshotIntValue(order.ProviderSnapshot["schema_version"]),
+ ProviderInstanceID: psSnapshotStringValue(order.ProviderSnapshot["provider_instance_id"]),
+ ProviderKey: psSnapshotStringValue(order.ProviderSnapshot["provider_key"]),
+ PaymentMode: psSnapshotStringValue(order.ProviderSnapshot["payment_mode"]),
+ }
+ if snapshot.SchemaVersion == 0 && snapshot.ProviderInstanceID == "" && snapshot.ProviderKey == "" && snapshot.PaymentMode == "" {
+ return nil
+ }
+ return snapshot
+}
+
+func psSnapshotStringValue(value any) string {
+ switch typed := value.(type) {
+ case string:
+ return strings.TrimSpace(typed)
+ default:
+ return ""
+ }
+}
+
+func psSnapshotIntValue(value any) int {
+ switch typed := value.(type) {
+ case int:
+ return typed
+ case int32:
+ return int(typed)
+ case int64:
+ return int(typed)
+ case float32:
+ return int(typed)
+ case float64:
+ return int(typed)
+ case string:
+ n, err := strconv.Atoi(strings.TrimSpace(typed))
+ if err == nil {
+ return n
+ }
+ }
+ return 0
+}
+
+func (s *PaymentService) resolveSnapshotOrderProviderInstance(ctx context.Context, order *dbent.PaymentOrder, snapshot *paymentOrderProviderSnapshot) (*dbent.PaymentProviderInstance, error) {
+ if s == nil || s.entClient == nil || order == nil || snapshot == nil {
+ return nil, nil
+ }
+
+ snapshotInstanceID := strings.TrimSpace(snapshot.ProviderInstanceID)
+ columnInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
+ if snapshotInstanceID == "" {
+ snapshotInstanceID = columnInstanceID
+ }
+ if snapshotInstanceID == "" {
+ return nil, fmt.Errorf("order %d provider snapshot is missing provider_instance_id", order.ID)
+ }
+ if columnInstanceID != "" && snapshot.ProviderInstanceID != "" && !strings.EqualFold(columnInstanceID, snapshot.ProviderInstanceID) {
+ return nil, fmt.Errorf("order %d provider snapshot instance mismatch: snapshot=%s order=%s", order.ID, snapshot.ProviderInstanceID, columnInstanceID)
+ }
+
+ instID, err := strconv.ParseInt(snapshotInstanceID, 10, 64)
+ if err != nil {
+ return nil, fmt.Errorf("order %d provider snapshot instance id is invalid: %s", order.ID, snapshotInstanceID)
+ }
+
+ inst, err := s.entClient.PaymentProviderInstance.Get(ctx, instID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, fmt.Errorf("order %d provider snapshot instance %s is missing", order.ID, snapshotInstanceID)
+ }
+ return nil, err
+ }
+
+ if snapshot.ProviderKey != "" && !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), snapshot.ProviderKey) {
+ return nil, fmt.Errorf("order %d provider snapshot key mismatch: snapshot=%s instance=%s", order.ID, snapshot.ProviderKey, inst.ProviderKey)
+ }
+
+ return inst, nil
+}
+
+func expectedNotificationProviderKeyForOrder(registry *payment.Registry, order *dbent.PaymentOrder, instanceProviderKey string) string {
+ if order == nil {
+ return strings.TrimSpace(instanceProviderKey)
+ }
+
+ orderProviderKey := psStringValue(order.ProviderKey)
+ if snapshot := psOrderProviderSnapshot(order); snapshot != nil && snapshot.ProviderKey != "" {
+ orderProviderKey = snapshot.ProviderKey
+ }
+
+ return expectedNotificationProviderKey(registry, order.PaymentType, orderProviderKey, instanceProviderKey)
+}
diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go
index 57469fa3..fbaeff99 100644
--- a/backend/internal/service/payment_refund.go
+++ b/backend/internal/service/payment_refund.go
@@ -27,6 +27,10 @@ func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.
return nil, nil
}
+ if snapshot := psOrderProviderSnapshot(o); snapshot != nil {
+ return s.resolveSnapshotOrderProviderInstance(ctx, o, snapshot)
+ }
+
instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID))
if instIDStr == "" {
return s.resolveUniqueLegacyOrderProviderInstance(ctx, o)
diff --git a/backend/internal/service/payment_webhook_provider.go b/backend/internal/service/payment_webhook_provider.go
index 82dc9ea3..f2da40d9 100644
--- a/backend/internal/service/payment_webhook_provider.go
+++ b/backend/internal/service/payment_webhook_provider.go
@@ -113,7 +113,7 @@ func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, pro
}
func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool {
- return order != nil && order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != ""
+ return order != nil && (psOrderProviderSnapshot(order) != nil || (order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != ""))
}
func (s *PaymentService) getEnabledWebhookProvidersByKey(ctx context.Context, providerKey string) ([]payment.Provider, error) {
diff --git a/backend/internal/service/payment_webhook_provider_test.go b/backend/internal/service/payment_webhook_provider_test.go
index 15b447c2..f12cf691 100644
--- a/backend/internal/service/payment_webhook_provider_test.go
+++ b/backend/internal/service/payment_webhook_provider_test.go
@@ -5,6 +5,7 @@ package service
import (
"context"
"encoding/json"
+ "strconv"
"testing"
"time"
@@ -205,6 +206,72 @@ func TestGetOrderProviderInstanceLeavesProviderKeyMatchUnresolvedWhenTypeNotSupp
require.Nil(t, got)
}
+func TestGetOrderProviderInstanceUsesProviderSnapshotWhenPinnedColumnMissing(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ inst, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeStripe).
+ SetName("stripe-snapshot").
+ SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_snapshot"})).
+ SetSupportedTypes("stripe").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order := &dbent.PaymentOrder{
+ ID: 42,
+ PaymentType: payment.TypeStripe,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 1,
+ "provider_instance_id": strconv.FormatInt(inst.ID, 10),
+ "provider_key": payment.TypeStripe,
+ },
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.NoError(t, err)
+ require.NotNil(t, got)
+ require.Equal(t, inst.ID, got.ID)
+}
+
+func TestGetOrderProviderInstanceRejectsMissingSnapshotInstanceWithoutLegacyFallback(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeStripe).
+ SetName("stripe-legacy-fallback").
+ SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_legacy"})).
+ SetSupportedTypes("stripe").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order := &dbent.PaymentOrder{
+ ID: 43,
+ PaymentType: payment.TypeStripe,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 1,
+ "provider_instance_id": "999999",
+ "provider_key": payment.TypeStripe,
+ },
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.Nil(t, got)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "provider snapshot instance 999999 is missing")
+}
+
func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
@@ -364,3 +431,86 @@ func TestGetWebhookProviderRejectsRegistryFallbackForPinnedOrder(t *testing.T) {
require.Error(t, err)
require.Contains(t, err.Error(), "provider instance")
}
+
+func TestGetWebhookProviderUsesProviderSnapshotBeforeWxpayFallback(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("snapshot-webhook@example.com").
+ SetPasswordHash("hash").
+ SetUsername("snapshot-webhook").
+ Save(ctx)
+ require.NoError(t, err)
+
+ wxpayConfigA := encryptWebhookProviderConfig(t, map[string]string{
+ "appId": "wx-app-snapshot-a",
+ "mchId": "mch-snapshot-a",
+ "privateKey": "private-key-snapshot-a",
+ "apiV3Key": webhookProviderTestEncryptionKey,
+ "publicKey": "public-key-snapshot-a",
+ "publicKeyId": "public-key-id-snapshot-a",
+ "certSerial": "cert-serial-snapshot-a",
+ })
+ wxpayConfigB := encryptWebhookProviderConfig(t, map[string]string{
+ "appId": "wx-app-snapshot-b",
+ "mchId": "mch-snapshot-b",
+ "privateKey": "private-key-snapshot-b",
+ "apiV3Key": webhookProviderTestEncryptionKey,
+ "publicKey": "public-key-snapshot-b",
+ "publicKeyId": "public-key-id-snapshot-b",
+ "certSerial": "cert-serial-snapshot-b",
+ })
+ instA, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-snapshot-a").
+ SetConfig(wxpayConfigA).
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-snapshot-b").
+ SetConfig(wxpayConfigB).
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(66).
+ SetPayAmount(66).
+ SetFeeRate(0).
+ SetRechargeCode("SNAPSHOT-WEBHOOK").
+ SetOutTradeNo("sub2_test_snapshot_webhook_order").
+ SetPaymentType(payment.TypeWxpay).
+ SetPaymentTradeNo("").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderSnapshot(map[string]any{
+ "schema_version": 1,
+ "provider_instance_id": strconv.FormatInt(instA.ID, 10),
+ "provider_key": payment.TypeWxpay,
+ "payment_mode": "native",
+ }).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ registry: payment.NewRegistry(),
+ providersLoaded: true,
+ }
+
+ providers, err := svc.GetWebhookProviders(ctx, payment.TypeWxpay, "sub2_test_snapshot_webhook_order")
+ require.NoError(t, err)
+ require.Len(t, providers, 1)
+ require.Equal(t, payment.TypeWxpay, providers[0].ProviderKey())
+}
From 119f784d19bf6528aa45bb14a2ffcb2fe91dd6b2 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 12:57:35 +0800
Subject: [PATCH 104/189] fix: validate wxpay payments against order snapshots
---
backend/internal/payment/provider/wxpay.go | 43 +++++++++++++-
.../internal/payment/provider/wxpay_test.go | 28 +++++++++
backend/internal/payment/types.go | 20 ++++---
.../internal/service/payment_fulfillment.go | 57 ++++++++++++++++++-
.../service/payment_fulfillment_test.go | 43 ++++++++++++++
backend/internal/service/payment_order.go | 26 ++++++++-
.../service/payment_order_lifecycle.go | 2 +-
.../payment_order_provider_snapshot.go | 14 ++++-
.../payment_order_provider_snapshot_test.go | 28 ++++++++-
9 files changed, 239 insertions(+), 22 deletions(-)
diff --git a/backend/internal/payment/provider/wxpay.go b/backend/internal/payment/provider/wxpay.go
index 7d51dff0..30016338 100644
--- a/backend/internal/payment/provider/wxpay.go
+++ b/backend/internal/payment/provider/wxpay.go
@@ -32,6 +32,13 @@ const (
wxpayResultPath = "/payment/result"
)
+const (
+ wxpayMetadataAppID = "appid"
+ wxpayMetadataMerchantID = "mchid"
+ wxpayMetadataCurrency = "currency"
+ wxpayMetadataTradeState = "trade_state"
+)
+
// WeChat Pay create-payment modes.
const (
wxpayModeNative = "native"
@@ -355,6 +362,32 @@ func mapWxState(s string) string {
}
}
+func buildWxpayTransactionMetadata(tx *payments.Transaction) map[string]string {
+ if tx == nil {
+ return nil
+ }
+
+ metadata := map[string]string{}
+ if appID := wxSV(tx.Appid); appID != "" {
+ metadata[wxpayMetadataAppID] = appID
+ }
+ if merchantID := wxSV(tx.Mchid); merchantID != "" {
+ metadata[wxpayMetadataMerchantID] = merchantID
+ }
+ if tradeState := wxSV(tx.TradeState); tradeState != "" {
+ metadata[wxpayMetadataTradeState] = tradeState
+ }
+ if tx.Amount != nil {
+ if currency := wxSV(tx.Amount.Currency); currency != "" {
+ metadata[wxpayMetadataCurrency] = currency
+ }
+ }
+ if len(metadata) == 0 {
+ return nil
+ }
+ return metadata
+}
+
func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
c, err := w.ensureClient()
if err != nil {
@@ -379,7 +412,13 @@ func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryO
if tx.SuccessTime != nil {
pa = *tx.SuccessTime
}
- return &payment.QueryOrderResponse{TradeNo: id, Status: mapWxState(wxSV(tx.TradeState)), Amount: amt, PaidAt: pa}, nil
+ return &payment.QueryOrderResponse{
+ TradeNo: id,
+ Status: mapWxState(wxSV(tx.TradeState)),
+ Amount: amt,
+ PaidAt: pa,
+ Metadata: buildWxpayTransactionMetadata(tx),
+ }, nil
}
func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) {
@@ -411,7 +450,7 @@ func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers
}
return &payment.PaymentNotification{
TradeNo: wxSV(tx.TransactionId), OrderID: wxSV(tx.OutTradeNo),
- Amount: amt, Status: st, RawData: rawBody,
+ Amount: amt, Status: st, RawData: rawBody, Metadata: buildWxpayTransactionMetadata(&tx),
}, nil
}
diff --git a/backend/internal/payment/provider/wxpay_test.go b/backend/internal/payment/provider/wxpay_test.go
index 6d0006be..b3f4f648 100644
--- a/backend/internal/payment/provider/wxpay_test.go
+++ b/backend/internal/payment/provider/wxpay_test.go
@@ -10,6 +10,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/wechatpay-apiv3/wechatpay-go/core"
+ "github.com/wechatpay-apiv3/wechatpay-go/services/payments"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/jsapi"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/native"
@@ -102,6 +103,33 @@ func TestWxSV(t *testing.T) {
}
}
+func TestBuildWxpayTransactionMetadata(t *testing.T) {
+ t.Parallel()
+
+ tx := &payments.Transaction{
+ Appid: strPtr("wx-app-id"),
+ Mchid: strPtr("mch-id"),
+ TradeState: strPtr(wxpayTradeStateSuccess),
+ Amount: &payments.Amount{
+ Currency: strPtr(wxpayCurrency),
+ },
+ }
+
+ metadata := buildWxpayTransactionMetadata(tx)
+ if metadata[wxpayMetadataAppID] != "wx-app-id" {
+ t.Fatalf("appid = %q", metadata[wxpayMetadataAppID])
+ }
+ if metadata[wxpayMetadataMerchantID] != "mch-id" {
+ t.Fatalf("mchid = %q", metadata[wxpayMetadataMerchantID])
+ }
+ if metadata[wxpayMetadataCurrency] != wxpayCurrency {
+ t.Fatalf("currency = %q", metadata[wxpayMetadataCurrency])
+ }
+ if metadata[wxpayMetadataTradeState] != wxpayTradeStateSuccess {
+ t.Fatalf("trade_state = %q", metadata[wxpayMetadataTradeState])
+ }
+}
+
func strPtr(s string) *string {
return &s
}
diff --git a/backend/internal/payment/types.go b/backend/internal/payment/types.go
index bb125247..29abf82b 100644
--- a/backend/internal/payment/types.go
+++ b/backend/internal/payment/types.go
@@ -149,19 +149,21 @@ type CreatePaymentResponse struct {
// QueryOrderResponse describes the payment status from the upstream provider.
type QueryOrderResponse struct {
- TradeNo string
- Status string // "pending", "paid", "failed", "refunded"
- Amount float64 // Amount in CNY
- PaidAt string // RFC3339 timestamp or empty
+ TradeNo string
+ Status string // "pending", "paid", "failed", "refunded"
+ Amount float64 // Amount in CNY
+ PaidAt string // RFC3339 timestamp or empty
+ Metadata map[string]string
}
// PaymentNotification is the parsed result of a webhook/notify callback.
type PaymentNotification struct {
- TradeNo string
- OrderID string
- Amount float64
- Status string // "success" or "failed"
- RawData string // Raw notification body for audit
+ TradeNo string
+ OrderID string
+ Amount float64
+ Status string // "success" or "failed"
+ RawData string // Raw notification body for audit
+ Metadata map[string]string
}
// RefundRequest contains the parameters for requesting a refund.
diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go
index 9cb03cca..7bde03c8 100644
--- a/backend/internal/service/payment_fulfillment.go
+++ b/backend/internal/service/payment_fulfillment.go
@@ -28,14 +28,14 @@ func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payme
// Fallback: try legacy format (sub2_N where N is DB ID)
trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix)
if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil {
- return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk)
+ return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk, n.Metadata)
}
return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID)
}
- return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk)
+ return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk, n.Metadata)
}
-func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string) error {
+func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string, metadata map[string]string) error {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
slog.Error("order not found", "orderID", oid)
@@ -54,6 +54,13 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
})
return fmt.Errorf("provider mismatch: expected %s, got %s", expectedProviderKey, pk)
}
+ if err := validateProviderNotificationMetadata(o, pk, metadata); err != nil {
+ s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_METADATA_MISMATCH", pk, map[string]any{
+ "detail": err.Error(),
+ "tradeNo": tradeNo,
+ })
+ return err
+ }
// Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount).
// Also skip if paid is NaN/Inf (malformed provider data).
if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) {
@@ -69,6 +76,50 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
return s.toPaid(ctx, o, tradeNo, paid, pk)
}
+func validateProviderNotificationMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error {
+ if order == nil || len(metadata) == 0 || !strings.EqualFold(strings.TrimSpace(providerKey), payment.TypeWxpay) {
+ return nil
+ }
+
+ snapshot := psOrderProviderSnapshot(order)
+ if snapshot == nil {
+ return nil
+ }
+
+ if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" {
+ actual := strings.TrimSpace(metadata["appid"])
+ if actual == "" {
+ return fmt.Errorf("wxpay notification missing appid")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("wxpay appid mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" {
+ actual := strings.TrimSpace(metadata["mchid"])
+ if actual == "" {
+ return fmt.Errorf("wxpay notification missing mchid")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("wxpay mchid mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ if expected := strings.TrimSpace(snapshot.Currency); expected != "" {
+ actual := strings.ToUpper(strings.TrimSpace(metadata["currency"]))
+ if actual == "" {
+ return fmt.Errorf("wxpay notification missing currency")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("wxpay currency mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ if actual := strings.TrimSpace(metadata["trade_state"]); actual != "" && !strings.EqualFold(actual, "SUCCESS") {
+ return fmt.Errorf("wxpay trade_state mismatch: expected SUCCESS, got %s", actual)
+ }
+
+ return nil
+}
+
func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, orderProviderKey string, instanceProviderKey string) string {
if key := strings.TrimSpace(instanceProviderKey); key != "" {
return key
diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go
index 3ce82973..8883d3b8 100644
--- a/backend/internal/service/payment_fulfillment_test.go
+++ b/backend/internal/service/payment_fulfillment_test.go
@@ -264,3 +264,46 @@ func TestExpectedNotificationProviderKeyForOrderUsesSnapshotProviderKey(t *testi
expectedNotificationProviderKeyForOrder(registry, order, ""),
)
}
+
+func TestValidateProviderNotificationMetadataRejectsWxpaySnapshotMismatch(t *testing.T) {
+ t.Parallel()
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeWxpay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 1,
+ "merchant_app_id": "wx-app-expected",
+ "merchant_id": "mch-expected",
+ "currency": "CNY",
+ },
+ }
+
+ err := validateProviderNotificationMetadata(order, payment.TypeWxpay, map[string]string{
+ "appid": "wx-app-other",
+ "mchid": "mch-expected",
+ "currency": "CNY",
+ "trade_state": "SUCCESS",
+ })
+ assert.ErrorContains(t, err, "wxpay appid mismatch")
+}
+
+func TestValidateProviderNotificationMetadataAllowsLegacyOrdersWithoutSnapshotFields(t *testing.T) {
+ t.Parallel()
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeWxpay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 1,
+ "provider_instance_id": "9",
+ "provider_key": payment.TypeWxpay,
+ },
+ }
+
+ err := validateProviderNotificationMetadata(order, payment.TypeWxpay, map[string]string{
+ "appid": "wx-app-runtime",
+ "mchid": "mch-runtime",
+ "currency": "CNY",
+ "trade_state": "SUCCESS",
+ })
+ assert.NoError(t, err)
+}
diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go
index 1f01bc11..6ee490a8 100644
--- a/backend/internal/service/payment_order.go
+++ b/backend/internal/service/payment_order.go
@@ -139,7 +139,7 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
tm = defaultOrderTimeoutMin
}
exp := time.Now().Add(time.Duration(tm) * time.Minute)
- providerSnapshot := buildPaymentOrderProviderSnapshot(sel)
+ providerSnapshot := buildPaymentOrderProviderSnapshot(sel, req)
selectedInstanceID := ""
selectedProviderKey := ""
if sel != nil {
@@ -208,13 +208,13 @@ func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, us
return nil
}
-func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection) map[string]any {
+func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection, req CreateOrderRequest) map[string]any {
if sel == nil {
return nil
}
snapshot := map[string]any{}
- snapshot["schema_version"] = 1
+ snapshot["schema_version"] = 2
instanceID := strings.TrimSpace(sel.InstanceID)
if instanceID != "" {
@@ -231,12 +231,32 @@ func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection) map[strin
snapshot["payment_mode"] = paymentMode
}
+ if providerKey == payment.TypeWxpay {
+ if merchantAppID := paymentOrderSnapshotWxpayAppID(sel, req); merchantAppID != "" {
+ snapshot["merchant_app_id"] = merchantAppID
+ }
+ if merchantID := strings.TrimSpace(sel.Config["mchId"]); merchantID != "" {
+ snapshot["merchant_id"] = merchantID
+ }
+ snapshot["currency"] = "CNY"
+ }
+
if len(snapshot) == 1 {
return nil
}
return snapshot
}
+func paymentOrderSnapshotWxpayAppID(sel *payment.InstanceSelection, req CreateOrderRequest) string {
+ if sel == nil || strings.TrimSpace(sel.ProviderKey) != payment.TypeWxpay {
+ return ""
+ }
+ if strings.TrimSpace(req.OpenID) != "" {
+ return strings.TrimSpace(provider.ResolveWxpayJSAPIAppID(sel.Config))
+ }
+ return strings.TrimSpace(sel.Config["appId"])
+}
+
func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error {
if limit <= 0 {
return nil
diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go
index 1564c36d..c11baac1 100644
--- a/backend/internal/service/payment_order_lifecycle.go
+++ b/backend/internal/service/payment_order_lifecycle.go
@@ -163,7 +163,7 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
}
notificationTradeNo = upstreamTradeNo
}
- if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: notificationTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil {
+ if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: notificationTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess, Metadata: resp.Metadata}, prov.ProviderKey()); err != nil {
slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err)
// Still return already_paid — order was paid, fulfillment can be retried
}
diff --git a/backend/internal/service/payment_order_provider_snapshot.go b/backend/internal/service/payment_order_provider_snapshot.go
index 9a0aa106..31a790c7 100644
--- a/backend/internal/service/payment_order_provider_snapshot.go
+++ b/backend/internal/service/payment_order_provider_snapshot.go
@@ -15,6 +15,9 @@ type paymentOrderProviderSnapshot struct {
ProviderInstanceID string
ProviderKey string
PaymentMode string
+ MerchantAppID string
+ MerchantID string
+ Currency string
}
func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSnapshot {
@@ -27,8 +30,17 @@ func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSna
ProviderInstanceID: psSnapshotStringValue(order.ProviderSnapshot["provider_instance_id"]),
ProviderKey: psSnapshotStringValue(order.ProviderSnapshot["provider_key"]),
PaymentMode: psSnapshotStringValue(order.ProviderSnapshot["payment_mode"]),
+ MerchantAppID: psSnapshotStringValue(order.ProviderSnapshot["merchant_app_id"]),
+ MerchantID: psSnapshotStringValue(order.ProviderSnapshot["merchant_id"]),
+ Currency: psSnapshotStringValue(order.ProviderSnapshot["currency"]),
}
- if snapshot.SchemaVersion == 0 && snapshot.ProviderInstanceID == "" && snapshot.ProviderKey == "" && snapshot.PaymentMode == "" {
+ if snapshot.SchemaVersion == 0 &&
+ snapshot.ProviderInstanceID == "" &&
+ snapshot.ProviderKey == "" &&
+ snapshot.PaymentMode == "" &&
+ snapshot.MerchantAppID == "" &&
+ snapshot.MerchantID == "" &&
+ snapshot.Currency == "" {
return nil
}
return snapshot
diff --git a/backend/internal/service/payment_order_provider_snapshot_test.go b/backend/internal/service/payment_order_provider_snapshot_test.go
index c75566bc..bc6666a8 100644
--- a/backend/internal/service/payment_order_provider_snapshot_test.go
+++ b/backend/internal/service/payment_order_provider_snapshot_test.go
@@ -26,18 +26,21 @@ func TestBuildPaymentOrderProviderSnapshot_ExcludesSensitiveConfig(t *testing.T)
},
}
- snapshot := buildPaymentOrderProviderSnapshot(sel)
+ snapshot := buildPaymentOrderProviderSnapshot(sel, CreateOrderRequest{})
require.Equal(t, map[string]any{
- "schema_version": 1,
+ "schema_version": 2,
"provider_instance_id": "12",
"provider_key": payment.TypeWxpay,
"payment_mode": "popup",
+ "merchant_app_id": "wx-app-id",
+ "currency": "CNY",
}, snapshot)
require.NotContains(t, snapshot, "config")
require.NotContains(t, snapshot, "privateKey")
require.NotContains(t, snapshot, "apiV3Key")
require.NotContains(t, snapshot, "supported_types")
require.NotContains(t, snapshot, "instance_name")
+ require.NotContains(t, snapshot, "merchant_id")
}
func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) {
@@ -98,7 +101,7 @@ func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) {
require.NoError(t, err)
require.Equal(t, strconv.FormatInt(instance.ID, 10), valueOrEmpty(order.ProviderInstanceID))
require.Equal(t, payment.TypeAlipay, valueOrEmpty(order.ProviderKey))
- require.Equal(t, float64(1), order.ProviderSnapshot["schema_version"])
+ require.Equal(t, float64(2), order.ProviderSnapshot["schema_version"])
require.Equal(t, strconv.FormatInt(instance.ID, 10), order.ProviderSnapshot["provider_instance_id"])
require.Equal(t, payment.TypeAlipay, order.ProviderSnapshot["provider_key"])
require.Equal(t, "redirect", order.ProviderSnapshot["payment_mode"])
@@ -108,6 +111,25 @@ func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) {
require.NotContains(t, order.ProviderSnapshot, "instance_name")
}
+func TestBuildPaymentOrderProviderSnapshot_UsesWxpayJSAPIAppIDForOpenIDOrders(t *testing.T) {
+ t.Parallel()
+
+ snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
+ InstanceID: "88",
+ ProviderKey: payment.TypeWxpay,
+ Config: map[string]string{
+ "appId": "wx-open-app",
+ "mpAppId": "wx-mp-app",
+ "mchId": "mch-88",
+ },
+ PaymentMode: "jsapi",
+ }, CreateOrderRequest{OpenID: "openid-123"})
+
+ require.Equal(t, "wx-mp-app", snapshot["merchant_app_id"])
+ require.Equal(t, "mch-88", snapshot["merchant_id"])
+ require.Equal(t, "CNY", snapshot["currency"])
+}
+
func valueOrEmpty(v *string) string {
if v == nil {
return ""
From 276ce052a30eab0148b95cc7f22e1feb955f75b5 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 13:01:21 +0800
Subject: [PATCH 105/189] fix: align payment recovery query refs and resume
authority
---
.../service/payment_order_lifecycle.go | 61 +++++++++++--
.../service/payment_order_lifecycle_test.go | 89 +++++++++++++++++++
.../internal/service/payment_resume_lookup.go | 15 +++-
.../service/payment_resume_lookup_test.go | 57 ++++++++++++
4 files changed, 212 insertions(+), 10 deletions(-)
diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go
index c11baac1..a192f599 100644
--- a/backend/internal/service/payment_order_lifecycle.go
+++ b/backend/internal/service/payment_order_lifecycle.go
@@ -5,6 +5,7 @@ import (
"fmt"
"log/slog"
"strconv"
+ "strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -139,20 +140,18 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
if err != nil {
return ""
}
- // Use OutTradeNo as fallback when PaymentTradeNo is empty
- // (e.g. EasyPay popup mode where trade_no arrives only via notify callback)
- tradeNo := o.PaymentTradeNo
- if tradeNo == "" {
- tradeNo = o.OutTradeNo
+ queryRef := paymentOrderQueryReference(o, prov)
+ if queryRef == "" {
+ return ""
}
- resp, err := prov.QueryOrder(ctx, tradeNo)
+ resp, err := prov.QueryOrder(ctx, queryRef)
if err != nil {
slog.Warn("query upstream failed", "orderID", o.ID, "error", err)
return ""
}
if resp.Status == payment.ProviderStatusPaid {
notificationTradeNo := o.PaymentTradeNo
- if upstreamTradeNo := resp.TradeNo; upstreamTradeNo != "" && upstreamTradeNo != notificationTradeNo {
+ if upstreamTradeNo := strings.TrimSpace(resp.TradeNo); paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, notificationTradeNo) {
if _, updateErr := s.entClient.PaymentOrder.Update().
Where(paymentorder.IDEQ(o.ID)).
SetPaymentTradeNo(upstreamTradeNo).
@@ -170,11 +169,57 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
return checkPaidResultAlreadyPaid
}
if cp, ok := prov.(payment.CancelableProvider); ok {
- _ = cp.CancelPayment(ctx, tradeNo)
+ _ = cp.CancelPayment(ctx, queryRef)
}
return ""
}
+func paymentOrderQueryReference(order *dbent.PaymentOrder, prov payment.Provider) string {
+ if order == nil {
+ return ""
+ }
+
+ providerKey := ""
+ if prov != nil {
+ providerKey = strings.TrimSpace(prov.ProviderKey())
+ }
+ if providerKey == "" {
+ if snapshot := psOrderProviderSnapshot(order); snapshot != nil {
+ providerKey = strings.TrimSpace(snapshot.ProviderKey)
+ }
+ }
+ if providerKey == "" {
+ providerKey = strings.TrimSpace(psStringValue(order.ProviderKey))
+ }
+ if providerKey == "" {
+ providerKey = strings.TrimSpace(order.PaymentType)
+ }
+
+ switch payment.GetBasePaymentType(providerKey) {
+ case payment.TypeAlipay, payment.TypeEasyPay, payment.TypeWxpay:
+ return strings.TrimSpace(order.OutTradeNo)
+ default:
+ if tradeNo := strings.TrimSpace(order.PaymentTradeNo); tradeNo != "" {
+ return tradeNo
+ }
+ return strings.TrimSpace(order.OutTradeNo)
+ }
+}
+
+func paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, currentTradeNo string) bool {
+ upstreamTradeNo = strings.TrimSpace(upstreamTradeNo)
+ if upstreamTradeNo == "" {
+ return false
+ }
+ if strings.EqualFold(upstreamTradeNo, strings.TrimSpace(currentTradeNo)) {
+ return false
+ }
+ if strings.EqualFold(upstreamTradeNo, strings.TrimSpace(queryRef)) {
+ return false
+ }
+ return true
+}
+
// VerifyOrderByOutTradeNo actively queries the upstream provider to check
// if a payment was made, and processes it if so. This handles the case where
// the provider's notify callback was missed (e.g. EasyPay popup mode).
diff --git a/backend/internal/service/payment_order_lifecycle_test.go b/backend/internal/service/payment_order_lifecycle_test.go
index 3d4773a4..3c6c65a5 100644
--- a/backend/internal/service/payment_order_lifecycle_test.go
+++ b/backend/internal/service/payment_order_lifecycle_test.go
@@ -234,6 +234,95 @@ func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) {
require.Equal(t, user.ID, redeemRepo.useCalls[0].userID)
}
+func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsForAlipay(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentOrderLifecycleTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("checkpaid-existing-trade@example.com").
+ SetPasswordHash("hash").
+ SetUsername("checkpaid-existing-trade-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("CHECKPAID-EXISTING-TRADE-NO").
+ SetOutTradeNo("sub2_checkpaid_use_out_trade_no").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("upstream-trade-existing").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ userRepo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: user.ID,
+ Email: user.Email,
+ Username: user.Username,
+ Balance: 0,
+ },
+ }
+ userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error {
+ require.Equal(t, user.ID, id)
+ if userRepo.getByIDUser != nil {
+ userRepo.getByIDUser.Balance += amount
+ }
+ return nil
+ }
+ redeemRepo := &paymentOrderLifecycleRedeemRepo{
+ codesByCode: map[string]*RedeemCode{
+ order.RechargeCode: {
+ ID: 1,
+ Code: order.RechargeCode,
+ Type: RedeemTypeBalance,
+ Value: order.Amount,
+ Status: StatusUnused,
+ },
+ },
+ }
+ redeemService := NewRedeemService(
+ redeemRepo,
+ userRepo,
+ nil,
+ nil,
+ nil,
+ client,
+ nil,
+ )
+ registry := payment.NewRegistry()
+ provider := &paymentOrderLifecycleQueryProvider{
+ resp: &payment.QueryOrderResponse{
+ TradeNo: "upstream-trade-existing",
+ Status: payment.ProviderStatusPaid,
+ Amount: 88,
+ },
+ }
+ registry.Register(provider)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ redeemService: redeemService,
+ userRepo: userRepo,
+ providersLoaded: true,
+ }
+
+ got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo)
+ require.Equal(t, "upstream-trade-existing", got.PaymentTradeNo)
+}
+
func newPaymentOrderLifecycleTestClient(t *testing.T) *dbent.Client {
t.Helper()
diff --git a/backend/internal/service/payment_resume_lookup.go b/backend/internal/service/payment_resume_lookup.go
index 048e489a..05626aa6 100644
--- a/backend/internal/service/payment_resume_lookup.go
+++ b/backend/internal/service/payment_resume_lookup.go
@@ -21,10 +21,21 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
if claims.UserID > 0 && order.UserID != claims.UserID {
return nil, fmt.Errorf("resume token user mismatch")
}
- if claims.ProviderInstanceID != "" && strings.TrimSpace(psStringValue(order.ProviderInstanceID)) != claims.ProviderInstanceID {
+ snapshot := psOrderProviderSnapshot(order)
+ orderProviderInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
+ orderProviderKey := strings.TrimSpace(psStringValue(order.ProviderKey))
+ if snapshot != nil {
+ if snapshot.ProviderInstanceID != "" {
+ orderProviderInstanceID = snapshot.ProviderInstanceID
+ }
+ if snapshot.ProviderKey != "" {
+ orderProviderKey = snapshot.ProviderKey
+ }
+ }
+ if claims.ProviderInstanceID != "" && orderProviderInstanceID != claims.ProviderInstanceID {
return nil, fmt.Errorf("resume token provider instance mismatch")
}
- if claims.ProviderKey != "" && strings.TrimSpace(psStringValue(order.ProviderKey)) != claims.ProviderKey {
+ if claims.ProviderKey != "" && orderProviderKey != claims.ProviderKey {
return nil, fmt.Errorf("resume token provider key mismatch")
}
if claims.PaymentType != "" && strings.TrimSpace(order.PaymentType) != claims.PaymentType {
diff --git a/backend/internal/service/payment_resume_lookup_test.go b/backend/internal/service/payment_resume_lookup_test.go
index 3a50b5bc..946e7aa2 100644
--- a/backend/internal/service/payment_resume_lookup_test.go
+++ b/backend/internal/service/payment_resume_lookup_test.go
@@ -146,6 +146,63 @@ func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) {
require.Contains(t, err.Error(), "resume token")
}
+func TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("resume-snapshot-authority@example.com").
+ SetPasswordHash("hash").
+ SetUsername("resume-snapshot-authority-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("RESUME-SNAPSHOT-AUTHORITY").
+ SetOutTradeNo("sub2_resume_snapshot_authority").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-snapshot-authority").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderInstanceID("legacy-column-instance").
+ SetProviderKey(payment.TypeAlipay).
+ SetProviderSnapshot(map[string]any{
+ "schema_version": 2,
+ "provider_instance_id": "snapshot-instance",
+ "provider_key": payment.TypeEasyPay,
+ }).
+ Save(ctx)
+ require.NoError(t, err)
+
+ resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := resumeSvc.CreateToken(ResumeTokenClaims{
+ OrderID: order.ID,
+ UserID: user.ID,
+ ProviderInstanceID: "snapshot-instance",
+ ProviderKey: payment.TypeEasyPay,
+ PaymentType: payment.TypeAlipay,
+ CanonicalReturnURL: "https://app.example.com/payment/result",
+ })
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ resumeService: resumeSvc,
+ }
+
+ got, err := svc.GetPublicOrderByResumeToken(ctx, token)
+ require.NoError(t, err)
+ require.Equal(t, order.ID, got.ID)
+}
+
func TestGetPublicOrderByResumeTokenChecksUpstreamForPendingOrder(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
From 64e401e22475336561966038968ffd0a9fba3a51 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 13:03:53 +0800
Subject: [PATCH 106/189] fix: tighten payment legacy fallback paths
---
.../internal/service/payment_fulfillment.go | 25 ++++++++++--
.../service/payment_fulfillment_test.go | 14 +++++++
.../service/payment_order_lifecycle.go | 38 +++++++++++++++++++
.../service/payment_order_lifecycle_test.go | 37 ++++++++++++++++++
4 files changed, 111 insertions(+), 3 deletions(-)
diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go
index 7bde03c8..423ed80f 100644
--- a/backend/internal/service/payment_fulfillment.go
+++ b/backend/internal/service/payment_fulfillment.go
@@ -25,9 +25,9 @@ func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payme
// Look up order by out_trade_no (the external order ID we sent to the provider)
order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(n.OrderID)).Only(ctx)
if err != nil {
- // Fallback: try legacy format (sub2_N where N is DB ID)
- trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix)
- if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil {
+ // Fallback only for true legacy "sub2_N" DB-ID payloads when the
+ // current out_trade_no lookup genuinely did not find an order.
+ if oid, ok := parseLegacyPaymentOrderID(n.OrderID, err); ok {
return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk, n.Metadata)
}
return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID)
@@ -35,6 +35,25 @@ func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payme
return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk, n.Metadata)
}
+func parseLegacyPaymentOrderID(orderID string, lookupErr error) (int64, bool) {
+ if !dbent.IsNotFound(lookupErr) {
+ return 0, false
+ }
+ orderID = strings.TrimSpace(orderID)
+ if !strings.HasPrefix(orderID, orderIDPrefix) {
+ return 0, false
+ }
+ trimmed := strings.TrimPrefix(orderID, orderIDPrefix)
+ if trimmed == "" || trimmed == orderID {
+ return 0, false
+ }
+ oid, err := strconv.ParseInt(trimmed, 10, 64)
+ if err != nil || oid <= 0 {
+ return 0, false
+ }
+ return oid, true
+}
+
func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string, metadata map[string]string) error {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go
index 8883d3b8..d70f8946 100644
--- a/backend/internal/service/payment_fulfillment_test.go
+++ b/backend/internal/service/payment_fulfillment_test.go
@@ -307,3 +307,17 @@ func TestValidateProviderNotificationMetadataAllowsLegacyOrdersWithoutSnapshotFi
})
assert.NoError(t, err)
}
+
+func TestParseLegacyPaymentOrderID(t *testing.T) {
+ t.Parallel()
+
+ oid, ok := parseLegacyPaymentOrderID("sub2_42", &dbent.NotFoundError{})
+ assert.True(t, ok)
+ assert.EqualValues(t, 42, oid)
+
+ _, ok = parseLegacyPaymentOrderID("42", &dbent.NotFoundError{})
+ assert.False(t, ok)
+
+ _, ok = parseLegacyPaymentOrderID("sub2_42", errors.New("db down"))
+ assert.False(t, ok)
+}
diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go
index a192f599..ccab7c11 100644
--- a/backend/internal/service/payment_order_lifecycle.go
+++ b/backend/internal/service/payment_order_lifecycle.go
@@ -292,10 +292,48 @@ func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentO
if inst != nil {
return s.createProviderFromInstance(ctx, inst)
}
+ if !paymentOrderAllowsRegistryFallback(o) {
+ return nil, fmt.Errorf("order %d provider instance is unresolved", o.ID)
+ }
+ providerKey := paymentOrderFallbackProviderKey(s.registry, o)
+ if providerKey == "" {
+ return nil, fmt.Errorf("order %d provider fallback key is missing", o.ID)
+ }
+ if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
+ return nil, fmt.Errorf("order %d provider fallback is ambiguous for %s", o.ID, providerKey)
+ }
s.EnsureProviders(ctx)
return s.registry.GetProvider(o.PaymentType)
}
+func paymentOrderAllowsRegistryFallback(order *dbent.PaymentOrder) bool {
+ if order == nil {
+ return false
+ }
+ if psOrderProviderSnapshot(order) != nil {
+ return false
+ }
+ if strings.TrimSpace(psStringValue(order.ProviderInstanceID)) != "" {
+ return false
+ }
+ if strings.TrimSpace(psStringValue(order.ProviderKey)) != "" {
+ return false
+ }
+ return true
+}
+
+func paymentOrderFallbackProviderKey(registry *payment.Registry, order *dbent.PaymentOrder) string {
+ if order == nil {
+ return ""
+ }
+ if registry != nil {
+ if key := strings.TrimSpace(registry.GetProviderKey(payment.PaymentType(order.PaymentType))); key != "" {
+ return key
+ }
+ }
+ return strings.TrimSpace(payment.GetBasePaymentType(strings.TrimSpace(order.PaymentType)))
+}
+
func (s *PaymentService) createProviderFromInstance(ctx context.Context, inst *dbent.PaymentProviderInstance) (payment.Provider, error) {
if inst == nil {
return nil, fmt.Errorf("payment provider instance is missing")
diff --git a/backend/internal/service/payment_order_lifecycle_test.go b/backend/internal/service/payment_order_lifecycle_test.go
index 3c6c65a5..39993a2f 100644
--- a/backend/internal/service/payment_order_lifecycle_test.go
+++ b/backend/internal/service/payment_order_lifecycle_test.go
@@ -323,6 +323,43 @@ func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsFor
require.Equal(t, "upstream-trade-existing", got.PaymentTradeNo)
}
+func TestPaymentOrderAllowsRegistryFallbackOnlyForLegacyOrdersWithoutPinnedProviderState(t *testing.T) {
+ t.Parallel()
+
+ require.True(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ }))
+
+ instanceID := "12"
+ require.False(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ ProviderInstanceID: &instanceID,
+ }))
+
+ require.False(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 2,
+ "provider_instance_id": "12",
+ },
+ }))
+}
+
+func TestPaymentOrderQueryReferenceUsesOutTradeNoForOfficialProviders(t *testing.T) {
+ t.Parallel()
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeWxpay,
+ OutTradeNo: "sub2_out_trade_no",
+ PaymentTradeNo: "wx-transaction-id",
+ }
+
+ require.Equal(t, "sub2_out_trade_no", paymentOrderQueryReference(order, &paymentOrderLifecycleQueryProvider{}))
+ require.Equal(t, "sub2_out_trade_no", paymentOrderQueryReference(order, paymentFulfillmentTestProvider{
+ key: payment.TypeWxpay,
+ }))
+}
+
func newPaymentOrderLifecycleTestClient(t *testing.T) *dbent.Client {
t.Helper()
From ebd053c87ea752fa1845944c303ed09ea6b181be Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 13:07:40 +0800
Subject: [PATCH 107/189] docs: clarify openai scheduler flag semantics
---
frontend/src/views/admin/SettingsView.vue | 6 +++---
.../src/views/admin/__tests__/SettingsView.spec.ts | 10 ++++++++++
2 files changed, 13 insertions(+), 3 deletions(-)
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index 9fb8da41..e56afe5f 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -1877,13 +1877,13 @@
- {{ localText('OpenAI 高级调度器', 'OpenAI advanced scheduler') }}
+ {{ localText('OpenAI 实验调度策略', 'OpenAI experimental scheduler policy') }}
{{
localText(
- '切换 OpenAI 侧新增的高级调度开关,供当前分支实验性调度逻辑使用。',
- 'Toggles the new OpenAI advanced scheduler flag for the experimental routing logic on this branch.'
+ '默认关闭。开启后仅影响本网关在 OpenAI 账号间的实验性调度选择逻辑,不代表上游 OpenAI 官方能力。',
+ 'Disabled by default. When enabled, this only changes the gateway\'s experimental account-selection policy for OpenAI traffic; it does not indicate an upstream OpenAI capability.'
)
}}
diff --git a/frontend/src/views/admin/__tests__/SettingsView.spec.ts b/frontend/src/views/admin/__tests__/SettingsView.spec.ts
index b6f8ab17..3541d994 100644
--- a/frontend/src/views/admin/__tests__/SettingsView.spec.ts
+++ b/frontend/src/views/admin/__tests__/SettingsView.spec.ts
@@ -472,4 +472,14 @@ describe('admin SettingsView payment visible method controls', () => {
expect(showError).toHaveBeenCalled()
expect(String(showError.mock.calls.at(-1)?.[0] ?? '')).toContain('支付来源')
})
+
+ it('renders advanced scheduler copy as local experimental gateway policy', async () => {
+ const wrapper = mountView()
+
+ await flushPromises()
+
+ expect(wrapper.text()).toContain('OpenAI 实验调度策略')
+ expect(wrapper.text()).toContain('默认关闭。开启后仅影响本网关在 OpenAI 账号间的实验性调度选择逻辑')
+ expect(wrapper.text()).not.toContain('OpenAI 高级调度器')
+ })
})
From 267844ebe63261e4ebfe6bb485fb06fc32df4719 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 13:10:59 +0800
Subject: [PATCH 108/189] fix: fail closed for legacy refund provider
resolution
---
backend/internal/service/payment_refund.go | 44 ++++++-
.../internal/service/payment_refund_test.go | 117 ++++++++++++++++++
2 files changed, 158 insertions(+), 3 deletions(-)
create mode 100644 backend/internal/service/payment_refund_test.go
diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go
index fbaeff99..6883056c 100644
--- a/backend/internal/service/payment_refund.go
+++ b/backend/internal/service/payment_refund.go
@@ -43,6 +43,37 @@ func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.
return s.entClient.PaymentProviderInstance.Get(ctx, instID)
}
+// getRefundOrderProviderInstance resolves the provider instance for refund paths.
+// Refunds must be pinned to an explicit historical binding, so legacy
+// "best-effort" provider guessing is intentionally not allowed here.
+func (s *PaymentService) getRefundOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
+ if s == nil || s.entClient == nil || o == nil {
+ return nil, nil
+ }
+
+ if snapshot := psOrderProviderSnapshot(o); snapshot != nil {
+ return s.resolveSnapshotOrderProviderInstance(ctx, o, snapshot)
+ }
+
+ instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID))
+ if instIDStr == "" {
+ return nil, nil
+ }
+
+ instID, err := strconv.ParseInt(instIDStr, 10, 64)
+ if err != nil {
+ return nil, fmt.Errorf("order %d refund provider instance id is invalid: %s", o.ID, instIDStr)
+ }
+ inst, err := s.entClient.PaymentProviderInstance.Get(ctx, instID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, fmt.Errorf("order %d refund provider instance %s is missing", o.ID, instIDStr)
+ }
+ return nil, err
+ }
+ return inst, nil
+}
+
func (s *PaymentService) resolveUniqueLegacyOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
paymentType := payment.GetBasePaymentType(strings.TrimSpace(o.PaymentType))
providerKey := strings.TrimSpace(psStringValue(o.ProviderKey))
@@ -157,7 +188,7 @@ func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int
return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund")
}
// Check provider instance allows user refund
- inst, err := s.getOrderProviderInstance(ctx, o)
+ inst, err := s.getRefundOrderProviderInstance(ctx, o)
if err != nil || inst == nil {
return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "refund is not available for this order")
}
@@ -177,7 +208,7 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund")
}
// Check provider instance allows admin refund
- inst, instErr := s.getOrderProviderInstance(ctx, o)
+ inst, instErr := s.getRefundOrderProviderInstance(ctx, o)
if instErr != nil {
slog.Warn("refund: provider instance lookup failed", "orderID", oid, "error", instErr)
return nil, nil, infraerrors.InternalServer("PROVIDER_LOOKUP_FAILED", "failed to look up payment provider for this order")
@@ -314,7 +345,14 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
// getRefundProvider creates a provider using the order's original instance config.
// Delegates to getOrderProvider which handles instance lookup and fallback.
func (s *PaymentService) getRefundProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
- return s.getOrderProvider(ctx, o)
+ inst, err := s.getRefundOrderProviderInstance(ctx, o)
+ if err != nil {
+ return nil, err
+ }
+ if inst == nil {
+ return nil, fmt.Errorf("refund provider instance is unavailable for order %d", o.ID)
+ }
+ return s.createProviderFromInstance(ctx, inst)
}
func (s *PaymentService) handleGwFail(ctx context.Context, p *RefundPlan, gErr error) (*RefundResult, error) {
diff --git a/backend/internal/service/payment_refund_test.go b/backend/internal/service/payment_refund_test.go
new file mode 100644
index 00000000..95104618
--- /dev/null
+++ b/backend/internal/service/payment_refund_test.go
@@ -0,0 +1,117 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/stretchr/testify/require"
+)
+
+func TestValidateRefundRequestRejectsLegacyGuessedProviderInstance(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("refund-legacy@example.com").
+ SetPasswordHash("hash").
+ SetUsername("refund-legacy-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("alipay-refund-instance").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ SetAllowUserRefund(true).
+ SetRefundEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("REFUND-LEGACY-ORDER").
+ SetOutTradeNo("sub2_refund_legacy_order").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-legacy-refund").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusCompleted).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetPaidAt(time.Now()).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ }
+
+ _, err = svc.validateRefundRequest(ctx, order.ID, user.ID)
+ require.Error(t, err)
+ require.Equal(t, "USER_REFUND_DISABLED", infraerrors.Reason(err))
+}
+
+func TestPrepareRefundRejectsLegacyGuessedProviderInstance(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("refund-legacy-admin@example.com").
+ SetPasswordHash("hash").
+ SetUsername("refund-legacy-admin-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("alipay-refund-admin-instance").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ SetAllowUserRefund(true).
+ SetRefundEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(188).
+ SetPayAmount(188).
+ SetFeeRate(0).
+ SetRechargeCode("REFUND-LEGACY-ADMIN-ORDER").
+ SetOutTradeNo("sub2_refund_legacy_admin_order").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-legacy-admin-refund").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusCompleted).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetPaidAt(time.Now()).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ }
+
+ plan, result, err := svc.PrepareRefund(ctx, order.ID, 0, "", false, false)
+ require.Nil(t, plan)
+ require.Nil(t, result)
+ require.Error(t, err)
+ require.Equal(t, "REFUND_DISABLED", infraerrors.Reason(err))
+}
From 0934f737d5c46a0451844c161b4c6e69bd2050d9 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 13:35:54 +0800
Subject: [PATCH 109/189] fix: snapshot merchant identity for alipay and
easypay
---
backend/internal/payment/provider/alipay.go | 44 ++++++++---
.../internal/payment/provider/alipay_test.go | 15 ++++
backend/internal/payment/provider/easypay.go | 28 ++++++-
.../payment/provider/easypay_sign_test.go | 15 ++++
.../internal/payment/provider/wxpay_test.go | 2 +-
backend/internal/payment/types.go | 6 ++
.../internal/service/payment_fulfillment.go | 42 +---------
.../service/payment_fulfillment_test.go | 34 ++++++++
backend/internal/service/payment_order.go | 10 +++
.../payment_order_provider_snapshot.go | 78 +++++++++++++++++++
.../payment_order_provider_snapshot_test.go | 34 ++++++++
backend/internal/service/payment_refund.go | 6 ++
.../internal/service/payment_refund_test.go | 69 ++++++++++++++++
13 files changed, 328 insertions(+), 55 deletions(-)
diff --git a/backend/internal/payment/provider/alipay.go b/backend/internal/payment/provider/alipay.go
index 4f87e5a7..0604883a 100644
--- a/backend/internal/payment/provider/alipay.go
+++ b/backend/internal/payment/provider/alipay.go
@@ -91,6 +91,17 @@ func (a *Alipay) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeAlipay}
}
+func (a *Alipay) MerchantIdentityMetadata() map[string]string {
+ if a == nil {
+ return nil
+ }
+ appID := strings.TrimSpace(a.config["appId"])
+ if appID == "" {
+ return nil
+ }
+ return map[string]string{"app_id": appID}
+}
+
// CreatePayment creates an Alipay payment page URL.
func (a *Alipay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
client, err := a.getClient()
@@ -181,10 +192,11 @@ func (a *Alipay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Query
}
return &payment.QueryOrderResponse{
- TradeNo: result.TradeNo,
- Status: status,
- Amount: amount,
- PaidAt: result.SendPayDate,
+ TradeNo: result.TradeNo,
+ Status: status,
+ Amount: amount,
+ PaidAt: result.SendPayDate,
+ Metadata: a.MerchantIdentityMetadata(),
}, nil
}
@@ -215,12 +227,21 @@ func (a *Alipay) VerifyNotification(ctx context.Context, rawBody string, _ map[s
return nil, fmt.Errorf("alipay parse notification amount %q: %w", notification.TotalAmount, err)
}
+ metadata := a.MerchantIdentityMetadata()
+ if appID := strings.TrimSpace(notification.AppId); appID != "" {
+ if metadata == nil {
+ metadata = map[string]string{}
+ }
+ metadata["app_id"] = appID
+ }
+
return &payment.PaymentNotification{
- TradeNo: notification.TradeNo,
- OrderID: notification.OutTradeNo,
- Amount: amount,
- Status: status,
- RawData: rawBody,
+ TradeNo: notification.TradeNo,
+ OrderID: notification.OutTradeNo,
+ Amount: amount,
+ Status: status,
+ RawData: rawBody,
+ Metadata: metadata,
}, nil
}
@@ -283,6 +304,7 @@ func isTradeNotExist(err error) bool {
// Ensure interface compliance.
var (
- _ payment.Provider = (*Alipay)(nil)
- _ payment.CancelableProvider = (*Alipay)(nil)
+ _ payment.Provider = (*Alipay)(nil)
+ _ payment.CancelableProvider = (*Alipay)(nil)
+ _ payment.MerchantIdentityProvider = (*Alipay)(nil)
)
diff --git a/backend/internal/payment/provider/alipay_test.go b/backend/internal/payment/provider/alipay_test.go
index 6cc4246c..b25c05bd 100644
--- a/backend/internal/payment/provider/alipay_test.go
+++ b/backend/internal/payment/provider/alipay_test.go
@@ -243,3 +243,18 @@ func TestCreateTradeUsesWapPayForMobile(t *testing.T) {
t.Fatalf("qr_code = %q, want empty", resp.QRCode)
}
}
+
+func TestAlipayMerchantIdentityMetadata(t *testing.T) {
+ t.Parallel()
+
+ provider := &Alipay{
+ config: map[string]string{
+ "appId": "2021001234567890",
+ },
+ }
+
+ metadata := provider.MerchantIdentityMetadata()
+ if metadata["app_id"] != "2021001234567890" {
+ t.Fatalf("app_id = %q, want %q", metadata["app_id"], "2021001234567890")
+ }
+}
diff --git a/backend/internal/payment/provider/easypay.go b/backend/internal/payment/provider/easypay.go
index e33a567d..37bd38b2 100644
--- a/backend/internal/payment/provider/easypay.go
+++ b/backend/internal/payment/provider/easypay.go
@@ -59,6 +59,17 @@ func (e *EasyPay) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeAlipay, payment.TypeWxpay}
}
+func (e *EasyPay) MerchantIdentityMetadata() map[string]string {
+ if e == nil {
+ return nil
+ }
+ pid := strings.TrimSpace(e.config["pid"])
+ if pid == "" {
+ return nil
+ }
+ return map[string]string{"pid": pid}
+}
+
func (e *EasyPay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
// Payment mode determined by instance config, not payment type.
// "popup" → hosted page (submit.php); "qrcode"/default → API call (mapi.php).
@@ -178,7 +189,12 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer
status = payment.ProviderStatusPaid
}
amount, _ := strconv.ParseFloat(resp.Money, 64)
- return &payment.QueryOrderResponse{TradeNo: tradeNo, Status: status, Amount: amount}, nil
+ return &payment.QueryOrderResponse{
+ TradeNo: tradeNo,
+ Status: status,
+ Amount: amount,
+ Metadata: e.MerchantIdentityMetadata(),
+ }, nil
}
func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[string]string) (*payment.PaymentNotification, error) {
@@ -203,9 +219,17 @@ func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[st
status = payment.ProviderStatusSuccess
}
amount, _ := strconv.ParseFloat(params["money"], 64)
+
+ metadata := e.MerchantIdentityMetadata()
+ if pid := strings.TrimSpace(params["pid"]); pid != "" {
+ if metadata == nil {
+ metadata = map[string]string{}
+ }
+ metadata["pid"] = pid
+ }
return &payment.PaymentNotification{
TradeNo: params["trade_no"], OrderID: params["out_trade_no"],
- Amount: amount, Status: status, RawData: rawBody,
+ Amount: amount, Status: status, RawData: rawBody, Metadata: metadata,
}, nil
}
diff --git a/backend/internal/payment/provider/easypay_sign_test.go b/backend/internal/payment/provider/easypay_sign_test.go
index 146a6fa1..8328d294 100644
--- a/backend/internal/payment/provider/easypay_sign_test.go
+++ b/backend/internal/payment/provider/easypay_sign_test.go
@@ -178,3 +178,18 @@ func TestEasyPayVerifySignWrongSignValue(t *testing.T) {
t.Fatal("easyPayVerifySign should return false for an incorrect sign value")
}
}
+
+func TestEasyPayMerchantIdentityMetadata(t *testing.T) {
+ t.Parallel()
+
+ provider := &EasyPay{
+ config: map[string]string{
+ "pid": "1001",
+ },
+ }
+
+ metadata := provider.MerchantIdentityMetadata()
+ if metadata["pid"] != "1001" {
+ t.Fatalf("pid = %q, want %q", metadata["pid"], "1001")
+ }
+}
diff --git a/backend/internal/payment/provider/wxpay_test.go b/backend/internal/payment/provider/wxpay_test.go
index b3f4f648..0d79b1b0 100644
--- a/backend/internal/payment/provider/wxpay_test.go
+++ b/backend/internal/payment/provider/wxpay_test.go
@@ -110,7 +110,7 @@ func TestBuildWxpayTransactionMetadata(t *testing.T) {
Appid: strPtr("wx-app-id"),
Mchid: strPtr("mch-id"),
TradeState: strPtr(wxpayTradeStateSuccess),
- Amount: &payments.Amount{
+ Amount: &payments.TransactionAmount{
Currency: strPtr(wxpayCurrency),
},
}
diff --git a/backend/internal/payment/types.go b/backend/internal/payment/types.go
index 29abf82b..e7ac6727 100644
--- a/backend/internal/payment/types.go
+++ b/backend/internal/payment/types.go
@@ -214,3 +214,9 @@ type CancelableProvider interface {
// CancelPayment cancels/expires a pending payment on the upstream platform.
CancelPayment(ctx context.Context, tradeNo string) error
}
+
+// MerchantIdentityProvider exposes the current non-sensitive merchant identity
+// derived from provider configuration for snapshot consistency checks.
+type MerchantIdentityProvider interface {
+ MerchantIdentityMetadata() map[string]string
+}
diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go
index 423ed80f..904960ee 100644
--- a/backend/internal/service/payment_fulfillment.go
+++ b/backend/internal/service/payment_fulfillment.go
@@ -96,47 +96,7 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
}
func validateProviderNotificationMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error {
- if order == nil || len(metadata) == 0 || !strings.EqualFold(strings.TrimSpace(providerKey), payment.TypeWxpay) {
- return nil
- }
-
- snapshot := psOrderProviderSnapshot(order)
- if snapshot == nil {
- return nil
- }
-
- if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" {
- actual := strings.TrimSpace(metadata["appid"])
- if actual == "" {
- return fmt.Errorf("wxpay notification missing appid")
- }
- if !strings.EqualFold(expected, actual) {
- return fmt.Errorf("wxpay appid mismatch: expected %s, got %s", expected, actual)
- }
- }
- if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" {
- actual := strings.TrimSpace(metadata["mchid"])
- if actual == "" {
- return fmt.Errorf("wxpay notification missing mchid")
- }
- if !strings.EqualFold(expected, actual) {
- return fmt.Errorf("wxpay mchid mismatch: expected %s, got %s", expected, actual)
- }
- }
- if expected := strings.TrimSpace(snapshot.Currency); expected != "" {
- actual := strings.ToUpper(strings.TrimSpace(metadata["currency"]))
- if actual == "" {
- return fmt.Errorf("wxpay notification missing currency")
- }
- if !strings.EqualFold(expected, actual) {
- return fmt.Errorf("wxpay currency mismatch: expected %s, got %s", expected, actual)
- }
- }
- if actual := strings.TrimSpace(metadata["trade_state"]); actual != "" && !strings.EqualFold(actual, "SUCCESS") {
- return fmt.Errorf("wxpay trade_state mismatch: expected SUCCESS, got %s", actual)
- }
-
- return nil
+ return validateProviderSnapshotMetadata(order, providerKey, metadata)
}
func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, orderProviderKey string, instanceProviderKey string) string {
diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go
index d70f8946..6aed19f8 100644
--- a/backend/internal/service/payment_fulfillment_test.go
+++ b/backend/internal/service/payment_fulfillment_test.go
@@ -321,3 +321,37 @@ func TestParseLegacyPaymentOrderID(t *testing.T) {
_, ok = parseLegacyPaymentOrderID("sub2_42", errors.New("db down"))
assert.False(t, ok)
}
+
+func TestValidateProviderNotificationMetadataRejectsAlipaySnapshotMismatch(t *testing.T) {
+ t.Parallel()
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 2,
+ "merchant_app_id": "alipay-app-expected",
+ },
+ }
+
+ err := validateProviderNotificationMetadata(order, payment.TypeAlipay, map[string]string{
+ "app_id": "alipay-app-other",
+ })
+ assert.ErrorContains(t, err, "alipay app_id mismatch")
+}
+
+func TestValidateProviderNotificationMetadataRejectsEasyPaySnapshotMismatch(t *testing.T) {
+ t.Parallel()
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 2,
+ "merchant_id": "pid-expected",
+ },
+ }
+
+ err := validateProviderNotificationMetadata(order, payment.TypeEasyPay, map[string]string{
+ "pid": "pid-other",
+ })
+ assert.ErrorContains(t, err, "easypay pid mismatch")
+}
diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go
index 6ee490a8..254af5fe 100644
--- a/backend/internal/service/payment_order.go
+++ b/backend/internal/service/payment_order.go
@@ -240,6 +240,16 @@ func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection, req Creat
}
snapshot["currency"] = "CNY"
}
+ if providerKey == payment.TypeAlipay {
+ if merchantAppID := strings.TrimSpace(sel.Config["appId"]); merchantAppID != "" {
+ snapshot["merchant_app_id"] = merchantAppID
+ }
+ }
+ if providerKey == payment.TypeEasyPay {
+ if merchantID := strings.TrimSpace(sel.Config["pid"]); merchantID != "" {
+ snapshot["merchant_id"] = merchantID
+ }
+ }
if len(snapshot) == 1 {
return nil
diff --git a/backend/internal/service/payment_order_provider_snapshot.go b/backend/internal/service/payment_order_provider_snapshot.go
index 31a790c7..bb60f9e2 100644
--- a/backend/internal/service/payment_order_provider_snapshot.go
+++ b/backend/internal/service/payment_order_provider_snapshot.go
@@ -125,3 +125,81 @@ func expectedNotificationProviderKeyForOrder(registry *payment.Registry, order *
return expectedNotificationProviderKey(registry, order.PaymentType, orderProviderKey, instanceProviderKey)
}
+
+func validateProviderSnapshotMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error {
+ if order == nil || len(metadata) == 0 {
+ return nil
+ }
+
+ snapshot := psOrderProviderSnapshot(order)
+ if snapshot == nil {
+ return nil
+ }
+
+ switch strings.TrimSpace(providerKey) {
+ case payment.TypeWxpay:
+ if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" {
+ actual := strings.TrimSpace(metadata["appid"])
+ if actual == "" {
+ return fmt.Errorf("wxpay notification missing appid")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("wxpay appid mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" {
+ actual := strings.TrimSpace(metadata["mchid"])
+ if actual == "" {
+ return fmt.Errorf("wxpay notification missing mchid")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("wxpay mchid mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ if expected := strings.TrimSpace(snapshot.Currency); expected != "" {
+ actual := strings.ToUpper(strings.TrimSpace(metadata["currency"]))
+ if actual == "" {
+ return fmt.Errorf("wxpay notification missing currency")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("wxpay currency mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ if actual := strings.TrimSpace(metadata["trade_state"]); actual != "" && !strings.EqualFold(actual, "SUCCESS") {
+ return fmt.Errorf("wxpay trade_state mismatch: expected SUCCESS, got %s", actual)
+ }
+ case payment.TypeAlipay:
+ if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" {
+ actual := strings.TrimSpace(metadata["app_id"])
+ if actual == "" {
+ return fmt.Errorf("alipay app_id missing")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("alipay app_id mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ case payment.TypeEasyPay:
+ if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" {
+ actual := strings.TrimSpace(metadata["pid"])
+ if actual == "" {
+ return fmt.Errorf("easypay pid missing")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("easypay pid mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ }
+
+ return nil
+}
+
+func providerMerchantIdentityMetadata(prov payment.Provider) map[string]string {
+ if prov == nil {
+ return nil
+ }
+ reporter, ok := prov.(payment.MerchantIdentityProvider)
+ if !ok {
+ return nil
+ }
+ return reporter.MerchantIdentityMetadata()
+}
diff --git a/backend/internal/service/payment_order_provider_snapshot_test.go b/backend/internal/service/payment_order_provider_snapshot_test.go
index bc6666a8..efa013b5 100644
--- a/backend/internal/service/payment_order_provider_snapshot_test.go
+++ b/backend/internal/service/payment_order_provider_snapshot_test.go
@@ -130,6 +130,40 @@ func TestBuildPaymentOrderProviderSnapshot_UsesWxpayJSAPIAppIDForOpenIDOrders(t
require.Equal(t, "CNY", snapshot["currency"])
}
+func TestBuildPaymentOrderProviderSnapshot_IncludesAlipayMerchantIdentity(t *testing.T) {
+ t.Parallel()
+
+ snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
+ InstanceID: "21",
+ ProviderKey: payment.TypeAlipay,
+ Config: map[string]string{
+ "appId": "alipay-app-21",
+ "privateKey": "secret",
+ },
+ PaymentMode: "redirect",
+ }, CreateOrderRequest{})
+
+ require.Equal(t, "alipay-app-21", snapshot["merchant_app_id"])
+ require.NotContains(t, snapshot, "privateKey")
+}
+
+func TestBuildPaymentOrderProviderSnapshot_IncludesEasyPayMerchantIdentity(t *testing.T) {
+ t.Parallel()
+
+ snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
+ InstanceID: "66",
+ ProviderKey: payment.TypeEasyPay,
+ Config: map[string]string{
+ "pid": "easypay-merchant-66",
+ "pkey": "secret",
+ },
+ PaymentMode: "popup",
+ }, CreateOrderRequest{PaymentType: payment.TypeAlipay})
+
+ require.Equal(t, "easypay-merchant-66", snapshot["merchant_id"])
+ require.NotContains(t, snapshot, "pkey")
+}
+
func valueOrEmpty(v *string) string {
if v == nil {
return ""
diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go
index 6883056c..7521878c 100644
--- a/backend/internal/service/payment_refund.go
+++ b/backend/internal/service/payment_refund.go
@@ -333,6 +333,12 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
if err != nil {
return fmt.Errorf("get refund provider: %w", err)
}
+ if err := validateProviderSnapshotMetadata(p.Order, prov.ProviderKey(), providerMerchantIdentityMetadata(prov)); err != nil {
+ s.writeAuditLog(ctx, p.Order.ID, "REFUND_PROVIDER_METADATA_MISMATCH", "admin", map[string]any{
+ "detail": err.Error(),
+ })
+ return err
+ }
_, err = prov.Refund(ctx, payment.RefundRequest{
TradeNo: p.Order.PaymentTradeNo,
OrderID: p.Order.OutTradeNo,
diff --git a/backend/internal/service/payment_refund_test.go b/backend/internal/service/payment_refund_test.go
index 95104618..ca5b62cb 100644
--- a/backend/internal/service/payment_refund_test.go
+++ b/backend/internal/service/payment_refund_test.go
@@ -4,6 +4,7 @@ package service
import (
"context"
+ "strconv"
"testing"
"time"
@@ -115,3 +116,71 @@ func TestPrepareRefundRejectsLegacyGuessedProviderInstance(t *testing.T) {
require.Error(t, err)
require.Equal(t, "REFUND_DISABLED", infraerrors.Reason(err))
}
+
+func TestGwRefundRejectsAlipayMerchantIdentitySnapshotMismatch(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("refund-snapshot-mismatch@example.com").
+ SetPasswordHash("hash").
+ SetUsername("refund-snapshot-mismatch-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ inst, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("alipay-refund-mismatch-instance").
+ SetConfig(encryptWebhookProviderConfig(t, map[string]string{
+ "appId": "runtime-alipay-app",
+ "privateKey": "runtime-private-key",
+ })).
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ SetRefundEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ instID := strconv.FormatInt(inst.ID, 10)
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("REFUND-SNAPSHOT-MISMATCH-ORDER").
+ SetOutTradeNo("sub2_refund_snapshot_mismatch_order").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-refund-snapshot-mismatch").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusCompleted).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetPaidAt(time.Now()).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderInstanceID(instID).
+ SetProviderKey(payment.TypeAlipay).
+ SetProviderSnapshot(map[string]any{
+ "schema_version": 2,
+ "provider_instance_id": instID,
+ "provider_key": payment.TypeAlipay,
+ "merchant_app_id": "expected-alipay-app",
+ }).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ err = svc.gwRefund(ctx, &RefundPlan{
+ OrderID: order.ID,
+ Order: order,
+ RefundAmount: order.Amount,
+ GatewayAmount: order.Amount,
+ Reason: "snapshot mismatch",
+ })
+ require.ErrorContains(t, err, "alipay app_id mismatch")
+}
From 12f1e19d688f1cbb892a152572eb793f5e7d097d Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 13:36:19 +0800
Subject: [PATCH 110/189] fix: restore wechat oauth legacy callback
compatibility
---
.../src/views/auth/WechatCallbackView.vue | 117 +++++++++++++++---
.../auth/__tests__/WechatCallbackView.spec.ts | 111 +++++++++++++++++
2 files changed, 208 insertions(+), 20 deletions(-)
diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue
index 36e3140c..a3efcf9a 100644
--- a/frontend/src/views/auth/WechatCallbackView.vue
+++ b/frontend/src/views/auth/WechatCallbackView.vue
@@ -300,6 +300,7 @@ import {
persistOAuthTokenContext,
resolveWeChatOAuthStartStrict,
type OAuthAdoptionDecision,
+ type OAuthTokenResponse,
type PendingOAuthExchangeResponse
} from '@/api/auth'
@@ -328,6 +329,7 @@ const pendingAccountAction = ref<'none' | 'create_account' | 'bind_login'>('none
const pendingAccountEmail = ref('')
const bindLoginEmail = ref('')
const bindLoginPassword = ref('')
+const legacyPendingOAuthToken = ref('')
const accountActionError = ref('')
const canReturnToCreateAccount = ref(false)
const needsTotpChallenge = ref(false)
@@ -354,12 +356,49 @@ type PendingWeChatCompletion = PendingOAuthExchangeResponse & {
user_email_masked?: string
}
+function persistPendingAuthSession(redirect?: string) {
+ authStore.setPendingAuthSession({
+ token: '',
+ token_field: 'pending_oauth_token',
+ provider: 'wechat',
+ redirect: sanitizeRedirectPath(redirect || redirectTo.value)
+ })
+}
+
+function clearPendingAuthSession() {
+ authStore.clearPendingAuthSession()
+}
+
function parseFragmentParams(): URLSearchParams {
const raw = typeof window !== 'undefined' ? window.location.hash : ''
const hash = raw.startsWith('#') ? raw.slice(1) : raw
return new URLSearchParams(hash)
}
+function readLegacyFragmentLogin(params: URLSearchParams): OAuthTokenResponse | null {
+ const accessToken = params.get('access_token')?.trim() || ''
+ if (!accessToken) {
+ return null
+ }
+
+ const completion: OAuthTokenResponse = {
+ access_token: accessToken
+ }
+ const refreshToken = params.get('refresh_token')?.trim() || ''
+ if (refreshToken) {
+ completion.refresh_token = refreshToken
+ }
+ const expiresIn = Number.parseInt(params.get('expires_in')?.trim() || '', 10)
+ if (Number.isFinite(expiresIn) && expiresIn > 0) {
+ completion.expires_in = expiresIn
+ }
+ const tokenType = params.get('token_type')?.trim() || ''
+ if (tokenType) {
+ completion.token_type = tokenType
+ }
+ return completion
+}
+
function sanitizeRedirectPath(path: string | null | undefined): string {
if (!path) return '/dashboard'
if (!path.startsWith('/')) return '/dashboard'
@@ -672,6 +711,7 @@ function isCreateAccountRecoveryError(error: unknown): boolean {
async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) {
if (getOAuthCompletionKind(completion) === 'bind') {
const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile')
+ clearPendingAuthSession()
appStore.showSuccess(bindSuccessMessage)
await router.replace(bindRedirect)
return
@@ -689,16 +729,19 @@ async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redi
async function finalizePendingAccountResponse(completion: PendingWeChatCompletion) {
applyAdoptionSuggestionState(completion)
+ const redirect = sanitizeRedirectPath(completion.redirect || redirectTo.value)
if (completion.error === 'invitation_required') {
pendingAccountAction.value = 'none'
needsInvitation.value = true
needsAdoptionConfirmation.value = false
isProcessing.value = false
+ persistPendingAuthSession(redirect)
return
}
if (applyTotpChallenge(completion)) {
+ persistPendingAuthSession(redirect)
return
}
@@ -707,10 +750,10 @@ async function finalizePendingAccountResponse(completion: PendingWeChatCompletio
needsInvitation.value = false
needsAdoptionConfirmation.value = false
isProcessing.value = false
+ persistPendingAuthSession(redirect)
return
}
- const redirect = sanitizeRedirectPath(completion.redirect || redirectTo.value)
await finalizeCompletion(completion, redirect)
}
@@ -720,10 +763,18 @@ async function handleSubmitInvitation() {
isSubmitting.value = true
try {
- const tokenData = await completeWeChatOAuthRegistration(
- invitationCode.value.trim(),
- currentAdoptionDecision()
- )
+ const tokenData = legacyPendingOAuthToken.value
+ ? (
+ await apiClient.post('/auth/oauth/wechat/complete-registration', {
+ pending_oauth_token: legacyPendingOAuthToken.value,
+ invitation_code: invitationCode.value.trim(),
+ ...serializeAdoptionDecision(currentAdoptionDecision())
+ })
+ ).data
+ : await completeWeChatOAuthRegistration(
+ invitationCode.value.trim(),
+ currentAdoptionDecision()
+ )
persistOAuthTokenContext(tokenData)
await authStore.setToken(tokenData.access_token)
appStore.showSuccess(t('auth.loginSuccess'))
@@ -864,48 +915,74 @@ onMounted(async () => {
}
const params = parseFragmentParams()
+ const legacyLogin = readLegacyFragmentLogin(params)
+ const legacyPendingToken = params.get('pending_oauth_token')?.trim() || ''
const error = params.get('error')
const errorDesc = params.get('error_description') || params.get('error_message') || ''
-
- if (error) {
- errorMessage.value = errorDesc || error
- appStore.showError(errorMessage.value)
- isProcessing.value = false
- return
- }
+ const redirect = sanitizeRedirectPath(
+ params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard'
+ )
try {
- const completion = await exchangePendingOAuthCompletion() as PendingWeChatCompletion
- const redirect = sanitizeRedirectPath(
- completion.redirect || (route.query.redirect as string | undefined) || '/dashboard'
- )
- applyAdoptionSuggestionState(completion)
- redirectTo.value = redirect
+ if (legacyLogin) {
+ persistOAuthTokenContext(legacyLogin)
+ await authStore.setToken(legacyLogin.access_token)
+ appStore.showSuccess(t('auth.loginSuccess'))
+ await router.replace(redirect)
+ return
+ }
- if (completion.error === 'invitation_required') {
+ if (error === 'invitation_required' && legacyPendingToken) {
+ legacyPendingOAuthToken.value = legacyPendingToken
+ redirectTo.value = redirect
needsInvitation.value = true
isProcessing.value = false
return
}
+ if (error) {
+ errorMessage.value = errorDesc || error
+ appStore.showError(errorMessage.value)
+ isProcessing.value = false
+ return
+ }
+
+ const completion = await exchangePendingOAuthCompletion() as PendingWeChatCompletion
+ const completionRedirect = sanitizeRedirectPath(
+ completion.redirect || (route.query.redirect as string | undefined) || '/dashboard'
+ )
+ applyAdoptionSuggestionState(completion)
+ redirectTo.value = completionRedirect
+
+ if (completion.error === 'invitation_required') {
+ needsInvitation.value = true
+ isProcessing.value = false
+ persistPendingAuthSession(completionRedirect)
+ return
+ }
+
if (applyTotpChallenge(completion)) {
+ persistPendingAuthSession(completionRedirect)
return
}
applyPendingAccountAction(completion)
if (pendingAccountAction.value !== 'none') {
isProcessing.value = false
+ persistPendingAuthSession(completionRedirect)
return
}
if (adoptionRequired.value && hasSuggestedProfile(completion)) {
needsAdoptionConfirmation.value = true
isProcessing.value = false
+ persistPendingAuthSession(completionRedirect)
return
}
- await finalizeCompletion(completion, redirect)
+ await finalizeCompletion(completion, completionRedirect)
} catch (e: unknown) {
+ clearPendingAuthSession()
errorMessage.value = getRequestErrorMessage(e, t('auth.loginFailed'))
appStore.showError(errorMessage.value)
isProcessing.value = false
diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
index e02060f6..98a0268d 100644
--- a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
@@ -14,6 +14,8 @@ const {
getAuthTokenMock,
replaceMock,
setTokenMock,
+ setPendingAuthSessionMock,
+ clearPendingAuthSessionMock,
showSuccessMock,
showErrorMock,
fetchPublicSettingsMock,
@@ -32,6 +34,8 @@ const {
getAuthTokenMock: vi.fn(),
replaceMock: vi.fn(),
setTokenMock: vi.fn(),
+ setPendingAuthSessionMock: vi.fn(),
+ clearPendingAuthSessionMock: vi.fn(),
showSuccessMock: vi.fn(),
showErrorMock: vi.fn(),
fetchPublicSettingsMock: vi.fn(),
@@ -111,6 +115,8 @@ vi.mock('vue-i18n', () => ({
vi.mock('@/stores', () => ({
useAuthStore: () => ({
setToken: setTokenMock,
+ setPendingAuthSession: setPendingAuthSessionMock,
+ clearPendingAuthSession: clearPendingAuthSessionMock,
}),
useAppStore: () => ({
...appStoreState,
@@ -152,6 +158,8 @@ describe('WechatCallbackView', () => {
getPublicSettingsMock.mockReset()
replaceMock.mockReset()
setTokenMock.mockReset()
+ setPendingAuthSessionMock.mockReset()
+ clearPendingAuthSessionMock.mockReset()
showSuccessMock.mockReset()
showErrorMock.mockReset()
prepareOAuthBindAccessTokenCookieMock.mockReset()
@@ -269,6 +277,81 @@ describe('WechatCallbackView', () => {
expect(locationState.current.href).toContain('mode=open')
})
+ it('accepts the legacy fragment token success callback without pending-session exchange', async () => {
+ locationState.current.hash =
+ '#access_token=legacy-access-token&refresh_token=legacy-refresh-token&expires_in=3600&token_type=Bearer&redirect=%2Flegacy-dashboard'
+ Object.defineProperty(window, 'location', {
+ configurable: true,
+ value: locationState.current,
+ })
+ setTokenMock.mockResolvedValue({})
+
+ mount(WechatCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletionMock).not.toHaveBeenCalled()
+ expect(setTokenMock).toHaveBeenCalledWith('legacy-access-token')
+ expect(localStorage.getItem('refresh_token')).toBe('legacy-refresh-token')
+ expect(localStorage.getItem('token_expires_at')).not.toBeNull()
+ expect(showSuccessMock).toHaveBeenCalledWith('Login success')
+ expect(replaceMock).toHaveBeenCalledWith('/legacy-dashboard')
+ })
+
+ it('accepts the legacy pending oauth invitation fragment without pending-session exchange', async () => {
+ locationState.current.hash =
+ '#error=invitation_required&pending_oauth_token=legacy-pending-token&redirect=%2Flegacy-invite'
+ Object.defineProperty(window, 'location', {
+ configurable: true,
+ value: locationState.current,
+ })
+ apiClientPostMock.mockResolvedValue({
+ data: {
+ access_token: 'legacy-access-token',
+ refresh_token: 'legacy-refresh-token',
+ expires_in: 3600,
+ token_type: 'Bearer',
+ },
+ })
+ setTokenMock.mockResolvedValue({})
+
+ const wrapper = mount(WechatCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletionMock).not.toHaveBeenCalled()
+ await wrapper.find('input[type="text"]').setValue('invite-code')
+ await wrapper.find('button').trigger('click')
+ await flushPromises()
+
+ expect(apiClientPostMock).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', {
+ pending_oauth_token: 'legacy-pending-token',
+ invitation_code: 'invite-code',
+ adopt_display_name: true,
+ adopt_avatar: true,
+ })
+ expect(setTokenMock).toHaveBeenCalledWith('legacy-access-token')
+ expect(replaceMock).toHaveBeenCalledWith('/legacy-invite')
+ })
+
it('does not send adoption decisions during the initial exchange', async () => {
exchangePendingOAuthCompletionMock.mockResolvedValue({
access_token: 'access-token',
@@ -382,6 +465,7 @@ describe('WechatCallbackView', () => {
adoptAvatar: true,
})
expect(setTokenMock).not.toHaveBeenCalled()
+ expect(clearPendingAuthSessionMock).toHaveBeenCalledTimes(1)
expect(showSuccessMock).toHaveBeenCalledWith('profile.authBindings.bindSuccess')
expect(replaceMock).toHaveBeenCalledWith('/profile/connections')
})
@@ -548,6 +632,33 @@ describe('WechatCallbackView', () => {
expect(replaceMock).toHaveBeenCalledWith('/welcome')
})
+ it('persists a pending auth session when the oauth flow still needs account creation', async () => {
+ exchangePendingOAuthCompletionMock.mockResolvedValue({
+ error: 'email_required',
+ redirect: '/welcome',
+ })
+
+ mount(WechatCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(setPendingAuthSessionMock).toHaveBeenCalledWith({
+ token: '',
+ token_field: 'pending_oauth_token',
+ provider: 'wechat',
+ redirect: '/welcome',
+ })
+ })
+
it('switches to bind-login when create-account returns EMAIL_EXISTS', async () => {
exchangePendingOAuthCompletionMock.mockResolvedValue({
error: 'email_required',
From 65efef1eee696b295b11e550d424e6d76a944f76 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 13:47:15 +0800
Subject: [PATCH 111/189] feat: support replacing bound primary email
---
backend/internal/handler/user_handler.go | 2 +-
backend/internal/handler/user_handler_test.go | 53 +++++++
.../internal/service/auth_email_binding.go | 31 ++--
.../service/auth_service_email_bind_test.go | 142 ++++++++++++++++++
.../ProfileIdentityBindingsSection.vue | 28 +++-
.../ProfileIdentityBindingsSection.spec.ts | 68 +++++++++
frontend/src/i18n/locales/en.ts | 3 +
frontend/src/i18n/locales/zh.ts | 3 +
8 files changed, 313 insertions(+), 17 deletions(-)
diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go
index a6a7be9a..497a23c4 100644
--- a/backend/internal/handler/user_handler.go
+++ b/backend/internal/handler/user_handler.go
@@ -167,7 +167,7 @@ type StartIdentityBindingRequest struct {
type BindEmailIdentityRequest struct {
Email string `json:"email" binding:"required,email"`
VerifyCode string `json:"verify_code" binding:"required"`
- Password string `json:"password" binding:"required,min=6"`
+ Password string `json:"password" binding:"required"`
}
type SendEmailBindingCodeRequest struct {
diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go
index 72b28293..24f715d4 100644
--- a/backend/internal/handler/user_handler_test.go
+++ b/backend/internal/handler/user_handler_test.go
@@ -422,6 +422,59 @@ func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) {
require.True(t, resp.Data.EmailBound)
}
+func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ user := &service.User{
+ ID: 11,
+ Email: "current@example.com",
+ Username: "bound-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ }
+ require.NoError(t, user.SetPassword("current-password"))
+
+ repo := &userHandlerRepoStub{user: user}
+ emailCache := &userHandlerEmailCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ },
+ }
+ emailService := service.NewEmailService(nil, emailCache)
+ authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil)
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
+
+ body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/account-bindings/email", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
+
+ handler.BindEmailIdentity(c)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Reason string `json:"reason"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, http.StatusBadRequest, resp.Code)
+ require.Equal(t, "PASSWORD_INCORRECT", resp.Reason)
+ require.Equal(t, "current password is incorrect", resp.Message)
+ require.Equal(t, "current@example.com", repo.user.Email)
+}
+
func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
gin.SetMode(gin.TestMode)
diff --git a/backend/internal/service/auth_email_binding.go b/backend/internal/service/auth_email_binding.go
index 58f8e647..b060ab76 100644
--- a/backend/internal/service/auth_email_binding.go
+++ b/backend/internal/service/auth_email_binding.go
@@ -13,7 +13,8 @@ import (
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
-// BindEmailIdentity verifies and binds a local email/password identity to the current user.
+// BindEmailIdentity verifies and binds a local email/password identity to the
+// current user, or replaces the existing bound primary email.
func (s *AuthService) BindEmailIdentity(
ctx context.Context,
userID int64,
@@ -43,6 +44,13 @@ func (s *AuthService) BindEmailIdentity(
if err != nil {
return nil, err
}
+ firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email)
+ if firstRealEmailBind && len(password) < 6 {
+ return nil, infraerrors.BadRequest("PASSWORD_TOO_SHORT", "password must be at least 6 characters")
+ }
+ if !firstRealEmailBind && !s.CheckPassword(password, currentUser.PasswordHash) {
+ return nil, ErrPasswordIncorrect
+ }
existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail)
switch {
@@ -57,9 +65,8 @@ func (s *AuthService) BindEmailIdentity(
return nil, fmt.Errorf("hash password: %w", err)
}
- firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email)
- if firstRealEmailBind && s.entClient != nil {
- if err := s.bindEmailIdentityWithDefaultsTx(ctx, currentUser, normalizedEmail, hashedPassword); err != nil {
+ if s.entClient != nil {
+ if err := s.updateBoundEmailIdentityTx(ctx, currentUser, normalizedEmail, hashedPassword, firstRealEmailBind); err != nil {
return nil, err
}
return currentUser, nil
@@ -137,14 +144,15 @@ func hasBindableEmailIdentitySubject(email string) bool {
return normalized != "" && !isReservedEmail(normalized)
}
-func (s *AuthService) bindEmailIdentityWithDefaultsTx(
+func (s *AuthService) updateBoundEmailIdentityTx(
ctx context.Context,
currentUser *User,
email string,
hashedPassword string,
+ applyFirstBindDefaults bool,
) error {
if tx := dbent.TxFromContext(ctx); tx != nil {
- return s.bindEmailIdentityWithDefaults(ctx, tx.Client(), currentUser, email, hashedPassword)
+ return s.updateBoundEmailIdentityWithClient(ctx, tx.Client(), currentUser, email, hashedPassword, applyFirstBindDefaults)
}
tx, err := s.entClient.Tx(ctx)
@@ -154,7 +162,7 @@ func (s *AuthService) bindEmailIdentityWithDefaultsTx(
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
- if err := s.bindEmailIdentityWithDefaults(txCtx, tx.Client(), currentUser, email, hashedPassword); err != nil {
+ if err := s.updateBoundEmailIdentityWithClient(txCtx, tx.Client(), currentUser, email, hashedPassword, applyFirstBindDefaults); err != nil {
return err
}
if err := tx.Commit(); err != nil {
@@ -163,12 +171,13 @@ func (s *AuthService) bindEmailIdentityWithDefaultsTx(
return nil
}
-func (s *AuthService) bindEmailIdentityWithDefaults(
+func (s *AuthService) updateBoundEmailIdentityWithClient(
ctx context.Context,
client *dbent.Client,
currentUser *User,
email string,
hashedPassword string,
+ applyFirstBindDefaults bool,
) error {
if client == nil || currentUser == nil || currentUser.ID <= 0 {
return ErrServiceUnavailable
@@ -192,8 +201,10 @@ func (s *AuthService) bindEmailIdentityWithDefaults(
return ErrServiceUnavailable
}
- if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, currentUser.ID, "email"); err != nil {
- return fmt.Errorf("apply email first bind defaults: %w", err)
+ if applyFirstBindDefaults {
+ if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, currentUser.ID, "email"); err != nil {
+ return fmt.Errorf("apply email first bind defaults: %w", err)
+ }
}
updatedUser, err := client.User.Get(ctx, currentUser.ID)
diff --git a/backend/internal/service/auth_service_email_bind_test.go b/backend/internal/service/auth_service_email_bind_test.go
index fd5f499b..d32a4a40 100644
--- a/backend/internal/service/auth_service_email_bind_test.go
+++ b/backend/internal/service/auth_service_email_bind_test.go
@@ -285,6 +285,148 @@ func TestAuthServiceBindEmailIdentity_RejectsReservedEmail(t *testing.T) {
require.Nil(t, updatedUser)
}
+func TestAuthServiceBindEmailIdentity_ReplacesBoundEmailAndSkipsFirstBindDefaults(t *testing.T) {
+ assigner := &emailBindDefaultSubAssignerStub{}
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, cache, assigner)
+
+ ctx := context.Background()
+ hashedPassword, err := svc.HashPassword("current-password")
+ require.NoError(t, err)
+
+ user, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("bound-user").
+ SetPasswordHash(hashedPassword).
+ SetBalance(7.5).
+ SetConcurrency(3).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ require.NoError(t, client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject("current@example.com").
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": "test"}).
+ Exec(ctx))
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "new@example.com", "123456", "current-password")
+ require.NoError(t, err)
+ require.NotNil(t, updatedUser)
+ require.Equal(t, "new@example.com", updatedUser.Email)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, "new@example.com", storedUser.Email)
+ require.Equal(t, 7.5, storedUser.Balance)
+ require.Equal(t, 3, storedUser.Concurrency)
+ require.True(t, svc.CheckPassword("current-password", storedUser.PasswordHash))
+
+ newIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("new@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, newIdentityCount)
+
+ oldIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("current@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 0, oldIdentityCount)
+
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceBindEmailIdentity_RejectsWrongCurrentPasswordForBoundEmail(t *testing.T) {
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
+
+ ctx := context.Background()
+ hashedPassword, err := svc.HashPassword("current-password")
+ require.NoError(t, err)
+
+ user, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("bound-user").
+ SetPasswordHash(hashedPassword).
+ SetBalance(1).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ require.NoError(t, client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject("current@example.com").
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": "test"}).
+ Exec(ctx))
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "new@example.com", "123456", "wrong-password")
+ require.ErrorIs(t, err, service.ErrPasswordIncorrect)
+ require.Nil(t, updatedUser)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, "current@example.com", storedUser.Email)
+ require.True(t, svc.CheckPassword("current-password", storedUser.PasswordHash))
+
+ oldIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("current@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, oldIdentityCount)
+
+ newIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("new@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 0, newIdentityCount)
+}
+
type emailBindSettingRepoStub struct {
values map[string]string
}
diff --git a/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue
index 653b4e33..ee582a60 100644
--- a/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue
+++ b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue
@@ -34,7 +34,7 @@
@@ -160,7 +160,7 @@ watch(
() => props.user,
(user) => {
localUser.value = null
- if (!user || getBindingStatusForUser(user, 'email')) {
+ if (!user) {
return
}
if (typeof user.email === 'string' && !user.email.endsWith('.invalid')) {
@@ -171,6 +171,17 @@ watch(
)
const currentUser = computed(() => localUser.value ?? props.user)
+const emailBound = computed(() => getBindingStatus('email'))
+const emailPasswordPlaceholder = computed(() =>
+ emailBound.value
+ ? t('profile.authBindings.replaceEmailPasswordPlaceholder')
+ : t('profile.authBindings.passwordPlaceholder')
+)
+const emailSubmitActionLabel = computed(() =>
+ emailBound.value
+ ? t('profile.authBindings.confirmEmailReplaceAction')
+ : t('profile.authBindings.confirmEmailBindAction')
+)
const wechatOAuthSettings = computed
(() => {
if (hasExplicitWeChatOAuthCapabilities(appStore.cachedPublicSettings)) {
@@ -286,7 +297,7 @@ function validateEmailBindingForm(requireCode: boolean): boolean {
appStore.showError(t('auth.passwordRequired'))
return false
}
- if (requireCode && emailBindingForm.password.length < 6) {
+ if (requireCode && !emailBound.value && emailBindingForm.password.length < 6) {
appStore.showError(t('auth.passwordMinLength'))
return false
}
@@ -321,10 +332,15 @@ async function bindEmail(): Promise {
verify_code: emailBindingForm.verifyCode,
password: emailBindingForm.password,
})
+ const replacingBoundEmail = emailBound.value
applyUpdatedUser(user)
emailBindingForm.verifyCode = ''
emailBindingForm.password = ''
- appStore.showSuccess(t('profile.authBindings.bindSuccess'))
+ appStore.showSuccess(
+ replacingBoundEmail
+ ? t('profile.authBindings.replaceSuccess')
+ : t('profile.authBindings.bindSuccess')
+ )
} catch (error) {
appStore.showError((error as { message?: string }).message || t('common.tryAgain'))
} finally {
diff --git a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
index c07acf18..8821cdc5 100644
--- a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
+++ b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
@@ -51,10 +51,14 @@ vi.mock('vue-i18n', async (importOriginal) => {
if (key === 'profile.authBindings.emailPlaceholder') return 'Email address'
if (key === 'profile.authBindings.codePlaceholder') return 'Verification code'
if (key === 'profile.authBindings.passwordPlaceholder') return 'Set password'
+ if (key === 'profile.authBindings.replaceEmailPasswordPlaceholder')
+ return 'Current password'
if (key === 'profile.authBindings.sendCodeAction') return 'Send code'
if (key === 'profile.authBindings.confirmEmailBindAction') return 'Bind email'
+ if (key === 'profile.authBindings.confirmEmailReplaceAction') return 'Replace primary email'
if (key === 'profile.authBindings.codeSentTo') return `Code sent to ${params?.email || ''}`.trim()
if (key === 'profile.authBindings.bindSuccess') return 'Bind success'
+ if (key === 'profile.authBindings.replaceSuccess') return 'Primary email updated'
return key
},
}),
@@ -324,4 +328,68 @@ describe('ProfileIdentityBindingsSection', () => {
expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Not bound')
expect(wrapper.get('[data-testid="profile-binding-email-input"]').exists()).toBe(true)
})
+
+ it('keeps the email form available for replacing a bound primary email', async () => {
+ userApiMocks.sendEmailBindingCode.mockResolvedValue(undefined)
+ userApiMocks.bindEmailIdentity.mockResolvedValue(
+ createUser({
+ email: 'new@example.com',
+ email_bound: true,
+ auth_bindings: {
+ email: { bound: true },
+ },
+ })
+ )
+
+ const appStore = useAppStore()
+ const authStore = useAuthStore()
+ authStore.user = createUser({
+ email: 'current@example.com',
+ email_bound: true,
+ auth_bindings: {
+ email: { bound: true },
+ },
+ })
+ const showSuccessSpy = vi.spyOn(appStore, 'showSuccess')
+
+ const wrapper = mount(ProfileIdentityBindingsSection, {
+ global: {
+ plugins: [pinia],
+ },
+ props: {
+ user: authStore.user,
+ linuxdoEnabled: false,
+ oidcEnabled: false,
+ wechatEnabled: false,
+ },
+ })
+
+ expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Bound')
+ expect(wrapper.get('[data-testid="profile-binding-email-input"]').exists()).toBe(true)
+ expect(wrapper.get('[data-testid="profile-binding-email-submit"]').text()).toBe(
+ 'Replace primary email'
+ )
+ expect(
+ (wrapper.get('[data-testid="profile-binding-email-password-input"]').element as HTMLInputElement)
+ .placeholder
+ ).toBe('Current password')
+
+ await wrapper.get('[data-testid="profile-binding-email-input"]').setValue('new@example.com')
+ await wrapper.get('[data-testid="profile-binding-email-send-code"]').trigger('click')
+ expect(userApiMocks.sendEmailBindingCode).toHaveBeenCalledWith('new@example.com')
+
+ await wrapper.get('[data-testid="profile-binding-email-code-input"]').setValue('123456')
+ await wrapper.get('[data-testid="profile-binding-email-password-input"]').setValue(
+ 'current-password'
+ )
+ await wrapper.get('[data-testid="profile-binding-email-submit"]').trigger('click')
+
+ expect(userApiMocks.bindEmailIdentity).toHaveBeenCalledWith({
+ email: 'new@example.com',
+ verify_code: '123456',
+ password: 'current-password',
+ })
+ expect(authStore.user?.email).toBe('new@example.com')
+ expect(showSuccessSpy).toHaveBeenCalledWith('Primary email updated')
+ })
})
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 2b41a3c3..345770a8 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -967,9 +967,12 @@ export default {
emailPlaceholder: 'Enter email address',
codePlaceholder: 'Enter verification code',
passwordPlaceholder: 'Set a login password',
+ replaceEmailPasswordPlaceholder: 'Enter current password',
sendCodeAction: 'Send code',
confirmEmailBindAction: 'Bind email',
+ confirmEmailReplaceAction: 'Replace primary email',
codeSentTo: 'Code sent to {email}',
+ replaceSuccess: 'Primary email updated',
status: {
bound: 'Bound',
notBound: 'Not bound',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index b60a69d6..6493ffe8 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -971,9 +971,12 @@ export default {
emailPlaceholder: '输入邮箱地址',
codePlaceholder: '输入验证码',
passwordPlaceholder: '设置登录密码',
+ replaceEmailPasswordPlaceholder: '输入当前密码',
sendCodeAction: '发送验证码',
confirmEmailBindAction: '绑定邮箱',
+ confirmEmailReplaceAction: '更换主邮箱',
codeSentTo: '验证码已发送到 {email}',
+ replaceSuccess: '主邮箱已更新',
status: {
bound: '已绑定',
notBound: '未绑定',
From ace082066a14824462cb45386f43626ed7db9547 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 13:50:55 +0800
Subject: [PATCH 112/189] fix: honor ws transport when scheduler is disabled
---
.../service/openai_account_scheduler.go | 59 ++++++++++--
.../service/openai_account_scheduler_test.go | 92 +++++++++++++++++++
2 files changed, 143 insertions(+), 8 deletions(-)
diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go
index 09e60220..38b92b47 100644
--- a/backend/internal/service/openai_account_scheduler.go
+++ b/backend/internal/service/openai_account_scheduler.go
@@ -767,14 +767,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
}
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
- // HTTP 入站可回退到 HTTP 线路,不需要在账号选择阶段做传输协议强过滤。
- if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
- return true
- }
- if s == nil || s.service == nil || account == nil {
+ if s == nil || s.service == nil {
return false
}
- return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
+ return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport)
}
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
@@ -899,9 +895,35 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
decision := OpenAIAccountScheduleDecision{}
scheduler := s.getOpenAIAccountScheduler(ctx)
if scheduler == nil {
- selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
decision.Layer = openAIAccountScheduleLayerLoadBalance
- return selection, decision, err
+ if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
+ selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
+ return selection, decision, err
+ }
+
+ effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
+ for {
+ selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs)
+ if err != nil {
+ return nil, decision, err
+ }
+ if selection == nil || selection.Account == nil {
+ return selection, decision, nil
+ }
+ if s.isOpenAIAccountTransportCompatible(selection.Account, requiredTransport) {
+ return selection, decision, nil
+ }
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ if effectiveExcludedIDs == nil {
+ effectiveExcludedIDs = make(map[int64]struct{})
+ }
+ if _, exists := effectiveExcludedIDs[selection.Account.ID]; exists {
+ return nil, decision, ErrNoAvailableAccounts
+ }
+ effectiveExcludedIDs[selection.Account.ID] = struct{}{}
+ }
}
var stickyAccountID int64
@@ -922,6 +944,27 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
})
}
+func cloneExcludedAccountIDs(excludedIDs map[int64]struct{}) map[int64]struct{} {
+ if len(excludedIDs) == 0 {
+ return nil
+ }
+ cloned := make(map[int64]struct{}, len(excludedIDs))
+ for id := range excludedIDs {
+ cloned[id] = struct{}{}
+ }
+ return cloned
+}
+
+func (s *OpenAIGatewayService) isOpenAIAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
+ if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
+ return true
+ }
+ if s == nil || account == nil {
+ return false
+ }
+ return s.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
+}
+
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go
index a54f2614..b02370cb 100644
--- a/backend/internal/service/openai_account_scheduler_test.go
+++ b/backend/internal/service/openai_account_scheduler_test.go
@@ -298,6 +298,98 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLega
require.False(t, decision.StickyPreviousHit)
}
+func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_RequiredWSV2_SkipsHTTPOnlyAccount(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ ctx := context.Background()
+ groupID := int64(10108)
+ accounts := []Account{
+ {
+ ID: 36011,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ },
+ {
+ ID: 36012,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 5,
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_enabled": true,
+ },
+ },
+ }
+ cfg := newSchedulerTestOpenAIWSV2Config()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = false
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportResponsesWebsocketV2,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(36012), selection.Account.ID)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+}
+
+func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_RequiredWSV2_NoAvailableAccount(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ ctx := context.Background()
+ groupID := int64(10109)
+ accounts := []Account{
+ {
+ ID: 36021,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ },
+ }
+ cfg := newSchedulerTestOpenAIWSV2Config()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = false
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportResponsesWebsocketV2,
+ )
+ require.ErrorContains(t, err, "no available OpenAI accounts")
+ require.Nil(t, selection)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+}
+
func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
From 0fcddce69edd1fca78dbd5d1b1099d8a6fe4c3b8 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 13:53:12 +0800
Subject: [PATCH 113/189] fix: reject http responses continuation ids
---
.../handler/openai_gateway_handler.go | 9 ++-
.../handler/openai_gateway_handler_test.go | 58 +++++++++++++++++++
2 files changed, 65 insertions(+), 2 deletions(-)
diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go
index 5319b55d..43999a01 100644
--- a/backend/internal/handler/openai_gateway_handler.go
+++ b/backend/internal/handler/openai_gateway_handler.go
@@ -187,6 +187,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id")
return
}
+ reqLog.Warn("openai.request_validation_failed",
+ zap.String("reason", "previous_response_id_requires_wsv2"),
+ )
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id is only supported on Responses WebSocket v2")
+ return
}
setOpsRequestContext(c, reqModel, reqStream, body)
@@ -856,7 +861,7 @@ func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context,
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_call_id"),
)
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id on HTTP requests; continuation via previous_response_id is only supported on Responses WebSocket v2")
return false
}
if validation.HasItemReferenceForAllCallIDs {
@@ -866,7 +871,7 @@ func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context,
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_item_reference"),
)
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id on HTTP requests; continuation via previous_response_id is only supported on Responses WebSocket v2")
return false
}
diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go
index d299fb81..8ecee59a 100644
--- a/backend/internal/handler/openai_gateway_handler_test.go
+++ b/backend/internal/handler/openai_gateway_handler_test.go
@@ -494,6 +494,64 @@ func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
require.Contains(t, w.Body.String(), "previous_response_id must be a response.id")
}
+func TestOpenAIResponses_RejectsHTTPContinuationPreviousResponseID(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(
+ `{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123456","input":[{"type":"input_text","text":"hello"}]}`,
+ ))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ groupID := int64(2)
+ c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
+ ID: 101,
+ GroupID: &groupID,
+ User: &service.User{ID: 1},
+ })
+ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
+ UserID: 1,
+ Concurrency: 1,
+ })
+
+ h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
+ h.Responses(c)
+
+ require.Equal(t, http.StatusBadRequest, w.Code)
+ require.Contains(t, w.Body.String(), "Responses WebSocket v2")
+ require.Contains(t, w.Body.String(), "previous_response_id")
+}
+
+func TestOpenAIResponses_FunctionCallOutputHTTPGuidanceDoesNotSuggestPreviousResponseReuse(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(
+ `{"model":"gpt-5.1","stream":false,"input":[{"type":"function_call_output","output":"{}"}]}`,
+ ))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ groupID := int64(2)
+ c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
+ ID: 101,
+ GroupID: &groupID,
+ User: &service.User{ID: 1},
+ })
+ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
+ UserID: 1,
+ Concurrency: 1,
+ })
+
+ h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
+ h.Responses(c)
+
+ require.Equal(t, http.StatusBadRequest, w.Code)
+ require.Contains(t, w.Body.String(), "Responses WebSocket v2")
+ require.NotContains(t, w.Body.String(), "reuse previous_response_id")
+}
+
func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) {
gin.SetMode(gin.TestMode)
From 62ff2d803f172defdc6648599edd7d3379539b35 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 13:56:02 +0800
Subject: [PATCH 114/189] fix: normalize chat completions service tier
---
.../openai_gateway_chat_completions.go | 42 +++++++++++++++++-
.../openai_gateway_chat_completions_test.go | 44 +++++++++++++++++++
2 files changed, 85 insertions(+), 1 deletion(-)
create mode 100644 backend/internal/service/openai_gateway_chat_completions_test.go
diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go
index ac7d28a7..663066a3 100644
--- a/backend/internal/service/openai_gateway_chat_completions.go
+++ b/backend/internal/service/openai_gateway_chat_completions.go
@@ -107,11 +107,15 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
responsesBody = stripped
}
}
+ responsesBody, normalizedServiceTier, err := normalizeResponsesBodyServiceTier(responsesBody)
+ if err != nil {
+ return nil, fmt.Errorf("normalize service_tier in responses-shape body: %w", err)
+ }
// Minimal stub populated from the raw body so downstream billing
// propagation (ServiceTier, ReasoningEffort) keeps working.
responsesReq = &apicompat.ResponsesRequest{
Model: upstreamModel,
- ServiceTier: gjson.GetBytes(responsesBody, "service_tier").String(),
+ ServiceTier: normalizedServiceTier,
}
if effort := gjson.GetBytes(responsesBody, "reasoning.effort").String(); effort != "" {
responsesReq.Reasoning = &apicompat.ResponsesReasoning{Effort: effort}
@@ -124,6 +128,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
}
responsesReq.Model = upstreamModel
+ normalizeResponsesRequestServiceTier(responsesReq)
responsesBody, err = json.Marshal(responsesReq)
if err != nil {
return nil, fmt.Errorf("marshal responses request: %w", err)
@@ -274,6 +279,41 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
return result, handleErr
}
+func normalizeResponsesRequestServiceTier(req *apicompat.ResponsesRequest) {
+ if req == nil {
+ return
+ }
+ req.ServiceTier = normalizedOpenAIServiceTierValue(req.ServiceTier)
+}
+
+func normalizeResponsesBodyServiceTier(body []byte) ([]byte, string, error) {
+ if len(body) == 0 {
+ return body, "", nil
+ }
+ rawServiceTier := gjson.GetBytes(body, "service_tier").String()
+ if rawServiceTier == "" {
+ return body, "", nil
+ }
+ normalizedServiceTier := normalizedOpenAIServiceTierValue(rawServiceTier)
+ if normalizedServiceTier == "" {
+ trimmed, err := sjson.DeleteBytes(body, "service_tier")
+ return trimmed, "", err
+ }
+ if normalizedServiceTier == rawServiceTier {
+ return body, normalizedServiceTier, nil
+ }
+ trimmed, err := sjson.SetBytes(body, "service_tier", normalizedServiceTier)
+ return trimmed, normalizedServiceTier, err
+}
+
+func normalizedOpenAIServiceTierValue(raw string) string {
+ normalized := normalizeOpenAIServiceTier(raw)
+ if normalized == nil {
+ return ""
+ }
+ return *normalized
+}
+
// handleChatCompletionsErrorResponse reads an upstream error and returns it in
// OpenAI Chat Completions error format.
func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
diff --git a/backend/internal/service/openai_gateway_chat_completions_test.go b/backend/internal/service/openai_gateway_chat_completions_test.go
new file mode 100644
index 00000000..a00fb71c
--- /dev/null
+++ b/backend/internal/service/openai_gateway_chat_completions_test.go
@@ -0,0 +1,44 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
+ t.Parallel()
+
+ req := &apicompat.ResponsesRequest{ServiceTier: " fast "}
+ normalizeResponsesRequestServiceTier(req)
+ require.Equal(t, "priority", req.ServiceTier)
+
+ req.ServiceTier = "flex"
+ normalizeResponsesRequestServiceTier(req)
+ require.Equal(t, "flex", req.ServiceTier)
+
+ req.ServiceTier = "default"
+ normalizeResponsesRequestServiceTier(req)
+ require.Empty(t, req.ServiceTier)
+}
+
+func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
+ t.Parallel()
+
+ body, tier, err := normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"fast"}`))
+ require.NoError(t, err)
+ require.Equal(t, "priority", tier)
+ require.Equal(t, "priority", gjson.GetBytes(body, "service_tier").String())
+
+ body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"flex"}`))
+ require.NoError(t, err)
+ require.Equal(t, "flex", tier)
+ require.Equal(t, "flex", gjson.GetBytes(body, "service_tier").String())
+
+ body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"default"}`))
+ require.NoError(t, err)
+ require.Empty(t, tier)
+ require.False(t, gjson.GetBytes(body, "service_tier").Exists())
+}
From 147ed42ad355ec4a100c0b29ba65ed3e1aeef233 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 14:10:30 +0800
Subject: [PATCH 115/189] fix: restrict payment return urls to internal result
page
---
backend/internal/service/payment_order.go | 2 +-
.../service/payment_resume_service.go | 46 +++++++++++++++++--
.../service/payment_resume_service_test.go | 24 ++++++++--
3 files changed, 63 insertions(+), 9 deletions(-)
diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go
index 254af5fe..354f3cd1 100644
--- a/backend/internal/service/payment_order.go
+++ b/backend/internal/service/payment_order.go
@@ -350,7 +350,7 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
}
subject := s.buildPaymentSubject(plan, limitAmount, cfg)
outTradeNo := order.OutTradeNo
- canonicalReturnURL, err := CanonicalizeReturnURL(req.ReturnURL)
+ canonicalReturnURL, err := CanonicalizeReturnURL(req.ReturnURL, req.SrcHost)
if err != nil {
return nil, err
}
diff --git a/backend/internal/service/payment_resume_service.go b/backend/internal/service/payment_resume_service.go
index 486aaac0..1806f5da 100644
--- a/backend/internal/service/payment_resume_service.go
+++ b/backend/internal/service/payment_resume_service.go
@@ -7,6 +7,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
+ "net"
"net/url"
"strconv"
"strings"
@@ -16,6 +17,8 @@ import (
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
+const paymentResultReturnPath = "/payment/result"
+
const (
PaymentSourceHostedRedirect = "hosted_redirect"
PaymentSourceWechatInAppResume = "wechat_in_app_resume"
@@ -215,7 +218,7 @@ func visibleMethodSourceSettingKey(method string) string {
}
}
-func CanonicalizeReturnURL(raw string) (string, error) {
+func CanonicalizeReturnURL(raw string, srcHost string) (string, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", nil
@@ -231,19 +234,29 @@ func CanonicalizeReturnURL(raw string) (string, error) {
if parsed.Path == "" {
parsed.Path = "/"
}
+ if parsed.Path != paymentResultReturnPath {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must target the canonical internal payment result page")
+ }
+ if !sameOriginHost(parsed.Host, srcHost) {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use the same host as the current site")
+ }
return parsed.String(), nil
}
func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (string, error) {
- canonical, err := CanonicalizeReturnURL(base)
- if err != nil || canonical == "" {
- return canonical, err
+ canonical := strings.TrimSpace(base)
+ if canonical == "" {
+ return "", nil
}
parsed, err := url.Parse(canonical)
if err != nil {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be a valid URL")
}
+ if !parsed.IsAbs() || parsed.Host == "" {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be a valid absolute URL")
+ }
+ parsed.Fragment = ""
query := parsed.Query()
if orderID > 0 {
@@ -258,6 +271,31 @@ func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (stri
return parsed.String(), nil
}
+func sameOriginHost(returnURLHost string, requestHost string) bool {
+ returnHost := strings.TrimSpace(returnURLHost)
+ reqHost := strings.TrimSpace(requestHost)
+ if returnHost == "" || reqHost == "" {
+ return false
+ }
+ if strings.EqualFold(returnHost, reqHost) {
+ return true
+ }
+
+ returnName, returnPort := splitHostPortDefault(returnHost)
+ reqName, reqPort := splitHostPortDefault(reqHost)
+ if returnName == "" || reqName == "" {
+ return false
+ }
+ return strings.EqualFold(returnName, reqName) && returnPort == reqPort
+}
+
+func splitHostPortDefault(raw string) (string, string) {
+ if host, port, err := net.SplitHostPort(raw); err == nil {
+ return host, port
+ }
+ return raw, ""
+}
+
func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) {
if err := s.ensureSigningKey(); err != nil {
return "", err
diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go
index 12d67be2..7fa8dca1 100644
--- a/backend/internal/service/payment_resume_service_test.go
+++ b/backend/internal/service/payment_resume_service_test.go
@@ -64,23 +64,39 @@ func TestNormalizePaymentSource(t *testing.T) {
func TestCanonicalizeReturnURL(t *testing.T) {
t.Parallel()
- got, err := CanonicalizeReturnURL("https://example.com/pay/result?b=2#a")
+ got, err := CanonicalizeReturnURL("https://example.com/payment/result?b=2#a", "example.com")
if err != nil {
t.Fatalf("CanonicalizeReturnURL returned error: %v", err)
}
- if got != "https://example.com/pay/result?b=2" {
- t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://example.com/pay/result?b=2")
+ if got != "https://example.com/payment/result?b=2" {
+ t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://example.com/payment/result?b=2")
}
}
func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) {
t.Parallel()
- if _, err := CanonicalizeReturnURL("/payment/result"); err == nil {
+ if _, err := CanonicalizeReturnURL("/payment/result", "example.com"); err == nil {
t.Fatal("CanonicalizeReturnURL should reject relative URLs")
}
}
+func TestCanonicalizeReturnURLRejectsExternalHost(t *testing.T) {
+ t.Parallel()
+
+ if _, err := CanonicalizeReturnURL("https://evil.example/payment/result", "app.example.com"); err == nil {
+ t.Fatal("CanonicalizeReturnURL should reject external hosts")
+ }
+}
+
+func TestCanonicalizeReturnURLRejectsNonCanonicalPath(t *testing.T) {
+ t.Parallel()
+
+ if _, err := CanonicalizeReturnURL("https://app.example.com/orders/42", "app.example.com"); err == nil {
+ t.Fatal("CanonicalizeReturnURL should reject non-canonical result paths")
+ }
+}
+
func TestBuildPaymentReturnURL(t *testing.T) {
t.Parallel()
From 422f3449a23f937f147f6db7740a550f7ae6c5d0 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 14:54:42 +0800
Subject: [PATCH 116/189] chore: remove local docs from repo
---
.gitignore | 5 +-
docs/ADMIN_PAYMENT_INTEGRATION_API.md | 243 ------
docs/PAYMENT.md | 287 -------
docs/PAYMENT_CN.md | 287 -------
...-04-20-auth-identity-payment-foundation.md | 539 -------------
...auth-identity-payment-foundation-design.md | 763 ------------------
6 files changed, 1 insertion(+), 2123 deletions(-)
delete mode 100644 docs/ADMIN_PAYMENT_INTEGRATION_API.md
delete mode 100644 docs/PAYMENT.md
delete mode 100644 docs/PAYMENT_CN.md
delete mode 100644 docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md
delete mode 100644 docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
diff --git a/.gitignore b/.gitignore
index 1a92ea3e..cf2bda9f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -126,12 +126,9 @@ backend/cmd/server/server
deploy/docker-compose.override.yml
.gocache/
vite.config.js
-docs/*
-!docs/PAYMENT.md
-!docs/PAYMENT_CN.md
+docs/
.serena/
.codex/
frontend/coverage/
aicodex
output/
-
diff --git a/docs/ADMIN_PAYMENT_INTEGRATION_API.md b/docs/ADMIN_PAYMENT_INTEGRATION_API.md
deleted file mode 100644
index f674f86c..00000000
--- a/docs/ADMIN_PAYMENT_INTEGRATION_API.md
+++ /dev/null
@@ -1,243 +0,0 @@
-# ADMIN_PAYMENT_INTEGRATION_API
-
-> 单文件中英双语文档 / Single-file bilingual documentation (Chinese + English)
-
----
-
-## 中文
-
-### 目标
-本文档用于对接外部支付系统(如 `sub2apipay`)与 Sub2API 的 Admin API,覆盖:
-- 支付成功后充值
-- 用户查询
-- 人工余额修正
-- 前端购买页参数透传
-
-### 基础地址
-- 生产:`https://`
-- Beta:`http://:8084`
-
-### 认证
-推荐使用:
-- `x-api-key: admin-<64hex>`
-- `Content-Type: application/json`
-- 幂等接口额外传:`Idempotency-Key`
-
-说明:管理员 JWT 也可访问 admin 路由,但服务间调用建议使用 Admin API Key。
-
-### 1) 一步完成创建并兑换
-`POST /api/v1/admin/redeem-codes/create-and-redeem`
-
-用途:原子完成“创建兑换码 + 兑换到指定用户”。
-
-请求头:
-- `x-api-key`
-- `Idempotency-Key`
-
-请求体示例:
-```json
-{
- "code": "s2p_cm1234567890",
- "type": "balance",
- "value": 100.0,
- "user_id": 123,
- "notes": "sub2apipay order: cm1234567890"
-}
-```
-
-幂等语义:
-- 同 `code` 且 `used_by` 一致:`200`
-- 同 `code` 但 `used_by` 不一致:`409`
-- 缺少 `Idempotency-Key`:`400`(`IDEMPOTENCY_KEY_REQUIRED`)
-
-curl 示例:
-```bash
-curl -X POST "${BASE}/api/v1/admin/redeem-codes/create-and-redeem" \
- -H "x-api-key: ${KEY}" \
- -H "Idempotency-Key: pay-cm1234567890-success" \
- -H "Content-Type: application/json" \
- -d '{
- "code":"s2p_cm1234567890",
- "type":"balance",
- "value":100.00,
- "user_id":123,
- "notes":"sub2apipay order: cm1234567890"
- }'
-```
-
-### 2) 查询用户(可选前置校验)
-`GET /api/v1/admin/users/:id`
-
-```bash
-curl -s "${BASE}/api/v1/admin/users/123" \
- -H "x-api-key: ${KEY}"
-```
-
-### 3) 余额调整(已有接口)
-`POST /api/v1/admin/users/:id/balance`
-
-用途:人工补偿 / 扣减,支持 `set` / `add` / `subtract`。
-
-请求体示例(扣减):
-```json
-{
- "balance": 100.0,
- "operation": "subtract",
- "notes": "manual correction"
-}
-```
-
-```bash
-curl -X POST "${BASE}/api/v1/admin/users/123/balance" \
- -H "x-api-key: ${KEY}" \
- -H "Idempotency-Key: balance-subtract-cm1234567890" \
- -H "Content-Type: application/json" \
- -d '{
- "balance":100.00,
- "operation":"subtract",
- "notes":"manual correction"
- }'
-```
-
-### 4) 购买页 / 自定义页面 URL Query 透传(iframe / 新窗口一致)
-当 Sub2API 打开 `purchase_subscription_url` 或用户侧自定义页面 iframe URL 时,会统一追加:
-- `user_id`
-- `token`
-- `theme`(`light` / `dark`)
-- `lang`(例如 `zh` / `en`,用于向嵌入页传递当前界面语言)
-- `ui_mode`(固定 `embedded`)
-
-示例:
-```text
-https://pay.example.com/pay?user_id=123&token=&theme=light&lang=zh&ui_mode=embedded
-```
-
-### 5) 失败处理建议
-- 支付成功与充值成功分状态落库
-- 回调验签成功后立即标记“支付成功”
-- 支付成功但充值失败的订单允许后续重试
-- 重试保持相同 `code`,并使用新的 `Idempotency-Key`
-
-### 6) `doc_url` 配置建议
-- 查看链接:`https://github.com/Wei-Shaw/sub2api/blob/main/ADMIN_PAYMENT_INTEGRATION_API.md`
-- 下载链接:`https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/ADMIN_PAYMENT_INTEGRATION_API.md`
-
----
-
-## English
-
-### Purpose
-This document describes the minimal Sub2API Admin API surface for external payment integrations (for example, `sub2apipay`), including:
-- Recharge after payment success
-- User lookup
-- Manual balance correction
-- Purchase page query parameter forwarding
-
-### Base URL
-- Production: `https://`
-- Beta: `http://:8084`
-
-### Authentication
-Recommended headers:
-- `x-api-key: admin-<64hex>`
-- `Content-Type: application/json`
-- `Idempotency-Key` for idempotent endpoints
-
-Note: Admin JWT can also access admin routes, but Admin API Key is recommended for server-to-server integration.
-
-### 1) Create and Redeem in one step
-`POST /api/v1/admin/redeem-codes/create-and-redeem`
-
-Use case: atomically create a redeem code and redeem it to a target user.
-
-Headers:
-- `x-api-key`
-- `Idempotency-Key`
-
-Request body:
-```json
-{
- "code": "s2p_cm1234567890",
- "type": "balance",
- "value": 100.0,
- "user_id": 123,
- "notes": "sub2apipay order: cm1234567890"
-}
-```
-
-Idempotency behavior:
-- Same `code` and same `used_by`: `200`
-- Same `code` but different `used_by`: `409`
-- Missing `Idempotency-Key`: `400` (`IDEMPOTENCY_KEY_REQUIRED`)
-
-curl example:
-```bash
-curl -X POST "${BASE}/api/v1/admin/redeem-codes/create-and-redeem" \
- -H "x-api-key: ${KEY}" \
- -H "Idempotency-Key: pay-cm1234567890-success" \
- -H "Content-Type: application/json" \
- -d '{
- "code":"s2p_cm1234567890",
- "type":"balance",
- "value":100.00,
- "user_id":123,
- "notes":"sub2apipay order: cm1234567890"
- }'
-```
-
-### 2) Query User (optional pre-check)
-`GET /api/v1/admin/users/:id`
-
-```bash
-curl -s "${BASE}/api/v1/admin/users/123" \
- -H "x-api-key: ${KEY}"
-```
-
-### 3) Balance Adjustment (existing API)
-`POST /api/v1/admin/users/:id/balance`
-
-Use case: manual correction with `set` / `add` / `subtract`.
-
-Request body example (`subtract`):
-```json
-{
- "balance": 100.0,
- "operation": "subtract",
- "notes": "manual correction"
-}
-```
-
-```bash
-curl -X POST "${BASE}/api/v1/admin/users/123/balance" \
- -H "x-api-key: ${KEY}" \
- -H "Idempotency-Key: balance-subtract-cm1234567890" \
- -H "Content-Type: application/json" \
- -d '{
- "balance":100.00,
- "operation":"subtract",
- "notes":"manual correction"
- }'
-```
-
-### 4) Purchase / Custom Page URL query forwarding (iframe and new tab)
-When Sub2API opens `purchase_subscription_url` or a user-facing custom page iframe URL, it appends:
-- `user_id`
-- `token`
-- `theme` (`light` / `dark`)
-- `lang` (for example `zh` / `en`, used to pass the current UI language to the embedded page)
-- `ui_mode` (fixed: `embedded`)
-
-Example:
-```text
-https://pay.example.com/pay?user_id=123&token=&theme=light&lang=zh&ui_mode=embedded
-```
-
-### 5) Failure handling recommendations
-- Persist payment success and recharge success as separate states
-- Mark payment as successful immediately after verified callback
-- Allow retry for orders with payment success but recharge failure
-- Keep the same `code` for retry, and use a new `Idempotency-Key`
-
-### 6) Recommended `doc_url`
-- View URL: `https://github.com/Wei-Shaw/sub2api/blob/main/ADMIN_PAYMENT_INTEGRATION_API.md`
-- Download URL: `https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/ADMIN_PAYMENT_INTEGRATION_API.md`
diff --git a/docs/PAYMENT.md b/docs/PAYMENT.md
deleted file mode 100644
index 9322f7bf..00000000
--- a/docs/PAYMENT.md
+++ /dev/null
@@ -1,287 +0,0 @@
-# Payment System Configuration Guide
-
-Sub2API has a built-in payment system that enables user self-service top-up without deploying a separate payment service.
-
----
-
-## Table of Contents
-
-- [Supported Payment Methods](#supported-payment-methods)
-- [Quick Start](#quick-start)
-- [System Settings](#system-settings)
-- [Provider Configuration](#provider-configuration)
-- [Provider Instance Management](#provider-instance-management)
-- [Webhook Configuration](#webhook-configuration)
-- [Payment Flow](#payment-flow)
-- [Migrating from Sub2ApiPay](#migrating-from-sub2apipay)
-
----
-
-## Supported Payment Methods
-
-| Provider | Payment Methods | Description |
-|----------|----------------|-------------|
-| **EasyPay** | Alipay, WeChat Pay | Third-party aggregation via EasyPay protocol |
-| **Alipay (Direct)** | Desktop QR code, mobile Alipay redirect | Direct integration with Alipay Open Platform, returning desktop QR codes and mobile WAP/app launch links |
-| **WeChat Pay (Direct)** | Native QR, H5, MP/JSAPI Pay | Direct integration with WeChat Pay APIv3 with environment-aware routing |
-| **Stripe** | Card, Alipay, WeChat Pay, Link, etc. | International payments, multi-currency support |
-
-> Alipay/WeChat Pay direct and EasyPay can both exist as backend provider instances, but the frontend always exposes only two visible buttons: `Alipay` and `WeChat Pay`. Admins choose exactly one source for each visible method: direct or EasyPay. Direct channels connect to payment APIs directly with lower fees; EasyPay aggregates through third-party platforms with easier setup.
-
-> **EasyPay Provider Recommendations**: Both options below are third-party aggregators compatible with the EasyPay protocol. Pick based on the funding channel and settlement currency you need:
->
-> - **Domestic channel / CNY settlement** — [ZPay](https://z-pay.cn/?uid=23808) (`https://z-pay.cn/?uid=23808`): direct integration with official Alipay / WeChat Pay APIs, fee **1.6%**; funds go straight to the merchant account with **T+1 automatic settlement**. Supports **individual users** (no business license required) with up to 10,000 CNY daily transactions; business-licensed accounts have no limit. Link contains the referral code of [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) original author [@touwaeriol](https://github.com/touwaeriol) — feel free to remove it.
-> - **International channel / USDT or USD settlement** — [Kyren Topup](https://kyren.top/?code=SUB2API) (`https://kyren.top/?code=SUB2API`): a ready-to-launch global payment stack for AI startups with WeChat Pay and Alipay support, local-currency checkout, and USD settlement. Fees: WeChat 2%, Alipay 2.5%; withdrawal 0.1% (min $40, max $150), settled in **USDT or USD**. No qualification review required — sign up and use immediately, making it the lowest barrier to entry. Withdrawal threshold is relatively high, recommended for users **who do not use domestic Chinese payment channels, cannot tolerate Stripe's 6%+ fees, have high transaction volume, and have USD or USDT channels to receive withdrawn funds**. Kyren Topup charges a $200 account opening fee; signing up via this link (which contains Sub2Api author [@Wei-Shaw](https://github.com/Wei-Shaw)'s referral code) **waives the opening fee**. Feel free to remove it if you prefer.
->
-> Please evaluate the security, reliability, and compliance of any third-party payment provider on your own — this project does not endorse or guarantee any of them.
-
----
-
-## Quick Start
-
-1. Go to Admin Dashboard → **Settings** → **Payment Settings** tab
-2. Enable **Payment**
-3. Configure basic parameters (amount range, timeout, etc.)
-4. Add at least one provider instance in **Provider Management**
-5. Users can now top up from the frontend
-
----
-
-## System Settings
-
-Configure the following in Admin Dashboard **Settings → Payment Settings**:
-
-### Basic Settings
-
-| Setting | Description | Default |
-|---------|-------------|---------|
-| **Enable Payment** | Enable or disable the payment system | Off |
-| **Product Name Prefix** | Prefix shown on payment page | - |
-| **Product Name Suffix** | Suffix (e.g., "Credits") | - |
-| **Minimum Amount** | Minimum single top-up amount | 1 |
-| **Maximum Amount** | Maximum single top-up amount (empty = unlimited) | - |
-| **Daily Limit** | Per-user daily cumulative limit (empty = unlimited) | - |
-| **Order Timeout** | Order timeout in minutes (minimum 1) | 30 |
-| **Max Pending Orders** | Maximum concurrent pending orders per user | 3 |
-| **Load Balance Strategy** | Strategy for selecting provider instances | Round Robin |
-
-### Frontend Visible Method Routing
-
-The current payment UX keeps the frontend method list unified and does not expose provider brands directly:
-
-- **Alipay**: when enabled, this button must be routed to either `Alipay (Direct)` or `EasyPay Alipay`
-- **WeChat Pay**: when enabled, this button must be routed to either `WeChat Pay (Direct)` or `EasyPay WeChat`
-- Each visible method can route to only one source at a time
-- If a visible method is enabled without a selected source, the frontend will not expose that method
-
-### Load Balance Strategies
-
-| Strategy | Description |
-|----------|-------------|
-| **Round Robin** | Distribute orders to instances in rotation |
-| **Least Amount** | Prefer instances with the lowest daily cumulative amount |
-
-### Cancel Rate Limiting
-
-Prevents users from repeatedly creating and canceling orders:
-
-| Setting | Description |
-|---------|-------------|
-| **Enable Limit** | Toggle |
-| **Window Mode** | Sliding / Fixed window |
-| **Time Window** | Window duration |
-| **Window Unit** | Minutes / Hours |
-| **Max Cancels** | Maximum cancellations allowed within the window |
-
-### Help Information
-
-| Setting | Description |
-|---------|-------------|
-| **Help Image** | Customer service QR code or help image (supports upload) |
-| **Help Text** | Instructions displayed on the payment page |
-
----
-
-## Provider Configuration
-
-Each provider type requires different credentials. Select the type when adding a new provider instance in **Provider Management → Add Provider**.
-
-> **Callback URLs are auto-generated**: When adding a provider, the Notify URL and Return URL are automatically constructed from your site domain. You only need to confirm the domain is correct.
-
-### EasyPay
-
-Compatible with any payment service that implements the EasyPay protocol.
-
-| Parameter | Description | Required |
-|-----------|-------------|----------|
-| **Merchant ID (PID)** | EasyPay merchant ID | Yes |
-| **Merchant Key (PKey)** | EasyPay merchant secret key | Yes |
-| **API Base URL** | EasyPay API base address | Yes |
-| **Alipay Channel ID** | Specify Alipay channel (optional) | No |
-| **WeChat Channel ID** | Specify WeChat channel (optional) | No |
-
-### Alipay (Direct)
-
-Direct integration with Alipay Open Platform. Desktop flows return a QR code for in-page display, while mobile flows return an Alipay WAP/app redirect URL.
-
-| Parameter | Description | Required |
-|-----------|-------------|----------|
-| **AppID** | Alipay application AppID | Yes |
-| **Private Key** | RSA2 application private key | Yes |
-| **Alipay Public Key** | Alipay public key | Yes |
-
-### WeChat Pay (Direct)
-
-Direct integration with WeChat Pay APIv3. Supports Native QR code payment, H5 payment, and MP/JSAPI payment inside the WeChat environment.
-
-| Parameter | Description | Required |
-|-----------|-------------|----------|
-| **AppID** | WeChat Pay AppID | Yes |
-| **Merchant ID (MchID)** | WeChat Pay merchant ID | Yes |
-| **Merchant API Private Key** | Merchant API private key (PEM format) | Yes |
-| **APIv3 Key** | 32-byte APIv3 key | Yes |
-| **WeChat Pay Public Key** | WeChat Pay public key (PEM format) | Yes |
-| **WeChat Pay Public Key ID** | WeChat Pay public key ID | Yes |
-| **Certificate Serial Number** | Merchant certificate serial number | Yes |
-
-### Stripe
-
-International payment platform supporting multiple payment methods and currencies.
-
-| Parameter | Description | Required |
-|-----------|-------------|----------|
-| **Secret Key** | Stripe secret key (`sk_live_...` or `sk_test_...`) | Yes |
-| **Publishable Key** | Stripe publishable key (`pk_live_...` or `pk_test_...`) | Yes |
-| **Webhook Secret** | Stripe Webhook signing secret (`whsec_...`) | Yes |
-
----
-
-## Provider Instance Management
-
-You can create **multiple instances** of the same provider type for load balancing and risk control:
-
-- **Multi-instance load balancing** — Distribute orders via round-robin or least-amount strategy
-- **Independent limits** — Each instance can have its own min/max amount and daily limit
-- **Independent toggle** — Enable/disable individual instances without affecting others
-- **Refund control** — Enable or disable refunds per instance
-- **Payment methods** — Each instance can support a subset of payment methods
-- **Ordering** — Drag to reorder instances
-
-### Instance Limit Configuration
-
-Each instance supports these limits:
-
-| Limit | Description |
-|-------|-------------|
-| **Minimum Amount** | Minimum order amount accepted by this instance |
-| **Maximum Amount** | Maximum order amount accepted by this instance |
-| **Daily Limit** | Daily cumulative transaction limit for this instance |
-
-> During load balancing, instances that exceed their limits are automatically skipped.
-
----
-
-## Webhook Configuration
-
-Payment callbacks are essential for the payment system to work correctly.
-
-### Callback URL Format
-
-When adding a provider, the system auto-generates callback URLs from your site domain:
-
-| Provider | Callback Path |
-|----------|-------------|
-| **EasyPay** | `https://your-domain.com/api/v1/payment/webhook/easypay` |
-| **Alipay (Direct)** | `https://your-domain.com/api/v1/payment/webhook/alipay` |
-| **WeChat Pay (Direct)** | `https://your-domain.com/api/v1/payment/webhook/wxpay` |
-| **Stripe** | `https://your-domain.com/api/v1/payment/webhook/stripe` |
-
-> Replace `your-domain.com` with your actual domain. For EasyPay / Alipay / WeChat Pay, the callback URL is auto-filled when adding the provider — no manual configuration needed.
-
-### Stripe Webhook Setup
-
-1. Log in to [Stripe Dashboard](https://dashboard.stripe.com/)
-2. Go to **Developers → Webhooks**
-3. Add an endpoint with the callback URL
-4. Subscribe to events: `payment_intent.succeeded`, `payment_intent.payment_failed`
-5. Copy the generated Webhook Secret (`whsec_...`) to your provider configuration
-
-### Important Notes
-
-- Callback URLs must use **HTTPS** (required by Stripe, strongly recommended for others)
-- Ensure your firewall allows callback requests from payment platforms
-- The system automatically verifies callback signatures to prevent forgery
-- Balance top-up is processed automatically upon successful payment — no manual intervention needed
-
----
-
-## Payment Flow
-
-```
-User selects amount and payment method
- │
- ▼
- Create Order (PENDING)
- ├─ Validate amount range, pending order count, daily limit
- ├─ Load balance to select provider instance
- └─ Call provider to get payment info
- │
- ▼
- User completes payment
- ├─ EasyPay → QR code / H5 redirect
- ├─ Alipay → Desktop QR / mobile Alipay redirect
- ├─ WeChat Pay → Desktop Native QR / non-WeChat H5 / in-WeChat JSAPI
- └─ Stripe → Payment Element (card/Alipay/WeChat/etc.)
- │
- ▼
- Webhook callback verified → Order PAID
- │
- ▼
- Auto top-up to user balance → Order COMPLETED
-```
-
-### Order Status Reference
-
-| Status | Description |
-|--------|-------------|
-| `PENDING` | Waiting for user to complete payment |
-| `PAID` | Payment confirmed, awaiting balance credit |
-| `COMPLETED` | Balance credited successfully |
-| `EXPIRED` | Timed out without payment |
-| `CANCELLED` | Cancelled by user |
-| `FAILED` | Balance credit failed, admin can retry |
-| `REFUND_REQUESTED` | Refund requested |
-| `REFUNDING` | Refund in progress |
-| `REFUNDED` | Refund completed |
-
-### Timeout and Fallback
-
-- Before marking an order as expired, the background job queries the upstream payment status first
-- If the user has actually paid but the callback was delayed, the system will reconcile automatically
-- The background job runs every 60 seconds to check for timed-out orders
-
----
-
-## Migrating from Sub2ApiPay
-
-If you previously used [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) as an external payment system, you can migrate to the built-in payment system:
-
-### Key Differences
-
-| Aspect | Sub2ApiPay | Built-in Payment |
-|--------|-----------|-----------------|
-| Deployment | Separate service (Next.js + PostgreSQL) | Built into Sub2API, no extra deployment |
-| Payment Methods | EasyPay, Alipay, WeChat, Stripe | Same |
-| Configuration | Environment variables + separate admin UI | Unified in Sub2API admin dashboard |
-| Top-up Integration | Via Admin API callback | Internal processing, more reliable |
-| Subscription Plans | Supported | Not yet (planned) |
-| Order Management | Separate admin interface | Integrated in Sub2API admin dashboard |
-
-### Migration Steps
-
-1. Enable payment in Sub2API admin dashboard and configure providers (use the same payment credentials)
-2. Update webhook callback URLs to Sub2API's callback endpoints
-3. Verify that new orders are processed correctly via built-in payment
-4. Decommission the Sub2ApiPay service
-
-> **Note**: Historical order data from Sub2ApiPay will not be automatically migrated. Keep Sub2ApiPay running for a while to access historical records.
diff --git a/docs/PAYMENT_CN.md b/docs/PAYMENT_CN.md
deleted file mode 100644
index 0fbc198a..00000000
--- a/docs/PAYMENT_CN.md
+++ /dev/null
@@ -1,287 +0,0 @@
-# 支付系统配置指南
-
-Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支付服务。
-
----
-
-## 目录
-
-- [支持的支付方式](#支持的支付方式)
-- [快速开始](#快速开始)
-- [系统设置](#系统设置)
-- [服务商配置](#服务商配置)
-- [服务商实例管理](#服务商实例管理)
-- [Webhook 配置](#webhook-配置)
-- [支付流程](#支付流程)
-- [从 Sub2ApiPay 迁移](#从-sub2apipay-迁移)
-
----
-
-## 支持的支付方式
-
-| 服务商 | 支付方式 | 说明 |
-|--------|---------|------|
-| **EasyPay(易支付)** | 支付宝、微信支付 | 兼容易支付协议的第三方聚合支付 |
-| **支付宝官方** | 桌面二维码扫码、移动端支付宝跳转 | 直接对接支付宝开放平台,桌面端返回二维码,移动端返回 WAP/唤起链接 |
-| **微信官方** | Native 扫码、H5、公众号/JSAPI 支付 | 直接对接微信支付 APIv3,按终端环境自动分流 |
-| **Stripe** | 银行卡、支付宝、微信支付、Link 等 | 国际支付,支持多币种 |
-
-> 支付宝官方 / 微信官方与易支付可以同时作为后台服务商实例存在,但前台始终只展示 `支付宝`、`微信支付` 两个可见按钮。管理员需要分别为这两个按钮选择唯一支付来源:官方或易支付。官方渠道直接对接 API,资金直达商户账户,手续费更低;易支付通过第三方平台聚合,接入门槛更低。
-
-> **易支付服务商推荐**:以下两家均为兼容易支付协议的第三方聚合支付,按资金通道与结算方式选择:
->
-> - **国内渠道 / 人民币结算** — [ZPay](https://z-pay.cn/?uid=23808)(`https://z-pay.cn/?uid=23808`):支付宝 / 微信官方 API 直连,手续费 **1.6%**;资金直达商家账户,**T+1 自动到账**。支持**个人用户**(无营业执照)每日 1 万元以内交易;拥有营业执照则无限额。链接含 [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) 原作者 [@touwaeriol](https://github.com/touwaeriol) 的邀请码,介意可去掉。
-> - **国际渠道 / USDT 或美元结算** — [启润支付](https://kyren.top/?code=SUB2API)(`https://kyren.top/?code=SUB2API`):为 AI 项目提供低门槛国际收款通道,支持国际版微信支付与支付宝,本地货币支付、美元结算。手续费:微信 2%、支付宝 2.5%;提现 0.1%(最低 40 美元、最高 150 美元),以 **USDT 或美元**到账。无资质审核、注册即用,使用门槛最低;提现门槛略高,适合**不使用国内支付渠道、无法接受 Stripe 高达 6%+ 手续费、流水较大,且拥有美元或 USDT 渠道可接收提现资金**的用户。启润支付开户费 200 美元,通过本链接注册(含 Sub2Api 作者 [@Wei-Shaw](https://github.com/Wei-Shaw) 邀请码)可**免开户费**,介意可去掉。
->
-> 支付渠道的安全性、稳定性及合规性请自行鉴别,本项目不对任何第三方支付服务商做担保或背书。
-
----
-
-## 快速开始
-
-1. 进入管理后台 → **设置** → **支付设置** 标签页
-2. 开启 **启用支付**
-3. 配置基本参数(金额范围、超时时间等)
-4. 在 **服务商管理** 中添加至少一个服务商实例
-5. 用户即可在前端页面进行充值
-
----
-
-## 系统设置
-
-在管理后台 **设置 → 支付设置** 中配置以下参数:
-
-### 基本设置
-
-| 设置项 | 说明 | 默认值 |
-|--------|------|--------|
-| **启用支付** | 启用或禁用支付系统 | 关闭 |
-| **商品名前缀** | 支付页面显示的商品名前缀 | - |
-| **商品名后缀** | 商品名后缀(如"元") | - |
-| **最低金额** | 单笔最低充值金额 | 1 |
-| **最高金额** | 单笔最高充值金额(留空表示不限制) | - |
-| **每日限额** | 每用户每日累计充值上限(留空表示不限制) | - |
-| **订单超时时间** | 订单超时分钟数,至少 1 分钟 | 30 |
-| **最大待支付订单数** | 同一用户最大并行待支付订单数 | 3 |
-| **负载均衡策略** | 多服务商实例时的选择策略 | 轮询 |
-
-### 前台可见支付方式路由
-
-当前版本对用户统一展示支付方式,不区分官方渠道还是易支付:
-
-- **支付宝**:后台启用后,需要额外指定该按钮路由到 `支付宝官方` 或 `易支付支付宝`
-- **微信支付**:后台启用后,需要额外指定该按钮路由到 `微信官方` 或 `易支付微信`
-- 同一个可见支付方式在同一时刻只能路由到一个来源
-- 支付来源未选择时,即使对应按钮被开启,前台也不会暴露该支付方式
-
-### 负载均衡策略
-
-| 策略 | 说明 |
-|------|------|
-| **轮询(round-robin)** | 按顺序轮流分配到各服务商实例 |
-| **最少金额(least-amount)** | 优先分配到当日累计金额最少的实例 |
-
-### 取消频率限制
-
-防止用户频繁创建并取消订单:
-
-| 设置项 | 说明 |
-|--------|------|
-| **启用限制** | 开关 |
-| **窗口模式** | 滚动窗口 / 固定窗口 |
-| **时间窗口** | 窗口长度 |
-| **窗口单位** | 分钟 / 小时 |
-| **最大次数** | 窗口内允许的最大取消次数 |
-
-### 帮助信息
-
-| 设置项 | 说明 |
-|--------|------|
-| **帮助图片** | 充值页面显示的客服二维码等图片(支持上传) |
-| **帮助文本** | 充值页面显示的说明文字 |
-
----
-
-## 服务商配置
-
-每种服务商需要不同的凭证和参数。在 **服务商管理 → 添加服务商** 中选择类型后填写。
-
-> **回调地址自动生成**:添加服务商时,异步回调地址(Notify URL)和同步跳转地址(Return URL)由系统根据你的站点域名自动拼接,无需手动填写。管理员只需确认域名正确即可。
-
-### EasyPay(易支付)
-
-兼容任何 EasyPay 协议的支付服务商。
-
-| 参数 | 说明 | 必填 |
-|------|------|------|
-| **商户 ID(PID)** | EasyPay 商户 ID | 是 |
-| **商户密钥(PKey)** | EasyPay 商户密钥 | 是 |
-| **API 地址** | EasyPay API 基础地址 | 是 |
-| **支付宝通道 ID** | 指定支付宝通道(可选) | 否 |
-| **微信通道 ID** | 指定微信通道(可选) | 否 |
-
-### 支付宝官方
-
-直接对接支付宝开放平台。桌面端返回二维码供页面内展示和扫码,移动端返回支付宝手机网站支付跳转链接。
-
-| 参数 | 说明 | 必填 |
-|------|------|------|
-| **AppID** | 支付宝应用 AppID | 是 |
-| **应用私钥** | RSA2 应用私钥 | 是 |
-| **支付宝公钥** | 支付宝公钥 | 是 |
-
-### 微信官方
-
-直接对接微信支付 APIv3,支持 Native 扫码支付、H5 支付,以及在微信环境内的公众号/JSAPI 支付。
-
-| 参数 | 说明 | 必填 |
-|------|------|------|
-| **AppID** | 微信支付 AppID | 是 |
-| **商户号(MchID)** | 微信支付商户号 | 是 |
-| **商户 API 私钥** | 商户 API 私钥(PEM 格式) | 是 |
-| **APIv3 密钥** | 32 位 APIv3 密钥 | 是 |
-| **微信支付公钥** | 微信支付公钥(PEM 格式) | 是 |
-| **微信支付公钥 ID** | 微信支付公钥 ID | 是 |
-| **商户证书序列号** | 商户证书序列号 | 是 |
-
-### Stripe
-
-国际支付平台,支持多种支付方式和币种。
-
-| 参数 | 说明 | 必填 |
-|------|------|------|
-| **Secret Key** | Stripe 密钥(`sk_live_...` 或 `sk_test_...`) | 是 |
-| **Publishable Key** | Stripe 可公开密钥(`pk_live_...` 或 `pk_test_...`) | 是 |
-| **Webhook Secret** | Stripe Webhook 签名密钥(`whsec_...`) | 是 |
-
----
-
-## 服务商实例管理
-
-同一种服务商可以创建**多个实例**,实现负载均衡和风控:
-
-- **多实例负载均衡** — 按轮询或最少金额策略分流订单
-- **独立限额** — 每个实例可独立配置单笔最小/最大金额和每日限额
-- **独立启停** — 可单独启用/禁用某个实例,不影响其他实例
-- **退款控制** — 每个实例可单独开启或关闭退款功能
-- **支付方式** — 每个实例可选择支持的支付方式子集
-- **排序** — 拖拽调整实例顺序
-
-### 实例限额配置
-
-每个实例支持以下限额:
-
-| 限额项 | 说明 |
-|--------|------|
-| **单笔最小金额** | 该实例接受的最小订单金额 |
-| **单笔最大金额** | 该实例接受的最大订单金额 |
-| **每日限额** | 该实例每日累计交易上限 |
-
-> 负载均衡时,系统会自动跳过超出限额的实例。
-
----
-
-## Webhook 配置
-
-支付回调是支付系统的核心环节,必须正确配置:
-
-### 回调地址格式
-
-添加服务商时,系统会自动根据站点域名拼接回调地址,格式如下:
-
-| 服务商 | 回调路径 |
-|--------|---------|
-| **EasyPay** | `https://your-domain.com/api/v1/payment/webhook/easypay` |
-| **支付宝官方** | `https://your-domain.com/api/v1/payment/webhook/alipay` |
-| **微信官方** | `https://your-domain.com/api/v1/payment/webhook/wxpay` |
-| **Stripe** | `https://your-domain.com/api/v1/payment/webhook/stripe` |
-
-> 将 `your-domain.com` 替换为你的实际域名。EasyPay / 支付宝 / 微信的回调地址在添加服务商时自动填入,无需手动配置。
-
-### Stripe Webhook 设置
-
-1. 登录 [Stripe Dashboard](https://dashboard.stripe.com/)
-2. 进入 **Developers → Webhooks**
-3. 添加端点,填写回调地址
-4. 订阅事件:`payment_intent.succeeded`、`payment_intent.payment_failed`
-5. 将生成的 Webhook Secret(`whsec_...`)填入服务商配置
-
-### 注意事项
-
-- 回调地址必须是 **HTTPS**(Stripe 强制要求,其他服务商强烈推荐)
-- 确保服务器防火墙允许支付平台的回调请求
-- 系统会自动进行签名验证,防止伪造回调
-- 支付成功后自动完成余额充值,无需人工干预
-
----
-
-## 支付流程
-
-```
-用户选择充值金额和支付方式
- │
- ▼
- 创建订单 (PENDING)
- ├─ 校验金额范围、待支付订单数、每日限额
- ├─ 负载均衡选择服务商实例
- └─ 调用服务商获取支付信息
- │
- ▼
- 用户完成支付
- ├─ EasyPay → 扫码 / H5 跳转
- ├─ 支付宝官方 → 桌面二维码 / 移动端支付宝跳转
- ├─ 微信官方 → 桌面 Native 扫码 / 非微信 H5 / 微信内 JSAPI
- └─ Stripe → Payment Element(银行卡/支付宝/微信等)
- │
- ▼
- 支付回调验签 → 订单 PAID
- │
- ▼
- 自动充值到用户余额 → 订单 COMPLETED
-```
-
-### 订单状态说明
-
-| 状态 | 说明 |
-|------|------|
-| `PENDING` | 待支付,等待用户完成支付 |
-| `PAID` | 已支付,等待充值到账 |
-| `COMPLETED` | 已完成,余额已到账 |
-| `EXPIRED` | 已过期,超时未支付 |
-| `CANCELLED` | 已取消,用户主动取消 |
-| `FAILED` | 充值失败,可管理员重试 |
-| `REFUND_REQUESTED` | 已申请退款 |
-| `REFUNDING` | 退款处理中 |
-| `REFUNDED` | 已退款 |
-
-### 超时与兜底
-
-- 订单超时后,后台任务会先查询上游支付状态再标记过期
-- 如果用户实际已支付但回调延迟,系统会通过查询补单
-- 后台任务每 60 秒执行一次超时检查
-
----
-
-## 从 Sub2ApiPay 迁移
-
-如果你之前使用 [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) 作为外部支付系统,现在可以迁移到内置支付:
-
-### 主要差异
-
-| 对比项 | Sub2ApiPay | 内置支付 |
-|--------|-----------|---------|
-| 部署方式 | 独立服务(Next.js + PostgreSQL) | 内置于 Sub2API,无需额外部署 |
-| 支付方式 | EasyPay、支付宝、微信、Stripe | 相同 |
-| 配置方式 | 环境变量 + 独立管理后台 | Sub2API 管理后台内统一配置 |
-| 充值对接 | 通过 Admin API 回调 | 内部直接处理,更可靠 |
-| 订阅套餐 | 支持 | 暂不支持(计划中) |
-| 订单管理 | 独立管理界面 | 集成在 Sub2API 管理后台 |
-
-### 迁移步骤
-
-1. 在 Sub2API 管理后台启用支付并配置服务商(使用相同的支付凭证)
-2. 更新 Webhook 回调地址为 Sub2API 的回调地址
-3. 确认新订单通过内置支付正常处理
-4. 停用 Sub2ApiPay 服务
-
-> **注意**:Sub2ApiPay 中的历史订单数据不会自动迁移。建议保留 Sub2ApiPay 一段时间以便查询历史记录。
diff --git a/docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md b/docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md
deleted file mode 100644
index 2d44e058..00000000
--- a/docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md
+++ /dev/null
@@ -1,539 +0,0 @@
-# Auth Identity Payment Foundation Implementation Plan
-
-> **For agentic workers:** REQUIRED SUB-SKILL: Use `superpowers:subagent-driven-development` (recommended) or `superpowers:executing-plans` to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
-
-**Goal:** Rebuild the auth identity, profile binding, payment routing, and OpenAI advanced scheduler foundation on top of a clean `origin/main` branch while preserving historical compatibility for existing email users, existing LinuxDo users, historical LinuxDo/WeChat/OIDC synthetic-email users, and historical WeChat `openid`-only records.
-**Architecture:** A unified identity foundation centered on durable provider subjects (`email`, `linuxdo`, `oidc`, `wechat`) and transactional pending-auth sessions; backend-owned payment source routing behind stable frontend methods (`alipay`, `wxpay`); compatibility-first migration/backfill before feature enablement.
-**Tech Stack:** Go, Gin, Ent, PostgreSQL, Redis, Vue 3, Pinia, TypeScript, Vitest, pnpm.
-
----
-
-## Non-Negotiable Product Rules
-
-- [ ] Preserve login continuity for existing email users, existing LinuxDo users, and historically migrated third-party users.
-- [ ] During migration, backfill historical LinuxDo/WeChat/OIDC synthetic-email users into explicit third-party identities before first post-upgrade login whenever deterministic recovery is possible.
-- [ ] During migration, surface historical WeChat `openid`-only records through explicit migration reports and remediation rules; do not silently reinterpret them as valid canonical identities.
-- [ ] Keep existing email login and add third-party login/bind for `linuxdo`, `oidc`, and `wechat`.
-- [ ] On first third-party login:
- - identity exists: direct login.
- - identity does not exist: start pending-auth flow.
- - local email binding is required only when system config says so.
- - upstream provider email verification never counts as local email verification.
-- [ ] When user-entered and locally verified email already exists:
- - offer bind-existing-account after local re-authentication.
- - offer change-email-and-create-new-account.
- - when email binding is mandatory, do not allow bypass without changing to another email.
-- [ ] On first third-party login or first third-party bind, provider nickname/avatar must be presented as independent replace options for the current nickname and avatar. They are not auto-applied.
-- [ ] Source-specific initial grants must support per-source defaults for balance, concurrency, and subscriptions.
-- [ ] Default grant timing: on successful new-account creation.
-- [ ] Optional grant timing: on first successful bind for the configured source.
-- [ ] Migration/backfill must never trigger first-bind or first-signup grants retroactively.
-- [ ] Avatar profile supports:
- - direct URL storage.
- - image data URL upload compressed to `<=100KB` before storing in DB.
- - explicit delete.
-- [ ] Admin user management must expose and sort by `last_login_at` and `last_active_at`.
-- [ ] WeChat login rules:
- - WeChat environment uses MP login.
- - non-WeChat browser uses Open/QR login.
- - canonical identity uses `unionid`.
- - when `unionid` is unavailable, fail the login/bind flow under the approved option-1 policy.
-- [ ] OIDC rules:
- - browser authorization-code flow always uses PKCE `S256`.
- - discovery issuer and ID token `iss` must match exactly.
- - `userinfo.sub` must match ID token `sub` when UserInfo is used.
- - upstream `email_verified` does not satisfy local email verification.
-- [ ] Payment UI rules:
- - user-facing methods stay `支付宝` and `微信支付`.
- - backend decides whether each method routes to official provider instance or EasyPay.
- - at runtime, each visible method may only have one active source.
-- [ ] Alipay rules:
- - PC: in-page QR.
- - mobile browser: jump to Alipay payment.
-- [ ] WeChat Pay rules:
- - PC: in-page QR.
- - WeChat H5: MP/JSAPI first, fallback to H5 pay.
- - non-WeChat H5: H5 pay, or prompt to open in WeChat when unavailable.
-- [ ] Payment success pages are informational only; actual fulfillment depends on webhook or server-side reconciliation.
-- [ ] WeChat in-app payment requiring `openid` must use a dedicated server-backed payment OAuth resume flow rather than frontend-only recovery state.
-- [ ] OpenAI advanced scheduler is available but default-disabled.
-
-## Hard Technical Constraints From Audit
-
-- [ ] Browser-based third-party auth must use Authorization Code + PKCE `S256`.
-- [ ] PKCE must not be admin-configurable off for browser authorization-code providers.
-- [ ] OIDC identity primary key must be `(issuer, subject)`, not email.
-- [ ] Email equality must never auto-link accounts.
-- [ ] Bind-existing-account must require explicit local re-authentication and TOTP verification when enabled.
-- [ ] Bind-current-user must originate from an already-authenticated local user and preserve explicit bind intent across callback completion.
-- [ ] OAuth redirect URI must be fixed server config, exact-match, and never derived from user input.
-- [ ] User-supplied redirect may only choose a normalized same-origin internal route after completion.
-- [ ] WeChat canonical identity must be `unionid`; `openid` remains channel/app-scoped support data only.
-- [ ] Every canonical identity uniqueness rule must include provider namespace (`provider_key`) consistently.
-- [ ] Callback completion must use backend session completion or a one-time opaque exchange code that is short-lived, one-time, browser-session-bound, `POST`-redeemed, and unusable as a bearer token.
-- [ ] Every payment order must snapshot the selected provider instance plus the order-time verification inputs required for callback verification, reconciliation, refund, and audit.
-- [ ] Frontend must not receive first-party bearer tokens through callback URL fragments in the rebuilt flow.
-- [ ] Public payment result polling must not expose order data by raw `out_trade_no` alone; use authenticated lookup or signed opaque result token.
-- [ ] WeChat Pay webhook handling must verify signature, decrypt payload, and compare `appid`, `mchid`, `out_trade_no`, `amount`, `currency`, and provider trade state against the order snapshot before fulfillment.
-
-## Baseline Notes
-
-- [ ] Current clean branch head when this plan was written: `721d7ab3`.
-- [ ] Baseline backend verification on clean `origin/main`: `cd backend && go test ./...` passes.
-- [ ] Baseline frontend verification on clean `origin/main`: `cd frontend && pnpm test:run` currently fails in unrelated existing suites. New work must add targeted tests and avoid claiming full frontend green until those baseline failures are addressed separately.
-- [ ] Existing migration directory currently ends at `107_*`; this rebuild reserves `108` through `111`.
-
-## Target File Map
-
-### New backend migrations
-
-- [ ] `backend/migrations/108_auth_identity_foundation_core.sql`
-- [ ] `backend/migrations/109_auth_identity_compat_backfill.sql`
-- [ ] `backend/migrations/110_pending_auth_and_provider_default_grants.sql`
-- [ ] `backend/migrations/111_payment_routing_and_scheduler_flags.sql`
-
-### New or rebuilt Ent schema
-
-- [ ] `backend/ent/schema/auth_identity.go`
-- [ ] `backend/ent/schema/auth_identity_channel.go`
-- [ ] `backend/ent/schema/pending_auth_session.go`
-- [ ] `backend/ent/schema/identity_adoption_decision.go`
-
-### New or rebuilt backend repositories/services/handlers
-
-- [ ] `backend/internal/repository/user_profile_identity_repo.go`
-- [ ] `backend/internal/repository/user_profile_identity_repo_contract_test.go`
-- [ ] `backend/internal/repository/auth_identity_migration_report.go`
-- [ ] `backend/internal/service/auth_identity_flow.go`
-- [ ] `backend/internal/service/auth_identity_flow_test.go`
-- [ ] `backend/internal/service/auth_pending_identity_service.go`
-- [ ] `backend/internal/service/auth_pending_identity_service_test.go`
-- [ ] `backend/internal/service/payment_config_service.go`
-- [ ] `backend/internal/service/payment_order.go`
-- [ ] `backend/internal/service/payment_order_lifecycle.go`
-- [ ] `backend/internal/service/payment_fulfillment.go`
-- [ ] `backend/internal/service/payment_resume_service.go`
-- [ ] `backend/internal/service/payment_resume_service_test.go`
-- [ ] `backend/internal/service/openai_account_scheduler.go`
-- [ ] `backend/internal/handler/auth_pending_identity_flow.go`
-- [ ] `backend/internal/handler/auth_linuxdo_oauth.go`
-- [ ] `backend/internal/handler/auth_oidc_oauth.go`
-- [ ] `backend/internal/handler/auth_wechat_oauth.go`
-- [ ] `backend/internal/handler/auth_handler.go`
-- [ ] `backend/internal/handler/user_handler.go`
-- [ ] `backend/internal/handler/payment_handler.go`
-- [ ] `backend/internal/handler/payment_webhook_handler.go`
-- [ ] `backend/internal/handler/admin/user_handler.go`
-- [ ] `backend/internal/handler/admin/setting_handler.go`
-
-### New or rebuilt frontend API/store/views/components
-
-- [ ] `frontend/src/api/auth.ts`
-- [ ] `frontend/src/api/user.ts`
-- [ ] `frontend/src/api/payment.ts`
-- [ ] `frontend/src/api/admin/settings.ts`
-- [ ] `frontend/src/api/admin/users.ts`
-- [ ] `frontend/src/stores/auth.ts`
-- [ ] `frontend/src/stores/payment.ts`
-- [ ] `frontend/src/components/auth/ThirdPartyAuthCallbackFlow.vue`
-- [ ] `frontend/src/components/auth/LinuxDoOAuthSection.vue`
-- [ ] `frontend/src/components/auth/OidcOAuthSection.vue`
-- [ ] `frontend/src/components/auth/WechatOAuthSection.vue`
-- [ ] `frontend/src/components/user/profile/ProfileAccountBindingsCard.vue`
-- [ ] `frontend/src/components/user/profile/ProfileInfoCard.vue`
-- [ ] `frontend/src/views/auth/LinuxDoCallbackView.vue`
-- [ ] `frontend/src/views/auth/OidcCallbackView.vue`
-- [ ] `frontend/src/views/auth/WechatCallbackView.vue`
-- [ ] `frontend/src/views/user/ProfileView.vue`
-- [ ] `frontend/src/views/user/PaymentView.vue`
-- [ ] `frontend/src/views/user/PaymentQRCodeView.vue`
-- [ ] `frontend/src/views/user/PaymentResultView.vue`
-
-## Phase 1: Migration And Compatibility Foundation
-
-### Task 1. Create core identity schema migration
-
-- [ ] Implement `backend/migrations/108_auth_identity_foundation_core.sql` with:
- - `auth_identities`
- - `auth_identity_channels`
- - `pending_auth_sessions`
- - `identity_adoption_decisions`
- - `users.last_login_at`
- - `users.last_active_at`
- - grant-tracking columns/tables required to prevent double-award
-- [ ] Add uniqueness/index rules:
- - one canonical identity per `(provider, provider_key, provider_subject)`
- - one channel record per `(provider, provider_channel, provider_app_id, provider_channel_subject)`
- - one adoption decision per pending session
-- [ ] Model `pending_auth_sessions` so immutable upstream claims and mutable local flow state are stored separately; do not reintroduce a mixed `metadata` catch-all.
-- [ ] Preserve null-safe compatibility defaults so historical rows remain readable before backfill finishes.
-- [ ] Add explicit rollback blocks only where safe; never repeat the destructive pattern observed in old `112_update_pending_auth_sessions.sql`.
-
-### Task 2. Materialize historical identities before runtime
-
-- [ ] Implement `backend/migrations/109_auth_identity_compat_backfill.sql` to backfill:
- - existing email users into `auth_identities(provider=email, provider_subject=normalized_email)`
- - historical LinuxDo users into `auth_identities(provider=linuxdo, provider_subject=linuxdo_subject)`
- - historical synthetic-email LinuxDo users into explicit LinuxDo identity rows by parsing legacy email mode and legacy provider metadata
- - historical synthetic-email WeChat users into explicit WeChat identities where `unionid` or equivalent deterministic provider identity is recoverable
- - historical synthetic-email OIDC users into explicit OIDC identities where deterministic provider identity is recoverable
- - profile/channel rows from historical `user_external_identities`-style data when present in upgraded databases
-- [ ] Write migration report output in `backend/internal/repository/auth_identity_migration_report.go` so production can inspect unmatched rows, `openid`-only WeChat rows, and non-deterministic synthetic-email rows instead of silently skipping them.
-- [ ] Set `signup_source` and provider provenance when recoverable from historical data. Do not flatten everything to `email`.
-
-### Task 3. Provider default grant and scheduler config migration
-
-- [ ] Implement `backend/migrations/110_pending_auth_and_provider_default_grants.sql` for:
- - provider-specific initial balance/concurrency/subscription defaults
- - grant timing flags: `on_signup`, optional `on_first_bind`
- - email-required-on-third-party-signup flags
- - profile avatar storage columns/settings
-- [ ] Implement `backend/migrations/111_payment_routing_and_scheduler_flags.sql` for:
- - stable payment method to provider-instance routing
- - visible-method normalization from historical `supported_types`, `payment_mode`, and legacy aliases such as `wxpay_direct`
- - admin exclusivity flags for `alipay` and `wxpay`
- - advanced scheduler enable flag defaulting to disabled
-
-### Task 4. Generate Ent and compile migration-safe model layer
-
-- [ ] Add the schema definitions in:
- - `backend/ent/schema/auth_identity.go`
- - `backend/ent/schema/auth_identity_channel.go`
- - `backend/ent/schema/pending_auth_session.go`
- - `backend/ent/schema/identity_adoption_decision.go`
-- [ ] Run:
- ```bash
- cd backend
- go generate ./ent
- ```
-- [ ] Compile after generation:
- ```bash
- cd backend
- go test ./... -run '^$'
- ```
-- [ ] Commit checkpoint:
- ```bash
- git add backend/migrations backend/ent/schema backend/ent
- git commit -m "feat: add auth identity foundation schema"
- ```
-
-## Phase 2: Backend Identity Flow Rebuild
-
-### Task 5. Build a single repository contract for identity lookups and grants
-
-- [ ] Implement `backend/internal/repository/user_profile_identity_repo.go` with transactional helpers for:
- - get user by canonical identity
- - get user by channel identity
- - create canonical + channel identity together
- - bind identity to existing user after verified re-auth
- - record one-time provider grant award
- - record adoption preference decisions
- - update `last_login_at` and `last_active_at`
-- [ ] Add repository contract coverage in `backend/internal/repository/user_profile_identity_repo_contract_test.go`.
-- [ ] Enforce dual-write for email registration/login so `users.email` and `auth_identities(provider=email, ...)` stay consistent from this phase onward.
-- [ ] Add repository coverage proving `last_login_at` and `last_active_at` use the required field names and are not silently replaced by derived `last_used_at` logic.
-
-### Task 6. Rebuild transactional pending-auth service
-
-- [ ] Implement `backend/internal/service/auth_pending_identity_service.go` and tests to own these flows:
- - create pending session from third-party callback
- - verify local email code
- - create new account from pending session with correct `signup_source`
- - bind pending identity to existing account after password/TOTP re-auth
- - apply configured provider defaults on the correct trigger only once
- - store provider nickname/avatar candidates and user opt-in replacement decisions independently
-- [ ] Implement callback completion so pending auth can finish through backend session completion or a one-time exchange code:
- - short TTL
- - one-time use
- - browser-session binding
- - `POST` redemption only
- - safe mixed-version bridge to legacy pending-token aliases during rollout
-- [ ] Keep pending session payload normalized:
- - provider identity fields live in typed columns/JSON structure
- - mutable local progression lives separately from immutable upstream claims
- - avoid the old branch’s mixed `metadata` and `upstream_identity_payload` ambiguity
-- [ ] Do not call plain email registration helpers from this flow. The old feature branch bug where pending third-party signup fell back to `RegisterWithVerification` must not reappear.
-
-### Task 7. Rebuild provider callback adapters
-
-- [ ] Refactor these handlers to thin adapters over the shared pending-auth service:
- - `backend/internal/handler/auth_linuxdo_oauth.go`
- - `backend/internal/handler/auth_oidc_oauth.go`
- - `backend/internal/handler/auth_wechat_oauth.go`
-- [ ] For OIDC:
- - require PKCE `S256`, `state`, and `nonce`
- - validate discovery issuer, `iss`, `aud`, optional `azp`, `exp`, and `nonce`
- - verify `userinfo.sub == id_token.sub` when UserInfo is used
- - persist canonical identity as `(issuer, sub)`
-- [ ] For WeChat:
- - MP flow in WeChat UA
- - Open/QR flow outside WeChat UA
- - website login uses authorization-code flow and persists channel/app binding
- - persist channel identity by `(channel, appid, openid)`
- - persist canonical identity by `unionid`
- - hard-fail when `unionid` is absent under the approved product policy
-- [ ] Replace callback URL fragment token delivery with backend session completion or one-time exchange code consumed by `frontend/src/stores/auth.ts`.
-
-### Task 8. Rebuild auth endpoints and profile binding endpoints
-
-- [ ] Implement `backend/internal/handler/auth_pending_identity_flow.go` for:
- - fetch pending session summary
- - submit verified email
- - choose create-new-account or bind-existing-account
- - submit nickname/avatar replacement choices
-- [ ] Make bind-existing-account and bind-current-user flows explicit:
- - no automatic linking on matching email
- - fresh password/TOTP proof is scoped to the intended target account only
- - no automatic metadata merge beyond explicitly selected nickname/avatar adoption
-- [ ] Update `backend/internal/handler/auth_handler.go` and `backend/internal/handler/user_handler.go` to expose:
- - current bindings summary
- - start-bind endpoints for LinuxDo/OIDC/WeChat
- - disconnect endpoints with safety checks
- - avatar upload/delete endpoints
-- [ ] Avatar handling requirements:
- - allow external URL
- - allow data URL upload
- - compress image payload to `<=100KB`
- - store compressed value in DB
- - deleting custom avatar must not implicitly resurrect stale provider avatar unless the user explicitly chooses provider avatar again
-
-### Task 9. Add admin visibility and sorting
-
-- [ ] Update `backend/internal/handler/admin/user_handler.go` and supporting query/service code so admin list supports:
- - `last_login_at`
- - `last_active_at`
- - sorting by both
- - binding/provider summary columns
-- [ ] Update `backend/internal/handler/admin/setting_handler.go` and setting service code for:
- - provider initial grant config
- - mandatory-email-on-third-party-signup config
- - payment source exclusivity config
- - advanced scheduler toggle
-
-### Task 10. Backend verification checkpoint
-
-- [ ] Run targeted backend tests:
- ```bash
- cd backend
- go test ./internal/repository -run 'TestUserProfileIdentity|TestAuthIdentityMigration'
- go test ./internal/service -run 'TestAuthIdentityFlow|TestPendingAuthIdentity|TestOpenAIAccountScheduler'
- go test ./internal/handler -run 'TestLinuxDo|TestOidc|TestWechat|TestPaymentWebhook'
- go test ./...
- ```
-- [ ] Commit checkpoint:
- ```bash
- git add backend
- git commit -m "feat: rebuild auth identity backend flows"
- ```
-
-## Phase 3: Frontend Third-Party Flow And Profile UX
-
-### Task 11. Rebuild callback flow UI around pending session decisions
-
-- [ ] Rebuild `frontend/src/components/auth/ThirdPartyAuthCallbackFlow.vue` so it:
- - loads pending-session summary from backend
- - shows provider nickname/avatar candidates
- - lets user independently choose nickname replacement and avatar replacement
- - handles create-new-account vs bind-existing-account
- - enforces verified local email before completion when required
- - handles “email already exists” by branching to bind-existing-account or change-email-and-create-new-account
-- [ ] Update:
- - `frontend/src/views/auth/LinuxDoCallbackView.vue`
- - `frontend/src/views/auth/OidcCallbackView.vue`
- - `frontend/src/views/auth/WechatCallbackView.vue`
- - `frontend/src/api/auth.ts`
- - `frontend/src/stores/auth.ts`
-- [ ] Replace any token-fragment bootstrap with backend session completion or one-time exchange code flow.
-- [ ] During rollout, keep temporary compatibility readers for legacy pending-token aliases behind a bounded bridge contract and explicit removal step.
-
-### Task 12. Rebuild profile account binding and avatar UX
-
-- [ ] Rebuild `frontend/src/components/user/profile/ProfileAccountBindingsCard.vue` to:
- - show linked LinuxDo/OIDC/WeChat providers
- - start bind/unbind flows
- - show provider avatars and nicknames as reference only
- - prevent unsafe disconnect when it would strand the account
-- [ ] Rebuild `frontend/src/components/user/profile/ProfileInfoCard.vue` and `frontend/src/views/user/ProfileView.vue` to:
- - support avatar URL entry
- - support data URL upload/compression preview
- - support avatar delete
- - clearly separate current profile nickname/avatar from provider-sourced suggested nickname/avatar
-
-### Task 13. Add frontend tests for rebuilt auth/profile flows
-
-- [ ] Add or update:
- - `frontend/src/components/auth/__tests__/ThirdPartyAuthCallbackFlow.spec.ts`
- - `frontend/src/components/auth/__tests__/LinuxDoCallbackView.spec.ts`
- - `frontend/src/components/auth/__tests__/WechatCallbackView.spec.ts`
- - `frontend/src/components/user/profile/__tests__/ProfileAccountBindingsCard.spec.ts`
- - `frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts`
-- [ ] Cover:
- - email-required branch
- - email-conflict branch
- - bind-existing-account with re-auth prompt
- - nickname replacement only
- - avatar replacement only
- - neither replacement
- - avatar delete after prior provider adoption
-
-## Phase 4: Payment Routing Rebuild
-
-### Task 14. Normalize payment routing backend
-
-- [ ] Rebuild `backend/internal/service/payment_config_service.go` to expose a stable method-routing contract:
- - frontend visible methods remain `alipay` and `wxpay`
- - admin chooses which provider instance serves each method
- - runtime validation guarantees only one active source per visible method
-- [ ] Add migration logic and tests to normalize historical provider-instance config:
- - `supported_types`
- - `payment_mode`
- - legacy aliases such as `wxpay_direct`
- - historical limit config
-- [ ] Rebuild `backend/internal/service/payment_order.go` and `backend/internal/service/payment_order_lifecycle.go` so order creation snapshots:
- - visible method
- - selected provider instance id
- - provider type
- - provider capability mode
- - verification-critical provider fields needed for later callback/query/refund validation
-- [ ] Rebuild `backend/internal/handler/payment_handler.go` for UX rules:
- - Alipay PC: QR page
- - Alipay mobile: direct jump
- - WeChat PC: QR page
- - WeChat H5 in WeChat: MP/JSAPI first, fallback to H5
- - WeChat H5 outside WeChat: H5 or “open in WeChat” prompt when unavailable
-- [ ] Never derive canonical return URL from `Referer`; use configured or signed internal callback targets only.
-- [ ] Implement `backend/internal/service/payment_resume_service.go` so WeChat in-app payment OAuth resume is server-backed rather than localStorage-backed:
- - create `oauth_required` resume context
- - persist amount/order_type/plan_id/visible method/redirect/state
- - redeem callback into same-origin internal resume target
- - expire and consume resume context safely
-
-### Task 15. Make fulfillment and reconciliation provider-instance-safe
-
-- [ ] Rebuild `backend/internal/handler/payment_webhook_handler.go` and `backend/internal/service/payment_fulfillment.go` so:
- - verification uses the order’s original provider instance
- - webhook processing is idempotent by provider event id and internal order id
- - missed webhook recovery uses server-side provider query, not frontend success return
-- [ ] For WeChat Pay specifically, enforce:
- - fixed HTTPS `notify_url` with no query params
- - no dependency on user login state
- - signature verification before decrypt
- - APIv3 decrypt before business parsing
- - comparison of `appid`, `mchid`, `out_trade_no`, `amount`, `currency`, and trade state against the order snapshot
-- [ ] Harden `frontend/src/views/user/PaymentResultView.vue` and `frontend/src/api/payment.ts` so result polling uses an authenticated order lookup or signed opaque token, not a raw public `out_trade_no` query.
-
-### Task 16. Rebuild payment frontend views
-
-- [ ] Rebuild `frontend/src/views/user/PaymentView.vue`, `frontend/src/views/user/PaymentQRCodeView.vue`, and `frontend/src/stores/payment.ts` so:
- - only two buttons are shown to user: `支付宝` and `微信支付`
- - frontend does not leak official-vs-EasyPay distinction
- - route-specific copy handles QR, jump, MP, H5 fallback correctly
-- [ ] Rebuild WeChat in-app payment resume UX around the server-backed resume context:
- - handle `oauth_required`
- - continue from same-origin resume target
- - avoid long-lived localStorage as the source of truth
-- [ ] Add or update:
- - `frontend/src/views/user/__tests__/PaymentView.spec.ts`
- - `frontend/src/views/user/__tests__/PaymentResultView.spec.ts`
- - backend webhook/payment routing tests
-
-### Task 17. Payment verification checkpoint
-
-- [ ] Run:
- ```bash
- cd backend
- go test ./internal/service -run 'TestPayment'
- go test ./internal/handler -run 'TestPayment'
- cd ../frontend
- pnpm test:run src/views/user/__tests__/PaymentView.spec.ts src/views/user/__tests__/PaymentResultView.spec.ts
- ```
-- [ ] Commit checkpoint:
- ```bash
- git add backend frontend
- git commit -m "feat: rebuild payment routing foundation"
- ```
-
-## Phase 5: Scheduler, Rollout, And Final Compatibility Pass
-
-### Task 18. Gate advanced scheduler behind explicit config
-
-- [ ] Update `backend/internal/service/openai_account_scheduler.go` and related admin setting surfaces so:
- - advanced scheduler remains compiled and testable
- - default runtime state is disabled
- - enablement is explicit through admin settings
- - legacy scheduling behavior remains default on upgrade
-- [ ] Add targeted coverage in `backend/internal/service/openai_account_scheduler_test.go`.
-
-### Task 19. Complete compatibility and rollout safety checks
-
-- [ ] Add migration/repository tests covering:
- - historical email-only user login after upgrade
- - historical LinuxDo user login after upgrade
- - historical synthetic-email LinuxDo user login after upgrade
- - historical synthetic-email WeChat user login after upgrade
- - historical synthetic-email OIDC user login after upgrade
- - historical WeChat `openid`-only rows are reported or explicitly remediated
- - no retroactive grant replay during migration
- - first-bind grant fires once only when enabled
- - email identity dual-write stays consistent
- - bind-existing-account requires password and TOTP where configured
- - mixed-version callback token bridge works during rollout and is removable afterward
- - historical payment config is normalized into visible-method routing without refund/query regression
-- [ ] Add deploy sequencing note to release docs or internal runbook:
- 1. deploy schema and backfill release.
- 2. inspect migration report for unmatched rows.
- 3. deploy backend identity/payment compatibility code with exchange bridge and legacy token aliases still enabled.
- 4. deploy frontend callback/profile/payment UI using session completion, exchange code, and server-backed WeChat payment resume.
- 5. remove legacy callback/token parsing after mixed-version window closes.
- 6. enable strict email-required signup or provider bind grants only after metrics are healthy.
-
-### Task 20. Final verification and handoff
-
-- [ ] Run final backend verification:
- ```bash
- cd backend
- go test ./...
- ```
-- [ ] Run targeted frontend verification:
- ```bash
- cd frontend
- pnpm test:run \
- src/components/auth/__tests__/ThirdPartyAuthCallbackFlow.spec.ts \
- src/components/auth/__tests__/LinuxDoCallbackView.spec.ts \
- src/components/auth/__tests__/WechatCallbackView.spec.ts \
- src/components/user/profile/__tests__/ProfileAccountBindingsCard.spec.ts \
- src/components/user/profile/__tests__/ProfileInfoCard.spec.ts \
- src/views/user/__tests__/PaymentView.spec.ts \
- src/views/user/__tests__/PaymentResultView.spec.ts
- ```
-- [ ] Run focused manual smoke checks:
- - email login with existing account
- - LinuxDo existing-account login after migration
- - WeChat synthetic-email account login after migration
- - OIDC synthetic-email account login after migration
- - third-party first login create-new-account path
- - third-party first login bind-existing-account path
- - first third-party bind with optional nickname/avatar replacement
- - PC Alipay QR
- - mobile Alipay jump
- - PC WeChat QR
- - WeChat H5 MP/JSAPI path
- - WeChat in-app OAuth resume path
- - non-WeChat H5 fallback path
-- [ ] Commit final checkpoint:
- ```bash
- git add docs backend frontend
- git commit -m "feat: rebuild auth identity and payment foundation"
- ```
-
-## Review Checklist
-
-- [ ] No flow still relies on provider email equality for account linking.
-- [ ] No flow still creates third-party users through plain email registration helpers.
-- [ ] No callback still returns first-party bearer tokens in URL fragments.
-- [ ] No callback completion path can be replayed as a bearer token substitute.
-- [ ] No payment result view trusts provider return page as authoritative fulfillment.
-- [ ] No webhook verification path selects provider credentials from “currently active config” instead of the order snapshot.
-- [ ] Existing email users, historical LinuxDo/WeChat/OIDC users, and `openid`-only WeChat remediation cases are covered by migration tests.
-- [ ] Avatar adoption and deletion semantics are explicit and reversible.
-- [ ] Grant timing is source-aware and one-time only.
diff --git a/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md b/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
deleted file mode 100644
index 23823cf0..00000000
--- a/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
+++ /dev/null
@@ -1,763 +0,0 @@
-# Auth Identity And Payment Foundation Design
-
-**Date:** 2026-04-20
-
-**Status:** Draft approved in conversation, written for implementation planning
-
-**Goal**
-
-Rebuild the `feat/auth-identity-foundation` intent on a clean branch from `main`, covering unified user identity, third-party login and binding, profile adoption, source-based signup defaults, unified payment routing and UX, admin configuration, compatibility with existing `main` data, and an opt-in OpenAI advanced scheduling switch.
-
-## Scope
-
-This design includes:
-
-- Email login and registration
-- Third-party login and binding for `LinuxDo`, `OIDC`, and `WeChat`
-- Unified identity storage for email and third-party identities
-- Pending auth sessions for callback-to-login/register/bind continuation
-- User-controlled nickname/avatar adoption during first relevant third-party flow
-- Profile binding management and avatar upload/delete
-- Source-based initial grants for balance, concurrency, and subscriptions
-- User management support for `last_login_at` and `last_active_at` sorting
-- Unified payment display methods (`alipay`, `wxpay`) mapped to a single active backend source each
-- Alipay and WeChat UX routing rules across PC, mobile, H5, and WeChat environments
-- Admin settings for auth providers, source defaults, payment sources, and OpenAI advanced scheduling
-- Incremental migration and compatibility for existing email users, existing LinuxDo users, historical LinuxDo/WeChat/OIDC synthetic-email users, and historical WeChat `openid`-only identity records
-
-This design does not treat unrelated upstream merges, docs churn, or license changes from the old branch as required scope.
-
-## Product Rules
-
-### Auth and identity
-
-- Existing email users remain valid and continue to log in with no manual action.
-- Existing LinuxDo, OIDC, and WeChat users represented by historical third-party or synthetic-email data must remain recoverable during migration.
-- Third-party first login behavior:
- - Existing bound identity: direct login
- - Missing identity: start first-login flow
-- Browser-based third-party authorization-code login always uses PKCE `S256`; this is not an admin-toggleable feature.
-- If `force_email_on_third_party_signup` is disabled, a first-login user may create an account without binding an email.
-- If `force_email_on_third_party_signup` is enabled, the user must provide an email.
-- If the provided and verified email already exists:
- - show that the email already exists
- - allow "verify and bind existing account"
- - allow "change email and continue registration"
- - do not allow bypassing the email requirement
-- Upstream provider email verification is not trusted as a local bound email.
-- Matching upstream email must never auto-link to an existing local account.
-- Linking to an existing local account is allowed only when:
- - the user explicitly chooses that target account
- - the target account passes fresh local re-authentication
- - required TOTP verification succeeds
-- New third-party bind initiated from profile must start from an already logged-in local account and preserve explicit bind intent end-to-end.
-- `redirect_to` may only represent a normalized same-origin internal route. It must never contain a third-party URL and must never be derived from `Referer`.
-- OIDC validation rules:
- - canonical identity key is `issuer + sub`
- - discovery issuer and ID token `iss` must match exactly
- - `userinfo.sub` must match ID token `sub` when UserInfo is used
- - upstream `email_verified` may improve UX copy but does not satisfy local email-binding requirements
-- WeChat login chooses channel by environment:
- - in WeChat environment: `mp`
- - outside WeChat: `open`
-- WeChat primary identity key is `unionid`.
-- If a WeChat login/bind flow cannot produce `unionid`, the flow fails and no fallback `openid` identity is created.
-- Historical WeChat records that only contain `openid` are treated as migration-remediation cases, not as a valid long-term canonical identity model.
-- WeChat website login uses authorization code flow, random `state`, and the provider channel/app binding must be persisted alongside the resolved identity.
-
-### Profile adoption
-
-- During the first relevant third-party flow, the user can independently decide:
- - replace current nickname or not
- - replace current avatar or not
-- This applies to first third-party registration and first third-party binding.
-- The decision is explicit user choice, not automatic replacement.
-
-### Source-based initial grants
-
-- Source-specific defaults exist for `email`, `linuxdo`, `oidc`, and `wechat`.
-- Each source defines:
- - default balance
- - default concurrency
- - default subscriptions
- - grant on signup
- - grant on first bind
-- Default behavior:
- - grant on signup: enabled
- - grant on first bind: disabled
-- First-bind grants are optional and controlled per source.
-- Grants must be idempotent.
-
-### Avatar management
-
-- Avatar supports:
- - external URL
- - image `data:` URL
-- `data:` URL images are compressed to at most `100KB` before persistence.
-- Avatar storage is database-backed.
-- Avatar delete is supported.
-
-### Payment UX and routing
-
-- Frontend shows only two display methods:
- - `alipay`
- - `wxpay`
-- Users never choose between official providers and EasyPay explicitly.
-- Backend allows only one active source per display method at a time.
-- Alipay UX:
- - PC: show QR code in page
- - mobile: jump to Alipay app/payment flow
-- WeChat UX:
- - PC: show QR code in page
- - non-WeChat H5: prefer H5 pay; if unavailable, tell the user to open in WeChat
- - WeChat environment: prefer MP/JSAPI pay; if unavailable, fall back to H5 pay
-- Payment success is confirmed by backend order state, webhook, and/or query, not only frontend return.
-- Frontend-visible labels remain `支付宝` and `微信支付`, while internal visible-method identifiers remain `alipay` and `wxpay`.
-- Public result pages must not verify order state by exposing raw `out_trade_no`; they use authenticated lookup or a signed opaque result token instead.
-- Payment callback or return URLs must be fixed same-origin internal targets. They must not be inferred from `Referer`.
-- WeChat payment webhook handling must use a fixed HTTPS `notify_url` with no query parameters and must not depend on user login state.
-
-### OpenAI advanced scheduling
-
-- OpenAI advanced scheduling is supported.
-- It is disabled by default.
-- Admin can enable it explicitly.
-
-## Architecture
-
-Keep `users` as the account owner table and move login identities, channel mappings, pending auth state, callback completion state, and first-bind grant idempotency into dedicated tables and services. Keep email login working while progressively introducing unified identity reads and writes.
-
-Payment uses a similar split between user-visible display methods and backend provider sources. Frontend works only with stable display methods while backend resolves to the currently active source and capability matrix, and stores enough order-time snapshot data to survive later provider-config changes.
-
-Compatibility is a first-class concern: migrations are additive, reads are compatibility-aware, and rollout must tolerate existing `main` data and short-lived frontend/backend version skew.
-
-## Data Model
-
-### `users`
-
-Preserve existing account ownership and local-login fields. Extend or use:
-
-- `email`
-- `password_hash`
-- `totp_enabled`
-- `signup_source`
-- `last_login_at`
-- `last_active_at`
-
-The `users` table remains the primary business subject for balance, concurrency, subscriptions, permissions, and profile.
-
-### `auth_identities`
-
-Represents all canonical login or bindable identities.
-
-Fields:
-
-- `user_id`
-- `provider_type`: `email`, `linuxdo`, `oidc`, `wechat`
-- `provider_key`
-- `provider_subject`
-- `verified_at`
-- `issuer`
-- `metadata`
-- timestamps
-
-Uniqueness:
-
-- `provider_type + provider_key + provider_subject` must be unique
-
-Rules:
-
-- email identity uses canonicalized local email
-- LinuxDo uses stable provider subject under the configured provider namespace
-- OIDC uses stable issuer + subject, with issuer namespace represented consistently through `provider_key` and `issuer`
-- WeChat uses `unionid` as canonical subject under the configured Open Platform namespace
-
-### `auth_identity_channels`
-
-Stores channel-specific subject mappings for an identity.
-
-Primary use:
-
-- WeChat `open` / `mp` / payment channel mapping
-
-Fields:
-
-- `identity_id`
-- `provider_type`
-- `provider_key`
-- `channel`
-- `channel_app_id`
-- `channel_subject`
-- `metadata`
-- timestamps
-
-Rules:
-
-- canonical WeChat identity still keys on `unionid`
-- `openid` values live here as channel mappings
-
-### `pending_auth_sessions`
-
-Stores callback state between third-party callback and final account action.
-
-Fields:
-
-- `intent`
-- `provider_type`
-- `provider_key`
-- `provider_subject`
-- `target_user_id`
-- `redirect_to`
-- `resolved_email`
-- `registration_password_hash`
-- `upstream_identity_claims`
-- `local_flow_state`
-- `browser_session_key`
-- `completion_code_hash`
-- `completion_code_expires_at`
-- `email_verified_at`
-- `password_verified_at`
-- `totp_verified_at`
-- `expires_at`
-- `consumed_at`
-- timestamps
-
-Responsibilities:
-
-- continue provider callback into register/login/bind flows
-- persist nickname/avatar suggestions
-- persist explicit adoption decisions
-- survive navigation between auth pages
-- support mixed-version rollout through short-lived legacy token aliases when required
-
-Security rules:
-
-- callback completion uses backend session completion or a one-time exchange code
-- exchange codes are short-lived, one-time, bound to browser session and pending session, and redeemed via `POST`
-- exchange codes must not behave as bearer tokens and must not be logged, stored in URL fragments, or reused after redemption
-- `local_flow_state` stores mutable local progression only; immutable upstream claims remain in `upstream_identity_claims`
-
-### `identity_adoption_decisions`
-
-Persists user adoption preference collected during a pending-auth flow and resolved onto the bound identity.
-
-Fields:
-
-- `pending_auth_session_id`
-- `identity_id`
-- `adopt_display_name`
-- `adopt_avatar`
-- `decided_at`
-- timestamps
-
-Rules:
-
-- one adoption-decision row exists per pending session
-- `identity_id` is filled once final account creation or bind succeeds
-
-### `user_avatars`
-
-Stores the currently effective custom avatar.
-
-Fields:
-
-- `user_id`
-- `storage_provider`
-- `storage_key`
-- `url`
-- `content_type`
-- `byte_size`
-- `sha256`
-- timestamps
-
-Rules:
-
-- supports URL-backed and inline data-backed representations
-- hard maximum payload size is `100KB`
-
-### `user_provider_default_grants`
-
-Stores idempotency state for source grants.
-
-Fields:
-
-- `user_id`
-- `provider_type`
-- `granted_at`
-- timestamps
-
-Responsibilities:
-
-- prevent duplicate first-bind grants
-- allow signup grants and first-bind grants to be reasoned about independently
-
-## Identity Keys And Canonicalization
-
-- Email canonical key: `lower(trim(email))`
-- LinuxDo canonical key: provider subject from LinuxDo
-- OIDC canonical key: `issuer + sub`
-- WeChat canonical key: `unionid`
-
-WeChat-specific rule:
-
-- `openid` never becomes the primary stored identity key
-- if only `openid` is available, login/bind fails with a configuration/identity error
-- historical `openid`-only records must be reported and either remediated during migration or explicitly blocked from silent auto-upgrade
-
-## Core Flows
-
-### Email register/login
-
-- Existing email auth flow remains
-- On email registration, create canonical `email` identity
-- Apply `email` source signup defaults
-
-### Third-party login with existing identity
-
-- Resolve canonical identity
-- Login mapped `user`
-- Update `last_login_at`
-- Do not issue signup or first-bind grants again
-
-### Third-party first login with no identity
-
-- Create `pending_auth_session`
-- Frontend callback flow decides next action
-- Pending session creation stores immutable upstream claims separately from mutable local progress fields
-
-Branches:
-
-- no forced email binding:
- - user can create account directly
-- forced email binding:
- - user must supply local email
-
-If supplied local email already exists:
-
-- tell the user the email already exists
-- allow verify-and-bind-existing-account
-- allow changing email to continue registration
-
-On new account creation:
-
-- create `users` row
-- create canonical third-party identity
-- create or update canonical email identity when local email binding succeeds
-- apply source signup grants
-- apply adoption choices if selected
-
-### Bind third-party identity to current logged-in user
-
-- current user starts bind flow
-- callback resolves to `bind_current_user`
-- bind intent is tied to the initiating local user session and cannot be re-targeted by email match
-- bind canonical identity to current user
-- if configured and first bind for that provider, apply first-bind grants
-- present nickname/avatar replacement choice
-
-### Bind existing account during first-login flow
-
-- user explicitly selects bind-existing-account
-- verify password for existing account
-- if account requires TOTP, verify TOTP
-- bind canonical identity to target account
-- optionally apply first-bind grants
-- present nickname/avatar replacement choice
-- no automatic profile or metadata merge occurs beyond explicitly selected nickname/avatar replacement
-
-### Callback completion and exchange flow
-
-- third-party callback never returns first-party bearer tokens in URL fragments
-- callback completion uses either:
- - backend session completion tied to the initiating browser session
- - one-time opaque exchange code redeemed by `POST`
-- mixed-version rollout may temporarily emit legacy pending token aliases in addition to the new completion path
-- legacy alias support is transitional and bounded to rollout windows only
-
-### WeChat login and channel mapping
-
-- environment chooses `mp` or `open`
-- website login uses authorization-code flow with provider-configured app/channel binding
-- callback must resolve to `unionid`
-- channel `openid` is optionally recorded in `auth_identity_channels`
-- failure to obtain `unionid` aborts flow
-
-### Avatar upload and delete
-
-- URL avatar: validate and persist reference
-- data URL avatar:
- - decode
- - validate image type
- - compress to `<=100KB`
- - persist database-backed inline representation
-- delete removes current custom avatar entry
-
-## Payment Routing Model
-
-### User-visible methods
-
-- `alipay`
-- `wxpay`
-
-### Backend source abstraction
-
-Each display method maps to exactly one active configured backend source:
-
-- `official_alipay`
-- `easypay_alipay`
-- `official_wechat`
-- `easypay_wechat`
-
-Frontend submits display method only. Backend resolves display method to active source and capability set.
-
-### Legacy payment-config normalization
-
-- existing provider-instance `supported_types`, legacy aliases such as `wxpay_direct`, and per-type limit structures are migrated into the visible-method model
-- migration preserves historical payment capability and refund semantics
-- the system keeps one normalized visible-method mapping per provider instance for rollout and audit
-
-### Alipay routing
-
-- PC: create QR-oriented result and show QR in page
-- mobile: create jump/redirect-oriented result
-
-### WeChat routing
-
-- PC: QR result
-- non-WeChat H5:
- - prefer H5 pay
- - if unavailable, show "open in WeChat" requirement
-- WeChat environment:
- - prefer MP/JSAPI
- - if unavailable, fall back to H5 pay
-
-### WeChat payment OAuth recovery
-
-- if WeChat in-app payment requires `openid` and the current request does not already hold it, backend returns an `oauth_required` response instead of guessing
-- backend creates a server-backed payment-resume context containing:
- - target visible method
- - amount/order type/plan context
- - redirect target
- - anti-replay state
-- backend redirects through a dedicated WeChat payment OAuth start endpoint
-- callback exchanges the provider code server-side, stores `openid` in the payment-resume context, and returns a same-origin internal resume target
-- frontend resumes the original order flow through the resume context instead of trusting raw callback query state or long-lived local storage
-
-### Payment completion
-
-- frontend return restores context and UI state
-- backend order state remains source of truth
-- webhook and/or order query remain authoritative for fulfillment
-- order fulfillment validates webhook or query payload against order-time snapshot data including provider instance, merchant identifiers, amount, currency, and provider order references
-- result pages use authenticated lookup or signed opaque result tokens, never raw public `out_trade_no`
-
-## Admin Configuration Model
-
-### Auth provider settings
-
-- email registration and verification settings
-- force email on third-party signup
-- LinuxDo client settings
-- OIDC issuer/client settings and provider display name
-- WeChat `open` / `mp` capability indicators derived from environment-backed configuration, surfaced to the frontend/admin read models as effective availability rather than full in-panel credential editing
-
-### Source default settings
-
-Per source (`email`, `linuxdo`, `oidc`, `wechat`):
-
-- default balance
-- default concurrency
-- default subscriptions
-- grant on signup
-- grant on first bind
-
-### Payment settings
-
-- active source for `alipay`
-- active source for `wechat`
-- source-specific credentials and enablement
-- effective WeChat payment capabilities may differ by enabled provider instances and selected visible-method source:
- - QR available
- - H5 available
- - MP/JSAPI available
-
-### Scheduling settings
-
-- OpenAI advanced scheduling enabled/disabled
-- default disabled
-
-## Compatibility And Rollout
-
-Compatibility is mandatory, especially for:
-
-- existing email users
-- existing LinuxDo users
-- historical LinuxDo synthetic-email accounts
-- historical WeChat synthetic-email accounts
-- historical OIDC synthetic-email accounts
-- historical WeChat `openid`-only records created by older branches
-
-### Additive migrations
-
-- preserve existing `users` data and behavior
-- add identity and pending-session tables
-- avoid destructive schema swaps
-
-### Migration backfill
-
-- backfill canonical `email` identities for valid existing email users
-- backfill canonical `linuxdo` identities during migration for historical synthetic-email LinuxDo users
-- backfill canonical `wechat` and `oidc` identities when historical synthetic-email or `user_external_identities` data allows deterministic reconstruction
-- emit migration reports for historical WeChat `openid`-only records that cannot be safely promoted to canonical `unionid`
-- backfill must be idempotent and repeatable
-
-### Compatibility reads
-
-During rollout:
-
-- read new identity model first
-- where necessary, retain compatibility logic for existing email and historical LinuxDo/WeChat/OIDC synthetic-email recognition
-
-### Grant idempotency
-
-- migration backfill must not trigger signup or first-bind grants
-- first-bind grants must use explicit idempotency tracking
-
-### API compatibility
-
-Retain transitional support for legacy/new request and response shapes where needed, including:
-
-- `pending_auth_token`
-- `pending_oauth_token`
-- old callback parsing expectations
-- historical profile field mappings
-- legacy callback fragment readers during the bounded rollout window
-
-### Settings and payment compatibility
-
-- preserve existing payment configs and order semantics from `main`
-- add new settings incrementally
-- avoid rewriting the entire settings schema in one cutover
-- preserve legacy provider-instance capabilities by explicitly mapping historical `supported_types`, `payment_mode`, and limit config into normalized visible-method routing
-
-### Rolling upgrade tolerance
-
-- do not assume simultaneous frontend/backend deployment
-- new backend must tolerate short-lived older frontend request shapes
-- rollout must define the deployment order and removal point for legacy callback token parsing and legacy payment resume parsing
-
-## Testing Strategy
-
-### Repository tests
-
-- identity upsert and lookup
-- WeChat channel mapping
-- pending auth session persistence
-- source grant idempotency
-- avatar persistence and delete
-- migration backfill behavior
-
-### Service tests
-
-- direct login by existing identity
-- first third-party signup
-- forced email flow
-- existing-email bind-existing-account flow
-- first-bind grant on/off
-- nickname/avatar adoption choices
-- WeChat `unionid` required behavior
-- payment routing resolution
-
-### Handler and route tests
-
-- LinuxDo/OIDC/WeChat callback handling
-- bind-existing
-- bind-current-user
-- create-account
-- TOTP continuation
-- payment create and recovery
-
-### Frontend tests
-
-- third-party callback flow state machine
-- register/login continuation
-- profile bindings card
-- avatar interactions
-- payment page routing behavior
-- admin settings UI
-
-### Compatibility tests
-
-- existing email users
-- historical LinuxDo synthetic-email users
-- historical WeChat synthetic-email users
-- historical OIDC synthetic-email users
-- historical WeChat `openid`-only records reported or remediated correctly
-- historical payment config
-- legacy auth payload field names
-- historical payment result handling
-- mixed-version callback token bridge behavior
-
-## Implementation Phases
-
-1. Add schema, migrations, compatibility backfill, and repository support
-2. Implement unified identity services and pending auth session flows
-3. Integrate profile binding, avatar, and adoption decision flows
-4. Add per-source default grants and admin config surfaces
-5. Rebuild payment routing abstraction and frontend payment UX
-6. Add user-management sorting and OpenAI advanced scheduling switch
-7. Run compatibility, rollout, and regression hardening
-
-## External Constraints And Best Practices
-
-Implementation must follow current primary-source guidance:
-
-- OAuth 2.0 Security BCP (RFC 9700): strict redirect handling, state protection, mix-up resistant design
-- PKCE (RFC 7636): require `S256` on browser authorization-code flows
-- OpenID Connect Core: stable issuer/subject handling for OIDC identities
-- Account linking best practice: require explicit user confirmation or re-authentication before linking to existing accounts
-- WeChat UnionID and website-login guidance: treat `unionid` as canonical cross-channel subject and persist channel/app binding with website login responses
-- WeChat Pay webhook guidance: verify signatures, decrypt payloads, and confirm merchant/order/amount fields against order-time state before fulfillment
-- Payment success-page guidance: custom success pages are informational and must not be the only fulfillment trigger
-
-References:
-
-- RFC 9700:
-- RFC 7636:
-- OpenID Connect Core 1.0:
-- Auth0 account linking guidance:
-- WeChat UnionID guidance:
-- WeChat website login guidance:
-- WeChat Pay callback/signature guidance:
-- Stripe Checkout fulfillment guidance:
-
-## Audit Synthesis
-
-The clean rebuild direction is not to copy either existing branch directly.
-
-- `feat/auth-identity-foundation` has the better long-term model:
- - unified auth identities
- - pending auth sessions
- - identity adoption decisions
- - provider-scoped default grants
- - payment display-method abstraction
- - OpenAI advanced scheduler layering
-- `personal-dev-branch` has the better real-world closure:
- - LinuxDo and WeChat callback flows are more operationally complete
- - profile binding and avatar UX is more complete
- - historical synthetic-email users across multiple providers are recognized and recovered in live flows
- - WeChat payment OAuth and recovery behavior is more complete
-- Primary-source guidance supplies hard constraints for OAuth/OIDC, account linking, WeChat identity handling, and payment completion semantics.
-
-The final rebuild must therefore:
-
-- keep the `feat/auth-identity-foundation` data model direction
-- absorb the strongest business-flow behavior from `personal-dev-branch`
-- reject transitional or half-finished behavior from both branches
-- treat compatibility and rollout as first-class implementation scope
-
-## Keep / Adapt / Drop
-
-### Keep
-
-Keep these architectural choices essentially intact:
-
-- `auth_identities`, `auth_identity_channels`, `pending_auth_sessions`, `identity_adoption_decisions`
-- per-provider default grants with one-time grant tracking
-- WeChat canonical identity plus channel mapping model
-- pending-auth verification gates before final bind
-- payment visible-method abstraction (`alipay`, `wechat`) decoupled from backend provider source
-- OpenAI advanced scheduler layering and test-backed behavior
-
-Keep these operational flow ideas from `personal-dev-branch`:
-
-- LinuxDo pending identity callback flow
-- WeChat pending identity callback flow
-- profile bindings UX and “cannot disconnect last usable login method” rule
-- separate WeChat login OAuth and WeChat payment OAuth entry points
-- historical synthetic-email recognition logic as a migration bridge
-- explicit WeChat payment OAuth recovery protocol as a product requirement, but reimplemented with server-backed resume state
-
-### Adapt
-
-These areas must be reimplemented with the same intent but stricter boundaries:
-
-- third-party account creation from pending-auth state must be transactional and must not register a plain local user before identity finalization succeeds
-- email identity lifecycle must become real dual-write state, not just one migration-time backfill
-- `signup_source` must be backfilled more accurately for known historical third-party users
-- WeChat payment recovery state must move from frontend-only storage to server-backed continuation state
-- avatar adoption fetches must be security-hardened and failure-visible
-- pending-auth payload modeling must clearly separate immutable upstream payload from mutable local metadata
-- callback completion must use a real exchange/session model instead of fragment-delivered bearer tokens
-- profile binding/avatar DTOs must be simplified to one authoritative backend contract instead of sprawling frontend fallback parsing
-- admin settings should preserve capability while reducing duplicated or transitional config branches
-
-### Drop
-
-Drop these as long-term design choices:
-
-- `user_external_identities` as the primary long-term identity model
-- synthetic email as a long-term canonical identity representation
-- OIDC as a side-path that does not participate in the same identity foundation as LinuxDo and WeChat
-- frontend multi-endpoint probing and broad compatibility parsing once the clean branch becomes the sole supported contract
-- unrelated branch noise such as generated-file churn, locale-only churn, or upstream merge residue as design inputs
-
-## Audit-Driven Hard Constraints
-
-The audit and source review establish these hard constraints:
-
-### Auth
-
-- all browser authorization-code providers use PKCE `S256` and do not expose an admin-off switch
-- callback handling uses strict `redirect_uri` discipline and state validation
-- OIDC identity key is `issuer + sub`
-- existing-account linking after email conflict must require explicit user action plus local-account verification
-- WeChat canonical identity key is `unionid`; `openid` is channel-scoped only
-
-### Compatibility
-
-- existing email users must continue to work with no manual intervention
-- existing LinuxDo users must not split into duplicate accounts
-- historical LinuxDo/WeChat/OIDC synthetic-email users must be backfilled into canonical identities during migration when deterministic recovery is possible
-- historical WeChat `openid`-only records must be surfaced through migration reporting and explicit remediation rules
-- migration backfills must not trigger signup or first-bind grants
-- legacy `pending_auth_token` and `pending_oauth_token` contracts must remain accepted during rollout
-- legacy auth/public setting aliases needed by older frontend builds must remain available during rollout
-- existing payment configs and historical order semantics must remain valid
-
-### Payment
-
-- frontend return pages do not determine final payment success
-- backend order state, webhook processing, and/or provider status query remain authoritative
-- each visible method (`alipay`, `wxpay`) may have only one active backend source at a time
-- public result pages must not expose raw `out_trade_no` lookup
-- WeChat Pay callback handling must verify signature, decrypt payload, and compare order fields against order-time snapshot data
-
-## Known Risks To Eliminate In Implementation
-
-These are specifically observed problems in the existing branches that the clean rebuild must eliminate:
-
-- third-party forced-email account creation currently bypasses the provider-aware account creation path and can leave orphan local accounts if bind finalization fails
-- post-migration email accounts are not fully dual-written into `auth_identities`
-- avatar adoption currently risks silent failure and insecure outbound fetch behavior
-- pending-auth payload responsibilities are internally inconsistent
-- OIDC parity is incomplete in `personal-dev-branch`; it must become a first-class provider in the unified identity model
-- WeChat union/open/channel identity handling is conceptually correct in the feature branch but still partially transitional across the codebase
-- WeChat payment recovery in `personal-dev-branch` is frontend-local and not robust across tabs or concurrent attempts
-- the existing pending-auth migration update is too destructive to reuse unchanged in a safer rollout
-- historical provider provenance should not be permanently flattened to `signup_source = email`
-- design/plan drift can reintroduce ambiguous identity uniqueness or ambiguous adoption-decision ownership if not aligned before implementation
-
-## Rollout Gates
-
-The rebuild is not ready for rollout until all of these are satisfied:
-
-1. Identity schema and migration chain are linearized and production-safe.
-2. Email identity backfill is complete and idempotent.
-3. Historical LinuxDo/WeChat/OIDC synthetic-email backfill to canonical identity is complete where deterministic, and non-recoverable rows are reported.
-4. Historical WeChat `openid`-only rows are either remediated or explicitly blocked with operator-visible reporting.
-5. `signup_source` backfill is accurate for known historical provider-created users.
-6. Dual token acceptance, exchange bridge behavior, and required legacy field aliases are present for the bounded rollout window.
-7. Existing payment configs are normalized and verified against current frontend-visible capabilities.
-8. New frontend flows are verified against mixed-version backend compatibility windows.
-9. Duplicate-account creation, first-bind grants, and payment route selection have regression coverage.
From ed01c599161acc67a18ef19c73daeb9fbe1243ab Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 14:54:53 +0800
Subject: [PATCH 117/189] feat: track authenticated user activity
---
.../server/middleware/admin_auth_test.go | 29 ++++++++
.../internal/server/middleware/jwt_auth.go | 16 ++++-
.../server/middleware/jwt_auth_test.go | 59 ++++++++++++++++
.../service/admin_service_apikey_test.go | 3 +
.../service/admin_service_delete_test.go | 12 ++++
.../admin_service_email_identity_sync_test.go | 4 ++
backend/internal/service/user_service.go | 68 +++++++++++++++++++
backend/internal/service/user_service_test.go | 65 ++++++++++++++----
frontend/src/views/admin/UsersView.vue | 17 ++---
.../views/admin/__tests__/UsersView.spec.ts | 7 +-
10 files changed, 254 insertions(+), 26 deletions(-)
diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go
index ed2578c8..cc5bead3 100644
--- a/backend/internal/server/middleware/admin_auth_test.go
+++ b/backend/internal/server/middleware/admin_auth_test.go
@@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -153,6 +154,18 @@ func (s *stubUserRepo) Delete(ctx context.Context, id int64) error {
panic("unexpected Delete call")
}
+func (s *stubUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
+ return nil, nil
+}
+
+func (s *stubUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ panic("unexpected UpsertUserAvatar call")
+}
+
+func (s *stubUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ panic("unexpected DeleteUserAvatar call")
+}
+
func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
@@ -161,6 +174,18 @@ func (s *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.Pa
panic("unexpected ListWithFilters call")
}
+func (s *stubUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ panic("unexpected GetLatestUsedAtByUserIDs call")
+}
+
+func (s *stubUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ panic("unexpected GetLatestUsedAtByUserID call")
+}
+
+func (s *stubUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+ panic("unexpected UpdateUserLastActiveAt call")
+}
+
func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected UpdateBalance call")
}
@@ -189,6 +214,10 @@ func (s *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64
panic("unexpected AddGroupToAllowedGroups call")
}
+func (s *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) {
+ panic("unexpected ListUserAuthIdentities call")
+}
+
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
diff --git a/backend/internal/server/middleware/jwt_auth.go b/backend/internal/server/middleware/jwt_auth.go
index 4aceb355..48cb9004 100644
--- a/backend/internal/server/middleware/jwt_auth.go
+++ b/backend/internal/server/middleware/jwt_auth.go
@@ -1,6 +1,7 @@
package middleware
import (
+ "context"
"errors"
"strings"
@@ -11,11 +12,19 @@ import (
// NewJWTAuthMiddleware 创建 JWT 认证中间件
func NewJWTAuthMiddleware(authService *service.AuthService, userService *service.UserService) JWTAuthMiddleware {
- return JWTAuthMiddleware(jwtAuth(authService, userService))
+ return JWTAuthMiddleware(jwtAuth(authService, userService, userService))
+}
+
+type jwtUserReader interface {
+ GetByID(ctx context.Context, id int64) (*service.User, error)
+}
+
+type userActivityToucher interface {
+ TouchLastActiveForUser(ctx context.Context, user *service.User)
}
// jwtAuth JWT认证中间件实现
-func jwtAuth(authService *service.AuthService, userService *service.UserService) gin.HandlerFunc {
+func jwtAuth(authService *service.AuthService, userService jwtUserReader, activityToucher userActivityToucher) gin.HandlerFunc {
return func(c *gin.Context) {
// 从Authorization header中提取token
authHeader := c.GetHeader("Authorization")
@@ -73,6 +82,9 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
Concurrency: user.Concurrency,
})
c.Set(string(ContextKeyUserRole), user.Role)
+ if activityToucher != nil {
+ activityToucher.TouchLastActiveForUser(c.Request.Context(), user)
+ }
c.Next()
}
diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go
index c483a51e..84fd6967 100644
--- a/backend/internal/server/middleware/jwt_auth_test.go
+++ b/backend/internal/server/middleware/jwt_auth_test.go
@@ -9,6 +9,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -30,6 +31,25 @@ func (r *stubJWTUserRepo) GetByID(_ context.Context, id int64) (*service.User, e
return u, nil
}
+func (r *stubJWTUserRepo) GetUserAvatar(_ context.Context, _ int64) (*service.UserAvatar, error) {
+ return nil, nil
+}
+
+func (r *stubJWTUserRepo) UpdateUserLastActiveAt(_ context.Context, _ int64, _ time.Time) error {
+ return nil
+}
+
+type recordingActivityToucher struct {
+ userIDs []int64
+}
+
+func (r *recordingActivityToucher) TouchLastActiveForUser(_ context.Context, user *service.User) {
+ if user == nil {
+ return
+ }
+ r.userIDs = append(r.userIDs, user.ID)
+}
+
// newJWTTestEnv 创建 JWT 认证中间件测试环境。
// 返回 gin.Engine(已注册 JWT 中间件)和 AuthService(用于生成 Token)。
func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthService) {
@@ -106,6 +126,45 @@ func TestJWTAuth_ValidToken_LowercaseBearer(t *testing.T) {
require.Equal(t, http.StatusOK, w.Code)
}
+func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) {
+ user := &service.User{
+ ID: 1,
+ Email: "test@example.com",
+ Role: "user",
+ Status: service.StatusActive,
+ Concurrency: 5,
+ TokenVersion: 1,
+ }
+
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.JWT.Secret = "test-jwt-secret-32bytes-long!!!"
+ cfg.JWT.AccessTokenExpireMinutes = 60
+
+ userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}}
+ authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
+ userSvc := service.NewUserService(userRepo, nil, nil, nil)
+ toucher := &recordingActivityToucher{}
+
+ r := gin.New()
+ r.Use(jwtAuth(authSvc, userSvc, toucher))
+ r.GET("/protected", func(c *gin.Context) {
+ c.Status(http.StatusOK)
+ })
+
+ token, err := authSvc.GenerateToken(user)
+ require.NoError(t, err)
+
+ w := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/protected", nil)
+ req.Header.Set("Authorization", "Bearer "+token)
+ r.ServeHTTP(w, req)
+
+ require.Equal(t, http.StatusOK, w.Code)
+ require.Equal(t, []int64{1}, toucher.userIDs)
+}
+
func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) {
router, _ := newJWTTestEnv(nil)
diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go
index e2eae0b4..aab35d25 100644
--- a/backend/internal/service/admin_service_apikey_test.go
+++ b/backend/internal/service/admin_service_apikey_test.go
@@ -88,6 +88,9 @@ func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserIDs(context.Context, [
func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
panic("unexpected")
}
+func (s *userRepoStubForGroupUpdate) UpdateUserLastActiveAt(context.Context, int64, time.Time) error {
+ panic("unexpected")
+}
func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
panic("unexpected")
}
diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go
index ac1d8ee7..126faad9 100644
--- a/backend/internal/service/admin_service_delete_test.go
+++ b/backend/internal/service/admin_service_delete_test.go
@@ -107,6 +107,18 @@ func (s *userRepoStub) ListWithFilters(ctx context.Context, params pagination.Pa
panic("unexpected ListWithFilters call")
}
+func (s *userRepoStub) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ panic("unexpected GetLatestUsedAtByUserIDs call")
+}
+
+func (s *userRepoStub) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ panic("unexpected GetLatestUsedAtByUserID call")
+}
+
+func (s *userRepoStub) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+ panic("unexpected UpdateUserLastActiveAt call")
+}
+
func (s *userRepoStub) UpdateBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected UpdateBalance call")
}
diff --git a/backend/internal/service/admin_service_email_identity_sync_test.go b/backend/internal/service/admin_service_email_identity_sync_test.go
index d6a7af9a..eaf4e84b 100644
--- a/backend/internal/service/admin_service_email_identity_sync_test.go
+++ b/backend/internal/service/admin_service_email_identity_sync_test.go
@@ -97,6 +97,10 @@ func (s *emailSyncRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*ti
return nil, nil
}
+func (s *emailSyncRepoStub) UpdateUserLastActiveAt(context.Context, int64, time.Time) error {
+ return nil
+}
+
func (s *emailSyncRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
func (s *emailSyncRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index e6053984..c6bf14c2 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -19,10 +19,13 @@ import (
"log/slog"
"net/url"
"sort"
+ "strconv"
"strings"
+ "sync"
"time"
xdraw "golang.org/x/image/draw"
+ "golang.org/x/sync/singleflight"
)
var (
@@ -47,6 +50,8 @@ const (
notifyCodeUserRateWindow = 10 * time.Minute
defaultUserIdentityRedirect = "/settings/profile"
+ userLastActiveMinTouch = 10 * time.Minute
+ userLastActiveFailBackoff = 30 * time.Second
)
var (
@@ -82,6 +87,7 @@ type UserRepository interface {
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error)
GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error)
GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error)
+ UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error
UpdateBalance(ctx context.Context, id int64, amount float64) error
DeductBalance(ctx context.Context, id int64, amount float64) error
@@ -192,6 +198,8 @@ type UserService struct {
settingRepo SettingRepository
authCacheInvalidator APIKeyAuthCacheInvalidator
billingCache BillingCache
+ lastActiveTouchL1 sync.Map
+ lastActiveTouchSF singleflight.Group
}
// NewUserService 创建用户服务实例
@@ -788,6 +796,66 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) {
return user, nil
}
+// TouchLastActive 通过防抖更新 users.last_active_at,减少鉴权热路径写放大。
+// 该操作为尽力而为,不应中断正常请求。
+func (s *UserService) TouchLastActive(ctx context.Context, userID int64) {
+ if s == nil || s.userRepo == nil || userID <= 0 {
+ return
+ }
+
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ slog.Debug("skip touch user last active after load failure", "user_id", userID, "error", err)
+ return
+ }
+ s.TouchLastActiveForUser(ctx, user)
+}
+
+// TouchLastActiveForUser 使用已加载的用户信息更新 last_active_at,避免重复读取数据库。
+func (s *UserService) TouchLastActiveForUser(ctx context.Context, user *User) {
+ if s == nil || s.userRepo == nil || user == nil || user.ID <= 0 {
+ return
+ }
+
+ now := time.Now()
+ if userLastActiveFresh(user.LastActiveAt, now) {
+ return
+ }
+ if v, ok := s.lastActiveTouchL1.Load(user.ID); ok {
+ if nextAllowedAt, ok := v.(time.Time); ok && now.Before(nextAllowedAt) {
+ return
+ }
+ }
+
+ _, err, _ := s.lastActiveTouchSF.Do(strconv.FormatInt(user.ID, 10), func() (any, error) {
+ latest := time.Now()
+ if v, ok := s.lastActiveTouchL1.Load(user.ID); ok {
+ if nextAllowedAt, ok := v.(time.Time); ok && latest.Before(nextAllowedAt) {
+ return nil, nil
+ }
+ }
+ if userLastActiveFresh(user.LastActiveAt, latest) {
+ return nil, nil
+ }
+ if err := s.userRepo.UpdateUserLastActiveAt(ctx, user.ID, latest); err != nil {
+ s.lastActiveTouchL1.Store(user.ID, latest.Add(userLastActiveFailBackoff))
+ return nil, fmt.Errorf("touch user last active: %w", err)
+ }
+ s.lastActiveTouchL1.Store(user.ID, latest.Add(userLastActiveMinTouch))
+ return nil, nil
+ })
+ if err != nil {
+ slog.Warn("touch user last active failed", "user_id", user.ID, "error", err)
+ }
+}
+
+func userLastActiveFresh(lastActiveAt *time.Time, now time.Time) bool {
+ if lastActiveAt == nil {
+ return false
+ }
+ return now.Before(lastActiveAt.Add(userLastActiveMinTouch))
+}
+
func (s *UserService) hydrateUserAvatar(ctx context.Context, user *User) error {
if s == nil || s.userRepo == nil || user == nil || user.ID == 0 {
return nil
diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go
index d771cb75..2c11f8ec 100644
--- a/backend/internal/service/user_service_test.go
+++ b/backend/internal/service/user_service_test.go
@@ -23,18 +23,21 @@ import (
// --- mock: UserRepository ---
type mockUserRepo struct {
- updateBalanceErr error
- updateBalanceFn func(ctx context.Context, id int64, amount float64) error
- getByIDUser *User
- getByIDErr error
- updateFn func(ctx context.Context, user *User) error
- updateCalls int
- upsertAvatarFn func(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error)
- upsertAvatarArgs []UpsertUserAvatarInput
- deleteAvatarFn func(ctx context.Context, userID int64) error
- deleteAvatarIDs []int64
- getAvatarFn func(ctx context.Context, userID int64) (*UserAvatar, error)
- txCalls int
+ updateBalanceErr error
+ updateBalanceFn func(ctx context.Context, id int64, amount float64) error
+ getByIDUser *User
+ getByIDErr error
+ updateLastActiveErr error
+ updateLastActiveUserIDs []int64
+ updateLastActiveAt []time.Time
+ updateFn func(ctx context.Context, user *User) error
+ updateCalls int
+ upsertAvatarFn func(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error)
+ upsertAvatarArgs []UpsertUserAvatarInput
+ deleteAvatarFn func(ctx context.Context, userID int64) error
+ deleteAvatarIDs []int64
+ getAvatarFn func(ctx context.Context, userID int64) (*UserAvatar, error)
+ txCalls int
}
type mockUserRepoTxKey struct{}
@@ -144,6 +147,11 @@ func (m *mockUserRepo) UpdateBalance(ctx context.Context, id int64, amount float
}
return m.updateBalanceErr
}
+func (m *mockUserRepo) UpdateUserLastActiveAt(_ context.Context, userID int64, activeAt time.Time) error {
+ m.updateLastActiveUserIDs = append(m.updateLastActiveUserIDs, userID)
+ m.updateLastActiveAt = append(m.updateLastActiveAt, activeAt)
+ return m.updateLastActiveErr
+}
func (m *mockUserRepo) DeductBalance(context.Context, int64, float64) error { return nil }
func (m *mockUserRepo) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
@@ -288,6 +296,39 @@ func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) {
}, 2*time.Second, 10*time.Millisecond, "即使失败也应调用 InvalidateUserBalance")
}
+func TestTouchLastActive_UpdatesWhenStale(t *testing.T) {
+ stale := time.Now().Add(-11 * time.Minute)
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 42,
+ LastActiveAt: &stale,
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ svc.TouchLastActive(context.Background(), 42)
+
+ require.Equal(t, []int64{42}, repo.updateLastActiveUserIDs)
+ require.Len(t, repo.updateLastActiveAt, 1)
+ require.WithinDuration(t, time.Now(), repo.updateLastActiveAt[0], 2*time.Second)
+}
+
+func TestTouchLastActive_SkipsWhenRecent(t *testing.T) {
+ recent := time.Now().Add(-time.Minute)
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 42,
+ LastActiveAt: &recent,
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ svc.TouchLastActive(context.Background(), 42)
+
+ require.Empty(t, repo.updateLastActiveUserIDs)
+ require.Empty(t, repo.updateLastActiveAt)
+}
+
func TestUpdateBalance_RepoError_ReturnsError(t *testing.T) {
repo := &mockUserRepo{updateBalanceErr: errors.New("database error")}
cache := &mockBillingCache{}
diff --git a/frontend/src/views/admin/UsersView.vue b/frontend/src/views/admin/UsersView.vue
index 07c9d437..93cfdbbe 100644
--- a/frontend/src/views/admin/UsersView.vue
+++ b/frontend/src/views/admin/UsersView.vue
@@ -455,12 +455,6 @@
{{ formatDateTime(value) }}
-
-
- {{ value ? formatDateTime(value) : '-' }}
-
-
-
{{ value ? formatDateTime(value) : '-' }}
@@ -718,7 +712,6 @@ const allColumns = computed(() => [
{ key: 'usage', label: t('admin.users.columns.usage'), sortable: false },
{ key: 'concurrency', label: t('admin.users.columns.concurrency'), sortable: true },
{ key: 'status', label: t('admin.users.columns.status'), sortable: true },
- { key: 'last_login_at', label: t('admin.users.columns.lastLogin'), sortable: true },
{ key: 'last_used_at', label: t('admin.users.columns.lastUsed'), sortable: true },
{ key: 'last_active_at', label: t('admin.users.columns.lastActive'), sortable: true },
{ key: 'created_at', label: t('admin.users.columns.created'), sortable: true },
@@ -735,7 +728,9 @@ const toggleableColumns = computed(() =>
const hiddenColumns = reactive>(new Set())
// Default hidden columns (columns hidden by default on first load)
-const DEFAULT_HIDDEN_COLUMNS = ['notes', 'groups', 'subscriptions', 'usage', 'concurrency', 'last_login_at', 'last_active_at']
+const DEFAULT_HIDDEN_COLUMNS = ['notes', 'groups', 'subscriptions', 'usage', 'concurrency']
+const REMOVED_COLUMNS = new Set(['last_login_at'])
+const FORCED_VISIBLE_COLUMNS = new Set(['last_active_at'])
// localStorage key for column settings
const HIDDEN_COLUMNS_KEY = 'user-hidden-columns'
@@ -746,7 +741,9 @@ const loadSavedColumns = () => {
const saved = localStorage.getItem(HIDDEN_COLUMNS_KEY)
if (saved) {
const parsed = JSON.parse(saved) as string[]
- parsed.forEach(key => hiddenColumns.add(key))
+ parsed
+ .filter(key => !REMOVED_COLUMNS.has(key) && !FORCED_VISIBLE_COLUMNS.has(key))
+ .forEach(key => hiddenColumns.add(key))
} else {
// Use default hidden columns on first load
DEFAULT_HIDDEN_COLUMNS.forEach(key => hiddenColumns.add(key))
@@ -808,7 +805,7 @@ const searchQuery = ref('')
const USER_SORT_STORAGE_KEY = 'admin-users-table-sort'
const loadInitialSortState = (): { sort_by: string; sort_order: 'asc' | 'desc' } => {
const fallback = { sort_by: 'created_at', sort_order: 'desc' as 'asc' | 'desc' }
- const sortable = new Set(['email', 'id', 'username', 'role', 'balance', 'concurrency', 'status', 'last_login_at', 'last_used_at', 'last_active_at', 'created_at'])
+ const sortable = new Set(['email', 'id', 'username', 'role', 'balance', 'concurrency', 'status', 'last_used_at', 'last_active_at', 'created_at'])
try {
const raw = localStorage.getItem(USER_SORT_STORAGE_KEY)
if (!raw) return fallback
diff --git a/frontend/src/views/admin/__tests__/UsersView.spec.ts b/frontend/src/views/admin/__tests__/UsersView.spec.ts
index 1ea67b63..d9076777 100644
--- a/frontend/src/views/admin/__tests__/UsersView.spec.ts
+++ b/frontend/src/views/admin/__tests__/UsersView.spec.ts
@@ -113,7 +113,7 @@ describe('admin UsersView', () => {
getBatchUserAttributes.mockResolvedValue({ values: {} })
})
- it('shows last_used_at column and requests last_used_at sort', async () => {
+ it('shows active and used activity columns, hides last_login_at, and requests last_used_at sort', async () => {
const wrapper = mount(UsersView, {
global: {
stubs: {
@@ -144,7 +144,10 @@ describe('admin UsersView', () => {
await flushPromises()
- expect(wrapper.get('[data-test="columns"]').text()).toContain('last_used_at')
+ const columns = wrapper.get('[data-test="columns"]').text()
+ expect(columns).toContain('last_used_at')
+ expect(columns).toContain('last_active_at')
+ expect(columns).not.toContain('last_login_at')
await wrapper.get('[data-test="sort-last-used"]').trigger('click')
await flushPromises()
From 49258dd3f6dfaab31589c797c033620c498a21d3 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 14:55:07 +0800
Subject: [PATCH 118/189] fix: preserve scheduler transport compatibility
defaults
---
backend/internal/service/openai_account_scheduler.go | 3 +++
1 file changed, 3 insertions(+)
diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go
index 38b92b47..5fda3abd 100644
--- a/backend/internal/service/openai_account_scheduler.go
+++ b/backend/internal/service/openai_account_scheduler.go
@@ -767,6 +767,9 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
}
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
+ if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
+ return true
+ }
if s == nil || s.service == nil {
return false
}
From 78f691d2de24d0d13ce68922e120c8119ea32856 Mon Sep 17 00:00:00 2001
From: shaw
Date: Tue, 21 Apr 2026 12:13:45 +0800
Subject: [PATCH 119/189] chore: update sponsors
---
README.md | 5 +++++
README_CN.md | 5 +++++
README_JA.md | 5 +++++
assets/partners/logos/bestproxy.png | Bin 0 -> 9716 bytes
4 files changed, 15 insertions(+)
create mode 100644 assets/partners/logos/bestproxy.png
diff --git a/README.md b/README.md
index bee2e8c3..3e609d65 100644
--- a/README.md
+++ b/README.md
@@ -96,6 +96,11 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
Huge thanks to BmoPlus for sponsoring this project! BmoPlus is a highly reliable AI account provider built strictly for heavy AI users and developers. They offer rock-solid, ready-to-use accounts and official top-up services for ChatGPT Plus / ChatGPT Pro (Full Warranty) / Claude Pro / Super Grok / Gemini Pro. By registering and ordering through BmoPlus - Premium AI Accounts & Top-ups , users can unlock the mind-blowing rate of 10% of the official GPT subscription price (90% OFF)
+
+
+Thanks to Bestproxy for sponsoring this project! Bestproxy provides high-purity residential IPs with dedicated one-IP-per-account support. By combining real home networks with fingerprint isolation, it enables link environment isolation and reduces the probability of association-based risk control.
+
+
## Ecosystem
diff --git a/README_CN.md b/README_CN.md
index 892eee61..add32a17 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -95,6 +95,11 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
感谢 BmoPlus 赞助了本项目!BmoPlus 是一家专为AI订阅重度用户打造的可靠 AI 账号代充服务商,提供稳定的 ChatGPT Plus / ChatGPT Pro(全程质保) / Claude Pro / Super Grok / Gemini Pro 的官方代充&成品账号。 通过BmoPlus AI成品号专卖/代充 注册下单的用户,可享GPT 官网订阅一折 的震撼价格!
+
+
+感谢 Bestproxy 赞助了本项目!Bestproxy 是一家提供高纯度住宅IP,支持一号一IP独享,结合真实家庭网络与指纹隔离,可实现链路环境隔离,降低关联风控概率。
+
+
## 生态项目
diff --git a/README_JA.md b/README_JA.md
index 6f0fc900..ccd595b9 100644
--- a/README_JA.md
+++ b/README_JA.md
@@ -95,6 +95,11 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらのBmoPlus AIアカウント専門店/代行チャージ 経由でご登録・ご注文いただいたユーザー様は、GPTを 公式サイト価格の約1割(90% OFF) という驚異的な価格でご利用いただけます!
+
+
+Bestproxy のご支援に感謝します!Bestproxy は高純度の住宅IPを提供し、1アカウント1IP専有をサポートしています。実際の家庭ネットワークとフィンガープリント分離を組み合わせることで、リンク環境の分離を実現し、関連付けによるリスク管理の確率を低減します。
+
+
## エコシステム
diff --git a/assets/partners/logos/bestproxy.png b/assets/partners/logos/bestproxy.png
new file mode 100644
index 0000000000000000000000000000000000000000..87c586705020a09a07484d0b0cb9f3bff79a219d
GIT binary patch
literal 9716
zcmd6NXE>bQ*Y9m2j6Mko5n_lMWz-mLFuEv_kRUpvL}&Eg1<~^8HADzOw1^&^=tM-c
z=!EFKo?DXh{GaDN?}u}p^WmJiu5s_ZXYIAu-h2I4yFI(FB1b|7?Rn$8(~5ZK9UIRXG_%f%Nbj-B)-09=-}P}jt2Dk}*a+gfuOn%Ej)xLmF6
zKxzOG6?e5WG`7TG;YJuU3mY+(^_m72xP^%r%VPm$Ze=?ujJbupy94H-yNbH8yQQ&^
z35z(2P}EfzWMGZK8p2(zt!x~HUBy^_@f8N!7sCh^kjTNrRQQ3k>>m=aC&prq#o7rY
z5H2n*TrPZEwhm?p9w8wi1UD~&mzNWy;B<7e!5X@9+Bn_@IS|5s@sP$i8ar6nVJ&QJ
z;1@g%jclE;Vk|6#@IS02u~-KSqrbV^IC2^O8gaUq!Ub*`VQAur;NjxF5DN|${Y@_{
zY2tX%{k_OTjQy|9@1B*#ua)c!9UL+0ZgvfW43LqvH8V`p?LqM
z`_I@fT2aKsC0behKKoytKZpPKjNdo&oA4jE|G%#hCdPkjgPoIu)h|smF-BmlFxD6w
z?1e7y{LuvyV_~cX)(Z0m
ztBaa13W6)#hcA?5uGdD)XGG3_@Lw#gpAQL!FO=oY2S+UXCNDLFEJlXFg#?(`;hb>r
zZ2p%s__M4i;vW+IUz~v5Kj3x&h+n9`fCBgfxP87q+Xs8Xl1^B2TL;wKhk&^c`AcE3
z3+buzuaf8A#Vp5`h+feDPYMu^zvTWux)RV}{Cy|?63Eri?jP6$$%KDl%GN>M*47Fo
zWoYAU2xonOu`$6oU>w-sAZEn=ffM)zHpKpc2ly`xi2c%gQN({`_up>e_pU%w1sW~H
z-=+=h{B0b;)x}sGK-2e5zOw){n<^Mh1yf~Z;11Y^0=PKTz-6$71Ab9BH~((S;;;gj
zevN|`8bbc3eL?m6V*oxbdcU47m!}p1fbK{^T2kFL^=FzJmZH7l+;*2Yg+Dt+c2FKY
zW{%HM5+(yNxxvlTf|Eiurj@pRxr%6fus#5zM
zX3X-{7YE1m#T7WC-=lt1@(qi4w!?)_#6n>Z
zVmv7iyYnA}+*A=hBuXd5!Lo)kX59RQ%#;%LB5l?%cIdx@Eqf%&aC9$0cZ(dEu1#x
zsNmKpgRcaXeUeB;Xp-zFC_o;a%{UW4q`dUmJF$%Tk(Ga{k%TBPBtFY5mp#LROL4<}
z?4uY|Gfl6>WJmv19su07LGB4ISp+$eNw^6>t9|sWZ$zjp74lx-YH&;6e1Rxt#-;KHi=)!r3cbhyKLW*7&q;VK<+bHqh@s4;o
zOMEa6ES35U`SBnTO^t$$>q*|pzKtK)`cn@GZ~|pTHX!&}(CT|91Q5IN9cG}Q|BJ*o
z75^;m#l8<(1LhwQ*Zt;*6Mb$1$$GU=iJ1Nl7_n1;k4b40fyCFS-dn+gtjf>t;oP?x$keZK@i)IC@9xa5Bfy(oKs(Ol
z!bx#A6@#SZc|Vr!ehh!NEwoX^LcBrX?~5;b8vVrC;+Tgv5ay8c5`uI
z8KI4gAbs0{e?>#-xe`QRP?eg5MYrq$^9tr=ECVopU%*T$+_}|9pdbHhB6Kv!?u?c7
zV0HsV?@JJPLPuG{@Y&kEZnFK>u}@DA=TOys6Mkyz_j6lYLIMKA_u@H}RNgm~0w&$2>cSsHPShmRYMeu}Kuk^xkz+<{Q~vQeG;(fUtn9ES7ftExt6
z+Idc=zH@Kt^|!pmXwV!{k@Wh$KvKJi_7Z4%4_A7
zN*k|!8`S_+y>sB
zzVu7O!}ls}PTRKT$hE%ktbAc=L124}}?xCMt4>XrJ<%jX(%Z
z2r@E!%-_|pdj0{OVRxvG>H1#_s;H+w+LHo3phOI`+`f
z(Q$C-pgjo8^KH{iye
zd0Lt+%4o$ecC+bzu)FmWXW&avb4+@AeR?{oa_n}OqC+VYA7AryN2N9=O?xXlEWGvG
zH{s@KsDP(z_RCE8lnHN}2Z@;2X+VI=QH29dLPiFAm!hgFqW^x;M@mY{NLxy=J1}T@
zrOTUqLLfh^tgNiGtgPtRvOZ=Y?BLB9)sLXGzK?d|r`XGmkvxGA*oMC_SFa!;`v3lUFr5=^5uIhh=6@nx;E8HlV}4L(A?z-efFet6wp8sd{g@V&?r7
z(wGAzZA&zjm~Yqh)4IIfb@viK>^u)d8v>uF#F4x+Ot=&>Zu=7xJ_pS!r5r^58y_z{
z$2Dm7jlCFvDkeSa4-UwMu8rfIFfDl>bF6M^27R
z%H2Z!vV|{!xa5FAgfWlBNKu__-7XhB6>ycDyy5)Rwn+C0J5-duWo7qt{T$`7V=h*x
zXJ~l#LKQXe?c3qbP~X8RPDBYjRQzb*cIA|}<%#n^tdo^83?hod>ob~50z$!NznH`
zO=vhN99hxOh=aA*vtE7XURMu(%FI$!Q_WpZHsz$%KzwlNtu|_Ev@P&qv~E#R;rf>I
zr!Xp)9oP?M>gJM{mGVZ&O0iPB4(TdI6X6BCTw6=mR2p1v2@Y=aIwqwR)24eS4W-8m
z7_g}v`XZA(us=1mzgmV!7oWNwt^j?m#{inK2x1qP{Zvu6h1JpckXuJ#J@Ni9-lPbZ
z@sfp&X=4_Z)E25snfm8PIkmN^gM$ocgOMUJkNsO4uE9KQre*WTQw_J-*#j2c!du6`
zepP}p2i9)3MaqOPPx|=iKYYkXYtNxuo%isa)RUWv6c<8Opt9Ki-r7eRYCUmefc95_1qwMT
z(z!2nUAojuMClGyi=q`doCz-$Jv-Qk_x-3DA*SsixK3ZXEDWlfi+pAsn
z`t{4{h$%Y2Hwq;@JnqRQ<|X54Ff}6BKR$k^ix9XXGf`Ms)X}ZfP>js5o!bw|A->jUKSv7g7r
zM#jd-fJc^tl=XIfX5z;;B&kj#L0`vSTj00&@H
zKsn`a&L>O}hcuT`#m*qszA=K_o@8b(a-cwNM?dnpa5MXN6*3kg~DXF5xjB~MJ
zZu;mnG0W`*5lSwOk!37+r9Nld`XFxGHFu9M+16J%Zq8h;oOmA=8TZuk@%u^_vC|#L
zMuTYB7Pw8yYuEDf9Tt@M06Urn*UhFrueCA-h3nq}cUMMirs~UHzPv)e;VSFAKB=!;
zr%oXQ+^480DXA(f3`lMmA78!39zTfJD-R+7H&9|l?mHcx!^D8wQ$A#c<-p<5b0_cY
zbu(}As1NJRkEsWC$7Ru>S*(SH*4Eavo16e6%7XHl^y!(zgV>GRw>;Dp)wPH^oI3~Ah1i#vgv4t@aw3+A=uNte
z+IT7EpwC38&2Q;W8_Qdsy?r9i$1oOyM@4K5H${-LHp`rDNKn;`x)cjG<
z&(6|`Yu}H?2Xu8OD7kiEXxI$s(6Pdxo+rH?EV$D~B|_syLSlv_6-|#RDhe+%OAw~1
zdiLyDqXEo#H-HSZ1%9B)sH>01#>K^?T2+h|ySER^jb@Ru#-@m~U2WQwMDc2%ROA
z!yIng*Lms+$?M_sXUBrYS`VeRDcWt2J}ZS9wsjR{IXTzp2dBh8cZ&D)2*a0vEg3^%O1&TQ4IcdlNw@fyKZG=MG(^5
zXkfQKSv^{Ium1FBF%x5=4%dE1XaWo!>VD#SI?~{Momq)LQP1*s
zm4HB(zK1n_yZyA?tD?1LFy0s%d_alNdxu?aN>0w7nL*zn*4|g(ln4!9PriFY$h&yP
zO@jGh)7E@Ao!eZBj7;PkWm%|(dUqOMuGIsY#&xo49UBdei_596cDGOrb2=HsOJ7t}
z414>ipq-BxP|59D*=_#HDB5lhP7a|v({^#Ob#YnO^rU!HsC^cM@O+dz^t>2oV-y#+
zov0oL^O{&jFh-cD*`M;-H!Ce&4I!m<64u=*U}I;0`{qsZsKO^c{S);;vBH-1;DHgF
zsAdy957G@$Ae!pYgoCu-fF{KZPe$I>qtT>kRQkwxiV$o7p9h}r2>^kiioL9Bx-WZ+
znKhRBbEeMtNT7fc^Q^9mwPu4ir*bWQpfrP8uH({3!{px4iU>nYD5Ju9-G1YQ%ha;n
zoD74SQ@6WJ_5*%lA&BZO(a!FxFf12fJwQy#W`Ai^&a
zw=accMOqKDNA&x~8*{q;0MXNz&P1WYfW)0?h-OH83mMJyt|rVWHcv!#r(xT{qlvO9j9=}03qpoW
zn#)Fs<9)XD*|#*@ILYy@B&FB2wd3HbKPT(NuRa67keCopil&{3nHeMg->VM&yA8l`@d8lvjV5oSHq3Rj7P85LeXFV*bkPYPLUrV2q%yG*(@l$0=ed!xTU7o045t5a@1P`gQ)
z6QA|r!@Fwt&em|TQjTNKGpUfsAG_6HlyCkfi$P9fIJThSrXU-i&Iy0Gw+R&Q8xB2W
zHh2FD75Pi-6%(Nu!kJa>0sP3r19Ys8q5Qdv%f!l=soVH>KLC)^R$q^%E+ry(tn@5R
z(X(82_o{gqP$HXv0{j*q0G_9uUrNEknvxji|OP=sd{=0YGM4QW%oomF{Iw9E;t`R&a
z=hf_~(={?biD*-rj#4xf5`Js4khkQo4e4`Q1L!z+}LBN;v)`^{oNz3B4
zJq^62Sj%zctyU8MwIglYF69;_8}J1a&RdCnAv-c9m0o;Er*j8a5;v)YFZ1cZPtlnz
ztOA)9<=n~(7O^@_rryMc$6KmBmuek;LZ8VRPunq45Q6t)z%9x@Cu~16_r$i2PgE4$
z#mJOIel-HPuesFG*OzlMCFP0Fc8WD8=i<_(Fqj;4XK&AX>daYHRbRA}1kfvX=}W&z
zIs+s#k}UKaynliTPl^4kdz~IGKz%Da2|d!t$TT#)Uy&jXnpV&X*_1I2aUSo{j@NtT
zPSItK*2gmX987PH=y`N4ZW|M4vG!*8+hx`^=HxEalEx>34WPZZ#JqCCx+as-E{o+
z?FtQHke>T;oNf&&lPd-W#S8HAIt633QjSXiI>
z^l8<+k8h)%d0ej9?J`j@-SP}Csy%?`W>t}`d@#z^E$}R&rU{INxlQSwSTT4qzsje~aX4B{&cDUUo
z6q#|gvVrEaiZ1?U^ql(JFR$?FP9Es%pC|G4vBuj-e+#nI+2~G_WbuX*
zk13D5YIg$KBZ!2%`w!5O*-U!o_L$BhtG)F=SBJ{TD5?`oQUU|H2bG90Bg}jgn7NVQ
zWoNP8>a%+|R0W338K?;g!XOznRyH>GgXv-+nkr~3?SXmxz=2=0X3#P_(0@?(*loqU
z#FhPZ%}pqMxq)QsY?P1B;T98KK+J432Ya~%cu>Bs=`P5vhTXzq&&}}t!BG*dN0D^r
zORbdWOU`7Md>b1b!2k2wWyyD%-Jlb{b7$d^1}EDchnA7l
zfkgpv@qIA8UJJX{weo|MPTae+*p{1$cnKQ})@#8~3@pn&F3kVb)AKsy^Jn$N`T5le
zkBw%ssJnOHf6!Ue(sqR{R@EXW+6HenQatvtH)LfERlOC3qlNuoZCKpQ}Ryx1z24z&ht5X1gT7^|3SfVoloqkwqNUQ&s%-mh41q7nO_V%
zt$3ZAo2I*l3Hzi@Es_maLc-^iUUyfrl@230Q<7#$f|
z1lj?~hYyJW3QEfQjix@ud*K2MXGyDvC4`*q@YS0cVR(Qq%w$eVg$-Z;xZg^jQLkaA
zJIeKuIH6Z_P~!~<>_DzGOYsT|98#$%uS%};SWsKmAFkxi*T+vuyOONpML1ixHfO@_
zRmbwOF}Q@Sc~fp6CP=Tiue;pL^xN#^whZeNp9tC8%1r6yf@eIpo6}=ryq}#WSe5p~
zqvaSLi$<(v6$F4Kv?ce&4{pg=sU3GO&m-e++l)+F9M`u$|D<>-MMf*^eCElHmWH+b
z_>n^@H#z6nqhtE^fB^3tiWP{_#D;kzXISM3uETMC18&FE)47q6=F$6a?o6`=hGxOE
z(iSjYP)xYO%2mRi+B53d<08i76;>D>wbBG?azr=(tA^eY@0FS^$K85tR?{2Z8k3=-
zABV~^2
z6e(mdO0Rv=xIAeHGu~anNyHwJ5JWPyabSNgi=DF;DLl?FS3aBe9K2RkY>V`&CIFy*
zASCDK9cPR0eunTpIw{4aOb+#-QvQ)g4!IE&olX*Z$v;zdi;i$#)`&-`h;!&N&Q+%-
z?T3?J=2{6(uKApX*cca
zA8uB;e#9yLMqd-^&2US=LD8N~`r3ysPRm9K?zYfb+%%QqtWl#+i=36?)@(f#-F$8p
z`w@8b{65}?oR!_L{b<0{`d;JN@a|DAuT-n*cO&wzWh}%8S72%(3tX4(5mw-v-zOm<
zdg>i^O33cOeI+|TE);Y0xwWnPiEcBpo4N7-&-c{-?ThlQbDSH~N@aH*^zwlJ5dbL2s7M!~4gCKH;L*^k
literal 0
HcmV?d00001
From c624cce88e2ccf48a410056192992366499c757d Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 15:56:30 +0800
Subject: [PATCH 120/189] fix: unblock auth identity compat backfill migration
---
...entity_compat_backfill_integration_test.go | 73 +++++++++++++++++++
.../internal/repository/migrations_runner.go | 6 ++
.../migrations_runner_checksum_test.go | 9 +++
.../109_auth_identity_compat_backfill.sql | 3 +
4 files changed, 91 insertions(+)
create mode 100644 backend/internal/repository/auth_identity_compat_backfill_integration_test.go
diff --git a/backend/internal/repository/auth_identity_compat_backfill_integration_test.go b/backend/internal/repository/auth_identity_compat_backfill_integration_test.go
new file mode 100644
index 00000000..56b37512
--- /dev/null
+++ b/backend/internal/repository/auth_identity_compat_backfill_integration_test.go
@@ -0,0 +1,73 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "strconv"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthIdentityCompatBackfillMigration_AllowsLongReportTypes(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migration108Path := filepath.Join("..", "..", "migrations", "108_auth_identity_foundation_core.sql")
+ migration108SQL, err := os.ReadFile(migration108Path)
+ require.NoError(t, err)
+
+ migration109Path := filepath.Join("..", "..", "migrations", "109_auth_identity_compat_backfill.sql")
+ migration109SQL, err := os.ReadFile(migration109Path)
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, `
+DROP TABLE IF EXISTS auth_identity_migration_reports CASCADE;
+DROP TABLE IF EXISTS auth_identity_channels CASCADE;
+DROP TABLE IF EXISTS identity_adoption_decisions CASCADE;
+DROP TABLE IF EXISTS pending_auth_sessions CASCADE;
+DROP TABLE IF EXISTS auth_identities CASCADE;
+
+ALTER TABLE users
+ DROP COLUMN IF EXISTS signup_source,
+ DROP COLUMN IF EXISTS last_login_at,
+ DROP COLUMN IF EXISTS last_active_at;
+`)
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, string(migration108SQL))
+ require.NoError(t, err)
+
+ var userID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('oidc-demo-subject@oidc-connect.invalid', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&userID))
+
+ _, err = tx.ExecContext(ctx, string(migration109SQL))
+ require.NoError(t, err)
+
+ var reportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'oidc_synthetic_email_requires_manual_recovery'
+ AND report_key = $1
+`, strconv.FormatInt(userID, 10)).Scan(&reportCount))
+ require.Equal(t, 1, reportCount)
+
+ var reportTypeLimit int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT character_maximum_length
+FROM information_schema.columns
+WHERE table_schema = 'public'
+ AND table_name = 'auth_identity_migration_reports'
+ AND column_name = 'report_type'
+`).Scan(&reportTypeLimit))
+ require.GreaterOrEqual(t, reportTypeLimit, 45)
+
+ require.NotZero(t, userID)
+}
diff --git a/backend/internal/repository/migrations_runner.go b/backend/internal/repository/migrations_runner.go
index 9cf3b392..5a2e6677 100644
--- a/backend/internal/repository/migrations_runner.go
+++ b/backend/internal/repository/migrations_runner.go
@@ -73,6 +73,12 @@ var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibil
"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3": {},
},
},
+ "109_auth_identity_compat_backfill.sql": {
+ fileChecksum: "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
+ acceptedDBChecksum: map[string]struct{}{
+ "2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3": {},
+ },
+ },
}
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
diff --git a/backend/internal/repository/migrations_runner_checksum_test.go b/backend/internal/repository/migrations_runner_checksum_test.go
index 6c3ad725..6030991b 100644
--- a/backend/internal/repository/migrations_runner_checksum_test.go
+++ b/backend/internal/repository/migrations_runner_checksum_test.go
@@ -51,4 +51,13 @@ func TestIsMigrationChecksumCompatible(t *testing.T) {
)
require.False(t, ok)
})
+
+ t.Run("109历史checksum可兼容", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "109_auth_identity_compat_backfill.sql",
+ "2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3",
+ "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
+ )
+ require.True(t, ok)
+ })
}
diff --git a/backend/migrations/109_auth_identity_compat_backfill.sql b/backend/migrations/109_auth_identity_compat_backfill.sql
index ddbbedbc..5147ae45 100644
--- a/backend/migrations/109_auth_identity_compat_backfill.sql
+++ b/backend/migrations/109_auth_identity_compat_backfill.sql
@@ -1,3 +1,6 @@
+ALTER TABLE auth_identity_migration_reports
+ALTER COLUMN report_type TYPE VARCHAR(80);
+
INSERT INTO auth_identities (
user_id,
provider_type,
From d08757ce9e06b0a893bbb5b40c129ddb5a18c9cb Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 17:34:18 +0800
Subject: [PATCH 121/189] refactor(admin): remove auth migration reports
---
.../admin/admin_basic_handlers_test.go | 26 +-
.../handler/admin/admin_service_stub_test.go | 48 --
.../internal/handler/admin/user_handler.go | 64 --
.../admin/user_handler_activity_test.go | 6 -
.../handler/auth_oauth_pending_flow_test.go | 4 +
backend/internal/handler/dto/mappers.go | 1 -
backend/internal/handler/dto/types.go | 1 -
.../handler/dto/user_mapper_activity_test.go | 3 -
backend/internal/handler/user_handler_test.go | 6 +
.../auth_identity_migration_report.go | 148 ----
...ser_profile_identity_repo_contract_test.go | 31 -
backend/internal/repository/user_repo.go | 4 -
.../user_repo_sort_integration_test.go | 21 -
backend/internal/server/routes/admin.go | 3 -
backend/internal/service/admin_service.go | 205 ------
..._service_identity_migration_report_test.go | 124 ----
.../__tests__/users.migrationReports.spec.ts | 112 ---
frontend/src/api/admin/users.ts | 65 +-
frontend/src/components/layout/AppSidebar.vue | 6 -
.../layout/__tests__/AppSidebar.spec.ts | 6 -
frontend/src/i18n/locales/en.ts | 2 +-
frontend/src/i18n/locales/zh.ts | 4 +-
frontend/src/router/index.ts | 10 -
frontend/src/types/index.ts | 1 -
.../AuthIdentityMigrationReportsView.vue | 689 ------------------
frontend/src/views/admin/UsersView.vue | 2 +-
.../AuthIdentityMigrationReportsView.spec.ts | 303 --------
.../views/admin/__tests__/UsersView.spec.ts | 9 +-
28 files changed, 20 insertions(+), 1884 deletions(-)
delete mode 100644 backend/internal/repository/auth_identity_migration_report.go
delete mode 100644 backend/internal/service/admin_service_identity_migration_report_test.go
delete mode 100644 frontend/src/api/__tests__/users.migrationReports.spec.ts
delete mode 100644 frontend/src/views/admin/AuthIdentityMigrationReportsView.vue
delete mode 100644 frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts
diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go
index 207931d9..ddeaab02 100644
--- a/backend/internal/handler/admin/admin_basic_handlers_test.go
+++ b/backend/internal/handler/admin/admin_basic_handlers_test.go
@@ -7,7 +7,6 @@ import (
"net/http/httptest"
"testing"
- servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -23,12 +22,6 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
redeemHandler := NewRedeemHandler(adminSvc, nil)
router.GET("/api/v1/admin/users", userHandler.List)
- router.GET("/api/v1/admin/users/auth-identity-migration-reports/summary", userHandler.GetAuthIdentityMigrationReportSummary)
- router.GET("/api/v1/admin/users/auth-identity-migration-reports", userHandler.ListAuthIdentityMigrationReports)
- router.POST("/api/v1/admin/users/auth-identity-migration-reports/:id/resolve", func(c *gin.Context) {
- c.Set(string(servermiddleware.ContextKeyUser), servermiddleware.AuthSubject{UserID: 99})
- userHandler.ResolveAuthIdentityMigrationReport(c)
- })
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
router.POST("/api/v1/admin/users/:id/auth-identities", userHandler.BindAuthIdentity)
router.POST("/api/v1/admin/users", userHandler.Create)
@@ -78,23 +71,6 @@ func TestUserHandlerEndpoints(t *testing.T) {
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
- rec = httptest.NewRecorder()
- req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/auth-identity-migration-reports/summary", nil)
- router.ServeHTTP(rec, req)
- require.Equal(t, http.StatusOK, rec.Code)
-
- rec = httptest.NewRecorder()
- req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/auth-identity-migration-reports?report_type=oidc_synthetic_email_requires_manual_recovery", nil)
- router.ServeHTTP(rec, req)
- require.Equal(t, http.StatusOK, rec.Code)
-
- body, _ := json.Marshal(map[string]any{"resolution_note": "resolved by manual bind"})
- rec = httptest.NewRecorder()
- req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/auth-identity-migration-reports/1/resolve", bytes.NewReader(body))
- req.Header.Set("Content-Type", "application/json")
- router.ServeHTTP(rec, req)
- require.Equal(t, http.StatusOK, rec.Code)
-
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1", nil)
router.ServeHTTP(rec, req)
@@ -111,7 +87,7 @@ func TestUserHandlerEndpoints(t *testing.T) {
"channel_subject": "openid-123",
},
}
- body, _ = json.Marshal(bindBody)
+ body, _ := json.Marshal(bindBody)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/auth-identities", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go
index dfdf327b..3a395342 100644
--- a/backend/internal/handler/admin/admin_service_stub_test.go
+++ b/backend/internal/handler/admin/admin_service_stub_test.go
@@ -17,7 +17,6 @@ type stubAdminService struct {
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
- migrationReports []service.AuthIdentityMigrationReport
boundAuthIdentity *service.AdminBindAuthIdentityInput
boundAuthIdentityFor int64
createdAccounts []*service.CreateAccountInput
@@ -134,15 +133,6 @@ func newStubAdminService() *stubAdminService {
proxies: []service.Proxy{proxy},
proxyCounts: []service.ProxyWithAccountCount{{Proxy: proxy, AccountCount: 1}},
redeems: []service.RedeemCode{redeem},
- migrationReports: []service.AuthIdentityMigrationReport{
- {
- ID: 1,
- ReportType: "oidc_synthetic_email_requires_manual_recovery",
- ReportKey: "u-1",
- Details: map[string]any{"user_id": 1},
- CreatedAt: now,
- },
- },
}
}
@@ -193,30 +183,6 @@ func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64,
return map[string]any{"user_id": userID}, nil
}
-func (s *stubAdminService) ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]service.AuthIdentityMigrationReport, int64, error) {
- if reportType == "" {
- return s.migrationReports, int64(len(s.migrationReports)), nil
- }
- filtered := make([]service.AuthIdentityMigrationReport, 0, len(s.migrationReports))
- for _, report := range s.migrationReports {
- if strings.EqualFold(report.ReportType, reportType) {
- filtered = append(filtered, report)
- }
- }
- return filtered, int64(len(filtered)), nil
-}
-
-func (s *stubAdminService) GetAuthIdentityMigrationReportSummary(ctx context.Context) (*service.AuthIdentityMigrationReportSummary, error) {
- summary := &service.AuthIdentityMigrationReportSummary{
- ByType: map[string]int64{},
- }
- for _, report := range s.migrationReports {
- summary.Total++
- summary.ByType[report.ReportType]++
- }
- return summary, nil
-}
-
func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int64, input service.AdminBindAuthIdentityInput) (*service.AdminBoundAuthIdentity, error) {
s.boundAuthIdentityFor = userID
copied := input
@@ -263,20 +229,6 @@ func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int6
return result, nil
}
-func (s *stubAdminService) ResolveAuthIdentityMigrationReport(ctx context.Context, reportID, resolvedByUserID int64, resolutionNote string) (*service.AuthIdentityMigrationReport, error) {
- now := time.Now().UTC()
- for i := range s.migrationReports {
- if s.migrationReports[i].ID != reportID {
- continue
- }
- s.migrationReports[i].ResolvedAt = &now
- s.migrationReports[i].ResolvedByUserID = &resolvedByUserID
- s.migrationReports[i].ResolutionNote = resolutionNote
- return &s.migrationReports[i], nil
- }
- return nil, nil
-}
-
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) {
return s.groups, int64(len(s.groups)), nil
}
diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go
index e582e322..b2ed9d18 100644
--- a/backend/internal/handler/admin/user_handler.go
+++ b/backend/internal/handler/admin/user_handler.go
@@ -7,7 +7,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
- servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -83,10 +82,6 @@ type BindUserAuthIdentityChannelRequest struct {
Metadata map[string]any `json:"metadata"`
}
-type ResolveAuthIdentityMigrationReportRequest struct {
- ResolutionNote string `json:"resolution_note"`
-}
-
// List handles listing all users with pagination
// GET /api/v1/admin/users
// Query params:
@@ -193,31 +188,6 @@ func (h *UserHandler) GetByID(c *gin.Context) {
response.Success(c, dto.UserFromServiceAdmin(user))
}
-// GetAuthIdentityMigrationReportSummary returns aggregate migration report counts.
-// GET /api/v1/admin/users/auth-identity-migration-reports/summary
-func (h *UserHandler) GetAuthIdentityMigrationReportSummary(c *gin.Context) {
- summary, err := h.adminService.GetAuthIdentityMigrationReportSummary(c.Request.Context())
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, summary)
-}
-
-// ListAuthIdentityMigrationReports returns paginated auth identity migration reports.
-// GET /api/v1/admin/users/auth-identity-migration-reports
-func (h *UserHandler) ListAuthIdentityMigrationReports(c *gin.Context) {
- page, pageSize := response.ParsePagination(c)
- reportType := strings.TrimSpace(c.Query("report_type"))
-
- reports, total, err := h.adminService.ListAuthIdentityMigrationReports(c.Request.Context(), reportType, page, pageSize)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Paginated(c, reports, total, page, pageSize)
-}
-
// BindAuthIdentity manually binds a canonical auth identity to a user.
// POST /api/v1/admin/users/:id/auth-identities
func (h *UserHandler) BindAuthIdentity(c *gin.Context) {
@@ -257,40 +227,6 @@ func (h *UserHandler) BindAuthIdentity(c *gin.Context) {
response.Success(c, result)
}
-// ResolveAuthIdentityMigrationReport marks a migration report as resolved.
-// POST /api/v1/admin/users/auth-identity-migration-reports/:id/resolve
-func (h *UserHandler) ResolveAuthIdentityMigrationReport(c *gin.Context) {
- reportID, err := strconv.ParseInt(c.Param("id"), 10, 64)
- if err != nil {
- response.BadRequest(c, "Invalid report ID")
- return
- }
-
- subject, ok := servermiddleware.GetAuthSubjectFromContext(c)
- if !ok || subject.UserID <= 0 {
- response.Unauthorized(c, "Authentication required")
- return
- }
-
- var req ResolveAuthIdentityMigrationReportRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- response.BadRequest(c, "Invalid request: "+err.Error())
- return
- }
-
- report, err := h.adminService.ResolveAuthIdentityMigrationReport(
- c.Request.Context(),
- reportID,
- subject.UserID,
- req.ResolutionNote,
- )
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, report)
-}
-
// Create handles creating a new user
// POST /api/v1/admin/users
func (h *UserHandler) Create(c *gin.Context) {
diff --git a/backend/internal/handler/admin/user_handler_activity_test.go b/backend/internal/handler/admin/user_handler_activity_test.go
index ac29086d..bfba2408 100644
--- a/backend/internal/handler/admin/user_handler_activity_test.go
+++ b/backend/internal/handler/admin/user_handler_activity_test.go
@@ -29,7 +29,6 @@ func TestUserHandlerListIncludesActivityFieldsAndSortParams(t *testing.T) {
Username: "activity-user",
Role: service.RoleUser,
Status: service.StatusActive,
- LastLoginAt: &lastLoginAt,
LastActiveAt: &lastActiveAt,
LastUsedAt: &lastUsedAt,
CreatedAt: lastLoginAt.Add(-24 * time.Hour),
@@ -57,7 +56,6 @@ func TestUserHandlerListIncludesActivityFieldsAndSortParams(t *testing.T) {
Code int `json:"code"`
Data struct {
Items []struct {
- LastLoginAt *time.Time `json:"last_login_at"`
LastActiveAt *time.Time `json:"last_active_at"`
LastUsedAt *time.Time `json:"last_used_at"`
} `json:"items"`
@@ -66,7 +64,6 @@ func TestUserHandlerListIncludesActivityFieldsAndSortParams(t *testing.T) {
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Len(t, resp.Data.Items, 1)
- require.WithinDuration(t, lastLoginAt, *resp.Data.Items[0].LastLoginAt, time.Second)
require.WithinDuration(t, lastActiveAt, *resp.Data.Items[0].LastActiveAt, time.Second)
require.WithinDuration(t, lastUsedAt, *resp.Data.Items[0].LastUsedAt, time.Second)
}
@@ -86,7 +83,6 @@ func TestUserHandlerGetByIDIncludesActivityFields(t *testing.T) {
Username: "detail-user",
Role: service.RoleUser,
Status: service.StatusActive,
- LastLoginAt: &lastLoginAt,
LastActiveAt: &lastActiveAt,
LastUsedAt: &lastUsedAt,
CreatedAt: lastLoginAt.Add(-24 * time.Hour),
@@ -107,14 +103,12 @@ func TestUserHandlerGetByIDIncludesActivityFields(t *testing.T) {
var resp struct {
Code int `json:"code"`
Data struct {
- LastLoginAt *time.Time `json:"last_login_at"`
LastActiveAt *time.Time `json:"last_active_at"`
LastUsedAt *time.Time `json:"last_used_at"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
- require.WithinDuration(t, lastLoginAt, *resp.Data.LastLoginAt, time.Second)
require.WithinDuration(t, lastActiveAt, *resp.Data.LastActiveAt, time.Second)
require.WithinDuration(t, lastUsedAt, *resp.Data.LastUsedAt, time.Second)
}
diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go
index 008c9da2..a4a5c544 100644
--- a/backend/internal/handler/auth_oauth_pending_flow_test.go
+++ b/backend/internal/handler/auth_oauth_pending_flow_test.go
@@ -2293,6 +2293,10 @@ func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.Use
return nil
}
+func (r *oauthPendingFlowUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+ return r.client.User.UpdateOneID(userID).SetLastActiveAt(activeAt).Exec(ctx)
+}
+
func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error {
if r.options.rejectDeleteWhileAuthIdentityExists {
count, err := r.client.AuthIdentity.Query().Where(authidentity.UserIDEQ(id)).Count(ctx)
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index d88c110c..9780ff79 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -21,7 +21,6 @@ func UserFromServiceShallow(u *service.User) *User {
Concurrency: u.Concurrency,
Status: u.Status,
AllowedGroups: u.AllowedGroups,
- LastLoginAt: u.LastLoginAt,
LastActiveAt: u.LastActiveAt,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index 15b8548a..c0bce40b 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -15,7 +15,6 @@ type User struct {
Concurrency int `json:"concurrency"`
Status string `json:"status"`
AllowedGroups []int64 `json:"allowed_groups"`
- LastLoginAt *time.Time `json:"last_login_at,omitempty"`
LastActiveAt *time.Time `json:"last_active_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
diff --git a/backend/internal/handler/dto/user_mapper_activity_test.go b/backend/internal/handler/dto/user_mapper_activity_test.go
index 1e362fba..a17f0ce4 100644
--- a/backend/internal/handler/dto/user_mapper_activity_test.go
+++ b/backend/internal/handler/dto/user_mapper_activity_test.go
@@ -21,16 +21,13 @@ func TestUserFromServiceAdmin_MapsActivityTimestamps(t *testing.T) {
Username: "admin",
Role: service.RoleAdmin,
Status: service.StatusActive,
- LastLoginAt: &lastLoginAt,
LastActiveAt: &lastActiveAt,
LastUsedAt: &lastUsedAt,
})
require.NotNil(t, out)
- require.NotNil(t, out.LastLoginAt)
require.NotNil(t, out.LastActiveAt)
require.NotNil(t, out.LastUsedAt)
- require.WithinDuration(t, lastLoginAt, *out.LastLoginAt, time.Second)
require.WithinDuration(t, lastActiveAt, *out.LastActiveAt, time.Second)
require.WithinDuration(t, lastUsedAt, *out.LastUsedAt, time.Second)
}
diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go
index 24f715d4..8095ed57 100644
--- a/backend/internal/handler/user_handler_test.go
+++ b/backend/internal/handler/user_handler_test.go
@@ -99,6 +99,12 @@ func (s *userHandlerRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64)
func (s *userHandlerRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
return nil, nil
}
+func (s *userHandlerRepoStub) UpdateUserLastActiveAt(_ context.Context, _ int64, activeAt time.Time) error {
+ if s.user != nil {
+ s.user.LastActiveAt = &activeAt
+ }
+ return nil
+}
func (s *userHandlerRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
return nil
}
diff --git a/backend/internal/repository/auth_identity_migration_report.go b/backend/internal/repository/auth_identity_migration_report.go
deleted file mode 100644
index 70f298c1..00000000
--- a/backend/internal/repository/auth_identity_migration_report.go
+++ /dev/null
@@ -1,148 +0,0 @@
-package repository
-
-import (
- "context"
- "database/sql"
- "encoding/json"
- "fmt"
- "strings"
- "time"
-)
-
-type AuthIdentityMigrationReport struct {
- ID int64
- ReportType string
- ReportKey string
- Details map[string]any
- CreatedAt time.Time
-}
-
-type AuthIdentityMigrationReportQuery struct {
- ReportType string
- Limit int
- Offset int
-}
-
-type AuthIdentityMigrationReportSummary struct {
- Total int64
- ByType map[string]int64
-}
-
-func (r *userRepository) ListAuthIdentityMigrationReports(ctx context.Context, query AuthIdentityMigrationReportQuery) ([]AuthIdentityMigrationReport, error) {
- exec := txAwareSQLExecutor(ctx, r.sql, r.client)
- if exec == nil {
- return nil, fmt.Errorf("sql executor is not configured")
- }
-
- limit := query.Limit
- if limit <= 0 {
- limit = 100
- }
- rows, err := exec.QueryContext(ctx, `
-SELECT id, report_type, report_key, details, created_at
-FROM auth_identity_migration_reports
-WHERE ($1 = '' OR report_type = $1)
-ORDER BY created_at DESC, id DESC
-LIMIT $2 OFFSET $3`,
- strings.TrimSpace(query.ReportType),
- limit,
- query.Offset,
- )
- if err != nil {
- return nil, err
- }
- defer func() { _ = rows.Close() }()
-
- reports := make([]AuthIdentityMigrationReport, 0)
- for rows.Next() {
- report, scanErr := scanAuthIdentityMigrationReport(rows)
- if scanErr != nil {
- return nil, scanErr
- }
- reports = append(reports, report)
- }
- if err := rows.Err(); err != nil {
- return nil, err
- }
- return reports, nil
-}
-
-func (r *userRepository) GetAuthIdentityMigrationReport(ctx context.Context, reportType, reportKey string) (*AuthIdentityMigrationReport, error) {
- exec := txAwareSQLExecutor(ctx, r.sql, r.client)
- if exec == nil {
- return nil, fmt.Errorf("sql executor is not configured")
- }
-
- rows, err := exec.QueryContext(ctx, `
-SELECT id, report_type, report_key, details, created_at
-FROM auth_identity_migration_reports
-WHERE report_type = $1 AND report_key = $2
-LIMIT 1`,
- strings.TrimSpace(reportType),
- strings.TrimSpace(reportKey),
- )
- if err != nil {
- return nil, err
- }
- defer func() { _ = rows.Close() }()
-
- if !rows.Next() {
- return nil, sql.ErrNoRows
- }
- report, err := scanAuthIdentityMigrationReport(rows)
- if err != nil {
- return nil, err
- }
- return &report, rows.Err()
-}
-
-func (r *userRepository) SummarizeAuthIdentityMigrationReports(ctx context.Context) (*AuthIdentityMigrationReportSummary, error) {
- exec := txAwareSQLExecutor(ctx, r.sql, r.client)
- if exec == nil {
- return nil, fmt.Errorf("sql executor is not configured")
- }
-
- rows, err := exec.QueryContext(ctx, `
-SELECT report_type, COUNT(*)
-FROM auth_identity_migration_reports
-GROUP BY report_type
-ORDER BY report_type ASC`)
- if err != nil {
- return nil, err
- }
- defer func() { _ = rows.Close() }()
-
- summary := &AuthIdentityMigrationReportSummary{
- ByType: make(map[string]int64),
- }
- for rows.Next() {
- var reportType string
- var count int64
- if err := rows.Scan(&reportType, &count); err != nil {
- return nil, err
- }
- summary.ByType[reportType] = count
- summary.Total += count
- }
- if err := rows.Err(); err != nil {
- return nil, err
- }
- return summary, nil
-}
-
-func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) {
- var (
- report AuthIdentityMigrationReport
- details []byte
- )
- if err := scanner.Scan(&report.ID, &report.ReportType, &report.ReportKey, &details, &report.CreatedAt); err != nil {
- return AuthIdentityMigrationReport{}, err
- }
- report.Details = map[string]any{}
- if len(details) > 0 {
- if err := json.Unmarshal(details, &report.Details); err != nil {
- return AuthIdentityMigrationReport{}, err
- }
- }
- return report, nil
-}
diff --git a/backend/internal/repository/user_profile_identity_repo_contract_test.go b/backend/internal/repository/user_profile_identity_repo_contract_test.go
index a02af62b..d24a7d83 100644
--- a/backend/internal/repository/user_profile_identity_repo_contract_test.go
+++ b/backend/internal/repository/user_profile_identity_repo_contract_test.go
@@ -36,7 +36,6 @@ TRUNCATE TABLE
auth_identity_channels,
auth_identities,
pending_auth_sessions,
- auth_identity_migration_reports,
user_provider_default_grants,
user_avatars
RESTART IDENTITY`)
@@ -393,36 +392,6 @@ func (s *UserProfileIdentityRepoSuite) TestUserAvatarCRUDAndUserLookup() {
s.Require().Nil(loadedAvatar)
}
-func (s *UserProfileIdentityRepoSuite) TestAuthIdentityMigrationReportHelpers_ListAndSummarize() {
- _, err := integrationDB.ExecContext(s.ctx, `
-INSERT INTO auth_identity_migration_reports (report_type, report_key, details, created_at)
-VALUES
- ('wechat_openid_only_requires_remediation', 'u-1', '{"user_id":1}'::jsonb, '2026-04-20T10:00:00Z'),
- ('wechat_openid_only_requires_remediation', 'u-2', '{"user_id":2}'::jsonb, '2026-04-20T11:00:00Z'),
- ('oidc_synthetic_email_requires_manual_recovery', 'u-3', '{"user_id":3}'::jsonb, '2026-04-20T12:00:00Z')`)
- s.Require().NoError(err)
-
- summary, err := s.repo.SummarizeAuthIdentityMigrationReports(s.ctx)
- s.Require().NoError(err)
- s.Require().Equal(int64(3), summary.Total)
- s.Require().Equal(int64(2), summary.ByType["wechat_openid_only_requires_remediation"])
- s.Require().Equal(int64(1), summary.ByType["oidc_synthetic_email_requires_manual_recovery"])
-
- reports, err := s.repo.ListAuthIdentityMigrationReports(s.ctx, AuthIdentityMigrationReportQuery{
- ReportType: "wechat_openid_only_requires_remediation",
- Limit: 10,
- })
- s.Require().NoError(err)
- s.Require().Len(reports, 2)
- s.Require().Equal("u-2", reports[0].ReportKey)
- s.Require().Equal(float64(2), reports[0].Details["user_id"])
-
- report, err := s.repo.GetAuthIdentityMigrationReport(s.ctx, "oidc_synthetic_email_requires_manual_recovery", "u-3")
- s.Require().NoError(err)
- s.Require().Equal("u-3", report.ReportKey)
- s.Require().Equal(float64(3), report.Details["user_id"])
-}
-
func (s *UserProfileIdentityRepoSuite) TestUpdateUserLastLoginAndActiveAt_UsesDedicatedColumns() {
user := s.mustCreateUser("activity")
loginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC)
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index 195776a3..b6f907aa 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -445,10 +445,6 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
case "created_at":
field = dbuser.FieldCreatedAt
defaultField = false
- case "last_login_at":
- field = dbuser.FieldLastLoginAt
- defaultField = false
- nullsLastField = true
case "last_active_at":
field = dbuser.FieldLastActiveAt
defaultField = false
diff --git a/backend/internal/repository/user_repo_sort_integration_test.go b/backend/internal/repository/user_repo_sort_integration_test.go
index e2445d5b..365a0402 100644
--- a/backend/internal/repository/user_repo_sort_integration_test.go
+++ b/backend/internal/repository/user_repo_sort_integration_test.go
@@ -95,27 +95,6 @@ func (s *UserRepoSuite) TestUpdate_PersistsSignupSourceAndActivityTimestamps() {
s.Require().True(got.LastActiveAt.Equal(lastActiveAt))
}
-func (s *UserRepoSuite) TestListWithFilters_SortByLastLoginAtDesc() {
- older := time.Now().Add(-4 * time.Hour).UTC().Truncate(time.Microsecond)
- newer := time.Now().Add(-1 * time.Hour).UTC().Truncate(time.Microsecond)
-
- s.mustCreateUser(&service.User{Email: "nil-login@example.com"})
- s.mustCreateUser(&service.User{Email: "older-login@example.com", LastLoginAt: &older})
- s.mustCreateUser(&service.User{Email: "newer-login@example.com", LastLoginAt: &newer})
-
- users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
- Page: 1,
- PageSize: 10,
- SortBy: "last_login_at",
- SortOrder: "desc",
- }, service.UserListFilters{})
- s.Require().NoError(err)
- s.Require().Len(users, 3)
- s.Require().Equal("newer-login@example.com", users[0].Email)
- s.Require().Equal("older-login@example.com", users[1].Email)
- s.Require().Equal("nil-login@example.com", users[2].Email)
-}
-
func (s *UserRepoSuite) TestListWithFilters_SortByLastActiveAtAsc() {
earlier := time.Now().Add(-3 * time.Hour).UTC().Truncate(time.Microsecond)
later := time.Now().Add(-45 * time.Minute).UTC().Truncate(time.Microsecond)
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index e5c0eac1..84c963ec 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -210,9 +210,6 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users := admin.Group("/users")
{
- users.GET("/auth-identity-migration-reports/summary", h.Admin.User.GetAuthIdentityMigrationReportSummary)
- users.GET("/auth-identity-migration-reports", h.Admin.User.ListAuthIdentityMigrationReports)
- users.POST("/auth-identity-migration-reports/:id/resolve", h.Admin.User.ResolveAuthIdentityMigrationReport)
users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID)
users.POST("/:id/auth-identities", h.Admin.User.BindAuthIdentity)
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index ce1c1a77..e0effa66 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -39,10 +39,7 @@ type AdminService interface {
// codeType is optional - pass empty string to return all types.
// Also returns totalRecharged (sum of all positive balance top-ups).
GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error)
- ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]AuthIdentityMigrationReport, int64, error)
- GetAuthIdentityMigrationReportSummary(ctx context.Context) (*AuthIdentityMigrationReportSummary, error)
BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error)
- ResolveAuthIdentityMigrationReport(ctx context.Context, reportID, resolvedByUserID int64, resolutionNote string) (*AuthIdentityMigrationReport, error)
// Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error)
@@ -137,24 +134,6 @@ type UpdateUserInput struct {
GroupRates map[int64]*float64
}
-type AuthIdentityMigrationReport struct {
- ID int64 `json:"id"`
- ReportType string `json:"report_type"`
- ReportKey string `json:"report_key"`
- Details map[string]any `json:"details"`
- CreatedAt time.Time `json:"created_at"`
- ResolvedAt *time.Time `json:"resolved_at,omitempty"`
- ResolvedByUserID *int64 `json:"resolved_by_user_id,omitempty"`
- ResolutionNote string `json:"resolution_note,omitempty"`
-}
-
-type AuthIdentityMigrationReportSummary struct {
- Total int64 `json:"total"`
- OpenTotal int64 `json:"open_total"`
- ResolvedTotal int64 `json:"resolved_total"`
- ByType map[string]int64 `json:"by_type"`
-}
-
type AdminBindAuthIdentityInput struct {
ProviderType string
ProviderKey string
@@ -874,152 +853,6 @@ func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int
return codes, result.Total, totalRecharged, nil
}
-func (s *adminServiceImpl) ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]AuthIdentityMigrationReport, int64, error) {
- db, err := s.adminSQLDB()
- if err != nil {
- return nil, 0, err
- }
-
- reportType = strings.TrimSpace(reportType)
- if page <= 0 {
- page = 1
- }
- if pageSize <= 0 {
- pageSize = 20
- }
- offset := (page - 1) * pageSize
-
- var total int64
- if err := db.QueryRowContext(ctx, `
-SELECT COUNT(*)
-FROM auth_identity_migration_reports
-WHERE ($1 = '' OR report_type = $1)`,
- reportType,
- ).Scan(&total); err != nil {
- return nil, 0, err
- }
-
- rows, err := db.QueryContext(ctx, `
-SELECT id, report_type, report_key, details, created_at, resolved_at, resolved_by_user_id, resolution_note
-FROM auth_identity_migration_reports
-WHERE ($1 = '' OR report_type = $1)
-ORDER BY created_at DESC, id DESC
-LIMIT $2 OFFSET $3`,
- reportType,
- pageSize,
- offset,
- )
- if err != nil {
- return nil, 0, err
- }
- defer func() { _ = rows.Close() }()
-
- reports := make([]AuthIdentityMigrationReport, 0)
- for rows.Next() {
- report, scanErr := scanAuthIdentityMigrationReport(rows)
- if scanErr != nil {
- return nil, 0, scanErr
- }
- reports = append(reports, report)
- }
- if err := rows.Err(); err != nil {
- return nil, 0, err
- }
- return reports, total, nil
-}
-
-func (s *adminServiceImpl) GetAuthIdentityMigrationReportSummary(ctx context.Context) (*AuthIdentityMigrationReportSummary, error) {
- db, err := s.adminSQLDB()
- if err != nil {
- return nil, err
- }
-
- rows, err := db.QueryContext(ctx, `
-SELECT
- report_type,
- COUNT(*),
- SUM(CASE WHEN resolved_at IS NULL THEN 1 ELSE 0 END),
- SUM(CASE WHEN resolved_at IS NOT NULL THEN 1 ELSE 0 END)
-FROM auth_identity_migration_reports
-GROUP BY report_type
-ORDER BY report_type ASC`)
- if err != nil {
- return nil, err
- }
- defer func() { _ = rows.Close() }()
-
- summary := &AuthIdentityMigrationReportSummary{
- ByType: make(map[string]int64),
- }
- for rows.Next() {
- var reportType string
- var count int64
- var openCount int64
- var resolvedCount int64
- if err := rows.Scan(&reportType, &count, &openCount, &resolvedCount); err != nil {
- return nil, err
- }
- summary.ByType[reportType] = count
- summary.Total += count
- summary.OpenTotal += openCount
- summary.ResolvedTotal += resolvedCount
- }
- if err := rows.Err(); err != nil {
- return nil, err
- }
- return summary, nil
-}
-
-func (s *adminServiceImpl) ResolveAuthIdentityMigrationReport(ctx context.Context, reportID, resolvedByUserID int64, resolutionNote string) (*AuthIdentityMigrationReport, error) {
- if reportID <= 0 {
- return nil, infraerrors.BadRequest("INVALID_INPUT", "report id must be greater than 0")
- }
- if resolvedByUserID <= 0 {
- return nil, infraerrors.BadRequest("INVALID_INPUT", "resolved_by_user_id must be greater than 0")
- }
-
- db, err := s.adminSQLDB()
- if err != nil {
- return nil, err
- }
-
- now := time.Now().UTC()
- result, err := db.ExecContext(ctx, `
-UPDATE auth_identity_migration_reports
-SET
- resolved_at = COALESCE(resolved_at, $2),
- resolved_by_user_id = COALESCE(resolved_by_user_id, $3),
- resolution_note = $4
-WHERE id = $1`,
- reportID,
- now,
- resolvedByUserID,
- strings.TrimSpace(resolutionNote),
- )
- if err != nil {
- return nil, err
- }
- affected, err := result.RowsAffected()
- if err != nil {
- return nil, err
- }
- if affected == 0 {
- return nil, infraerrors.NotFound("AUTH_IDENTITY_MIGRATION_REPORT_NOT_FOUND", "auth identity migration report not found")
- }
-
- row := db.QueryRowContext(ctx, `
-SELECT id, report_type, report_key, details, created_at, resolved_at, resolved_by_user_id, resolution_note
-FROM auth_identity_migration_reports
-WHERE id = $1`,
- reportID,
- )
- report, err := scanAuthIdentityMigrationReport(row)
- if err != nil {
- return nil, err
- }
- return &report, nil
-}
-
func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) {
if userID <= 0 {
return nil, infraerrors.BadRequest("INVALID_INPUT", "user_id must be greater than 0")
@@ -1252,44 +1085,6 @@ func cloneAdminAuthIdentityMetadata(input map[string]any) map[string]any {
return out
}
-func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) {
- var (
- report AuthIdentityMigrationReport
- details []byte
- resolvedAt sql.NullTime
- resolvedByUserID sql.NullInt64
- resolutionNote sql.NullString
- )
- if err := scanner.Scan(
- &report.ID,
- &report.ReportType,
- &report.ReportKey,
- &details,
- &report.CreatedAt,
- &resolvedAt,
- &resolvedByUserID,
- &resolutionNote,
- ); err != nil {
- return AuthIdentityMigrationReport{}, err
- }
- report.Details = map[string]any{}
- if len(details) > 0 {
- if err := json.Unmarshal(details, &report.Details); err != nil {
- return AuthIdentityMigrationReport{}, err
- }
- }
- if resolvedAt.Valid {
- report.ResolvedAt = &resolvedAt.Time
- }
- if resolvedByUserID.Valid {
- report.ResolvedByUserID = &resolvedByUserID.Int64
- }
- if resolutionNote.Valid {
- report.ResolutionNote = resolutionNote.String
- }
- return report, nil
-}
-
// Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
diff --git a/backend/internal/service/admin_service_identity_migration_report_test.go b/backend/internal/service/admin_service_identity_migration_report_test.go
deleted file mode 100644
index 6975604b..00000000
--- a/backend/internal/service/admin_service_identity_migration_report_test.go
+++ /dev/null
@@ -1,124 +0,0 @@
-package service
-
-import (
- "context"
- "database/sql"
- "testing"
- "time"
-
- dbent "github.com/Wei-Shaw/sub2api/ent"
- "github.com/Wei-Shaw/sub2api/ent/enttest"
- "github.com/stretchr/testify/require"
-
- "entgo.io/ent/dialect"
- entsql "entgo.io/ent/dialect/sql"
- _ "modernc.org/sqlite"
-)
-
-func newAdminServiceMigrationReportTestClient(t *testing.T) *dbent.Client {
- t.Helper()
-
- db, err := sql.Open("sqlite", "file:admin_service_migration_reports?mode=memory&cache=shared&_fk=1")
- require.NoError(t, err)
- t.Cleanup(func() { _ = db.Close() })
-
- _, err = db.Exec("PRAGMA foreign_keys = ON")
- require.NoError(t, err)
-
- _, err = db.Exec(`CREATE TABLE auth_identity_migration_reports (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- report_type TEXT NOT NULL,
- report_key TEXT NOT NULL,
- details TEXT NOT NULL DEFAULT '{}',
- created_at DATETIME NOT NULL,
- resolved_at DATETIME NULL,
- resolved_by_user_id INTEGER NULL,
- resolution_note TEXT NOT NULL DEFAULT ''
- )`)
- require.NoError(t, err)
-
- drv := entsql.OpenDB(dialect.SQLite, db)
- client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
- t.Cleanup(func() { _ = client.Close() })
- return client
-}
-
-func TestAdminServiceListAuthIdentityMigrationReports(t *testing.T) {
- client := newAdminServiceMigrationReportTestClient(t)
- driver, ok := client.Driver().(*entsql.Driver)
- require.True(t, ok)
-
- now := time.Now().UTC()
- _, err := driver.DB().ExecContext(context.Background(), `
-INSERT INTO auth_identity_migration_reports (report_type, report_key, details, created_at)
-VALUES
- ($1, $2, $3, $4),
- ($5, $6, $7, $8)`,
- "oidc_synthetic_email_requires_manual_recovery", "u-1", `{"user_id":1}`, now,
- "wechat_provider_key_conflict", "u-2", `{"user_id":2}`, now.Add(-time.Minute),
- )
- require.NoError(t, err)
-
- svc := &adminServiceImpl{entClient: client}
- reports, total, err := svc.ListAuthIdentityMigrationReports(context.Background(), "oidc_synthetic_email_requires_manual_recovery", 1, 20)
- require.NoError(t, err)
- require.Equal(t, int64(1), total)
- require.Len(t, reports, 1)
- require.Equal(t, "oidc_synthetic_email_requires_manual_recovery", reports[0].ReportType)
- require.Equal(t, float64(1), reports[0].Details["user_id"])
-}
-
-func TestAdminServiceGetAuthIdentityMigrationReportSummary(t *testing.T) {
- client := newAdminServiceMigrationReportTestClient(t)
- driver, ok := client.Driver().(*entsql.Driver)
- require.True(t, ok)
-
- now := time.Now().UTC()
- _, err := driver.DB().ExecContext(context.Background(), `
-INSERT INTO auth_identity_migration_reports (report_type, report_key, details, created_at)
-VALUES
- ($1, $2, $3, $4),
- ($5, $6, $7, $8),
- ($9, $10, $11, $12)`,
- "oidc_synthetic_email_requires_manual_recovery", "u-1", `{"user_id":1}`, now,
- "wechat_provider_key_conflict", "u-2", `{"user_id":2}`, now.Add(-time.Minute),
- "wechat_provider_key_conflict", "u-3", `{"user_id":3}`, now.Add(-2*time.Minute),
- )
- require.NoError(t, err)
-
- svc := &adminServiceImpl{entClient: client}
- summary, err := svc.GetAuthIdentityMigrationReportSummary(context.Background())
- require.NoError(t, err)
- require.Equal(t, int64(3), summary.Total)
- require.Equal(t, int64(3), summary.OpenTotal)
- require.Zero(t, summary.ResolvedTotal)
- require.Equal(t, int64(1), summary.ByType["oidc_synthetic_email_requires_manual_recovery"])
- require.Equal(t, int64(2), summary.ByType["wechat_provider_key_conflict"])
-}
-
-func TestAdminServiceResolveAuthIdentityMigrationReport(t *testing.T) {
- client := newAdminServiceMigrationReportTestClient(t)
- driver, ok := client.Driver().(*entsql.Driver)
- require.True(t, ok)
-
- now := time.Now().UTC()
- _, err := driver.DB().ExecContext(context.Background(), `
-INSERT INTO auth_identity_migration_reports (report_type, report_key, details, created_at)
-VALUES ($1, $2, $3, $4)`,
- "oidc_synthetic_email_requires_manual_recovery", "u-1", `{"user_id":1}`, now,
- )
- require.NoError(t, err)
-
- svc := &adminServiceImpl{entClient: client}
- report, err := svc.ResolveAuthIdentityMigrationReport(context.Background(), 1, 99, "resolved by admin binding")
- require.NoError(t, err)
- require.NotNil(t, report.ResolvedAt)
- require.NotNil(t, report.ResolvedByUserID)
- require.Equal(t, int64(99), *report.ResolvedByUserID)
- require.Equal(t, "resolved by admin binding", report.ResolutionNote)
-
- summary, err := svc.GetAuthIdentityMigrationReportSummary(context.Background())
- require.NoError(t, err)
- require.Zero(t, summary.OpenTotal)
- require.Equal(t, int64(1), summary.ResolvedTotal)
-}
diff --git a/frontend/src/api/__tests__/users.migrationReports.spec.ts b/frontend/src/api/__tests__/users.migrationReports.spec.ts
deleted file mode 100644
index ca889600..00000000
--- a/frontend/src/api/__tests__/users.migrationReports.spec.ts
+++ /dev/null
@@ -1,112 +0,0 @@
-import { beforeEach, describe, expect, it, vi } from 'vitest'
-
-const { get, post } = vi.hoisted(() => ({
- get: vi.fn(),
- post: vi.fn(),
-}))
-
-vi.mock('@/api/client', () => ({
- apiClient: {
- get,
- post,
- },
-}))
-
-import {
- bindUserAuthIdentity,
- getAuthIdentityMigrationReportSummary,
- listAuthIdentityMigrationReports,
- resolveAuthIdentityMigrationReport,
-} from '@/api/admin/users'
-
-describe('admin users auth identity migration reports API', () => {
- beforeEach(() => {
- get.mockReset()
- post.mockReset()
- })
-
- it('lists migration reports with pagination and report type filter', async () => {
- const response = {
- items: [],
- total: 0,
- page: 2,
- page_size: 10,
- pages: 0,
- }
- get.mockResolvedValue({ data: response })
-
- const result = await listAuthIdentityMigrationReports({
- page: 2,
- pageSize: 10,
- reportType: 'oidc_synthetic_email_requires_manual_recovery',
- })
-
- expect(get).toHaveBeenCalledWith('/admin/users/auth-identity-migration-reports', {
- params: {
- page: 2,
- page_size: 10,
- report_type: 'oidc_synthetic_email_requires_manual_recovery',
- },
- })
- expect(result).toBe(response)
- })
-
- it('loads migration report summary', async () => {
- const response = {
- total: 2,
- open_total: 1,
- resolved_total: 1,
- by_type: {
- oidc_synthetic_email_requires_manual_recovery: 2,
- },
- }
- get.mockResolvedValue({ data: response })
-
- const result = await getAuthIdentityMigrationReportSummary()
-
- expect(get).toHaveBeenCalledWith('/admin/users/auth-identity-migration-reports/summary')
- expect(result).toBe(response)
- })
-
- it('submits report resolution note', async () => {
- const response = {
- id: 7,
- resolution_note: 'resolved by admin',
- }
- post.mockResolvedValue({ data: response })
-
- const result = await resolveAuthIdentityMigrationReport(7, 'resolved by admin')
-
- expect(post).toHaveBeenCalledWith('/admin/users/auth-identity-migration-reports/7/resolve', {
- resolution_note: 'resolved by admin',
- })
- expect(result).toBe(response)
- })
-
- it('binds a canonical auth identity to a user for remediation', async () => {
- const response = {
- identity_id: 11,
- provider_type: 'oidc',
- provider_key: 'https://issuer.example',
- provider_subject: 'subject-123',
- }
- post.mockResolvedValue({ data: response })
-
- const result = await bindUserAuthIdentity(42, {
- provider_type: 'oidc',
- provider_key: 'https://issuer.example',
- provider_subject: 'subject-123',
- issuer: 'https://issuer.example',
- metadata: { source: 'migration-report' },
- })
-
- expect(post).toHaveBeenCalledWith('/admin/users/42/auth-identities', {
- provider_type: 'oidc',
- provider_key: 'https://issuer.example',
- provider_subject: 'subject-123',
- issuer: 'https://issuer.example',
- metadata: { source: 'migration-report' },
- })
- expect(result).toBe(response)
- })
-})
diff --git a/frontend/src/api/admin/users.ts b/frontend/src/api/admin/users.ts
index be30eddc..1bb3d54c 100644
--- a/frontend/src/api/admin/users.ts
+++ b/frontend/src/api/admin/users.ts
@@ -6,24 +6,6 @@
import { apiClient } from '../client'
import type { AdminUser, UpdateUserRequest, PaginatedResponse, ApiKey } from '@/types'
-export interface AuthIdentityMigrationReport {
- id: number
- report_type: string
- report_key: string
- details: Record
- created_at: string
- resolved_at?: string | null
- resolved_by_user_id?: number | null
- resolution_note?: string
-}
-
-export interface AuthIdentityMigrationReportSummary {
- total: number
- open_total: number
- resolved_total: number
- by_type: Record
-}
-
export interface AdminBindAuthIdentityChannelRequest {
channel: string
channel_app_id?: string
@@ -48,12 +30,6 @@ export interface AdminBoundAuthIdentity {
channel_id?: number | null
}
-export interface ListAuthIdentityMigrationReportsParams {
- page?: number
- pageSize?: number
- reportType?: string
-}
-
/**
* List all users with pagination
* @param page - Page number (default: 1)
@@ -296,42 +272,6 @@ export async function replaceGroup(
return data
}
-export async function getAuthIdentityMigrationReportSummary(): Promise {
- const { data } = await apiClient.get(
- '/admin/users/auth-identity-migration-reports/summary'
- )
- return data
-}
-
-export async function listAuthIdentityMigrationReports(
- params: ListAuthIdentityMigrationReportsParams = {}
-): Promise> {
- const { data } = await apiClient.get>(
- '/admin/users/auth-identity-migration-reports',
- {
- params: {
- page: params.page ?? 1,
- page_size: params.pageSize ?? 20,
- report_type: params.reportType ?? ''
- }
- }
- )
- return data
-}
-
-export async function resolveAuthIdentityMigrationReport(
- id: number,
- resolutionNote: string
-): Promise {
- const { data } = await apiClient.post(
- `/admin/users/auth-identity-migration-reports/${id}/resolve`,
- {
- resolution_note: resolutionNote
- }
- )
- return data
-}
-
export async function bindUserAuthIdentity(
userId: number,
input: AdminBindAuthIdentityRequest
@@ -356,10 +296,7 @@ export const usersAPI = {
getUserUsageStats,
getUserBalanceHistory,
replaceGroup,
- bindUserAuthIdentity,
- getAuthIdentityMigrationReportSummary,
- listAuthIdentityMigrationReports,
- resolveAuthIdentityMigrationReport
+ bindUserAuthIdentity
}
export default usersAPI
diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue
index b7158d9b..92dcc519 100644
--- a/frontend/src/components/layout/AppSidebar.vue
+++ b/frontend/src/components/layout/AppSidebar.vue
@@ -663,12 +663,6 @@ const adminNavItems = computed((): NavItem[] => {
? [{ path: '/admin/ops', label: t('nav.ops'), icon: ChartIcon }]
: []),
{ path: '/admin/users', label: t('nav.users'), icon: UsersIcon, hideInSimpleMode: true },
- {
- path: '/admin/users/auth-identity-migration-reports',
- label: 'Migration Reports',
- icon: UsersIcon,
- hideInSimpleMode: true
- },
{ path: '/admin/groups', label: t('nav.groups'), icon: FolderIcon, hideInSimpleMode: true },
{ path: '/admin/channels', label: t('nav.channels', '渠道管理'), icon: ChannelIcon, hideInSimpleMode: true },
{ path: '/admin/subscriptions', label: t('nav.subscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
diff --git a/frontend/src/components/layout/__tests__/AppSidebar.spec.ts b/frontend/src/components/layout/__tests__/AppSidebar.spec.ts
index 915a67f8..118c7615 100644
--- a/frontend/src/components/layout/__tests__/AppSidebar.spec.ts
+++ b/frontend/src/components/layout/__tests__/AppSidebar.spec.ts
@@ -30,9 +30,3 @@ describe('AppSidebar header styles', () => {
expect(sidebarBrandBlockMatch?.[0]).not.toContain('overflow: hidden;')
})
})
-
-describe('AppSidebar admin navigation', () => {
- it('includes a visible entry for auth identity migration reports', () => {
- expect(componentSource).toContain("'/admin/users/auth-identity-migration-reports'")
- })
-})
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 345770a8..4c948f64 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -1438,8 +1438,8 @@ export default {
usage: 'Usage',
concurrency: 'Concurrency',
status: 'Status',
- lastLogin: 'Last Login',
lastActive: 'Last Active',
+ lastUsed: 'Last Used',
created: 'Created',
actions: 'Actions'
},
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 6493ffe8..9e6c0d58 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -1464,8 +1464,8 @@ export default {
usage: '用量',
concurrency: '并发数',
status: '状态',
- lastLogin: '最后登录',
- lastActive: '最后使用',
+ lastActive: '最后活跃时间',
+ lastUsed: '最后使用时间',
created: '创建时间',
actions: '操作'
},
diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts
index 2bace626..0d217139 100644
--- a/frontend/src/router/index.ts
+++ b/frontend/src/router/index.ts
@@ -341,16 +341,6 @@ const routes: RouteRecordRaw[] = [
descriptionKey: 'admin.users.description'
}
},
- {
- path: '/admin/users/auth-identity-migration-reports',
- name: 'AdminAuthIdentityMigrationReports',
- component: () => import('@/views/admin/AuthIdentityMigrationReportsView.vue'),
- meta: {
- requiresAuth: true,
- requiresAdmin: true,
- title: 'Auth Identity Migration Reports'
- }
- },
{
path: '/admin/groups',
name: 'AdminGroups',
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index bfc11cb2..9f13f8a0 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -84,7 +84,6 @@ export interface User {
balance_notify_threshold: number | null
balance_notify_extra_emails: NotifyEmailEntry[]
subscriptions?: UserSubscription[] // User's active subscriptions
- last_login_at?: string | null
last_active_at?: string | null
created_at: string
updated_at: string
diff --git a/frontend/src/views/admin/AuthIdentityMigrationReportsView.vue b/frontend/src/views/admin/AuthIdentityMigrationReportsView.vue
deleted file mode 100644
index 12fbf1b9..00000000
--- a/frontend/src/views/admin/AuthIdentityMigrationReportsView.vue
+++ /dev/null
@@ -1,689 +0,0 @@
-
-
-
-
-
-
- {{ copy.total }}
-
-
- {{ summary.total }}
-
-
-
-
- {{ copy.open }}
-
-
- {{ summary.open_total }}
-
-
-
-
- {{ copy.resolved }}
-
-
- {{ summary.resolved_total }}
-
-
-
-
-
-
-
-
-
- {{ copy.title }}
-
-
- {{ copy.subtitle }}
-
-
-
-
-
-
-
-
-
-
-
- {{ copy.reportType }}
-
- {{ copy.allReportTypes }}
-
- {{ option.label }}
-
-
-
-
-
-
-
-
-
-
- {{ row.resolved_at ? copy.resolvedBadge : copy.openBadge }}
-
-
-
-
- {{ value }}
-
-
-
- {{ value }}
-
-
-
-
-
- {{ entry.key }}: {{ entry.value }}
-
-
-
-
-
- {{ formatDateTime(value) }}
-
-
-
-
- {{ value ? formatDateTime(value) : copy.notResolved }}
-
-
-
-
-
- {{ copy.viewDetails }}
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- {{ copy.detailTitle }}
-
-
- {{ selectedReport ? selectedReport.report_key : copy.selectPrompt }}
-
-
-
- {{ selectedReport.resolved_at ? copy.resolvedBadge : copy.openBadge }}
-
-
-
-
-
-
-
{{ copy.reportType }}
- {{ selectedReport.report_type }}
-
-
-
{{ copy.reportKey }}
- {{ selectedReport.report_key }}
-
-
-
{{ copy.createdAt }}
- {{ formatDateTime(selectedReport.created_at) }}
-
-
-
{{ copy.resolvedAt }}
-
- {{ selectedReport.resolved_at ? formatDateTime(selectedReport.resolved_at) : copy.notResolved }}
-
-
-
-
{{ copy.resolvedBy }}
- {{ selectedReport.resolved_by_user_id ?? '-' }}
-
-
-
{{ copy.resolutionNote }}
-
- {{ selectedReport.resolution_note || copy.emptyResolutionNote }}
-
-
-
-
-
-
{{ copy.keyFields }}
-
-
- {{ entry.key }}: {{ entry.value }}
-
-
-
-
-
-
{{ copy.rawDetails }}
-
{{ formatDetailsJson(selectedReport.details) }}
-
-
-
-
- {{ copy.selectPrompt }}
-
-
-
-
-
- {{ copy.resolveTitle }}
-
-
- {{ copy.resolveSubtitle }}
-
-
-
-
- {{ copy.resolutionNote }}
-
-
-
-
- {{ resolving ? copy.resolving : copy.resolveAction }}
-
-
-
-
-
- {{ copy.remediationTitle }}
-
-
- {{ copy.remediationSubtitle }}
-
-
-
-
-
-
-
-
-
-
-
diff --git a/frontend/src/views/admin/UsersView.vue b/frontend/src/views/admin/UsersView.vue
index 93cfdbbe..39c9b377 100644
--- a/frontend/src/views/admin/UsersView.vue
+++ b/frontend/src/views/admin/UsersView.vue
@@ -712,8 +712,8 @@ const allColumns = computed(() => [
{ key: 'usage', label: t('admin.users.columns.usage'), sortable: false },
{ key: 'concurrency', label: t('admin.users.columns.concurrency'), sortable: true },
{ key: 'status', label: t('admin.users.columns.status'), sortable: true },
- { key: 'last_used_at', label: t('admin.users.columns.lastUsed'), sortable: true },
{ key: 'last_active_at', label: t('admin.users.columns.lastActive'), sortable: true },
+ { key: 'last_used_at', label: t('admin.users.columns.lastUsed'), sortable: true },
{ key: 'created_at', label: t('admin.users.columns.created'), sortable: true },
{ key: 'actions', label: t('admin.users.columns.actions'), sortable: false }
])
diff --git a/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts b/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts
deleted file mode 100644
index 20f57fa1..00000000
--- a/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts
+++ /dev/null
@@ -1,303 +0,0 @@
-import { beforeEach, describe, expect, it, vi } from 'vitest'
-import { flushPromises, mount } from '@vue/test-utils'
-import { defineComponent, h } from 'vue'
-
-import AuthIdentityMigrationReportsView from '../AuthIdentityMigrationReportsView.vue'
-
-const {
- bindUserAuthIdentity,
- getAuthIdentityMigrationReportSummary,
- listAuthIdentityMigrationReports,
- resolveAuthIdentityMigrationReport,
-} = vi.hoisted(() => ({
- bindUserAuthIdentity: vi.fn(),
- getAuthIdentityMigrationReportSummary: vi.fn(),
- listAuthIdentityMigrationReports: vi.fn(),
- resolveAuthIdentityMigrationReport: vi.fn(),
-}))
-
-const { showError, showSuccess } = vi.hoisted(() => ({
- showError: vi.fn(),
- showSuccess: vi.fn(),
-}))
-
-vi.mock('@/api/admin', () => ({
- adminAPI: {
- users: {
- bindUserAuthIdentity,
- getAuthIdentityMigrationReportSummary,
- listAuthIdentityMigrationReports,
- resolveAuthIdentityMigrationReport,
- },
- },
-}))
-
-vi.mock('@/stores/app', () => ({
- useAppStore: () => ({
- showError,
- showSuccess,
- }),
-}))
-
-vi.mock('vue-i18n', async () => {
- const actual = await vi.importActual('vue-i18n')
- return {
- ...actual,
- useI18n: () => ({
- locale: { value: 'en' },
- t: (key: string) => key,
- }),
- }
-})
-
-vi.mock('@/utils/format', () => ({
- formatDateTime: (value: string | null | undefined) => value ?? '',
-}))
-
-const sampleReport = {
- id: 1,
- report_type: 'oidc_synthetic_email_requires_manual_recovery',
- report_key: 'legacy@example.invalid',
- details: {
- user_id: 42,
- legacy_email: 'legacy@example.invalid',
- provider_key: 'https://issuer.example',
- provider_subject: 'subject-123',
- },
- created_at: '2026-04-20T01:02:03Z',
- resolved_at: null,
- resolved_by_user_id: null,
- resolution_note: '',
-}
-
-const summaryResponse = {
- total: 2,
- open_total: 1,
- resolved_total: 1,
- by_type: {
- oidc_synthetic_email_requires_manual_recovery: 2,
- },
-}
-
-const listResponse = {
- items: [sampleReport],
- total: 1,
- page: 1,
- page_size: 20,
- pages: 1,
-}
-
-const AppLayoutStub = defineComponent({
- setup(_, { slots }) {
- return () => h('div', slots.default?.())
- },
-})
-
-const TablePageLayoutStub = defineComponent({
- setup(_, { slots }) {
- return () => h('div', [
- slots.actions?.(),
- slots.filters?.(),
- slots.table?.(),
- slots.default?.(),
- slots.pagination?.(),
- ])
- },
-})
-
-const DataTableStub = defineComponent({
- props: {
- columns: { type: Array, default: () => [] },
- data: { type: Array, default: () => [] },
- loading: { type: Boolean, default: false },
- },
- setup(props, { slots }) {
- return () => h('div', { 'data-test': 'data-table' }, [
- props.loading
- ? h('div', 'loading')
- : (props.data as Array>).map((row) =>
- h(
- 'div',
- { key: String(row.id ?? row.report_key) },
- (props.columns as Array<{ key: string }>).map((column) => {
- const slot = slots[`cell-${column.key}`]
- return h(
- 'div',
- { key: column.key, [`data-test-cell`]: `${String(row.id)}-${column.key}` },
- slot
- ? slot({ row, value: row[column.key] })
- : String(row[column.key] ?? '')
- )
- })
- )
- ),
- ])
- },
-})
-
-const PaginationStub = defineComponent({
- props: {
- total: { type: Number, required: true },
- page: { type: Number, required: true },
- pageSize: { type: Number, required: true },
- },
- emits: ['update:page', 'update:pageSize'],
- setup(props, { emit }) {
- return () => h('div', { 'data-test': 'pagination' }, [
- h('button', {
- type: 'button',
- 'data-test': 'next-page',
- onClick: () => emit('update:page', props.page + 1),
- }, 'next'),
- h('button', {
- type: 'button',
- 'data-test': 'page-size-50',
- onClick: () => emit('update:pageSize', 50),
- }, '50'),
- ])
- },
-})
-
-describe('AuthIdentityMigrationReportsView', () => {
- beforeEach(() => {
- getAuthIdentityMigrationReportSummary.mockReset()
- listAuthIdentityMigrationReports.mockReset()
- resolveAuthIdentityMigrationReport.mockReset()
- bindUserAuthIdentity.mockReset()
- showError.mockReset()
- showSuccess.mockReset()
-
- getAuthIdentityMigrationReportSummary.mockResolvedValue(summaryResponse)
- listAuthIdentityMigrationReports.mockResolvedValue(listResponse)
- resolveAuthIdentityMigrationReport.mockResolvedValue({
- ...sampleReport,
- resolved_at: '2026-04-20T02:00:00Z',
- resolved_by_user_id: 100,
- resolution_note: 'resolved by admin',
- })
- bindUserAuthIdentity.mockResolvedValue({
- identity_id: 77,
- provider_type: 'oidc',
- provider_key: 'https://issuer.example',
- provider_subject: 'subject-123',
- })
- })
-
- const mountView = () =>
- mount(AuthIdentityMigrationReportsView, {
- global: {
- stubs: {
- AppLayout: AppLayoutStub,
- TablePageLayout: TablePageLayoutStub,
- DataTable: DataTableStub,
- Pagination: PaginationStub,
- Icon: true,
- },
- },
- })
-
- it('loads summary and first page of reports on mount', async () => {
- const wrapper = mountView()
-
- await flushPromises()
-
- expect(getAuthIdentityMigrationReportSummary).toHaveBeenCalledTimes(1)
- expect(listAuthIdentityMigrationReports).toHaveBeenCalledWith({
- page: 1,
- pageSize: 20,
- reportType: '',
- })
- expect(wrapper.get('[data-test="summary-total"]').text()).toContain('2')
- expect(wrapper.get('[data-test="summary-open"]').text()).toContain('1')
- expect(wrapper.get('[data-test="summary-resolved"]').text()).toContain('1')
- expect(wrapper.text()).toContain('legacy@example.invalid')
- })
-
- it('reloads list when the report type filter changes', async () => {
- const wrapper = mountView()
-
- await flushPromises()
-
- listAuthIdentityMigrationReports.mockClear()
-
- await wrapper.get('[data-test="report-type-filter"]').setValue(
- 'oidc_synthetic_email_requires_manual_recovery'
- )
- await flushPromises()
-
- expect(listAuthIdentityMigrationReports).toHaveBeenCalledWith({
- page: 1,
- pageSize: 20,
- reportType: 'oidc_synthetic_email_requires_manual_recovery',
- })
- })
-
- it('submits resolve note for the selected report and refreshes data', async () => {
- const wrapper = mountView()
-
- await flushPromises()
-
- getAuthIdentityMigrationReportSummary.mockClear()
- listAuthIdentityMigrationReports.mockClear()
-
- await wrapper.get('[data-test="select-report-1"]').trigger('click')
- await wrapper.get('[data-test="resolution-note"]').setValue('resolved by admin')
- await wrapper.get('[data-test="resolve-submit"]').trigger('click')
- await flushPromises()
-
- expect(resolveAuthIdentityMigrationReport).toHaveBeenCalledWith(1, 'resolved by admin')
- expect(showSuccess).toHaveBeenCalled()
- expect(getAuthIdentityMigrationReportSummary).toHaveBeenCalledTimes(1)
- expect(listAuthIdentityMigrationReports).toHaveBeenCalledWith({
- page: 1,
- pageSize: 20,
- reportType: '',
- })
- })
-
- it('pre-fills and submits remediation binding for the selected report', async () => {
- const wrapper = mountView()
-
- await flushPromises()
- await wrapper.get('[data-test="select-report-1"]').trigger('click')
- await flushPromises()
-
- expect((wrapper.get('[data-test="remediation-user-id"]').element as HTMLInputElement).value).toBe('42')
- expect((wrapper.get('[data-test="remediation-provider-type"]').element as HTMLInputElement).value).toBe('oidc')
- expect((wrapper.get('[data-test="remediation-provider-key"]').element as HTMLInputElement).value).toBe(
- 'https://issuer.example'
- )
- expect((wrapper.get('[data-test="remediation-provider-subject"]').element as HTMLInputElement).value).toBe(
- 'subject-123'
- )
-
- await wrapper.get('[data-test="remediation-submit"]').trigger('click')
- await flushPromises()
-
- expect(bindUserAuthIdentity).toHaveBeenCalledWith(42, {
- provider_type: 'oidc',
- provider_key: 'https://issuer.example',
- provider_subject: 'subject-123',
- issuer: undefined,
- metadata: {},
- })
- expect(showSuccess).toHaveBeenCalled()
- })
-
- it('keeps report type filter options available from list data when summary fails', async () => {
- getAuthIdentityMigrationReportSummary.mockRejectedValueOnce(new Error('summary failed'))
- listAuthIdentityMigrationReports.mockResolvedValueOnce(listResponse)
-
- const wrapper = mountView()
-
- await flushPromises()
-
- const options = wrapper
- .get('[data-test="report-type-filter"]')
- .findAll('option')
- .map((node) => node.element.value)
-
- expect(showError).toHaveBeenCalled()
- expect(options).toContain('oidc_synthetic_email_requires_manual_recovery')
- })
-})
diff --git a/frontend/src/views/admin/__tests__/UsersView.spec.ts b/frontend/src/views/admin/__tests__/UsersView.spec.ts
index d9076777..532d89f1 100644
--- a/frontend/src/views/admin/__tests__/UsersView.spec.ts
+++ b/frontend/src/views/admin/__tests__/UsersView.spec.ts
@@ -70,7 +70,6 @@ const createAdminUser = (): AdminUser => ({
created_at: '2026-04-17T00:00:00Z',
updated_at: '2026-04-17T00:00:00Z',
notes: '',
- last_login_at: '2026-04-16T01:00:00Z',
last_active_at: '2026-04-16T02:00:00Z',
last_used_at: '2026-04-17T02:00:00Z',
current_concurrency: 0
@@ -113,7 +112,7 @@ describe('admin UsersView', () => {
getBatchUserAttributes.mockResolvedValue({ values: {} })
})
- it('shows active and used activity columns, hides last_login_at, and requests last_used_at sort', async () => {
+ it('shows active, used, and created activity columns in order and requests last_used_at sort', async () => {
const wrapper = mount(UsersView, {
global: {
stubs: {
@@ -145,9 +144,9 @@ describe('admin UsersView', () => {
await flushPromises()
const columns = wrapper.get('[data-test="columns"]').text()
- expect(columns).toContain('last_used_at')
- expect(columns).toContain('last_active_at')
- expect(columns).not.toContain('last_login_at')
+ const visibleColumns = columns.split(',')
+ expect(visibleColumns.slice(-4, -1)).toEqual(['last_active_at', 'last_used_at', 'created_at'])
+ expect(visibleColumns).not.toContain('last_login_at')
await wrapper.get('[data-test="sort-last-used"]').trigger('click')
await flushPromises()
From ee3f158f4e11dff97e0db0dccecc6885dcfd1b96 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 17:35:12 +0800
Subject: [PATCH 122/189] fix(settings): restore wechat and payment config
persistence
---
.../internal/handler/admin/setting_handler.go | 99 +
backend/internal/handler/auth_wechat_oauth.go | 42 +-
.../handler/auth_wechat_oauth_test.go | 100 +-
backend/internal/handler/dto/settings.go | 8 +
.../handler/setting_handler_public_test.go | 21 +-
backend/internal/service/domain_constants.go | 9 +
.../service/payment_config_service.go | 9 +
.../service/payment_config_service_test.go | 47 +-
backend/internal/service/payment_order.go | 16 +-
.../service/payment_order_jsapi_test.go | 16 +-
.../service/payment_order_result_test.go | 36 +-
backend/internal/service/setting_service.go | 168 +-
.../service/setting_service_public_test.go | 21 +-
.../setting_service_wechat_config_test.go | 77 +
backend/internal/service/settings_view.go | 20 +
.../__tests__/settings.wechatConnect.spec.ts | 21 +
frontend/src/api/admin/settings.ts | 1032 +-
frontend/src/views/admin/SettingsView.vue | 9289 ++++++++++-------
.../admin/__tests__/SettingsView.spec.ts | 602 +-
19 files changed, 7160 insertions(+), 4473 deletions(-)
create mode 100644 backend/internal/service/setting_service_wechat_config_test.go
create mode 100644 frontend/src/api/__tests__/settings.wechatConnect.spec.ts
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index f0e91f3a..9bc20771 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -122,6 +122,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
+ WeChatConnectEnabled: settings.WeChatConnectEnabled,
+ WeChatConnectAppID: settings.WeChatConnectAppID,
+ WeChatConnectAppSecretConfigured: settings.WeChatConnectAppSecretConfigured,
+ WeChatConnectMode: settings.WeChatConnectMode,
+ WeChatConnectScopes: settings.WeChatConnectScopes,
+ WeChatConnectRedirectURL: settings.WeChatConnectRedirectURL,
+ WeChatConnectFrontendRedirectURL: settings.WeChatConnectFrontendRedirectURL,
OIDCConnectEnabled: settings.OIDCConnectEnabled,
OIDCConnectProviderName: settings.OIDCConnectProviderName,
OIDCConnectClientID: settings.OIDCConnectClientID,
@@ -246,6 +253,15 @@ type UpdateSettingsRequest struct {
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
+ // WeChat Connect OAuth 登录
+ WeChatConnectEnabled bool `json:"wechat_connect_enabled"`
+ WeChatConnectAppID string `json:"wechat_connect_app_id"`
+ WeChatConnectAppSecret string `json:"wechat_connect_app_secret"`
+ WeChatConnectMode string `json:"wechat_connect_mode"`
+ WeChatConnectScopes string `json:"wechat_connect_scopes"`
+ WeChatConnectRedirectURL string `json:"wechat_connect_redirect_url"`
+ WeChatConnectFrontendRedirectURL string `json:"wechat_connect_frontend_redirect_url"`
+
// Generic OIDC OAuth 登录
OIDCConnectEnabled bool `json:"oidc_connect_enabled"`
OIDCConnectProviderName string `json:"oidc_connect_provider_name"`
@@ -509,6 +525,54 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
+ if req.WeChatConnectEnabled {
+ req.WeChatConnectAppID = strings.TrimSpace(req.WeChatConnectAppID)
+ req.WeChatConnectAppSecret = strings.TrimSpace(req.WeChatConnectAppSecret)
+ req.WeChatConnectMode = strings.ToLower(strings.TrimSpace(req.WeChatConnectMode))
+ req.WeChatConnectScopes = strings.TrimSpace(req.WeChatConnectScopes)
+ req.WeChatConnectRedirectURL = strings.TrimSpace(req.WeChatConnectRedirectURL)
+ req.WeChatConnectFrontendRedirectURL = strings.TrimSpace(req.WeChatConnectFrontendRedirectURL)
+
+ if req.WeChatConnectAppID == "" {
+ response.BadRequest(c, "WeChat App ID is required when enabled")
+ return
+ }
+ if req.WeChatConnectAppSecret == "" {
+ if previousSettings.WeChatConnectAppSecret == "" {
+ response.BadRequest(c, "WeChat App Secret is required when enabled")
+ return
+ }
+ req.WeChatConnectAppSecret = previousSettings.WeChatConnectAppSecret
+ }
+ if req.WeChatConnectMode == "" {
+ req.WeChatConnectMode = "open"
+ }
+ switch req.WeChatConnectMode {
+ case "open", "mp":
+ default:
+ response.BadRequest(c, "WeChat mode must be open or mp")
+ return
+ }
+ if req.WeChatConnectScopes == "" {
+ req.WeChatConnectScopes = service.DefaultWeChatConnectScopesForMode(req.WeChatConnectMode)
+ }
+ if req.WeChatConnectRedirectURL == "" {
+ response.BadRequest(c, "WeChat Redirect URL is required when enabled")
+ return
+ }
+ if err := config.ValidateAbsoluteHTTPURL(req.WeChatConnectRedirectURL); err != nil {
+ response.BadRequest(c, "WeChat Redirect URL must be an absolute http(s) URL")
+ return
+ }
+ if req.WeChatConnectFrontendRedirectURL == "" {
+ req.WeChatConnectFrontendRedirectURL = "/auth/wechat/callback"
+ }
+ if err := config.ValidateFrontendRedirectURL(req.WeChatConnectFrontendRedirectURL); err != nil {
+ response.BadRequest(c, "WeChat Frontend Redirect URL is invalid")
+ return
+ }
+ }
+
// Generic OIDC 参数验证
if req.OIDCConnectEnabled {
req.OIDCConnectProviderName = strings.TrimSpace(req.OIDCConnectProviderName)
@@ -857,6 +921,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
+ WeChatConnectEnabled: req.WeChatConnectEnabled,
+ WeChatConnectAppID: req.WeChatConnectAppID,
+ WeChatConnectAppSecret: req.WeChatConnectAppSecret,
+ WeChatConnectMode: req.WeChatConnectMode,
+ WeChatConnectScopes: req.WeChatConnectScopes,
+ WeChatConnectRedirectURL: req.WeChatConnectRedirectURL,
+ WeChatConnectFrontendRedirectURL: req.WeChatConnectFrontendRedirectURL,
OIDCConnectEnabled: req.OIDCConnectEnabled,
OIDCConnectProviderName: req.OIDCConnectProviderName,
OIDCConnectClientID: req.OIDCConnectClientID,
@@ -1136,6 +1207,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
+ WeChatConnectEnabled: updatedSettings.WeChatConnectEnabled,
+ WeChatConnectAppID: updatedSettings.WeChatConnectAppID,
+ WeChatConnectAppSecretConfigured: updatedSettings.WeChatConnectAppSecretConfigured,
+ WeChatConnectMode: updatedSettings.WeChatConnectMode,
+ WeChatConnectScopes: updatedSettings.WeChatConnectScopes,
+ WeChatConnectRedirectURL: updatedSettings.WeChatConnectRedirectURL,
+ WeChatConnectFrontendRedirectURL: updatedSettings.WeChatConnectFrontendRedirectURL,
OIDCConnectEnabled: updatedSettings.OIDCConnectEnabled,
OIDCConnectProviderName: updatedSettings.OIDCConnectProviderName,
OIDCConnectClientID: updatedSettings.OIDCConnectClientID,
@@ -1329,6 +1407,27 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL {
changed = append(changed, "linuxdo_connect_redirect_url")
}
+ if before.WeChatConnectEnabled != after.WeChatConnectEnabled {
+ changed = append(changed, "wechat_connect_enabled")
+ }
+ if before.WeChatConnectAppID != after.WeChatConnectAppID {
+ changed = append(changed, "wechat_connect_app_id")
+ }
+ if req.WeChatConnectAppSecret != "" {
+ changed = append(changed, "wechat_connect_app_secret")
+ }
+ if before.WeChatConnectMode != after.WeChatConnectMode {
+ changed = append(changed, "wechat_connect_mode")
+ }
+ if before.WeChatConnectScopes != after.WeChatConnectScopes {
+ changed = append(changed, "wechat_connect_scopes")
+ }
+ if before.WeChatConnectRedirectURL != after.WeChatConnectRedirectURL {
+ changed = append(changed, "wechat_connect_redirect_url")
+ }
+ if before.WeChatConnectFrontendRedirectURL != after.WeChatConnectFrontendRedirectURL {
+ changed = append(changed, "wechat_connect_frontend_redirect_url")
+ }
if before.OIDCConnectEnabled != after.OIDCConnectEnabled {
changed = append(changed, "oidc_connect_enabled")
}
diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go
index 5e697fb5..734fb2ef 100644
--- a/backend/internal/handler/auth_wechat_oauth.go
+++ b/backend/internal/handler/auth_wechat_oauth.go
@@ -8,7 +8,6 @@ import (
"io"
"net/http"
"net/url"
- "os"
"strconv"
"strings"
"time"
@@ -149,7 +148,7 @@ func (h *AuthHandler) WeChatOAuthStart(c *gin.Context) {
// WeChatOAuthCallback exchanges the code with WeChat, resolves openid/unionid,
// and stores the result in the unified pending-auth flow.
func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
- frontendCallback := wechatOAuthFrontendCallback()
+ frontendCallback := h.wechatOAuthFrontendCallback(c.Request.Context())
if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
@@ -859,6 +858,10 @@ func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string,
return wechatOAuthConfig{}, err
}
+ if h == nil || h.settingSvc == nil {
+ return wechatOAuthConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "wechat oauth settings service not ready")
+ }
+
apiBaseURL := ""
if h != nil && h.settingSvc != nil {
settings, err := h.settingSvc.GetAllSettings(ctx)
@@ -867,27 +870,28 @@ func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string,
}
}
+ effective, err := h.settingSvc.GetWeChatConnectOAuthConfig(ctx)
+ if err != nil {
+ return wechatOAuthConfig{}, err
+ }
+ if effective.Mode != mode {
+ return wechatOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled")
+ }
+
cfg := wechatOAuthConfig{
mode: mode,
- redirectURI: resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/callback"),
- frontendCallback: wechatOAuthFrontendCallback(),
+ appID: strings.TrimSpace(effective.AppID),
+ appSecret: strings.TrimSpace(effective.AppSecret),
+ redirectURI: firstNonEmpty(strings.TrimSpace(effective.RedirectURL), resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/callback")),
+ frontendCallback: firstNonEmpty(strings.TrimSpace(effective.FrontendRedirectURL), wechatOAuthDefaultFrontendCB),
+ scope: firstNonEmpty(strings.TrimSpace(effective.Scopes), service.DefaultWeChatConnectScopesForMode(mode)),
}
switch mode {
case "mp":
- cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID"))
- cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET"))
cfg.authorizeURL = "https://open.weixin.qq.com/connect/oauth2/authorize"
- cfg.scope = "snsapi_userinfo"
default:
- cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID"))
- cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET"))
cfg.authorizeURL = "https://open.weixin.qq.com/connect/qrconnect"
- cfg.scope = "snsapi_login"
- }
-
- if cfg.appID == "" || cfg.appSecret == "" {
- return wechatOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled")
}
if strings.TrimSpace(cfg.redirectURI) == "" {
return wechatOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured")
@@ -896,8 +900,14 @@ func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string,
return cfg, nil
}
-func wechatOAuthFrontendCallback() string {
- return firstNonEmpty(strings.TrimSpace(os.Getenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL")), wechatOAuthDefaultFrontendCB)
+func (h *AuthHandler) wechatOAuthFrontendCallback(ctx context.Context) string {
+ if h != nil && h.settingSvc != nil {
+ cfg, err := h.settingSvc.GetWeChatConnectOAuthConfig(ctx)
+ if err == nil && strings.TrimSpace(cfg.FrontendRedirectURL) != "" {
+ return strings.TrimSpace(cfg.FrontendRedirectURL)
+ }
+ }
+ return wechatOAuthDefaultFrontendCB
}
func resolveWeChatOAuthMode(rawMode string, c *gin.Context) (string, error) {
diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go
index cd34f52f..b0fee617 100644
--- a/backend/internal/handler/auth_wechat_oauth_test.go
+++ b/backend/internal/handler/auth_wechat_oauth_test.go
@@ -33,16 +33,22 @@ import (
)
func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) {
- t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
- t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
-
gin.SetMode(gin.TestMode)
+ handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, map[string]string{
+ service.SettingKeyWeChatConnectEnabled: "true",
+ service.SettingKeyWeChatConnectAppID: "wx-open-app",
+ service.SettingKeyWeChatConnectAppSecret: "wx-open-secret",
+ service.SettingKeyWeChatConnectMode: "open",
+ service.SettingKeyWeChatConnectScopes: "snsapi_login",
+ service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ })
+ defer client.Close()
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil)
c.Request.Host = "api.example.com"
- handler := &AuthHandler{}
handler.WeChatOAuthStart(c)
require.Equal(t, http.StatusFound, recorder.Code)
@@ -60,10 +66,6 @@ func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) {
}
func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) {
- t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
- t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
- t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
-
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
@@ -124,10 +126,6 @@ func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) {
}
func TestWeChatOAuthCallbackRejectsMissingUnionID(t *testing.T) {
- t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
- t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
- t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "https://app.example.com/auth/wechat/callback")
-
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
@@ -151,7 +149,7 @@ func TestWeChatOAuthCallbackRejectsMissingUnionID(t *testing.T) {
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
- handler, client := newWeChatOAuthTestHandler(t, false)
+ handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "https://app.example.com/auth/wechat/callback"))
defer client.Close()
recorder := httptest.NewRecorder()
@@ -177,9 +175,6 @@ func TestWeChatOAuthCallbackRejectsMissingUnionID(t *testing.T) {
}
func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) {
- t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app")
- t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wx-mp-secret")
-
originalAccessTokenURL := wechatOAuthAccessTokenURL
t.Cleanup(func() {
wechatOAuthAccessTokenURL = originalAccessTokenURL
@@ -196,7 +191,7 @@ func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T)
defer upstream.Close()
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
- handler, client := newWeChatOAuthTestHandler(t, false)
+ handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("mp", "wx-mp-app", "wx-mp-secret", "/auth/wechat/callback"))
defer client.Close()
handler.cfg.Totp.EncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
@@ -240,7 +235,6 @@ func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *test
testCases := []struct {
name string
mode string
- appIDEnv string
appID string
appSecret string
openID string
@@ -248,7 +242,6 @@ func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *test
{
name: "open",
mode: "open",
- appIDEnv: "WECHAT_OAUTH_OPEN_APP_ID",
appID: "wx-open-app",
appSecret: "wx-open-secret",
openID: "openid-open-123",
@@ -256,7 +249,6 @@ func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *test
{
name: "mp",
mode: "mp",
- appIDEnv: "WECHAT_OAUTH_MP_APP_ID",
appID: "wx-mp-app",
appSecret: "wx-mp-secret",
openID: "openid-mp-123",
@@ -265,15 +257,6 @@ func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *test
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- t.Setenv(tc.appIDEnv, tc.appID)
- switch tc.mode {
- case "open":
- t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", tc.appSecret)
- case "mp":
- t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", tc.appSecret)
- }
- t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
-
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
@@ -297,7 +280,7 @@ func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *test
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
- handler, client := newWeChatOAuthTestHandler(t, false)
+ handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings(tc.mode, tc.appID, tc.appSecret, "/auth/wechat/callback"))
defer client.Close()
currentUser, err := client.User.Create().
@@ -354,10 +337,6 @@ func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *test
}
func TestWeChatOAuthCallbackBindRejectsCanonicalOwnershipConflict(t *testing.T) {
- t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
- t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
- t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
-
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
@@ -436,10 +415,6 @@ func TestWeChatOAuthCallbackBindRejectsCanonicalOwnershipConflict(t *testing.T)
}
func TestWeChatOAuthCallbackBindRejectsChannelOwnershipConflict(t *testing.T) {
- t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
- t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
- t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
-
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
@@ -529,10 +504,6 @@ func TestWeChatOAuthCallbackBindRejectsChannelOwnershipConflict(t *testing.T) {
}
func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *testing.T) {
- t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
- t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
- t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
-
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
@@ -611,10 +582,6 @@ func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *tes
}
func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) {
- t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
- t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
- t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
-
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
@@ -737,10 +704,6 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
}
func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) {
- t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
- t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
- t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
-
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
@@ -900,10 +863,6 @@ func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testi
}
func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) {
- t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
- t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
- t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
-
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
@@ -1010,6 +969,22 @@ func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing
}
func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
+ return newWeChatOAuthTestHandlerWithSettings(t, invitationEnabled, nil)
+}
+
+func wechatOAuthTestSettings(mode, appID, secret, frontendRedirect string) map[string]string {
+ return map[string]string{
+ service.SettingKeyWeChatConnectEnabled: "true",
+ service.SettingKeyWeChatConnectAppID: appID,
+ service.SettingKeyWeChatConnectAppSecret: secret,
+ service.SettingKeyWeChatConnectMode: mode,
+ service.SettingKeyWeChatConnectScopes: service.DefaultWeChatConnectScopesForMode(mode),
+ service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ service.SettingKeyWeChatConnectFrontendRedirectURL: frontendRedirect,
+ }
+}
+
+func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool, extraSettings map[string]string) (*AuthHandler, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:auth_wechat_oauth?mode=memory&cache=shared")
@@ -1036,12 +1011,17 @@ func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandl
UserConcurrency: 1,
},
}
- settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{
- values: map[string]string{
- service.SettingKeyRegistrationEnabled: "true",
- service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
- },
- }, cfg)
+ values := map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
+ }
+ for key, value := range wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "/auth/wechat/callback") {
+ values[key] = value
+ }
+ for key, value := range extraSettings {
+ values[key] = value
+ }
+ settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{values: values}, cfg)
authSvc := service.NewAuthService(
client,
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index d67b29a0..4c5edfbf 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -51,6 +51,14 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
+ WeChatConnectEnabled bool `json:"wechat_connect_enabled"`
+ WeChatConnectAppID string `json:"wechat_connect_app_id"`
+ WeChatConnectAppSecretConfigured bool `json:"wechat_connect_app_secret_configured"`
+ WeChatConnectMode string `json:"wechat_connect_mode"`
+ WeChatConnectScopes string `json:"wechat_connect_scopes"`
+ WeChatConnectRedirectURL string `json:"wechat_connect_redirect_url"`
+ WeChatConnectFrontendRedirectURL string `json:"wechat_connect_frontend_redirect_url"`
+
OIDCConnectEnabled bool `json:"oidc_connect_enabled"`
OIDCConnectProviderName string `json:"oidc_connect_provider_name"`
OIDCConnectClientID string `json:"oidc_connect_client_id"`
diff --git a/backend/internal/handler/setting_handler_public_test.go b/backend/internal/handler/setting_handler_public_test.go
index b50c982c..628d9341 100644
--- a/backend/internal/handler/setting_handler_public_test.go
+++ b/backend/internal/handler/setting_handler_public_test.go
@@ -84,12 +84,17 @@ func TestSettingHandler_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t
func TestSettingHandler_GetPublicSettings_ExposesWeChatOAuthModeCapabilities(t *testing.T) {
gin.SetMode(gin.TestMode)
- t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
- t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
- t.Setenv("WECHAT_OAUTH_MP_APP_ID", "")
- t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "")
-
- h := NewSettingHandler(service.NewSettingService(&settingHandlerPublicRepoStub{}, &config.Config{}), "test-version")
+ h := NewSettingHandler(service.NewSettingService(&settingHandlerPublicRepoStub{
+ values: map[string]string{
+ service.SettingKeyWeChatConnectEnabled: "true",
+ service.SettingKeyWeChatConnectAppID: "wx-mp-app",
+ service.SettingKeyWeChatConnectAppSecret: "wx-mp-secret",
+ service.SettingKeyWeChatConnectMode: "mp",
+ service.SettingKeyWeChatConnectScopes: "snsapi_base",
+ service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ },
+ }, &config.Config{}), "test-version")
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@@ -110,6 +115,6 @@ func TestSettingHandler_GetPublicSettings_ExposesWeChatOAuthModeCapabilities(t *
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.True(t, resp.Data.WeChatOAuthEnabled)
- require.True(t, resp.Data.WeChatOAuthOpenEnabled)
- require.False(t, resp.Data.WeChatOAuthMPEnabled)
+ require.False(t, resp.Data.WeChatOAuthOpenEnabled)
+ require.True(t, resp.Data.WeChatOAuthMPEnabled)
}
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index 1dddf77e..4d63cabc 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -111,6 +111,15 @@ const (
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
+ // WeChat Connect OAuth 登录设置
+ SettingKeyWeChatConnectEnabled = "wechat_connect_enabled"
+ SettingKeyWeChatConnectAppID = "wechat_connect_app_id"
+ SettingKeyWeChatConnectAppSecret = "wechat_connect_app_secret"
+ SettingKeyWeChatConnectMode = "wechat_connect_mode"
+ SettingKeyWeChatConnectScopes = "wechat_connect_scopes"
+ SettingKeyWeChatConnectRedirectURL = "wechat_connect_redirect_url"
+ SettingKeyWeChatConnectFrontendRedirectURL = "wechat_connect_frontend_redirect_url"
+
// Generic OIDC OAuth 登录设置
SettingKeyOIDCConnectEnabled = "oidc_connect_enabled"
SettingKeyOIDCConnectProviderName = "oidc_connect_provider_name"
diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go
index 34462a3a..2d1f3f42 100644
--- a/backend/internal/service/payment_config_service.go
+++ b/backend/internal/service/payment_config_service.go
@@ -93,6 +93,11 @@ type UpdatePaymentConfigRequest struct {
CancelRateLimitWindow *int `json:"cancel_rate_limit_window"`
CancelRateLimitUnit *string `json:"cancel_rate_limit_unit"`
CancelRateLimitMode *string `json:"cancel_rate_limit_window_mode"`
+
+ VisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
+ VisibleMethodWxpaySource *string `json:"payment_visible_method_wxpay_source"`
+ VisibleMethodAlipayEnabled *bool `json:"payment_visible_method_alipay_enabled"`
+ VisibleMethodWxpayEnabled *bool `json:"payment_visible_method_wxpay_enabled"`
}
// MethodLimits holds per-payment-type limits.
@@ -319,6 +324,10 @@ func (s *PaymentConfigService) UpdatePaymentConfig(ctx context.Context, req Upda
SettingCancelWindowSize: formatPositiveInt(req.CancelRateLimitWindow),
SettingCancelWindowUnit: derefStr(req.CancelRateLimitUnit),
SettingCancelWindowMode: derefStr(req.CancelRateLimitMode),
+ SettingPaymentVisibleMethodAlipaySource: derefStr(req.VisibleMethodAlipaySource),
+ SettingPaymentVisibleMethodWxpaySource: derefStr(req.VisibleMethodWxpaySource),
+ SettingPaymentVisibleMethodAlipayEnabled: formatBoolOrEmpty(req.VisibleMethodAlipayEnabled),
+ SettingPaymentVisibleMethodWxpayEnabled: formatBoolOrEmpty(req.VisibleMethodWxpayEnabled),
}
if req.EnabledTypes != nil {
m[SettingEnabledPaymentTypes] = strings.Join(req.EnabledTypes, ",")
diff --git a/backend/internal/service/payment_config_service_test.go b/backend/internal/service/payment_config_service_test.go
index 10919058..d58ee234 100644
--- a/backend/internal/service/payment_config_service_test.go
+++ b/backend/internal/service/payment_config_service_test.go
@@ -366,7 +366,8 @@ func newPaymentConfigServiceTestClient(t *testing.T) *dbent.Client {
}
type paymentConfigSettingRepoStub struct {
- values map[string]string
+ values map[string]string
+ updates map[string]string
}
func (s *paymentConfigSettingRepoStub) Get(context.Context, string) (*Setting, error) {
@@ -383,10 +384,52 @@ func (s *paymentConfigSettingRepoStub) GetMultiple(_ context.Context, keys []str
}
return out, nil
}
-func (s *paymentConfigSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+func (s *paymentConfigSettingRepoStub) SetMultiple(_ context.Context, values map[string]string) error {
+ s.updates = make(map[string]string, len(values))
+ for key, value := range values {
+ s.updates[key] = value
+ if s.values == nil {
+ s.values = map[string]string{}
+ }
+ s.values[key] = value
+ }
return nil
}
func (s *paymentConfigSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
return s.values, nil
}
func (s *paymentConfigSettingRepoStub) Delete(context.Context, string) error { return nil }
+
+func TestUpdatePaymentConfig_PersistsVisibleMethodRouting(t *testing.T) {
+ repo := &paymentConfigSettingRepoStub{values: map[string]string{}}
+ svc := &PaymentConfigService{settingRepo: repo}
+
+ alipayEnabled := true
+ wxpayEnabled := false
+ err := svc.UpdatePaymentConfig(context.Background(), UpdatePaymentConfigRequest{
+ VisibleMethodAlipayEnabled: &alipayEnabled,
+ VisibleMethodAlipaySource: paymentConfigStrPtr(VisibleMethodSourceEasyPayAlipay),
+ VisibleMethodWxpayEnabled: &wxpayEnabled,
+ VisibleMethodWxpaySource: paymentConfigStrPtr(VisibleMethodSourceOfficialWechat),
+ })
+ if err != nil {
+ t.Fatalf("UpdatePaymentConfig returned error: %v", err)
+ }
+
+ if repo.values[SettingPaymentVisibleMethodAlipayEnabled] != "true" {
+ t.Fatalf("alipay enabled = %q, want true", repo.values[SettingPaymentVisibleMethodAlipayEnabled])
+ }
+ if repo.values[SettingPaymentVisibleMethodAlipaySource] != VisibleMethodSourceEasyPayAlipay {
+ t.Fatalf("alipay source = %q, want %q", repo.values[SettingPaymentVisibleMethodAlipaySource], VisibleMethodSourceEasyPayAlipay)
+ }
+ if repo.values[SettingPaymentVisibleMethodWxpayEnabled] != "false" {
+ t.Fatalf("wxpay enabled = %q, want false", repo.values[SettingPaymentVisibleMethodWxpayEnabled])
+ }
+ if repo.values[SettingPaymentVisibleMethodWxpaySource] != VisibleMethodSourceOfficialWechat {
+ t.Fatalf("wxpay source = %q, want %q", repo.values[SettingPaymentVisibleMethodWxpaySource], VisibleMethodSourceOfficialWechat)
+ }
+}
+
+func paymentConfigStrPtr(value string) *string {
+ return &value
+}
diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go
index 354f3cd1..01d8c642 100644
--- a/backend/internal/service/payment_order.go
+++ b/backend/internal/service/payment_order.go
@@ -6,7 +6,6 @@ import (
"log/slog"
"math"
"net/url"
- "os"
"strconv"
"strings"
"time"
@@ -512,16 +511,21 @@ func requiresWeChatJSAPICompatibleSelection(req CreateOrderRequest, sel *payment
return req.IsWeChatBrowser || strings.TrimSpace(req.OpenID) != ""
}
-func (s *PaymentService) getWeChatPaymentOAuthCredential(context.Context) (string, string, error) {
- appID := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID"))
- appSecret := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET"))
- if appID == "" || appSecret == "" {
+func (s *PaymentService) getWeChatPaymentOAuthCredential(ctx context.Context) (string, string, error) {
+ if s == nil || s.configService == nil || s.configService.settingRepo == nil {
return "", "", infraerrors.ServiceUnavailable(
"WECHAT_PAYMENT_MP_NOT_CONFIGURED",
"wechat in-app payment requires a complete WeChat MP OAuth credential",
)
}
- return appID, appSecret, nil
+ cfg, err := (&SettingService{settingRepo: s.configService.settingRepo}).GetWeChatConnectOAuthConfig(ctx)
+ if err != nil || cfg.Mode != "mp" || strings.TrimSpace(cfg.AppID) == "" || strings.TrimSpace(cfg.AppSecret) == "" {
+ return "", "", infraerrors.ServiceUnavailable(
+ "WECHAT_PAYMENT_MP_NOT_CONFIGURED",
+ "wechat in-app payment requires a complete WeChat MP OAuth credential",
+ )
+ }
+ return strings.TrimSpace(cfg.AppID), strings.TrimSpace(cfg.AppSecret), nil
}
func classifyCreatePaymentError(req CreateOrderRequest, providerKey string, err error) error {
diff --git a/backend/internal/service/payment_order_jsapi_test.go b/backend/internal/service/payment_order_jsapi_test.go
index 08492432..25f209af 100644
--- a/backend/internal/service/payment_order_jsapi_test.go
+++ b/backend/internal/service/payment_order_jsapi_test.go
@@ -60,10 +60,17 @@ func TestSelectCreateOrderInstancePrefersJSAPICompatibleWxpayInstance(t *testing
}
configService := &PaymentConfigService{
- entClient: client,
+ entClient: client,
settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{
- SettingPaymentVisibleMethodWxpayEnabled: "true",
- SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
+ SettingPaymentVisibleMethodWxpayEnabled: "true",
+ SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
+ SettingKeyWeChatConnectEnabled: "true",
+ SettingKeyWeChatConnectAppID: "wx-mp-app",
+ SettingKeyWeChatConnectAppSecret: "wechat-secret",
+ SettingKeyWeChatConnectMode: "mp",
+ SettingKeyWeChatConnectScopes: "snsapi_base",
+ SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
}},
encryptionKey: []byte(jsapiTestEncryptionKey),
}
@@ -77,9 +84,6 @@ func TestSelectCreateOrderInstancePrefersJSAPICompatibleWxpayInstance(t *testing
configService: configService,
}
- t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app")
- t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wechat-secret")
-
sel, err := svc.selectCreateOrderInstance(ctx, CreateOrderRequest{
PaymentType: payment.TypeWxpay,
OpenID: "openid-123",
diff --git a/backend/internal/service/payment_order_result_test.go b/backend/internal/service/payment_order_result_test.go
index 0daa8213..16757323 100644
--- a/backend/internal/service/payment_order_result_test.go
+++ b/backend/internal/service/payment_order_result_test.go
@@ -91,10 +91,15 @@ func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) {
}
func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) {
- t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx123456")
- t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wechat-secret")
-
- svc := &PaymentService{}
+ svc := newWeChatPaymentOAuthTestService(map[string]string{
+ SettingKeyWeChatConnectEnabled: "true",
+ SettingKeyWeChatConnectAppID: "wx123456",
+ SettingKeyWeChatConnectAppSecret: "wechat-secret",
+ SettingKeyWeChatConnectMode: "mp",
+ SettingKeyWeChatConnectScopes: "snsapi_base",
+ SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ })
resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
Amount: 12.5,
@@ -132,7 +137,7 @@ func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) {
func TestMaybeBuildWeChatOAuthRequiredResponseRequiresMPConfigInWeChat(t *testing.T) {
t.Parallel()
- svc := &PaymentService{}
+ svc := newWeChatPaymentOAuthTestService(nil)
resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
Amount: 12.5,
@@ -155,10 +160,15 @@ func TestMaybeBuildWeChatOAuthRequiredResponseRequiresMPConfigInWeChat(t *testin
}
func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) {
- t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx123456")
- t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wechat-secret")
-
- svc := &PaymentService{}
+ svc := newWeChatPaymentOAuthTestService(map[string]string{
+ SettingKeyWeChatConnectEnabled: "true",
+ SettingKeyWeChatConnectAppID: "wx123456",
+ SettingKeyWeChatConnectAppSecret: "wechat-secret",
+ SettingKeyWeChatConnectMode: "mp",
+ SettingKeyWeChatConnectScopes: "snsapi_base",
+ SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ })
resp, err := svc.maybeBuildWeChatOAuthRequiredResponseForSelection(context.Background(), CreateOrderRequest{
Amount: 12.5,
@@ -175,3 +185,11 @@ func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t
t.Fatalf("expected nil response, got %+v", resp)
}
}
+
+func newWeChatPaymentOAuthTestService(values map[string]string) *PaymentService {
+ return &PaymentService{
+ configService: &PaymentConfigService{
+ settingRepo: &paymentConfigSettingRepoStub{values: values},
+ },
+ }
+}
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index 8c879d52..373c1aef 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -9,7 +9,6 @@ import (
"fmt"
"log/slog"
"net/url"
- "os"
"sort"
"strconv"
"strings"
@@ -173,8 +172,43 @@ var (
const (
defaultAuthSourceBalance = 0
defaultAuthSourceConcurrency = 5
+ defaultWeChatConnectMode = "open"
+ defaultWeChatConnectScopes = "snsapi_login"
+ defaultWeChatConnectFrontend = "/auth/wechat/callback"
)
+func normalizeWeChatConnectModeSetting(raw string) string {
+ switch strings.ToLower(strings.TrimSpace(raw)) {
+ case "mp":
+ return "mp"
+ default:
+ return "open"
+ }
+}
+
+func defaultWeChatConnectScopeForMode(mode string) string {
+ if normalizeWeChatConnectModeSetting(mode) == "mp" {
+ return "snsapi_userinfo"
+ }
+ return defaultWeChatConnectScopes
+}
+
+func normalizeWeChatConnectScopeSetting(raw, mode string) string {
+ switch normalizeWeChatConnectModeSetting(mode) {
+ case "mp":
+ switch strings.TrimSpace(raw) {
+ case "snsapi_base":
+ return "snsapi_base"
+ case "snsapi_userinfo":
+ return "snsapi_userinfo"
+ default:
+ return defaultWeChatConnectScopeForMode(mode)
+ }
+ default:
+ return defaultWeChatConnectScopes
+ }
+}
+
// NewSettingService 创建系统设置服务实例
func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *SettingService {
return &SettingService{
@@ -240,6 +274,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyCustomMenuItems,
SettingKeyCustomEndpoints,
SettingKeyLinuxDoConnectEnabled,
+ SettingKeyWeChatConnectEnabled,
+ SettingKeyWeChatConnectAppID,
+ SettingKeyWeChatConnectAppSecret,
+ SettingKeyWeChatConnectMode,
+ SettingKeyWeChatConnectScopes,
+ SettingKeyWeChatConnectRedirectURL,
+ SettingKeyWeChatConnectFrontendRedirectURL,
SettingKeyBackendModeEnabled,
SettingPaymentEnabled,
SettingKeyOIDCConnectEnabled,
@@ -274,9 +315,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
if oidcProviderName == "" {
oidcProviderName = "OIDC"
}
- weChatOpenEnabled := isWeChatOAuthOpenConfigured()
- weChatMPEnabled := isWeChatOAuthMPConfigured()
- weChatEnabled := weChatOpenEnabled || weChatMPEnabled
+ weChatEnabled, weChatOpenEnabled, weChatMPEnabled := s.weChatOAuthCapabilitiesFromSettings(settings)
// Password reset requires email verification to be enabled
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
@@ -431,6 +470,56 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
}, nil
}
+func DefaultWeChatConnectScopesForMode(mode string) string {
+ return defaultWeChatConnectScopeForMode(mode)
+}
+
+func (s *SettingService) parseWeChatConnectOAuthConfig(settings map[string]string) (WeChatConnectOAuthConfig, error) {
+ cfg := WeChatConnectOAuthConfig{
+ Enabled: settings[SettingKeyWeChatConnectEnabled] == "true",
+ AppID: strings.TrimSpace(settings[SettingKeyWeChatConnectAppID]),
+ AppSecret: strings.TrimSpace(settings[SettingKeyWeChatConnectAppSecret]),
+ Mode: normalizeWeChatConnectModeSetting(settings[SettingKeyWeChatConnectMode]),
+ Scopes: normalizeWeChatConnectScopeSetting(settings[SettingKeyWeChatConnectScopes], settings[SettingKeyWeChatConnectMode]),
+ RedirectURL: strings.TrimSpace(settings[SettingKeyWeChatConnectRedirectURL]),
+ FrontendRedirectURL: strings.TrimSpace(settings[SettingKeyWeChatConnectFrontendRedirectURL]),
+ }
+ if cfg.FrontendRedirectURL == "" {
+ cfg.FrontendRedirectURL = defaultWeChatConnectFrontend
+ }
+
+ if !cfg.Enabled {
+ return WeChatConnectOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled")
+ }
+ if cfg.AppID == "" {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth app id not configured")
+ }
+ if cfg.AppSecret == "" {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth app secret not configured")
+ }
+ if cfg.RedirectURL == "" {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured")
+ }
+ if cfg.FrontendRedirectURL == "" {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth frontend redirect url not configured")
+ }
+ if err := config.ValidateAbsoluteHTTPURL(cfg.RedirectURL); err != nil {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url invalid")
+ }
+ if err := config.ValidateFrontendRedirectURL(cfg.FrontendRedirectURL); err != nil {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth frontend redirect url invalid")
+ }
+ return cfg, nil
+}
+
+func (s *SettingService) weChatOAuthCapabilitiesFromSettings(settings map[string]string) (bool, bool, bool) {
+ cfg, err := s.parseWeChatConnectOAuthConfig(settings)
+ if err != nil {
+ return false, false, false
+ }
+ return true, cfg.Mode == "open", cfg.Mode == "mp"
+}
+
// filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON
// array string, returning only items with visibility != "admin".
func filterUserVisibleMenuItems(raw string) json.RawMessage {
@@ -467,20 +556,6 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage {
return result
}
-func isWeChatOAuthConfigured() bool {
- return isWeChatOAuthOpenConfigured() || isWeChatOAuthMPConfigured()
-}
-
-func isWeChatOAuthOpenConfigured() bool {
- return strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID")) != "" &&
- strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET")) != ""
-}
-
-func isWeChatOAuthMPConfigured() bool {
- return strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID")) != "" &&
- strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET")) != ""
-}
-
// safeRawJSONArray returns raw as json.RawMessage if it's valid JSON, otherwise "[]".
func safeRawJSONArray(raw string) json.RawMessage {
raw = strings.TrimSpace(raw)
@@ -625,6 +700,15 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
}
settings.PaymentVisibleMethodAlipaySource = alipaySource
settings.PaymentVisibleMethodWxpaySource = wxpaySource
+ settings.WeChatConnectAppID = strings.TrimSpace(settings.WeChatConnectAppID)
+ settings.WeChatConnectAppSecret = strings.TrimSpace(settings.WeChatConnectAppSecret)
+ settings.WeChatConnectMode = normalizeWeChatConnectModeSetting(settings.WeChatConnectMode)
+ settings.WeChatConnectScopes = normalizeWeChatConnectScopeSetting(settings.WeChatConnectScopes, settings.WeChatConnectMode)
+ settings.WeChatConnectRedirectURL = strings.TrimSpace(settings.WeChatConnectRedirectURL)
+ settings.WeChatConnectFrontendRedirectURL = strings.TrimSpace(settings.WeChatConnectFrontendRedirectURL)
+ if settings.WeChatConnectFrontendRedirectURL == "" {
+ settings.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend
+ }
updates := make(map[string]string)
@@ -694,6 +778,17 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
updates[SettingKeyOIDCConnectClientSecret] = settings.OIDCConnectClientSecret
}
+ // WeChat Connect OAuth 登录
+ updates[SettingKeyWeChatConnectEnabled] = strconv.FormatBool(settings.WeChatConnectEnabled)
+ updates[SettingKeyWeChatConnectAppID] = settings.WeChatConnectAppID
+ updates[SettingKeyWeChatConnectMode] = settings.WeChatConnectMode
+ updates[SettingKeyWeChatConnectScopes] = settings.WeChatConnectScopes
+ updates[SettingKeyWeChatConnectRedirectURL] = settings.WeChatConnectRedirectURL
+ updates[SettingKeyWeChatConnectFrontendRedirectURL] = settings.WeChatConnectFrontendRedirectURL
+ if settings.WeChatConnectAppSecret != "" {
+ updates[SettingKeyWeChatConnectAppSecret] = settings.WeChatConnectAppSecret
+ }
+
// OEM设置
updates[SettingKeySiteName] = settings.SiteName
updates[SettingKeySiteLogo] = settings.SiteLogo
@@ -1200,6 +1295,10 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyTablePageSizeOptions: "[10,20,50,100]",
SettingKeyCustomMenuItems: "[]",
SettingKeyCustomEndpoints: "[]",
+ SettingKeyWeChatConnectEnabled: "false",
+ SettingKeyWeChatConnectMode: "open",
+ SettingKeyWeChatConnectScopes: "snsapi_login",
+ SettingKeyWeChatConnectFrontendRedirectURL: defaultWeChatConnectFrontend,
SettingKeyOIDCConnectEnabled: "false",
SettingKeyOIDCConnectProviderName: "OIDC",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
@@ -1491,6 +1590,19 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
}
result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != ""
+ // WeChat Connect 设置:完全以 DB 系统设置为准。
+ result.WeChatConnectEnabled = settings[SettingKeyWeChatConnectEnabled] == "true"
+ result.WeChatConnectAppID = strings.TrimSpace(settings[SettingKeyWeChatConnectAppID])
+ result.WeChatConnectAppSecret = strings.TrimSpace(settings[SettingKeyWeChatConnectAppSecret])
+ result.WeChatConnectAppSecretConfigured = result.WeChatConnectAppSecret != ""
+ result.WeChatConnectMode = normalizeWeChatConnectModeSetting(settings[SettingKeyWeChatConnectMode])
+ result.WeChatConnectScopes = normalizeWeChatConnectScopeSetting(settings[SettingKeyWeChatConnectScopes], settings[SettingKeyWeChatConnectMode])
+ result.WeChatConnectRedirectURL = strings.TrimSpace(settings[SettingKeyWeChatConnectRedirectURL])
+ result.WeChatConnectFrontendRedirectURL = strings.TrimSpace(settings[SettingKeyWeChatConnectFrontendRedirectURL])
+ if result.WeChatConnectFrontendRedirectURL == "" {
+ result.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend
+ }
+
// Model fallback settings
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
@@ -1972,6 +2084,26 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
return effective, nil
}
+// GetWeChatConnectOAuthConfig 返回用于登录的最终生效 WeChat Connect 配置。
+//
+// WeChat Connect 已回归 DB 系统设置模型,不再回退到 config/env。
+func (s *SettingService) GetWeChatConnectOAuthConfig(ctx context.Context) (WeChatConnectOAuthConfig, error) {
+ keys := []string{
+ SettingKeyWeChatConnectEnabled,
+ SettingKeyWeChatConnectAppID,
+ SettingKeyWeChatConnectAppSecret,
+ SettingKeyWeChatConnectMode,
+ SettingKeyWeChatConnectScopes,
+ SettingKeyWeChatConnectRedirectURL,
+ SettingKeyWeChatConnectFrontendRedirectURL,
+ }
+ settings, err := s.settingRepo.GetMultiple(ctx, keys)
+ if err != nil {
+ return WeChatConnectOAuthConfig{}, fmt.Errorf("get wechat connect settings: %w", err)
+ }
+ return s.parseWeChatConnectOAuthConfig(settings)
+}
+
// GetOverloadCooldownSettings 获取529过载冷却配置
func (s *SettingService) GetOverloadCooldownSettings(ctx context.Context) (*OverloadCooldownSettings, error) {
value, err := s.settingRepo.GetValue(ctx, SettingKeyOverloadCooldownSettings)
diff --git a/backend/internal/service/setting_service_public_test.go b/backend/internal/service/setting_service_public_test.go
index 4cfa9f0c..ffc069dc 100644
--- a/backend/internal/service/setting_service_public_test.go
+++ b/backend/internal/service/setting_service_public_test.go
@@ -92,16 +92,21 @@ func TestSettingService_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t
}
func TestSettingService_GetPublicSettings_ExposesWeChatOAuthModeCapabilities(t *testing.T) {
- t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
- t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
- t.Setenv("WECHAT_OAUTH_MP_APP_ID", "")
- t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "")
-
- svc := NewSettingService(&settingPublicRepoStub{}, &config.Config{})
+ svc := NewSettingService(&settingPublicRepoStub{
+ values: map[string]string{
+ SettingKeyWeChatConnectEnabled: "true",
+ SettingKeyWeChatConnectAppID: "wx-mp-app",
+ SettingKeyWeChatConnectAppSecret: "wx-mp-secret",
+ SettingKeyWeChatConnectMode: "mp",
+ SettingKeyWeChatConnectScopes: "snsapi_base",
+ SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ },
+ }, &config.Config{})
settings, err := svc.GetPublicSettings(context.Background())
require.NoError(t, err)
require.True(t, settings.WeChatOAuthEnabled)
- require.True(t, settings.WeChatOAuthOpenEnabled)
- require.False(t, settings.WeChatOAuthMPEnabled)
+ require.False(t, settings.WeChatOAuthOpenEnabled)
+ require.True(t, settings.WeChatOAuthMPEnabled)
}
diff --git a/backend/internal/service/setting_service_wechat_config_test.go b/backend/internal/service/setting_service_wechat_config_test.go
new file mode 100644
index 00000000..2cb312cc
--- /dev/null
+++ b/backend/internal/service/setting_service_wechat_config_test.go
@@ -0,0 +1,77 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+type settingWeChatRepoStub struct {
+ values map[string]string
+}
+
+func (s *settingWeChatRepoStub) Get(context.Context, string) (*Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *settingWeChatRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ if value, ok := s.values[key]; ok {
+ return value, nil
+ }
+ return "", ErrSettingNotFound
+}
+
+func (s *settingWeChatRepoStub) Set(context.Context, string, string) error {
+ panic("unexpected Set call")
+}
+
+func (s *settingWeChatRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *settingWeChatRepoStub) SetMultiple(context.Context, map[string]string) error {
+ panic("unexpected SetMultiple call")
+}
+
+func (s *settingWeChatRepoStub) GetAll(context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *settingWeChatRepoStub) Delete(context.Context, string) error {
+ panic("unexpected Delete call")
+}
+
+func TestSettingService_GetWeChatConnectOAuthConfig_UsesDatabaseOverrides(t *testing.T) {
+ repo := &settingWeChatRepoStub{
+ values: map[string]string{
+ SettingKeyWeChatConnectEnabled: "true",
+ SettingKeyWeChatConnectAppID: "wx-db-app",
+ SettingKeyWeChatConnectAppSecret: "wx-db-secret",
+ SettingKeyWeChatConnectMode: "mp",
+ SettingKeyWeChatConnectScopes: "snsapi_base",
+ SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ },
+ }
+ svc := NewSettingService(repo, &config.Config{})
+
+ got, err := svc.GetWeChatConnectOAuthConfig(context.Background())
+ require.NoError(t, err)
+ require.True(t, got.Enabled)
+ require.Equal(t, "wx-db-app", got.AppID)
+ require.Equal(t, "wx-db-secret", got.AppSecret)
+ require.Equal(t, "mp", got.Mode)
+ require.Equal(t, "snsapi_base", got.Scopes)
+ require.Equal(t, "https://api.example.com/api/v1/auth/oauth/wechat/callback", got.RedirectURL)
+ require.Equal(t, "/auth/wechat/callback", got.FrontendRedirectURL)
+}
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index 1b859dbd..41229bfb 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -31,6 +31,16 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool
LinuxDoConnectRedirectURL string
+ // WeChat Connect OAuth 登录
+ WeChatConnectEnabled bool
+ WeChatConnectAppID string
+ WeChatConnectAppSecret string
+ WeChatConnectAppSecretConfigured bool
+ WeChatConnectMode string
+ WeChatConnectScopes string
+ WeChatConnectRedirectURL string
+ WeChatConnectFrontendRedirectURL string
+
// Generic OIDC OAuth 登录
OIDCConnectEnabled bool
OIDCConnectProviderName string
@@ -177,6 +187,16 @@ type PublicSettings struct {
BalanceLowNotifyRechargeURL string
}
+type WeChatConnectOAuthConfig struct {
+ Enabled bool
+ AppID string
+ AppSecret string
+ Mode string
+ Scopes string
+ RedirectURL string
+ FrontendRedirectURL string
+}
+
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
type StreamTimeoutSettings struct {
// Enabled 是否启用流超时处理
diff --git a/frontend/src/api/__tests__/settings.wechatConnect.spec.ts b/frontend/src/api/__tests__/settings.wechatConnect.spec.ts
new file mode 100644
index 00000000..eccb7214
--- /dev/null
+++ b/frontend/src/api/__tests__/settings.wechatConnect.spec.ts
@@ -0,0 +1,21 @@
+import { describe, expect, it } from "vitest";
+
+import {
+ defaultWeChatConnectScopesForMode,
+ normalizeWeChatConnectMode,
+} from "@/api/admin/settings";
+
+describe("admin settings wechat connect helpers", () => {
+ it("normalizes legacy or noisy mode values to the backend contract", () => {
+ expect(normalizeWeChatConnectMode("OPEN")).toBe("open");
+ expect(normalizeWeChatConnectMode(" open_platform ")).toBe("open");
+ expect(normalizeWeChatConnectMode("mp")).toBe("mp");
+ expect(normalizeWeChatConnectMode("official_account")).toBe("mp");
+ expect(normalizeWeChatConnectMode("unknown")).toBe("open");
+ });
+
+ it("maps each mode to the backend default scopes", () => {
+ expect(defaultWeChatConnectScopesForMode("open")).toBe("snsapi_login");
+ expect(defaultWeChatConnectScopesForMode("mp")).toBe("snsapi_userinfo");
+ });
+});
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index 235bda7b..ed78a1bc 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -3,155 +3,240 @@
* Handles system settings management for administrators
*/
-import { apiClient } from '../client'
-import type { CustomMenuItem, CustomEndpoint, NotifyEmailEntry } from '@/types'
+import { apiClient } from "../client";
+import type { CustomMenuItem, CustomEndpoint, NotifyEmailEntry } from "@/types";
export interface DefaultSubscriptionSetting {
- group_id: number
- validity_days: number
+ group_id: number;
+ validity_days: number;
}
-export type AuthSourceType = 'email' | 'linuxdo' | 'oidc' | 'wechat'
+export type AuthSourceType = "email" | "linuxdo" | "oidc" | "wechat";
export interface AuthSourceDefaultsValue {
- balance: number
- concurrency: number
- subscriptions: DefaultSubscriptionSetting[]
- grant_on_signup: boolean
- grant_on_first_bind: boolean
+ balance: number;
+ concurrency: number;
+ subscriptions: DefaultSubscriptionSetting[];
+ grant_on_signup: boolean;
+ grant_on_first_bind: boolean;
}
-export type AuthSourceDefaultsState = Record
-export type PaymentVisibleMethod = 'alipay' | 'wxpay'
+export type AuthSourceDefaultsState = Record<
+ AuthSourceType,
+ AuthSourceDefaultsValue
+>;
+export type PaymentVisibleMethod = "alipay" | "wxpay";
export type PaymentVisibleMethodSource =
- | ''
- | 'official_alipay'
- | 'easypay_alipay'
- | 'official_wxpay'
- | 'easypay_wxpay'
+ | ""
+ | "official_alipay"
+ | "easypay_alipay"
+ | "official_wxpay"
+ | "easypay_wxpay";
+export type WeChatConnectMode = "open" | "mp";
export interface PaymentVisibleMethodSourceOption {
- value: PaymentVisibleMethodSource
- labelZh: string
- labelEn: string
+ value: PaymentVisibleMethodSource;
+ labelZh: string;
+ labelEn: string;
}
-const AUTH_SOURCE_TYPES: AuthSourceType[] = ['email', 'linuxdo', 'oidc', 'wechat']
-const AUTH_SOURCE_DEFAULT_BALANCE = 0
-const AUTH_SOURCE_DEFAULT_CONCURRENCY = 5
+export interface WeChatConnectModeOption {
+ value: WeChatConnectMode;
+ labelZh: string;
+ labelEn: string;
+}
+
+const AUTH_SOURCE_TYPES: AuthSourceType[] = [
+ "email",
+ "linuxdo",
+ "oidc",
+ "wechat",
+];
+const AUTH_SOURCE_DEFAULT_BALANCE = 0;
+const AUTH_SOURCE_DEFAULT_CONCURRENCY = 5;
const PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS: Record<
PaymentVisibleMethod,
PaymentVisibleMethodSourceOption[]
> = {
alipay: [
- { value: '', labelZh: '未配置', labelEn: 'Not configured' },
- { value: 'official_alipay', labelZh: '支付宝官方', labelEn: 'Official Alipay' },
- { value: 'easypay_alipay', labelZh: '易支付支付宝', labelEn: 'EasyPay Alipay' },
+ { value: "", labelZh: "未配置", labelEn: "Not configured" },
+ {
+ value: "official_alipay",
+ labelZh: "支付宝官方",
+ labelEn: "Official Alipay",
+ },
+ {
+ value: "easypay_alipay",
+ labelZh: "易支付支付宝",
+ labelEn: "EasyPay Alipay",
+ },
],
wxpay: [
- { value: '', labelZh: '未配置', labelEn: 'Not configured' },
- { value: 'official_wxpay', labelZh: '微信官方', labelEn: 'Official WeChat Pay' },
- { value: 'easypay_wxpay', labelZh: '易支付微信', labelEn: 'EasyPay WeChat Pay' },
+ { value: "", labelZh: "未配置", labelEn: "Not configured" },
+ {
+ value: "official_wxpay",
+ labelZh: "微信官方",
+ labelEn: "Official WeChat Pay",
+ },
+ {
+ value: "easypay_wxpay",
+ labelZh: "易支付微信",
+ labelEn: "EasyPay WeChat Pay",
+ },
],
-}
+};
const PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES: Record<
PaymentVisibleMethod,
Record
> = {
alipay: {
- official_alipay: 'official_alipay',
- alipay: 'official_alipay',
- alipay_direct: 'official_alipay',
- official: 'official_alipay',
- easypay_alipay: 'easypay_alipay',
- easypay: 'easypay_alipay',
+ official_alipay: "official_alipay",
+ alipay: "official_alipay",
+ alipay_direct: "official_alipay",
+ official: "official_alipay",
+ easypay_alipay: "easypay_alipay",
+ easypay: "easypay_alipay",
},
wxpay: {
- official_wxpay: 'official_wxpay',
- wxpay: 'official_wxpay',
- wxpay_direct: 'official_wxpay',
- wechat: 'official_wxpay',
- official: 'official_wxpay',
- easypay_wxpay: 'easypay_wxpay',
- easypay: 'easypay_wxpay',
+ official_wxpay: "official_wxpay",
+ wxpay: "official_wxpay",
+ wxpay_direct: "official_wxpay",
+ wechat: "official_wxpay",
+ official: "official_wxpay",
+ easypay_wxpay: "easypay_wxpay",
+ easypay: "easypay_wxpay",
},
-}
+};
+const WECHAT_CONNECT_MODE_OPTIONS: WeChatConnectModeOption[] = [
+ { value: "open", labelZh: "微信开放平台", labelEn: "WeChat Open Platform" },
+ {
+ value: "mp",
+ labelZh: "微信公众号 / 小程序",
+ labelEn: "WeChat Official Account / Mini Program",
+ },
+];
+const WECHAT_CONNECT_MODE_ALIASES: Record = {
+ open: "open",
+ open_platform: "open",
+ official: "open",
+ wx_open: "open",
+ mp: "mp",
+ official_account: "mp",
+ wechat_mp: "mp",
+ mini_program: "mp",
+};
export function normalizeDefaultSubscriptionSettings(
- subscriptions: DefaultSubscriptionSetting[] | null | undefined
+ subscriptions: DefaultSubscriptionSetting[] | null | undefined,
): DefaultSubscriptionSetting[] {
- if (!Array.isArray(subscriptions)) return []
+ if (!Array.isArray(subscriptions)) return [];
return subscriptions
.filter((item) => item.group_id > 0 && item.validity_days > 0)
.map((item) => ({
group_id: Math.floor(item.group_id),
- validity_days: Math.min(36500, Math.max(1, Math.floor(item.validity_days)))
- }))
+ validity_days: Math.min(
+ 36500,
+ Math.max(1, Math.floor(item.validity_days)),
+ ),
+ }));
}
export function buildAuthSourceDefaultsState(
- settings: Partial
+ settings: Partial,
): AuthSourceDefaultsState {
- const raw = settings as Record
+ const raw = settings as Record;
return AUTH_SOURCE_TYPES.reduce((acc, source) => {
- const subscriptions = raw[`auth_source_default_${source}_subscriptions`]
+ const subscriptions = raw[`auth_source_default_${source}_subscriptions`];
acc[source] = {
- balance: Number(raw[`auth_source_default_${source}_balance`] ?? AUTH_SOURCE_DEFAULT_BALANCE),
+ balance: Number(
+ raw[`auth_source_default_${source}_balance`] ??
+ AUTH_SOURCE_DEFAULT_BALANCE,
+ ),
concurrency: Math.max(
1,
- Number(raw[`auth_source_default_${source}_concurrency`] ?? AUTH_SOURCE_DEFAULT_CONCURRENCY)
+ Number(
+ raw[`auth_source_default_${source}_concurrency`] ??
+ AUTH_SOURCE_DEFAULT_CONCURRENCY,
+ ),
),
subscriptions: normalizeDefaultSubscriptionSettings(
- Array.isArray(subscriptions) ? (subscriptions as DefaultSubscriptionSetting[]) : []
+ Array.isArray(subscriptions)
+ ? (subscriptions as DefaultSubscriptionSetting[])
+ : [],
),
- grant_on_signup: raw[`auth_source_default_${source}_grant_on_signup`] !== false,
- grant_on_first_bind: raw[`auth_source_default_${source}_grant_on_first_bind`] === true,
- }
- return acc
- }, {} as AuthSourceDefaultsState)
+ grant_on_signup:
+ raw[`auth_source_default_${source}_grant_on_signup`] !== false,
+ grant_on_first_bind:
+ raw[`auth_source_default_${source}_grant_on_first_bind`] === true,
+ };
+ return acc;
+ }, {} as AuthSourceDefaultsState);
}
export function appendAuthSourceDefaultsToUpdateRequest(
payload: UpdateSettingsRequest,
- authSourceDefaults: AuthSourceDefaultsState
+ authSourceDefaults: AuthSourceDefaultsState,
): UpdateSettingsRequest {
- const target = payload as Record
+ const target = payload as Record;
for (const source of AUTH_SOURCE_TYPES) {
- const current = authSourceDefaults[source]
- target[`auth_source_default_${source}_balance`] = Number(current.balance) || 0
+ const current = authSourceDefaults[source];
+ target[`auth_source_default_${source}_balance`] =
+ Number(current.balance) || 0;
target[`auth_source_default_${source}_concurrency`] = Math.max(
1,
- Math.floor(Number(current.concurrency) || AUTH_SOURCE_DEFAULT_CONCURRENCY)
- )
- target[`auth_source_default_${source}_subscriptions`] = normalizeDefaultSubscriptionSettings(
- current.subscriptions
- )
- target[`auth_source_default_${source}_grant_on_signup`] = current.grant_on_signup
- target[`auth_source_default_${source}_grant_on_first_bind`] = current.grant_on_first_bind
+ Math.floor(
+ Number(current.concurrency) || AUTH_SOURCE_DEFAULT_CONCURRENCY,
+ ),
+ );
+ target[`auth_source_default_${source}_subscriptions`] =
+ normalizeDefaultSubscriptionSettings(current.subscriptions);
+ target[`auth_source_default_${source}_grant_on_signup`] =
+ current.grant_on_signup;
+ target[`auth_source_default_${source}_grant_on_first_bind`] =
+ current.grant_on_first_bind;
}
- return payload
+ return payload;
}
export function getPaymentVisibleMethodSourceOptions(
- method: PaymentVisibleMethod
+ method: PaymentVisibleMethod,
): PaymentVisibleMethodSourceOption[] {
- return PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS[method]
+ return PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS[method];
}
export function normalizePaymentVisibleMethodSource(
method: PaymentVisibleMethod,
- source: unknown
+ source: unknown,
): PaymentVisibleMethodSource {
- if (typeof source !== 'string') return ''
+ if (typeof source !== "string") return "";
- const normalized = source.trim().toLowerCase()
- if (!normalized) return ''
+ const normalized = source.trim().toLowerCase();
+ if (!normalized) return "";
- return PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES[method][normalized] ?? ''
+ return PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES[method][normalized] ?? "";
+}
+
+export function getWeChatConnectModeOptions(): WeChatConnectModeOption[] {
+ return WECHAT_CONNECT_MODE_OPTIONS;
+}
+
+export function normalizeWeChatConnectMode(source: unknown): WeChatConnectMode {
+ if (typeof source !== "string") return "open";
+
+ const normalized = source.trim().toLowerCase();
+ if (!normalized) return "open";
+
+ return WECHAT_CONNECT_MODE_ALIASES[normalized] ?? "open";
+}
+
+export function defaultWeChatConnectScopesForMode(mode: unknown): string {
+ return normalizeWeChatConnectMode(mode) === "mp"
+ ? "snsapi_userinfo"
+ : "snsapi_login";
}
/**
@@ -159,293 +244,309 @@ export function normalizePaymentVisibleMethodSource(
*/
export interface SystemSettings {
// Registration settings
- registration_enabled: boolean
- email_verify_enabled: boolean
- registration_email_suffix_whitelist: string[]
- promo_code_enabled: boolean
- password_reset_enabled: boolean
- frontend_url: string
- invitation_code_enabled: boolean
- totp_enabled: boolean // TOTP 双因素认证
- totp_encryption_key_configured: boolean // TOTP 加密密钥是否已配置
+ registration_enabled: boolean;
+ email_verify_enabled: boolean;
+ registration_email_suffix_whitelist: string[];
+ promo_code_enabled: boolean;
+ password_reset_enabled: boolean;
+ frontend_url: string;
+ invitation_code_enabled: boolean;
+ totp_enabled: boolean; // TOTP 双因素认证
+ totp_encryption_key_configured: boolean; // TOTP 加密密钥是否已配置
// Default settings
- default_balance: number
- default_concurrency: number
- default_subscriptions: DefaultSubscriptionSetting[]
- auth_source_default_email_balance?: number
- auth_source_default_email_concurrency?: number
- auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[]
- auth_source_default_email_grant_on_signup?: boolean
- auth_source_default_email_grant_on_first_bind?: boolean
- auth_source_default_linuxdo_balance?: number
- auth_source_default_linuxdo_concurrency?: number
- auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[]
- auth_source_default_linuxdo_grant_on_signup?: boolean
- auth_source_default_linuxdo_grant_on_first_bind?: boolean
- auth_source_default_oidc_balance?: number
- auth_source_default_oidc_concurrency?: number
- auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[]
- auth_source_default_oidc_grant_on_signup?: boolean
- auth_source_default_oidc_grant_on_first_bind?: boolean
- auth_source_default_wechat_balance?: number
- auth_source_default_wechat_concurrency?: number
- auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[]
- auth_source_default_wechat_grant_on_signup?: boolean
- auth_source_default_wechat_grant_on_first_bind?: boolean
- force_email_on_third_party_signup?: boolean
+ default_balance: number;
+ default_concurrency: number;
+ default_subscriptions: DefaultSubscriptionSetting[];
+ auth_source_default_email_balance?: number;
+ auth_source_default_email_concurrency?: number;
+ auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_email_grant_on_signup?: boolean;
+ auth_source_default_email_grant_on_first_bind?: boolean;
+ auth_source_default_linuxdo_balance?: number;
+ auth_source_default_linuxdo_concurrency?: number;
+ auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_linuxdo_grant_on_signup?: boolean;
+ auth_source_default_linuxdo_grant_on_first_bind?: boolean;
+ auth_source_default_oidc_balance?: number;
+ auth_source_default_oidc_concurrency?: number;
+ auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_oidc_grant_on_signup?: boolean;
+ auth_source_default_oidc_grant_on_first_bind?: boolean;
+ auth_source_default_wechat_balance?: number;
+ auth_source_default_wechat_concurrency?: number;
+ auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_wechat_grant_on_signup?: boolean;
+ auth_source_default_wechat_grant_on_first_bind?: boolean;
+ force_email_on_third_party_signup?: boolean;
// OEM settings
- site_name: string
- site_logo: string
- site_subtitle: string
- api_base_url: string
- contact_info: string
- doc_url: string
- home_content: string
- hide_ccs_import_button: boolean
- table_default_page_size: number
- table_page_size_options: number[]
- backend_mode_enabled: boolean
- custom_menu_items: CustomMenuItem[]
- custom_endpoints: CustomEndpoint[]
+ site_name: string;
+ site_logo: string;
+ site_subtitle: string;
+ api_base_url: string;
+ contact_info: string;
+ doc_url: string;
+ home_content: string;
+ hide_ccs_import_button: boolean;
+ table_default_page_size: number;
+ table_page_size_options: number[];
+ backend_mode_enabled: boolean;
+ custom_menu_items: CustomMenuItem[];
+ custom_endpoints: CustomEndpoint[];
// SMTP settings
- smtp_host: string
- smtp_port: number
- smtp_username: string
- smtp_password_configured: boolean
- smtp_from_email: string
- smtp_from_name: string
- smtp_use_tls: boolean
+ smtp_host: string;
+ smtp_port: number;
+ smtp_username: string;
+ smtp_password_configured: boolean;
+ smtp_from_email: string;
+ smtp_from_name: string;
+ smtp_use_tls: boolean;
// Cloudflare Turnstile settings
- turnstile_enabled: boolean
- turnstile_site_key: string
- turnstile_secret_key_configured: boolean
+ turnstile_enabled: boolean;
+ turnstile_site_key: string;
+ turnstile_secret_key_configured: boolean;
// LinuxDo Connect OAuth settings
- linuxdo_connect_enabled: boolean
- linuxdo_connect_client_id: string
- linuxdo_connect_client_secret_configured: boolean
- linuxdo_connect_redirect_url: string
+ linuxdo_connect_enabled: boolean;
+ linuxdo_connect_client_id: string;
+ linuxdo_connect_client_secret_configured: boolean;
+ linuxdo_connect_redirect_url: string;
+
+ // WeChat Connect OAuth settings
+ wechat_connect_enabled: boolean;
+ wechat_connect_app_id: string;
+ wechat_connect_app_secret_configured: boolean;
+ wechat_connect_mode: string;
+ wechat_connect_scopes: string;
+ wechat_connect_redirect_url: string;
+ wechat_connect_frontend_redirect_url: string;
// Generic OIDC OAuth settings
- oidc_connect_enabled: boolean
- oidc_connect_provider_name: string
- oidc_connect_client_id: string
- oidc_connect_client_secret_configured: boolean
- oidc_connect_issuer_url: string
- oidc_connect_discovery_url: string
- oidc_connect_authorize_url: string
- oidc_connect_token_url: string
- oidc_connect_userinfo_url: string
- oidc_connect_jwks_url: string
- oidc_connect_scopes: string
- oidc_connect_redirect_url: string
- oidc_connect_frontend_redirect_url: string
- oidc_connect_token_auth_method: string
- oidc_connect_use_pkce: boolean
- oidc_connect_validate_id_token: boolean
- oidc_connect_allowed_signing_algs: string
- oidc_connect_clock_skew_seconds: number
- oidc_connect_require_email_verified: boolean
- oidc_connect_userinfo_email_path: string
- oidc_connect_userinfo_id_path: string
- oidc_connect_userinfo_username_path: string
+ oidc_connect_enabled: boolean;
+ oidc_connect_provider_name: string;
+ oidc_connect_client_id: string;
+ oidc_connect_client_secret_configured: boolean;
+ oidc_connect_issuer_url: string;
+ oidc_connect_discovery_url: string;
+ oidc_connect_authorize_url: string;
+ oidc_connect_token_url: string;
+ oidc_connect_userinfo_url: string;
+ oidc_connect_jwks_url: string;
+ oidc_connect_scopes: string;
+ oidc_connect_redirect_url: string;
+ oidc_connect_frontend_redirect_url: string;
+ oidc_connect_token_auth_method: string;
+ oidc_connect_use_pkce: boolean;
+ oidc_connect_validate_id_token: boolean;
+ oidc_connect_allowed_signing_algs: string;
+ oidc_connect_clock_skew_seconds: number;
+ oidc_connect_require_email_verified: boolean;
+ oidc_connect_userinfo_email_path: string;
+ oidc_connect_userinfo_id_path: string;
+ oidc_connect_userinfo_username_path: string;
// Model fallback configuration
- enable_model_fallback: boolean
- fallback_model_anthropic: string
- fallback_model_openai: string
- fallback_model_gemini: string
- fallback_model_antigravity: string
+ enable_model_fallback: boolean;
+ fallback_model_anthropic: string;
+ fallback_model_openai: string;
+ fallback_model_gemini: string;
+ fallback_model_antigravity: string;
// Identity patch configuration (Claude -> Gemini)
- enable_identity_patch: boolean
- identity_patch_prompt: string
+ enable_identity_patch: boolean;
+ identity_patch_prompt: string;
// Ops Monitoring (vNext)
- ops_monitoring_enabled: boolean
- ops_realtime_monitoring_enabled: boolean
- ops_query_mode_default: 'auto' | 'raw' | 'preagg' | string
- ops_metrics_interval_seconds: number
+ ops_monitoring_enabled: boolean;
+ ops_realtime_monitoring_enabled: boolean;
+ ops_query_mode_default: "auto" | "raw" | "preagg" | string;
+ ops_metrics_interval_seconds: number;
// Claude Code version check
- min_claude_code_version: string
- max_claude_code_version: string
+ min_claude_code_version: string;
+ max_claude_code_version: string;
// 分组隔离
- allow_ungrouped_key_scheduling: boolean
+ allow_ungrouped_key_scheduling: boolean;
// Gateway forwarding behavior
- enable_fingerprint_unification: boolean
- enable_metadata_passthrough: boolean
- enable_cch_signing: boolean
- web_search_emulation_enabled?: boolean
+ enable_fingerprint_unification: boolean;
+ enable_metadata_passthrough: boolean;
+ enable_cch_signing: boolean;
+ web_search_emulation_enabled?: boolean;
// Payment configuration
- payment_enabled: boolean
- payment_min_amount: number
- payment_max_amount: number
- payment_daily_limit: number
- payment_order_timeout_minutes: number
- payment_max_pending_orders: number
- payment_enabled_types: string[]
- payment_balance_disabled: boolean
- payment_balance_recharge_multiplier: number
- payment_recharge_fee_rate: number
- payment_load_balance_strategy: string
- payment_product_name_prefix: string
- payment_product_name_suffix: string
- payment_help_image_url: string
- payment_help_text: string
- payment_cancel_rate_limit_enabled: boolean
- payment_cancel_rate_limit_max: number
- payment_cancel_rate_limit_window: number
- payment_cancel_rate_limit_unit: string
- payment_cancel_rate_limit_window_mode: string
- payment_visible_method_alipay_source?: string
- payment_visible_method_wxpay_source?: string
- payment_visible_method_alipay_enabled?: boolean
- payment_visible_method_wxpay_enabled?: boolean
- openai_advanced_scheduler_enabled?: boolean
+ payment_enabled: boolean;
+ payment_min_amount: number;
+ payment_max_amount: number;
+ payment_daily_limit: number;
+ payment_order_timeout_minutes: number;
+ payment_max_pending_orders: number;
+ payment_enabled_types: string[];
+ payment_balance_disabled: boolean;
+ payment_balance_recharge_multiplier: number;
+ payment_recharge_fee_rate: number;
+ payment_load_balance_strategy: string;
+ payment_product_name_prefix: string;
+ payment_product_name_suffix: string;
+ payment_help_image_url: string;
+ payment_help_text: string;
+ payment_cancel_rate_limit_enabled: boolean;
+ payment_cancel_rate_limit_max: number;
+ payment_cancel_rate_limit_window: number;
+ payment_cancel_rate_limit_unit: string;
+ payment_cancel_rate_limit_window_mode: string;
+ payment_visible_method_alipay_source?: string;
+ payment_visible_method_wxpay_source?: string;
+ payment_visible_method_alipay_enabled?: boolean;
+ payment_visible_method_wxpay_enabled?: boolean;
+ openai_advanced_scheduler_enabled?: boolean;
// Balance & quota notification
- balance_low_notify_enabled: boolean
- balance_low_notify_threshold: number
- balance_low_notify_recharge_url: string
- account_quota_notify_enabled: boolean
- account_quota_notify_emails: NotifyEmailEntry[]
+ balance_low_notify_enabled: boolean;
+ balance_low_notify_threshold: number;
+ balance_low_notify_recharge_url: string;
+ account_quota_notify_enabled: boolean;
+ account_quota_notify_emails: NotifyEmailEntry[];
}
export interface UpdateSettingsRequest {
- registration_enabled?: boolean
- email_verify_enabled?: boolean
- registration_email_suffix_whitelist?: string[]
- promo_code_enabled?: boolean
- password_reset_enabled?: boolean
- frontend_url?: string
- invitation_code_enabled?: boolean
- totp_enabled?: boolean // TOTP 双因素认证
- default_balance?: number
- default_concurrency?: number
- default_subscriptions?: DefaultSubscriptionSetting[]
- auth_source_default_email_balance?: number
- auth_source_default_email_concurrency?: number
- auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[]
- auth_source_default_email_grant_on_signup?: boolean
- auth_source_default_email_grant_on_first_bind?: boolean
- auth_source_default_linuxdo_balance?: number
- auth_source_default_linuxdo_concurrency?: number
- auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[]
- auth_source_default_linuxdo_grant_on_signup?: boolean
- auth_source_default_linuxdo_grant_on_first_bind?: boolean
- auth_source_default_oidc_balance?: number
- auth_source_default_oidc_concurrency?: number
- auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[]
- auth_source_default_oidc_grant_on_signup?: boolean
- auth_source_default_oidc_grant_on_first_bind?: boolean
- auth_source_default_wechat_balance?: number
- auth_source_default_wechat_concurrency?: number
- auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[]
- auth_source_default_wechat_grant_on_signup?: boolean
- auth_source_default_wechat_grant_on_first_bind?: boolean
- force_email_on_third_party_signup?: boolean
- site_name?: string
- site_logo?: string
- site_subtitle?: string
- api_base_url?: string
- contact_info?: string
- doc_url?: string
- home_content?: string
- hide_ccs_import_button?: boolean
- table_default_page_size?: number
- table_page_size_options?: number[]
- backend_mode_enabled?: boolean
- custom_menu_items?: CustomMenuItem[]
- custom_endpoints?: CustomEndpoint[]
- smtp_host?: string
- smtp_port?: number
- smtp_username?: string
- smtp_password?: string
- smtp_from_email?: string
- smtp_from_name?: string
- smtp_use_tls?: boolean
- turnstile_enabled?: boolean
- turnstile_site_key?: string
- turnstile_secret_key?: string
- linuxdo_connect_enabled?: boolean
- linuxdo_connect_client_id?: string
- linuxdo_connect_client_secret?: string
- linuxdo_connect_redirect_url?: string
- oidc_connect_enabled?: boolean
- oidc_connect_provider_name?: string
- oidc_connect_client_id?: string
- oidc_connect_client_secret?: string
- oidc_connect_issuer_url?: string
- oidc_connect_discovery_url?: string
- oidc_connect_authorize_url?: string
- oidc_connect_token_url?: string
- oidc_connect_userinfo_url?: string
- oidc_connect_jwks_url?: string
- oidc_connect_scopes?: string
- oidc_connect_redirect_url?: string
- oidc_connect_frontend_redirect_url?: string
- oidc_connect_token_auth_method?: string
- oidc_connect_use_pkce?: boolean
- oidc_connect_validate_id_token?: boolean
- oidc_connect_allowed_signing_algs?: string
- oidc_connect_clock_skew_seconds?: number
- oidc_connect_require_email_verified?: boolean
- oidc_connect_userinfo_email_path?: string
- oidc_connect_userinfo_id_path?: string
- oidc_connect_userinfo_username_path?: string
- enable_model_fallback?: boolean
- fallback_model_anthropic?: string
- fallback_model_openai?: string
- fallback_model_gemini?: string
- fallback_model_antigravity?: string
- enable_identity_patch?: boolean
- identity_patch_prompt?: string
- ops_monitoring_enabled?: boolean
- ops_realtime_monitoring_enabled?: boolean
- ops_query_mode_default?: 'auto' | 'raw' | 'preagg' | string
- ops_metrics_interval_seconds?: number
- min_claude_code_version?: string
- max_claude_code_version?: string
- allow_ungrouped_key_scheduling?: boolean
- enable_fingerprint_unification?: boolean
- enable_metadata_passthrough?: boolean
- enable_cch_signing?: boolean
+ registration_enabled?: boolean;
+ email_verify_enabled?: boolean;
+ registration_email_suffix_whitelist?: string[];
+ promo_code_enabled?: boolean;
+ password_reset_enabled?: boolean;
+ frontend_url?: string;
+ invitation_code_enabled?: boolean;
+ totp_enabled?: boolean; // TOTP 双因素认证
+ default_balance?: number;
+ default_concurrency?: number;
+ default_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_email_balance?: number;
+ auth_source_default_email_concurrency?: number;
+ auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_email_grant_on_signup?: boolean;
+ auth_source_default_email_grant_on_first_bind?: boolean;
+ auth_source_default_linuxdo_balance?: number;
+ auth_source_default_linuxdo_concurrency?: number;
+ auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_linuxdo_grant_on_signup?: boolean;
+ auth_source_default_linuxdo_grant_on_first_bind?: boolean;
+ auth_source_default_oidc_balance?: number;
+ auth_source_default_oidc_concurrency?: number;
+ auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_oidc_grant_on_signup?: boolean;
+ auth_source_default_oidc_grant_on_first_bind?: boolean;
+ auth_source_default_wechat_balance?: number;
+ auth_source_default_wechat_concurrency?: number;
+ auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_wechat_grant_on_signup?: boolean;
+ auth_source_default_wechat_grant_on_first_bind?: boolean;
+ force_email_on_third_party_signup?: boolean;
+ site_name?: string;
+ site_logo?: string;
+ site_subtitle?: string;
+ api_base_url?: string;
+ contact_info?: string;
+ doc_url?: string;
+ home_content?: string;
+ hide_ccs_import_button?: boolean;
+ table_default_page_size?: number;
+ table_page_size_options?: number[];
+ backend_mode_enabled?: boolean;
+ custom_menu_items?: CustomMenuItem[];
+ custom_endpoints?: CustomEndpoint[];
+ smtp_host?: string;
+ smtp_port?: number;
+ smtp_username?: string;
+ smtp_password?: string;
+ smtp_from_email?: string;
+ smtp_from_name?: string;
+ smtp_use_tls?: boolean;
+ turnstile_enabled?: boolean;
+ turnstile_site_key?: string;
+ turnstile_secret_key?: string;
+ linuxdo_connect_enabled?: boolean;
+ linuxdo_connect_client_id?: string;
+ linuxdo_connect_client_secret?: string;
+ linuxdo_connect_redirect_url?: string;
+ wechat_connect_enabled?: boolean;
+ wechat_connect_app_id?: string;
+ wechat_connect_app_secret?: string;
+ wechat_connect_mode?: string;
+ wechat_connect_scopes?: string;
+ wechat_connect_redirect_url?: string;
+ wechat_connect_frontend_redirect_url?: string;
+ oidc_connect_enabled?: boolean;
+ oidc_connect_provider_name?: string;
+ oidc_connect_client_id?: string;
+ oidc_connect_client_secret?: string;
+ oidc_connect_issuer_url?: string;
+ oidc_connect_discovery_url?: string;
+ oidc_connect_authorize_url?: string;
+ oidc_connect_token_url?: string;
+ oidc_connect_userinfo_url?: string;
+ oidc_connect_jwks_url?: string;
+ oidc_connect_scopes?: string;
+ oidc_connect_redirect_url?: string;
+ oidc_connect_frontend_redirect_url?: string;
+ oidc_connect_token_auth_method?: string;
+ oidc_connect_use_pkce?: boolean;
+ oidc_connect_validate_id_token?: boolean;
+ oidc_connect_allowed_signing_algs?: string;
+ oidc_connect_clock_skew_seconds?: number;
+ oidc_connect_require_email_verified?: boolean;
+ oidc_connect_userinfo_email_path?: string;
+ oidc_connect_userinfo_id_path?: string;
+ oidc_connect_userinfo_username_path?: string;
+ enable_model_fallback?: boolean;
+ fallback_model_anthropic?: string;
+ fallback_model_openai?: string;
+ fallback_model_gemini?: string;
+ fallback_model_antigravity?: string;
+ enable_identity_patch?: boolean;
+ identity_patch_prompt?: string;
+ ops_monitoring_enabled?: boolean;
+ ops_realtime_monitoring_enabled?: boolean;
+ ops_query_mode_default?: "auto" | "raw" | "preagg" | string;
+ ops_metrics_interval_seconds?: number;
+ min_claude_code_version?: string;
+ max_claude_code_version?: string;
+ allow_ungrouped_key_scheduling?: boolean;
+ enable_fingerprint_unification?: boolean;
+ enable_metadata_passthrough?: boolean;
+ enable_cch_signing?: boolean;
// Payment configuration
- payment_enabled?: boolean
- payment_min_amount?: number
- payment_max_amount?: number
- payment_daily_limit?: number
- payment_order_timeout_minutes?: number
- payment_max_pending_orders?: number
- payment_enabled_types?: string[]
- payment_balance_disabled?: boolean
- payment_balance_recharge_multiplier?: number
- payment_recharge_fee_rate?: number
- payment_load_balance_strategy?: string
- payment_product_name_prefix?: string
- payment_product_name_suffix?: string
- payment_help_image_url?: string
- payment_help_text?: string
- payment_cancel_rate_limit_enabled?: boolean
- payment_cancel_rate_limit_max?: number
- payment_cancel_rate_limit_window?: number
- payment_cancel_rate_limit_unit?: string
- payment_cancel_rate_limit_window_mode?: string
- payment_visible_method_alipay_source?: string
- payment_visible_method_wxpay_source?: string
- payment_visible_method_alipay_enabled?: boolean
- payment_visible_method_wxpay_enabled?: boolean
- openai_advanced_scheduler_enabled?: boolean
+ payment_enabled?: boolean;
+ payment_min_amount?: number;
+ payment_max_amount?: number;
+ payment_daily_limit?: number;
+ payment_order_timeout_minutes?: number;
+ payment_max_pending_orders?: number;
+ payment_enabled_types?: string[];
+ payment_balance_disabled?: boolean;
+ payment_balance_recharge_multiplier?: number;
+ payment_recharge_fee_rate?: number;
+ payment_load_balance_strategy?: string;
+ payment_product_name_prefix?: string;
+ payment_product_name_suffix?: string;
+ payment_help_image_url?: string;
+ payment_help_text?: string;
+ payment_cancel_rate_limit_enabled?: boolean;
+ payment_cancel_rate_limit_max?: number;
+ payment_cancel_rate_limit_window?: number;
+ payment_cancel_rate_limit_unit?: string;
+ payment_cancel_rate_limit_window_mode?: string;
+ payment_visible_method_alipay_source?: string;
+ payment_visible_method_wxpay_source?: string;
+ payment_visible_method_alipay_enabled?: boolean;
+ payment_visible_method_wxpay_enabled?: boolean;
+ openai_advanced_scheduler_enabled?: boolean;
// Balance & quota notification
- balance_low_notify_enabled?: boolean
- balance_low_notify_threshold?: number
- balance_low_notify_recharge_url?: string
- account_quota_notify_enabled?: boolean
- account_quota_notify_emails?: NotifyEmailEntry[]
+ balance_low_notify_enabled?: boolean;
+ balance_low_notify_threshold?: number;
+ balance_low_notify_recharge_url?: string;
+ account_quota_notify_enabled?: boolean;
+ account_quota_notify_emails?: NotifyEmailEntry[];
}
/**
@@ -453,8 +554,8 @@ export interface UpdateSettingsRequest {
* @returns System settings
*/
export async function getSettings(): Promise {
- const { data } = await apiClient.get('/admin/settings')
- return data
+ const { data } = await apiClient.get("/admin/settings");
+ return data;
}
/**
@@ -462,20 +563,25 @@ export async function getSettings(): Promise {
* @param settings - Partial settings to update
* @returns Updated settings
*/
-export async function updateSettings(settings: UpdateSettingsRequest): Promise {
- const { data } = await apiClient.put('/admin/settings', settings)
- return data
+export async function updateSettings(
+ settings: UpdateSettingsRequest,
+): Promise {
+ const { data } = await apiClient.put(
+ "/admin/settings",
+ settings,
+ );
+ return data;
}
/**
* Test SMTP connection request
*/
export interface TestSmtpRequest {
- smtp_host: string
- smtp_port: number
- smtp_username: string
- smtp_password: string
- smtp_use_tls: boolean
+ smtp_host: string;
+ smtp_port: number;
+ smtp_username: string;
+ smtp_password: string;
+ smtp_use_tls: boolean;
}
/**
@@ -483,23 +589,28 @@ export interface TestSmtpRequest {
* @param config - SMTP configuration to test
* @returns Test result message
*/
-export async function testSmtpConnection(config: TestSmtpRequest): Promise<{ message: string }> {
- const { data } = await apiClient.post<{ message: string }>('/admin/settings/test-smtp', config)
- return data
+export async function testSmtpConnection(
+ config: TestSmtpRequest,
+): Promise<{ message: string }> {
+ const { data } = await apiClient.post<{ message: string }>(
+ "/admin/settings/test-smtp",
+ config,
+ );
+ return data;
}
/**
* Send test email request
*/
export interface SendTestEmailRequest {
- email: string
- smtp_host: string
- smtp_port: number
- smtp_username: string
- smtp_password: string
- smtp_from_email: string
- smtp_from_name: string
- smtp_use_tls: boolean
+ email: string;
+ smtp_host: string;
+ smtp_port: number;
+ smtp_username: string;
+ smtp_password: string;
+ smtp_from_email: string;
+ smtp_from_name: string;
+ smtp_use_tls: boolean;
}
/**
@@ -507,20 +618,22 @@ export interface SendTestEmailRequest {
* @param request - Email address and SMTP config
* @returns Test result message
*/
-export async function sendTestEmail(request: SendTestEmailRequest): Promise<{ message: string }> {
+export async function sendTestEmail(
+ request: SendTestEmailRequest,
+): Promise<{ message: string }> {
const { data } = await apiClient.post<{ message: string }>(
- '/admin/settings/send-test-email',
- request
- )
- return data
+ "/admin/settings/send-test-email",
+ request,
+ );
+ return data;
}
/**
* Admin API Key status response
*/
export interface AdminApiKeyStatus {
- exists: boolean
- masked_key: string
+ exists: boolean;
+ masked_key: string;
}
/**
@@ -528,8 +641,10 @@ export interface AdminApiKeyStatus {
* @returns Status indicating if key exists and masked version
*/
export async function getAdminApiKey(): Promise {
- const { data } = await apiClient.get('/admin/settings/admin-api-key')
- return data
+ const { data } = await apiClient.get(
+ "/admin/settings/admin-api-key",
+ );
+ return data;
}
/**
@@ -537,8 +652,10 @@ export async function getAdminApiKey(): Promise {
* @returns The new full API key (only shown once)
*/
export async function regenerateAdminApiKey(): Promise<{ key: string }> {
- const { data } = await apiClient.post<{ key: string }>('/admin/settings/admin-api-key/regenerate')
- return data
+ const { data } = await apiClient.post<{ key: string }>(
+ "/admin/settings/admin-api-key/regenerate",
+ );
+ return data;
}
/**
@@ -546,8 +663,10 @@ export async function regenerateAdminApiKey(): Promise<{ key: string }> {
* @returns Success message
*/
export async function deleteAdminApiKey(): Promise<{ message: string }> {
- const { data } = await apiClient.delete<{ message: string }>('/admin/settings/admin-api-key')
- return data
+ const { data } = await apiClient.delete<{ message: string }>(
+ "/admin/settings/admin-api-key",
+ );
+ return data;
}
// ==================== Overload Cooldown Settings ====================
@@ -556,23 +675,25 @@ export async function deleteAdminApiKey(): Promise<{ message: string }> {
* Overload cooldown settings interface (529 handling)
*/
export interface OverloadCooldownSettings {
- enabled: boolean
- cooldown_minutes: number
+ enabled: boolean;
+ cooldown_minutes: number;
}
export async function getOverloadCooldownSettings(): Promise {
- const { data } = await apiClient.get('/admin/settings/overload-cooldown')
- return data
+ const { data } = await apiClient.get(
+ "/admin/settings/overload-cooldown",
+ );
+ return data;
}
export async function updateOverloadCooldownSettings(
- settings: OverloadCooldownSettings
+ settings: OverloadCooldownSettings,
): Promise {
const { data } = await apiClient.put(
- '/admin/settings/overload-cooldown',
- settings
- )
- return data
+ "/admin/settings/overload-cooldown",
+ settings,
+ );
+ return data;
}
// ==================== Stream Timeout Settings ====================
@@ -581,11 +702,11 @@ export async function updateOverloadCooldownSettings(
* Stream timeout settings interface
*/
export interface StreamTimeoutSettings {
- enabled: boolean
- action: 'temp_unsched' | 'error' | 'none'
- temp_unsched_minutes: number
- threshold_count: number
- threshold_window_minutes: number
+ enabled: boolean;
+ action: "temp_unsched" | "error" | "none";
+ temp_unsched_minutes: number;
+ threshold_count: number;
+ threshold_window_minutes: number;
}
/**
@@ -593,8 +714,10 @@ export interface StreamTimeoutSettings {
* @returns Stream timeout settings
*/
export async function getStreamTimeoutSettings(): Promise {
- const { data } = await apiClient.get('/admin/settings/stream-timeout')
- return data
+ const { data } = await apiClient.get(
+ "/admin/settings/stream-timeout",
+ );
+ return data;
}
/**
@@ -603,13 +726,13 @@ export async function getStreamTimeoutSettings(): Promise
* @returns Updated settings
*/
export async function updateStreamTimeoutSettings(
- settings: StreamTimeoutSettings
+ settings: StreamTimeoutSettings,
): Promise {
const { data } = await apiClient.put(
- '/admin/settings/stream-timeout',
- settings
- )
- return data
+ "/admin/settings/stream-timeout",
+ settings,
+ );
+ return data;
}
// ==================== Rectifier Settings ====================
@@ -618,11 +741,11 @@ export async function updateStreamTimeoutSettings(
* Rectifier settings interface
*/
export interface RectifierSettings {
- enabled: boolean
- thinking_signature_enabled: boolean
- thinking_budget_enabled: boolean
- apikey_signature_enabled: boolean
- apikey_signature_patterns: string[]
+ enabled: boolean;
+ thinking_signature_enabled: boolean;
+ thinking_budget_enabled: boolean;
+ apikey_signature_enabled: boolean;
+ apikey_signature_patterns: string[];
}
/**
@@ -630,8 +753,10 @@ export interface RectifierSettings {
* @returns Rectifier settings
*/
export async function getRectifierSettings(): Promise {
- const { data } = await apiClient.get('/admin/settings/rectifier')
- return data
+ const { data } = await apiClient.get(
+ "/admin/settings/rectifier",
+ );
+ return data;
}
/**
@@ -640,13 +765,13 @@ export async function getRectifierSettings(): Promise {
* @returns Updated settings
*/
export async function updateRectifierSettings(
- settings: RectifierSettings
+ settings: RectifierSettings,
): Promise {
const { data } = await apiClient.put(
- '/admin/settings/rectifier',
- settings
- )
- return data
+ "/admin/settings/rectifier",
+ settings,
+ );
+ return data;
}
// ==================== Beta Policy Settings ====================
@@ -655,20 +780,20 @@ export async function updateRectifierSettings(
* Beta policy rule interface
*/
export interface BetaPolicyRule {
- beta_token: string
- action: 'pass' | 'filter' | 'block'
- scope: 'all' | 'oauth' | 'apikey' | 'bedrock'
- error_message?: string
- model_whitelist?: string[]
- fallback_action?: 'pass' | 'filter' | 'block'
- fallback_error_message?: string
+ beta_token: string;
+ action: "pass" | "filter" | "block";
+ scope: "all" | "oauth" | "apikey" | "bedrock";
+ error_message?: string;
+ model_whitelist?: string[];
+ fallback_action?: "pass" | "filter" | "block";
+ fallback_error_message?: string;
}
/**
* Beta policy settings interface
*/
export interface BetaPolicySettings {
- rules: BetaPolicyRule[]
+ rules: BetaPolicyRule[];
}
/**
@@ -676,8 +801,10 @@ export interface BetaPolicySettings {
* @returns Beta policy settings
*/
export async function getBetaPolicySettings(): Promise {
- const { data } = await apiClient.get('/admin/settings/beta-policy')
- return data
+ const { data } = await apiClient.get(
+ "/admin/settings/beta-policy",
+ );
+ return data;
}
/**
@@ -686,70 +813,73 @@ export async function getBetaPolicySettings(): Promise {
* @returns Updated settings
*/
export async function updateBetaPolicySettings(
- settings: BetaPolicySettings
+ settings: BetaPolicySettings,
): Promise {
const { data } = await apiClient.put(
- '/admin/settings/beta-policy',
- settings
- )
- return data
+ "/admin/settings/beta-policy",
+ settings,
+ );
+ return data;
}
// --- Web Search Emulation Config ---
export interface WebSearchProviderConfig {
- type: 'brave' | 'tavily'
- api_key: string
- api_key_configured: boolean
- quota_limit: number | null
- subscribed_at: number | null
- quota_used?: number
- proxy_id: number | null
- expires_at: number | null
+ type: "brave" | "tavily";
+ api_key: string;
+ api_key_configured: boolean;
+ quota_limit: number | null;
+ subscribed_at: number | null;
+ quota_used?: number;
+ proxy_id: number | null;
+ expires_at: number | null;
}
export interface WebSearchEmulationConfig {
- enabled: boolean
- providers: WebSearchProviderConfig[]
+ enabled: boolean;
+ providers: WebSearchProviderConfig[];
}
export interface WebSearchTestResult {
- provider: string
- results: { url: string; title: string; snippet: string; page_age?: string }[]
- query: string
+ provider: string;
+ results: { url: string; title: string; snippet: string; page_age?: string }[];
+ query: string;
}
export async function getWebSearchEmulationConfig(): Promise {
const { data } = await apiClient.get(
- '/admin/settings/web-search-emulation'
- )
- return data
+ "/admin/settings/web-search-emulation",
+ );
+ return data;
}
export async function updateWebSearchEmulationConfig(
- config: WebSearchEmulationConfig
+ config: WebSearchEmulationConfig,
): Promise {
const { data } = await apiClient.put(
- '/admin/settings/web-search-emulation',
- config
- )
- return data
+ "/admin/settings/web-search-emulation",
+ config,
+ );
+ return data;
}
export async function testWebSearchEmulation(
- query: string
+ query: string,
): Promise {
const { data } = await apiClient.post(
- '/admin/settings/web-search-emulation/test',
- { query }
- )
- return data
+ "/admin/settings/web-search-emulation/test",
+ { query },
+ );
+ return data;
}
-export async function resetWebSearchUsage(
- payload: { provider_type: string }
-): Promise {
- await apiClient.post('/admin/settings/web-search-emulation/reset-usage', payload)
+export async function resetWebSearchUsage(payload: {
+ provider_type: string;
+}): Promise {
+ await apiClient.post(
+ "/admin/settings/web-search-emulation/reset-usage",
+ payload,
+ );
}
export const settingsAPI = {
@@ -771,7 +901,7 @@ export const settingsAPI = {
getWebSearchEmulationConfig,
updateWebSearchEmulationConfig,
testWebSearchEmulation,
- resetWebSearchUsage
-}
+ resetWebSearchUsage,
+};
-export default settingsAPI
+export default settingsAPI;
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index e56afe5f..9612d9f9 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -3,7 +3,9 @@