feat(channels): add custom account stats pricing rules
Allow channels to configure independent model pricing for account statistics cost calculation, decoupled from user billing. Backend: - Migration 101: channels.apply_pricing_to_account_stats toggle, channel_account_stats_pricing_rules/model_pricing tables, usage_logs.account_stats_cost column - resolveAccountStatsCost: match rules by group/account, then channel pricing, fallback to original formula when unconfigured - Integrate into both GatewayService.recordUsageCore and OpenAIGatewayService.RecordUsage - Update 8 account stats SQL queries to use COALESCE(account_stats_cost, total_cost) * account_rate_multiplier - 23 unit tests for matching, pricing lookup, and cost calculation Frontend: - Channel edit dialog: toggle + custom rules UI with group/account multi-select and pricing entry cards - API types and i18n (zh/en)
This commit is contained in:
parent
7fad9f604f
commit
7535e312e0
@ -26,28 +26,30 @@ func NewChannelHandler(channelService *service.ChannelService, billingService *s
|
||||
// --- Request / Response types ---
|
||||
|
||||
type createChannelRequest struct {
|
||||
Name string `json:"name" binding:"required,max=100"`
|
||||
Description string `json:"description"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
Features string `json:"features"`
|
||||
FeaturesConfig map[string]any `json:"features_config"`
|
||||
Name string `json:"name" binding:"required,max=100"`
|
||||
Description string `json:"description"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
Features string `json:"features"`
|
||||
ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
|
||||
AccountStatsPricingRules []accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
|
||||
}
|
||||
|
||||
type updateChannelRequest struct {
|
||||
Name string `json:"name" binding:"omitempty,max=100"`
|
||||
Description *string `json:"description"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
||||
RestrictModels *bool `json:"restrict_models"`
|
||||
Features *string `json:"features"`
|
||||
FeaturesConfig map[string]any `json:"features_config"`
|
||||
Name string `json:"name" binding:"omitempty,max=100"`
|
||||
Description *string `json:"description"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
||||
RestrictModels *bool `json:"restrict_models"`
|
||||
Features *string `json:"features"`
|
||||
ApplyPricingToAccountStats *bool `json:"apply_pricing_to_account_stats"`
|
||||
AccountStatsPricingRules *[]accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
|
||||
}
|
||||
|
||||
type channelModelPricingRequest struct {
|
||||
@ -75,20 +77,28 @@ type pricingIntervalRequest struct {
|
||||
SortOrder int `json:"sort_order"`
|
||||
}
|
||||
|
||||
type accountStatsPricingRuleRequest struct {
|
||||
Name string `json:"name"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
AccountIDs []int64 `json:"account_ids"`
|
||||
Pricing []channelModelPricingRequest `json:"pricing"`
|
||||
}
|
||||
|
||||
type channelResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Status string `json:"status"`
|
||||
BillingModelSource string `json:"billing_model_source"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
Features string `json:"features"`
|
||||
FeaturesConfig map[string]any `json:"features_config"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Status string `json:"status"`
|
||||
BillingModelSource string `json:"billing_model_source"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
Features string `json:"features"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
|
||||
AccountStatsPricingRules []accountStatsPricingRuleResponse `json:"account_stats_pricing_rules"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
type channelModelPricingResponse struct {
|
||||
@ -118,6 +128,14 @@ type pricingIntervalResponse struct {
|
||||
SortOrder int `json:"sort_order"`
|
||||
}
|
||||
|
||||
type accountStatsPricingRuleResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
AccountIDs []int64 `json:"account_ids"`
|
||||
Pricing []channelModelPricingResponse `json:"pricing"`
|
||||
}
|
||||
|
||||
func channelToResponse(ch *service.Channel) *channelResponse {
|
||||
if ch == nil {
|
||||
return nil
|
||||
@ -129,7 +147,6 @@ func channelToResponse(ch *service.Channel) *channelResponse {
|
||||
Status: ch.Status,
|
||||
RestrictModels: ch.RestrictModels,
|
||||
Features: ch.Features,
|
||||
FeaturesConfig: ch.FeaturesConfig,
|
||||
GroupIDs: ch.GroupIDs,
|
||||
ModelMapping: ch.ModelMapping,
|
||||
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
@ -150,6 +167,29 @@ func channelToResponse(ch *service.Channel) *channelResponse {
|
||||
for _, p := range ch.ModelPricing {
|
||||
resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p))
|
||||
}
|
||||
|
||||
resp.ApplyPricingToAccountStats = ch.ApplyPricingToAccountStats
|
||||
resp.AccountStatsPricingRules = make([]accountStatsPricingRuleResponse, 0, len(ch.AccountStatsPricingRules))
|
||||
for _, rule := range ch.AccountStatsPricingRules {
|
||||
ruleResp := accountStatsPricingRuleResponse{
|
||||
ID: rule.ID,
|
||||
Name: rule.Name,
|
||||
GroupIDs: rule.GroupIDs,
|
||||
AccountIDs: rule.AccountIDs,
|
||||
Pricing: make([]channelModelPricingResponse, 0, len(rule.Pricing)),
|
||||
}
|
||||
if ruleResp.GroupIDs == nil {
|
||||
ruleResp.GroupIDs = []int64{}
|
||||
}
|
||||
if ruleResp.AccountIDs == nil {
|
||||
ruleResp.AccountIDs = []int64{}
|
||||
}
|
||||
for i := range rule.Pricing {
|
||||
ruleResp.Pricing = append(ruleResp.Pricing, pricingToResponse(&rule.Pricing[i]))
|
||||
}
|
||||
resp.AccountStatsPricingRules = append(resp.AccountStatsPricingRules, ruleResp)
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
@ -241,6 +281,15 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
|
||||
return result
|
||||
}
|
||||
|
||||
func accountStatsPricingRuleRequestToService(r accountStatsPricingRuleRequest) service.AccountStatsPricingRule {
|
||||
return service.AccountStatsPricingRule{
|
||||
Name: r.Name,
|
||||
GroupIDs: r.GroupIDs,
|
||||
AccountIDs: r.AccountIDs,
|
||||
Pricing: pricingRequestToService(r.Pricing),
|
||||
}
|
||||
}
|
||||
|
||||
// --- Handlers ---
|
||||
|
||||
// List handles listing channels with pagination
|
||||
@ -300,16 +349,24 @@ func (h *ChannelHandler) Create(c *gin.Context) {
|
||||
|
||||
pricing := pricingRequestToService(req.ModelPricing)
|
||||
|
||||
var statsRules []service.AccountStatsPricingRule
|
||||
for i, r := range req.AccountStatsPricingRules {
|
||||
rule := accountStatsPricingRuleRequestToService(r)
|
||||
rule.SortOrder = i
|
||||
statsRules = append(statsRules, rule)
|
||||
}
|
||||
|
||||
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ModelPricing: pricing,
|
||||
ModelMapping: req.ModelMapping,
|
||||
BillingModelSource: req.BillingModelSource,
|
||||
RestrictModels: req.RestrictModels,
|
||||
Features: req.Features,
|
||||
FeaturesConfig: req.FeaturesConfig,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ModelPricing: pricing,
|
||||
ModelMapping: req.ModelMapping,
|
||||
BillingModelSource: req.BillingModelSource,
|
||||
RestrictModels: req.RestrictModels,
|
||||
Features: req.Features,
|
||||
ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
|
||||
AccountStatsPricingRules: statsRules,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@ -335,20 +392,29 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
||||
}
|
||||
|
||||
input := &service.UpdateChannelInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ModelMapping: req.ModelMapping,
|
||||
BillingModelSource: req.BillingModelSource,
|
||||
RestrictModels: req.RestrictModels,
|
||||
Features: req.Features,
|
||||
FeaturesConfig: req.FeaturesConfig,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ModelMapping: req.ModelMapping,
|
||||
BillingModelSource: req.BillingModelSource,
|
||||
RestrictModels: req.RestrictModels,
|
||||
Features: req.Features,
|
||||
ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
|
||||
}
|
||||
if req.ModelPricing != nil {
|
||||
pricing := pricingRequestToService(*req.ModelPricing)
|
||||
input.ModelPricing = &pricing
|
||||
}
|
||||
if req.AccountStatsPricingRules != nil {
|
||||
statsRules := make([]service.AccountStatsPricingRule, 0, len(*req.AccountStatsPricingRules))
|
||||
for i, r := range *req.AccountStatsPricingRules {
|
||||
rule := accountStatsPricingRuleRequestToService(r)
|
||||
rule.SortOrder = i
|
||||
statsRules = append(statsRules, rule)
|
||||
}
|
||||
input.AccountStatsPricingRules = &statsRules
|
||||
}
|
||||
|
||||
channel, err := h.channelService.Update(c.Request.Context(), id, input)
|
||||
if err != nil {
|
||||
|
||||
@ -41,14 +41,10 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = tx.QueryRowContext(ctx,
|
||||
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id, created_at, updated_at`,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ApplyPricingToAccountStats,
|
||||
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
@ -71,17 +67,24 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
|
||||
}
|
||||
}
|
||||
|
||||
// 设置账号统计定价规则
|
||||
if len(channel.AccountStatsPricingRules) > 0 {
|
||||
if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
|
||||
ch := &service.Channel{}
|
||||
var modelMappingJSON, featuresConfigJSON []byte
|
||||
var modelMappingJSON []byte
|
||||
err := r.db.QueryRowContext(ctx,
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats, created_at, updated_at
|
||||
FROM channels WHERE id = $1`, id,
|
||||
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt)
|
||||
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, service.ErrChannelNotFound
|
||||
}
|
||||
@ -89,7 +92,6 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
|
||||
return nil, fmt.Errorf("get channel: %w", err)
|
||||
}
|
||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
||||
|
||||
groupIDs, err := r.GetGroupIDs(ctx, id)
|
||||
if err != nil {
|
||||
@ -103,6 +105,12 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
|
||||
}
|
||||
ch.ModelPricing = pricing
|
||||
|
||||
statsPricingRules, err := r.loadAccountStatsPricingRules(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ch.AccountStatsPricingRules = statsPricingRules
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
@ -112,14 +120,10 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := tx.ExecContext(ctx,
|
||||
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, features_config = $8, updated_at = NOW()
|
||||
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, apply_pricing_to_account_stats = $8, updated_at = NOW()
|
||||
WHERE id = $9`,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ID,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ApplyPricingToAccountStats, channel.ID,
|
||||
)
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
@ -146,6 +150,13 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
|
||||
}
|
||||
}
|
||||
|
||||
// 更新账号统计定价规则
|
||||
if channel.AccountStatsPricingRules != nil {
|
||||
if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@ -196,7 +207,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
||||
|
||||
// 查询 channel 列表
|
||||
dataQuery := fmt.Sprintf(
|
||||
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.features_config, c.created_at, c.updated_at
|
||||
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.apply_pricing_to_account_stats, c.created_at, c.updated_at
|
||||
FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
|
||||
whereClause, channelListOrderBy(params), argIdx, argIdx+1,
|
||||
)
|
||||
@ -212,12 +223,11 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
||||
var channelIDs []int64
|
||||
for rows.Next() {
|
||||
var ch service.Channel
|
||||
var modelMappingJSON, featuresConfigJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
var modelMappingJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
return nil, nil, fmt.Errorf("scan channel: %w", err)
|
||||
}
|
||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
||||
channels = append(channels, ch)
|
||||
channelIDs = append(channelIDs, ch.ID)
|
||||
}
|
||||
@ -235,9 +245,14 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for i := range channels {
|
||||
channels[i].GroupIDs = groupMap[channels[i].ID]
|
||||
channels[i].ModelPricing = pricingMap[channels[i].ID]
|
||||
channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID]
|
||||
}
|
||||
}
|
||||
|
||||
@ -283,7 +298,7 @@ func channelListOrderBy(params pagination.PaginationParams) string {
|
||||
|
||||
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at FROM channels ORDER BY id`,
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats, created_at, updated_at FROM channels ORDER BY id`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query all channels: %w", err)
|
||||
@ -294,12 +309,11 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
|
||||
var channelIDs []int64
|
||||
for rows.Next() {
|
||||
var ch service.Channel
|
||||
var modelMappingJSON, featuresConfigJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
var modelMappingJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan channel: %w", err)
|
||||
}
|
||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
||||
channels = append(channels, ch)
|
||||
channelIDs = append(channelIDs, ch.ID)
|
||||
}
|
||||
@ -323,9 +337,16 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 批量加载账号统计定价规则
|
||||
statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range channels {
|
||||
channels[i].GroupIDs = groupMap[channels[i].ID]
|
||||
channels[i].ModelPricing = pricingMap[channels[i].ID]
|
||||
channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID]
|
||||
}
|
||||
|
||||
return channels, nil
|
||||
@ -467,28 +488,6 @@ func unmarshalModelMapping(data []byte) map[string]map[string]string {
|
||||
return m
|
||||
}
|
||||
|
||||
func marshalFeaturesConfig(m map[string]any) ([]byte, error) {
|
||||
if len(m) == 0 {
|
||||
return []byte("{}"), nil
|
||||
}
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal features_config: %w", err)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func unmarshalFeaturesConfig(data []byte) map[string]any {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return nil
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// GetGroupPlatforms 批量查询分组 ID 对应的平台
|
||||
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
|
||||
@ -0,0 +1,170 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// --- 账号统计定价规则 ---
|
||||
|
||||
// batchLoadAccountStatsPricingRules 批量加载多个渠道的账号统计定价规则(含模型定价)
|
||||
func (r *channelRepository) batchLoadAccountStatsPricingRules(ctx context.Context, channelIDs []int64) (map[int64][]service.AccountStatsPricingRule, error) {
|
||||
// 1. 查询规则
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, channel_id, name, group_ids, account_ids, sort_order, created_at, updated_at
|
||||
FROM channel_account_stats_pricing_rules WHERE channel_id = ANY($1) ORDER BY channel_id, sort_order, id`,
|
||||
pq.Array(channelIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch load account stats pricing rules: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var allRules []service.AccountStatsPricingRule
|
||||
var ruleIDs []int64
|
||||
for rows.Next() {
|
||||
var rule service.AccountStatsPricingRule
|
||||
if err := rows.Scan(
|
||||
&rule.ID, &rule.ChannelID, &rule.Name,
|
||||
pq.Array(&rule.GroupIDs), pq.Array(&rule.AccountIDs),
|
||||
&rule.SortOrder, &rule.CreatedAt, &rule.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan account stats pricing rule: %w", err)
|
||||
}
|
||||
ruleIDs = append(ruleIDs, rule.ID)
|
||||
allRules = append(allRules, rule)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate account stats pricing rules: %w", err)
|
||||
}
|
||||
|
||||
// 2. 批量加载规则的模型定价
|
||||
pricingMap, err := r.batchLoadAccountStatsModelPricing(ctx, ruleIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3. 按 channelID 分组并关联定价
|
||||
result := make(map[int64][]service.AccountStatsPricingRule, len(channelIDs))
|
||||
for i := range allRules {
|
||||
allRules[i].Pricing = pricingMap[allRules[i].ID]
|
||||
result[allRules[i].ChannelID] = append(result[allRules[i].ChannelID], allRules[i])
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// batchLoadAccountStatsModelPricing 批量加载规则的模型定价
|
||||
func (r *channelRepository) batchLoadAccountStatsModelPricing(ctx context.Context, ruleIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
|
||||
if len(ruleIDs) == 0 {
|
||||
return make(map[int64][]service.ChannelModelPricing), nil
|
||||
}
|
||||
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, rule_id, platform, models, billing_mode, input_price, output_price,
|
||||
cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
|
||||
FROM channel_account_stats_model_pricing WHERE rule_id = ANY($1) ORDER BY rule_id, id`,
|
||||
pq.Array(ruleIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch load account stats model pricing: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
pricingMap := make(map[int64][]service.ChannelModelPricing, len(ruleIDs))
|
||||
for rows.Next() {
|
||||
var p service.ChannelModelPricing
|
||||
var ruleID int64
|
||||
var modelsJSON []byte
|
||||
if err := rows.Scan(
|
||||
&p.ID, &ruleID, &p.Platform, &modelsJSON, &p.BillingMode,
|
||||
&p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
|
||||
&p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan account stats model pricing: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(modelsJSON, &p.Models); err != nil {
|
||||
p.Models = []string{}
|
||||
}
|
||||
pricingMap[ruleID] = append(pricingMap[ruleID], p)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate account stats model pricing: %w", err)
|
||||
}
|
||||
return pricingMap, nil
|
||||
}
|
||||
|
||||
// loadAccountStatsPricingRules 加载单个渠道的账号统计定价规则(供 GetByID 使用)
|
||||
func (r *channelRepository) loadAccountStatsPricingRules(ctx context.Context, channelID int64) ([]service.AccountStatsPricingRule, error) {
|
||||
result, err := r.batchLoadAccountStatsPricingRules(ctx, []int64{channelID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result[channelID], nil
|
||||
}
|
||||
|
||||
// replaceAccountStatsPricingRulesTx 在事务中替换渠道的账号统计定价规则(删除旧的 + 插入新的)
|
||||
func replaceAccountStatsPricingRulesTx(ctx context.Context, tx *sql.Tx, channelID int64, rules []service.AccountStatsPricingRule) error {
|
||||
// CASCADE 会自动删除关联的 model_pricing
|
||||
if _, err := tx.ExecContext(ctx,
|
||||
`DELETE FROM channel_account_stats_pricing_rules WHERE channel_id = $1`, channelID,
|
||||
); err != nil {
|
||||
return fmt.Errorf("delete old account stats pricing rules: %w", err)
|
||||
}
|
||||
|
||||
for i := range rules {
|
||||
rules[i].ChannelID = channelID
|
||||
if err := createAccountStatsPricingRuleTx(ctx, tx, &rules[i]); err != nil {
|
||||
return fmt.Errorf("insert account stats pricing rule: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createAccountStatsPricingRuleTx 在事务中创建单条账号统计定价规则及其模型定价
|
||||
func createAccountStatsPricingRuleTx(ctx context.Context, tx *sql.Tx, rule *service.AccountStatsPricingRule) error {
|
||||
err := tx.QueryRowContext(ctx,
|
||||
`INSERT INTO channel_account_stats_pricing_rules (channel_id, name, group_ids, account_ids, sort_order)
|
||||
VALUES ($1, $2, $3, $4, $5) RETURNING id, created_at, updated_at`,
|
||||
rule.ChannelID, rule.Name, pq.Array(rule.GroupIDs), pq.Array(rule.AccountIDs), rule.SortOrder,
|
||||
).Scan(&rule.ID, &rule.CreatedAt, &rule.UpdatedAt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert account stats pricing rule: %w", err)
|
||||
}
|
||||
|
||||
for j := range rule.Pricing {
|
||||
if err := createAccountStatsModelPricingTx(ctx, tx, rule.ID, &rule.Pricing[j]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createAccountStatsModelPricingTx 在事务中创建单条账号统计模型定价
|
||||
func createAccountStatsModelPricingTx(ctx context.Context, tx *sql.Tx, ruleID int64, pricing *service.ChannelModelPricing) error {
|
||||
modelsJSON, err := json.Marshal(pricing.Models)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal models: %w", err)
|
||||
}
|
||||
billingMode := pricing.BillingMode
|
||||
if billingMode == "" {
|
||||
billingMode = service.BillingModeToken
|
||||
}
|
||||
platform := pricing.Platform
|
||||
err = tx.QueryRowContext(ctx,
|
||||
`INSERT INTO channel_account_stats_model_pricing (rule_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
|
||||
ruleID, platform, modelsJSON, billingMode,
|
||||
pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
|
||||
pricing.ImageOutputPrice, pricing.PerRequestPrice,
|
||||
).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert account stats model pricing: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -28,7 +28,7 @@ import (
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, account_stats_cost, created_at"
|
||||
|
||||
// usageLogInsertArgTypes must stay in the same order as:
|
||||
// 1. prepareUsageLogInsert().args
|
||||
@ -82,6 +82,7 @@ var usageLogInsertArgTypes = [...]string{
|
||||
"text", // model_mapping_chain
|
||||
"text", // billing_tier
|
||||
"text", // billing_mode
|
||||
"numeric", // account_stats_cost
|
||||
"timestamptz", // created_at
|
||||
}
|
||||
|
||||
@ -360,6 +361,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
account_stats_cost,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7,
|
||||
@ -367,7 +369,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
||||
$10, $11, $12, $13,
|
||||
$14, $15, $16, $17,
|
||||
$18, $19, $20, $21, $22, $23,
|
||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
|
||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@ -797,6 +799,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
account_stats_cost,
|
||||
created_at
|
||||
) AS (VALUES `)
|
||||
|
||||
@ -873,6 +876,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
account_stats_cost,
|
||||
created_at
|
||||
)
|
||||
SELECT
|
||||
@ -920,6 +924,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
account_stats_cost,
|
||||
created_at
|
||||
FROM input
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
@ -1007,10 +1012,11 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
account_stats_cost,
|
||||
created_at
|
||||
) AS (VALUES `)
|
||||
|
||||
args := make([]any, 0, len(preparedList)*45)
|
||||
args := make([]any, 0, len(preparedList)*46)
|
||||
argPos := 1
|
||||
for idx, prepared := range preparedList {
|
||||
if idx > 0 {
|
||||
@ -1080,6 +1086,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
account_stats_cost,
|
||||
created_at
|
||||
)
|
||||
SELECT
|
||||
@ -1127,6 +1134,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
account_stats_cost,
|
||||
created_at
|
||||
FROM input
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
@ -1182,6 +1190,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
||||
model_mapping_chain,
|
||||
billing_tier,
|
||||
billing_mode,
|
||||
account_stats_cost,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7,
|
||||
@ -1189,7 +1198,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
||||
$10, $11, $12, $13,
|
||||
$14, $15, $16, $17,
|
||||
$18, $19, $20, $21, $22, $23,
|
||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
|
||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
`, prepared.args...)
|
||||
@ -1285,6 +1294,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
||||
modelMappingChain,
|
||||
billingTier,
|
||||
billingMode,
|
||||
log.AccountStatsCost, // account_stats_cost
|
||||
createdAt,
|
||||
},
|
||||
}
|
||||
@ -1959,7 +1969,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
|
||||
SELECT
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||
COALESCE(SUM(total_cost), 0) as standard_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||
FROM usage_logs
|
||||
@ -1989,7 +1999,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
|
||||
SELECT
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||
COALESCE(SUM(total_cost), 0) as standard_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||
FROM usage_logs
|
||||
@ -2026,7 +2036,7 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc
|
||||
account_id,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||
COALESCE(SUM(total_cost), 0) as standard_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||
FROM usage_logs
|
||||
@ -2990,7 +3000,7 @@ func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Contex
|
||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
|
||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||
actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||
}
|
||||
modelExpr := resolveModelDimensionExpression(source)
|
||||
|
||||
@ -3358,7 +3368,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
|
||||
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
|
||||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||||
FROM usage_logs
|
||||
%s
|
||||
@ -3433,7 +3443,7 @@ type EndpointStat = usagestats.EndpointStat
|
||||
func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
|
||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||
actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
@ -3500,7 +3510,7 @@ func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Con
|
||||
func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
|
||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||
actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
@ -3591,7 +3601,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
|
||||
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||
FROM usage_logs
|
||||
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
|
||||
@ -4069,6 +4079,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
modelMappingChain sql.NullString
|
||||
billingTier sql.NullString
|
||||
billingMode sql.NullString
|
||||
accountStatsCost sql.NullFloat64
|
||||
createdAt time.Time
|
||||
)
|
||||
|
||||
@ -4118,6 +4129,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&modelMappingChain,
|
||||
&billingTier,
|
||||
&billingMode,
|
||||
&accountStatsCost,
|
||||
&createdAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
@ -4214,6 +4226,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
if billingMode.Valid {
|
||||
log.BillingMode = &billingMode.String
|
||||
}
|
||||
if accountStatsCost.Valid {
|
||||
log.AccountStatsCost = &accountStatsCost.Float64
|
||||
}
|
||||
|
||||
return log, nil
|
||||
}
|
||||
|
||||
@ -85,6 +85,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
sqlmock.AnyArg(), // model_mapping_chain
|
||||
sqlmock.AnyArg(), // billing_tier
|
||||
sqlmock.AnyArg(), // billing_mode
|
||||
sqlmock.AnyArg(), // account_stats_cost
|
||||
createdAt,
|
||||
).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
|
||||
@ -163,6 +164,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
||||
sqlmock.AnyArg(), // model_mapping_chain
|
||||
sqlmock.AnyArg(), // billing_tier
|
||||
sqlmock.AnyArg(), // billing_mode
|
||||
sqlmock.AnyArg(), // account_stats_cost
|
||||
createdAt,
|
||||
).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
|
||||
@ -483,10 +485,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
sql.NullInt64{}, // channel_id
|
||||
sql.NullString{}, // model_mapping_chain
|
||||
sql.NullString{}, // billing_tier
|
||||
sql.NullString{}, // billing_mode
|
||||
sql.NullInt64{}, // channel_id
|
||||
sql.NullString{}, // model_mapping_chain
|
||||
sql.NullString{}, // billing_tier
|
||||
sql.NullString{}, // billing_mode
|
||||
sql.NullFloat64{}, // account_stats_cost
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
@ -530,10 +533,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
sql.NullInt64{}, // channel_id
|
||||
sql.NullString{}, // model_mapping_chain
|
||||
sql.NullString{}, // billing_tier
|
||||
sql.NullString{}, // billing_mode
|
||||
sql.NullInt64{}, // channel_id
|
||||
sql.NullString{}, // model_mapping_chain
|
||||
sql.NullString{}, // billing_tier
|
||||
sql.NullString{}, // billing_mode
|
||||
sql.NullFloat64{}, // account_stats_cost
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
@ -577,10 +581,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
sql.NullInt64{}, // channel_id
|
||||
sql.NullString{}, // model_mapping_chain
|
||||
sql.NullString{}, // billing_tier
|
||||
sql.NullString{}, // billing_mode
|
||||
sql.NullInt64{}, // channel_id
|
||||
sql.NullString{}, // model_mapping_chain
|
||||
sql.NullString{}, // billing_tier
|
||||
sql.NullString{}, // billing_mode
|
||||
sql.NullFloat64{}, // account_stats_cost
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
|
||||
192
backend/internal/service/account_stats_pricing.go
Normal file
192
backend/internal/service/account_stats_pricing.go
Normal file
@ -0,0 +1,192 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// resolveAccountStatsCost 计算账号统计定价费用。
|
||||
// 返回 nil 表示不覆盖,使用默认公式(total_cost × account_rate_multiplier)。
|
||||
//
|
||||
// 匹配优先级(先命中为准):
|
||||
// 1. 自定义规则(AccountStatsPricingRules,按数组顺序遍历)
|
||||
// 2. 渠道已有的模型定价(ApplyPricingToAccountStats 开启时)
|
||||
// 3. nil → 走默认公式
|
||||
func resolveAccountStatsCost(
|
||||
ctx context.Context,
|
||||
channelService *ChannelService,
|
||||
billingService *BillingService,
|
||||
accountID int64,
|
||||
groupID int64,
|
||||
billingModel string,
|
||||
tokens UsageTokens,
|
||||
requestCount int,
|
||||
serviceTier string,
|
||||
) *float64 {
|
||||
if channelService == nil || billingService == nil {
|
||||
return nil
|
||||
}
|
||||
channel, err := channelService.GetChannelForGroup(ctx, groupID)
|
||||
if err != nil || channel == nil || !channel.ApplyPricingToAccountStats {
|
||||
return nil
|
||||
}
|
||||
|
||||
platform := channelService.GetGroupPlatform(ctx, groupID)
|
||||
modelLower := strings.ToLower(billingModel)
|
||||
|
||||
// 优先级 1:自定义规则
|
||||
if cost := tryCustomRules(channel, accountID, groupID, platform, modelLower, tokens, requestCount); cost != nil {
|
||||
return cost
|
||||
}
|
||||
|
||||
// 优先级 2:渠道已有模型定价
|
||||
return tryChannelPricing(ctx, channelService, groupID, billingModel, tokens, requestCount)
|
||||
}
|
||||
|
||||
// tryCustomRules 遍历自定义规则,按数组顺序先命中为准。
|
||||
func tryCustomRules(
|
||||
channel *Channel, accountID, groupID int64,
|
||||
platform, modelLower string, tokens UsageTokens, requestCount int,
|
||||
) *float64 {
|
||||
for _, rule := range channel.AccountStatsPricingRules {
|
||||
if !matchAccountStatsRule(&rule, accountID, groupID) {
|
||||
continue
|
||||
}
|
||||
pricing := findPricingForModel(rule.Pricing, platform, modelLower)
|
||||
if pricing == nil {
|
||||
continue // 规则匹配但模型不在规则定价中,继续下一条
|
||||
}
|
||||
return calculateStatsCost(pricing, tokens, requestCount)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// tryChannelPricing 使用渠道已有的模型定价计算账号统计费用。
|
||||
func tryChannelPricing(
|
||||
ctx context.Context, channelService *ChannelService,
|
||||
groupID int64, billingModel string, tokens UsageTokens, requestCount int,
|
||||
) *float64 {
|
||||
pricing := channelService.GetChannelModelPricing(ctx, groupID, billingModel)
|
||||
if pricing == nil {
|
||||
return nil
|
||||
}
|
||||
return calculateStatsCost(pricing, tokens, requestCount)
|
||||
}
|
||||
|
||||
// matchAccountStatsRule 检查规则是否匹配指定的 accountID 和 groupID。
|
||||
// 匹配条件:accountID ∈ rule.AccountIDs 或 groupID ∈ rule.GroupIDs。
|
||||
// 如果规则的 AccountIDs 和 GroupIDs 都为空,视为不匹配。
|
||||
func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int64) bool {
|
||||
if len(rule.AccountIDs) == 0 && len(rule.GroupIDs) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, id := range rule.AccountIDs {
|
||||
if id == accountID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, id := range rule.GroupIDs {
|
||||
if id == groupID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// wildcardMatch 通配符匹配候选项(用于排序)
|
||||
type wildcardMatch struct {
|
||||
prefixLen int
|
||||
pricing *ChannelModelPricing
|
||||
}
|
||||
|
||||
// findPricingForModel 在定价列表中查找匹配的模型定价。
|
||||
// 先精确匹配,再通配符匹配(前缀越长优先级越高)。
|
||||
func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing {
|
||||
// 精确匹配优先
|
||||
for i := range pricingList {
|
||||
p := &pricingList[i]
|
||||
if !isPlatformMatch(platform, p.Platform) {
|
||||
continue
|
||||
}
|
||||
for _, m := range p.Models {
|
||||
if strings.ToLower(m) == modelLower {
|
||||
return p
|
||||
}
|
||||
}
|
||||
}
|
||||
// 通配符匹配:收集所有匹配项,按前缀长度降序取最长
|
||||
var matches []wildcardMatch
|
||||
for i := range pricingList {
|
||||
p := &pricingList[i]
|
||||
if !isPlatformMatch(platform, p.Platform) {
|
||||
continue
|
||||
}
|
||||
for _, m := range p.Models {
|
||||
ml := strings.ToLower(m)
|
||||
if !strings.HasSuffix(ml, "*") {
|
||||
continue
|
||||
}
|
||||
prefix := strings.TrimSuffix(ml, "*")
|
||||
if strings.HasPrefix(modelLower, prefix) {
|
||||
matches = append(matches, wildcardMatch{prefixLen: len(prefix), pricing: p})
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(matches) == 0 {
|
||||
return nil
|
||||
}
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
return matches[i].prefixLen > matches[j].prefixLen
|
||||
})
|
||||
return matches[0].pricing
|
||||
}
|
||||
|
||||
// isPlatformMatch 判断平台是否匹配(空平台视为不限平台)。
|
||||
func isPlatformMatch(queryPlatform, pricingPlatform string) bool {
|
||||
if queryPlatform == "" || pricingPlatform == "" {
|
||||
return true
|
||||
}
|
||||
return queryPlatform == pricingPlatform
|
||||
}
|
||||
|
||||
// calculateStatsCost 使用给定的定价计算费用(不含任何倍率,原始费用)。
|
||||
func calculateStatsCost(pricing *ChannelModelPricing, tokens UsageTokens, requestCount int) *float64 {
|
||||
if pricing == nil {
|
||||
return nil
|
||||
}
|
||||
switch pricing.BillingMode {
|
||||
case BillingModePerRequest, BillingModeImage:
|
||||
return calculatePerRequestStatsCost(pricing, requestCount)
|
||||
default:
|
||||
return calculateTokenStatsCost(pricing, tokens)
|
||||
}
|
||||
}
|
||||
|
||||
// calculatePerRequestStatsCost 按次/图片计费。
|
||||
func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int) *float64 {
|
||||
if pricing.PerRequestPrice == nil || *pricing.PerRequestPrice <= 0 {
|
||||
return nil
|
||||
}
|
||||
cost := *pricing.PerRequestPrice * float64(requestCount)
|
||||
return &cost
|
||||
}
|
||||
|
||||
// calculateTokenStatsCost Token 计费。
|
||||
func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 {
|
||||
deref := func(p *float64) float64 {
|
||||
if p == nil {
|
||||
return 0
|
||||
}
|
||||
return *p
|
||||
}
|
||||
cost := float64(tokens.InputTokens)*deref(pricing.InputPrice) +
|
||||
float64(tokens.OutputTokens)*deref(pricing.OutputPrice) +
|
||||
float64(tokens.CacheCreationTokens)*deref(pricing.CacheWritePrice) +
|
||||
float64(tokens.CacheReadTokens)*deref(pricing.CacheReadPrice) +
|
||||
float64(tokens.ImageOutputTokens)*deref(pricing.ImageOutputPrice)
|
||||
if cost == 0 {
|
||||
return nil
|
||||
}
|
||||
return &cost
|
||||
}
|
||||
430
backend/internal/service/account_stats_pricing_test.go
Normal file
430
backend/internal/service/account_stats_pricing_test.go
Normal file
@ -0,0 +1,430 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// matchAccountStatsRule
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestMatchAccountStatsRule_BothEmpty_NoMatch(t *testing.T) {
|
||||
rule := &AccountStatsPricingRule{}
|
||||
require.False(t, matchAccountStatsRule(rule, 1, 10))
|
||||
}
|
||||
|
||||
func TestMatchAccountStatsRule_AccountIDMatch(t *testing.T) {
|
||||
rule := &AccountStatsPricingRule{AccountIDs: []int64{1, 2, 3}}
|
||||
require.True(t, matchAccountStatsRule(rule, 2, 999))
|
||||
}
|
||||
|
||||
func TestMatchAccountStatsRule_GroupIDMatch(t *testing.T) {
|
||||
rule := &AccountStatsPricingRule{GroupIDs: []int64{10, 20}}
|
||||
require.True(t, matchAccountStatsRule(rule, 999, 20))
|
||||
}
|
||||
|
||||
func TestMatchAccountStatsRule_BothConfigured_AccountMatch(t *testing.T) {
|
||||
rule := &AccountStatsPricingRule{
|
||||
AccountIDs: []int64{1, 2},
|
||||
GroupIDs: []int64{10, 20},
|
||||
}
|
||||
require.True(t, matchAccountStatsRule(rule, 2, 999))
|
||||
}
|
||||
|
||||
func TestMatchAccountStatsRule_BothConfigured_GroupMatch(t *testing.T) {
|
||||
rule := &AccountStatsPricingRule{
|
||||
AccountIDs: []int64{1, 2},
|
||||
GroupIDs: []int64{10, 20},
|
||||
}
|
||||
require.True(t, matchAccountStatsRule(rule, 999, 10))
|
||||
}
|
||||
|
||||
func TestMatchAccountStatsRule_BothConfigured_NeitherMatch(t *testing.T) {
|
||||
rule := &AccountStatsPricingRule{
|
||||
AccountIDs: []int64{1, 2},
|
||||
GroupIDs: []int64{10, 20},
|
||||
}
|
||||
require.False(t, matchAccountStatsRule(rule, 999, 999))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// findPricingForModel
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestFindPricingForModel(t *testing.T) {
|
||||
exactPricing := ChannelModelPricing{
|
||||
ID: 1,
|
||||
Models: []string{"claude-opus-4"},
|
||||
}
|
||||
wildcardPricing := ChannelModelPricing{
|
||||
ID: 2,
|
||||
Models: []string{"claude-*"},
|
||||
}
|
||||
platformPricing := ChannelModelPricing{
|
||||
ID: 3,
|
||||
Platform: "openai",
|
||||
Models: []string{"gpt-4o"},
|
||||
}
|
||||
emptyPlatformPricing := ChannelModelPricing{
|
||||
ID: 4,
|
||||
Models: []string{"gemini-2.5-pro"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
list []ChannelModelPricing
|
||||
platform string
|
||||
model string
|
||||
wantID int64
|
||||
wantNil bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
list: []ChannelModelPricing{exactPricing},
|
||||
platform: "anthropic",
|
||||
model: "claude-opus-4",
|
||||
wantID: 1,
|
||||
},
|
||||
{
|
||||
name: "exact match case insensitive",
|
||||
list: []ChannelModelPricing{{ID: 5, Models: []string{"Claude-Opus-4"}}},
|
||||
platform: "",
|
||||
model: "claude-opus-4",
|
||||
wantID: 5,
|
||||
},
|
||||
{
|
||||
name: "wildcard match",
|
||||
list: []ChannelModelPricing{wildcardPricing},
|
||||
platform: "anthropic",
|
||||
model: "claude-opus-4",
|
||||
wantID: 2,
|
||||
},
|
||||
{
|
||||
name: "exact match takes priority over wildcard",
|
||||
list: []ChannelModelPricing{wildcardPricing, exactPricing},
|
||||
platform: "anthropic",
|
||||
model: "claude-opus-4",
|
||||
wantID: 1,
|
||||
},
|
||||
{
|
||||
name: "platform mismatch skipped",
|
||||
list: []ChannelModelPricing{platformPricing},
|
||||
platform: "anthropic",
|
||||
model: "gpt-4o",
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "empty platform in pricing matches any",
|
||||
list: []ChannelModelPricing{emptyPlatformPricing},
|
||||
platform: "gemini",
|
||||
model: "gemini-2.5-pro",
|
||||
wantID: 4,
|
||||
},
|
||||
{
|
||||
name: "empty platform in query matches any pricing platform",
|
||||
list: []ChannelModelPricing{platformPricing},
|
||||
platform: "",
|
||||
model: "gpt-4o",
|
||||
wantID: 3,
|
||||
},
|
||||
{
|
||||
name: "no match at all",
|
||||
list: []ChannelModelPricing{exactPricing, wildcardPricing},
|
||||
platform: "anthropic",
|
||||
model: "gpt-4o",
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "empty list returns nil",
|
||||
list: nil,
|
||||
model: "claude-opus-4",
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "longer wildcard prefix wins over shorter",
|
||||
list: []ChannelModelPricing{
|
||||
{ID: 10, Models: []string{"claude-*"}},
|
||||
{ID: 11, Models: []string{"claude-opus-*"}},
|
||||
},
|
||||
platform: "",
|
||||
model: "claude-opus-4",
|
||||
wantID: 11, // "claude-opus-" (12 chars) > "claude-" (7 chars)
|
||||
},
|
||||
{
|
||||
name: "shorter wildcard used when longer does not match",
|
||||
list: []ChannelModelPricing{
|
||||
{ID: 10, Models: []string{"claude-*"}},
|
||||
{ID: 11, Models: []string{"claude-opus-*"}},
|
||||
},
|
||||
platform: "",
|
||||
model: "claude-sonnet-4",
|
||||
wantID: 10, // only "claude-*" matches
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := findPricingForModel(tt.list, tt.platform, tt.model)
|
||||
if tt.wantNil {
|
||||
require.Nil(t, result)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, tt.wantID, result.ID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// calculateStatsCost
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestCalculateStatsCost_NilPricing(t *testing.T) {
|
||||
result := calculateStatsCost(nil, UsageTokens{}, 1)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_TokenBilling(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(0.001),
|
||||
OutputPrice: testPtrFloat64(0.002),
|
||||
}
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
}
|
||||
result := calculateStatsCost(pricing, tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
// 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
|
||||
require.InDelta(t, 0.2, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_TokenBilling_WithCache(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(0.001),
|
||||
OutputPrice: testPtrFloat64(0.002),
|
||||
CacheWritePrice: testPtrFloat64(0.003),
|
||||
CacheReadPrice: testPtrFloat64(0.0005),
|
||||
}
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
CacheCreationTokens: 200,
|
||||
CacheReadTokens: 300,
|
||||
}
|
||||
result := calculateStatsCost(pricing, tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
// 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005
|
||||
// = 0.1 + 0.1 + 0.6 + 0.15 = 0.95
|
||||
require.InDelta(t, 0.95, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_TokenBilling_WithImageOutput(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(0.001),
|
||||
OutputPrice: testPtrFloat64(0.002),
|
||||
ImageOutputPrice: testPtrFloat64(0.01),
|
||||
}
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
ImageOutputTokens: 10,
|
||||
}
|
||||
result := calculateStatsCost(pricing, tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
// 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3
|
||||
require.InDelta(t, 0.3, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_TokenBilling_PartialPricesNil(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(0.001),
|
||||
// OutputPrice, CacheWritePrice, etc. are all nil → treated as 0
|
||||
}
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
CacheCreationTokens: 200,
|
||||
}
|
||||
result := calculateStatsCost(pricing, tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
// Only input contributes: 100*0.001 = 0.1
|
||||
require.InDelta(t, 0.1, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_TokenBilling_AllTokensZero(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModeToken,
|
||||
InputPrice: testPtrFloat64(0.001),
|
||||
OutputPrice: testPtrFloat64(0.002),
|
||||
}
|
||||
tokens := UsageTokens{} // all zeros
|
||||
result := calculateStatsCost(pricing, tokens, 1)
|
||||
// totalCost == 0 → returns nil (does not override, falls back to default formula)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_PerRequestBilling(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModePerRequest,
|
||||
PerRequestPrice: testPtrFloat64(0.05),
|
||||
}
|
||||
tokens := UsageTokens{InputTokens: 999, OutputTokens: 999}
|
||||
result := calculateStatsCost(pricing, tokens, 3)
|
||||
require.NotNil(t, result)
|
||||
// 0.05 * 3 = 0.15
|
||||
require.InDelta(t, 0.15, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_PerRequestBilling_PriceNil(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModePerRequest,
|
||||
// PerRequestPrice is nil
|
||||
}
|
||||
result := calculateStatsCost(pricing, UsageTokens{}, 1)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_PerRequestBilling_PriceZero(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModePerRequest,
|
||||
PerRequestPrice: testPtrFloat64(0),
|
||||
}
|
||||
result := calculateStatsCost(pricing, UsageTokens{}, 1)
|
||||
// price == 0 → condition *pricing.PerRequestPrice > 0 is false → returns nil
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_ImageBilling(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModeImage,
|
||||
PerRequestPrice: testPtrFloat64(0.10),
|
||||
}
|
||||
result := calculateStatsCost(pricing, UsageTokens{}, 2)
|
||||
require.NotNil(t, result)
|
||||
// 0.10 * 2 = 0.20
|
||||
require.InDelta(t, 0.20, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_ImageBilling_PriceNil(t *testing.T) {
|
||||
pricing := &ChannelModelPricing{
|
||||
BillingMode: BillingModeImage,
|
||||
// PerRequestPrice is nil
|
||||
}
|
||||
result := calculateStatsCost(pricing, UsageTokens{}, 1)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestCalculateStatsCost_DefaultBillingMode_FallsToToken(t *testing.T) {
|
||||
// BillingMode is empty string (default) → falls into token billing
|
||||
pricing := &ChannelModelPricing{
|
||||
InputPrice: testPtrFloat64(0.001),
|
||||
OutputPrice: testPtrFloat64(0.002),
|
||||
}
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
}
|
||||
result := calculateStatsCost(pricing, tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 0.2, *result, 1e-12)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// tryCustomRules — 多规则顺序测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestTryCustomRules_FirstMatchWins(t *testing.T) {
|
||||
channel := &Channel{
|
||||
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||
{
|
||||
GroupIDs: []int64{1},
|
||||
Pricing: []ChannelModelPricing{
|
||||
{ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01), OutputPrice: testPtrFloat64(0.02)},
|
||||
},
|
||||
},
|
||||
{
|
||||
GroupIDs: []int64{1},
|
||||
Pricing: []ChannelModelPricing{
|
||||
{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99), OutputPrice: testPtrFloat64(0.99)},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
|
||||
result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
// 应使用第一条规则的价格:100*0.01 + 50*0.02 = 2.0
|
||||
require.InDelta(t, 2.0, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestTryCustomRules_SkipsNonMatchingRules(t *testing.T) {
|
||||
channel := &Channel{
|
||||
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||
{
|
||||
AccountIDs: []int64{888}, // 不匹配
|
||||
Pricing: []ChannelModelPricing{
|
||||
{ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99)},
|
||||
},
|
||||
},
|
||||
{
|
||||
GroupIDs: []int64{1}, // 匹配
|
||||
Pricing: []ChannelModelPricing{
|
||||
{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
tokens := UsageTokens{InputTokens: 100}
|
||||
result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
// 跳过规则1(账号不匹配),使用规则2:100*0.05 = 5.0
|
||||
require.InDelta(t, 5.0, *result, 1e-12)
|
||||
}
|
||||
|
||||
func TestTryCustomRules_NoMatch_ReturnsNil(t *testing.T) {
|
||||
channel := &Channel{
|
||||
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||
{
|
||||
AccountIDs: []int64{888},
|
||||
Pricing: []ChannelModelPricing{
|
||||
{ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01)},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
tokens := UsageTokens{InputTokens: 100}
|
||||
result := tryCustomRules(channel, 999, 2, "", "claude-opus-4", tokens, 1)
|
||||
require.Nil(t, result) // 账号和分组都不匹配
|
||||
}
|
||||
|
||||
func TestTryCustomRules_RuleMatchesButModelNot_ContinuesToNext(t *testing.T) {
|
||||
channel := &Channel{
|
||||
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||
{
|
||||
GroupIDs: []int64{1},
|
||||
Pricing: []ChannelModelPricing{
|
||||
{ID: 100, Models: []string{"gpt-4o"}, InputPrice: testPtrFloat64(0.01)}, // 模型不匹配
|
||||
},
|
||||
},
|
||||
{
|
||||
GroupIDs: []int64{1},
|
||||
Pricing: []ChannelModelPricing{
|
||||
{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)}, // 模型匹配
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
tokens := UsageTokens{InputTokens: 100}
|
||||
result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 5.0, *result, 1e-12) // 使用规则2
|
||||
}
|
||||
@ -49,21 +49,25 @@ type Channel struct {
|
||||
ModelPricing []ChannelModelPricing
|
||||
// 渠道级模型映射(按平台分组:platform → {src→dst})
|
||||
ModelMapping map[string]map[string]string
|
||||
// 渠道特性配置(如 {"web_search_emulation": {"anthropic": true}})
|
||||
FeaturesConfig map[string]any
|
||||
|
||||
// 账号统计定价
|
||||
ApplyPricingToAccountStats bool // 是否应用渠道模型定价到账号统计
|
||||
AccountStatsPricingRules []AccountStatsPricingRule // 自定义账号统计定价规则(按 SortOrder 排序,先命中为准)
|
||||
}
|
||||
|
||||
// IsWebSearchEmulationEnabled 返回该渠道是否为指定平台启用了 web search 模拟。
|
||||
func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool {
|
||||
if c == nil || c.FeaturesConfig == nil {
|
||||
return false
|
||||
}
|
||||
wse, ok := c.FeaturesConfig[featureKeyWebSearchEmulation].(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
enabled, ok := wse[platform].(bool)
|
||||
return ok && enabled
|
||||
// AccountStatsPricingRule 账号统计定价规则
|
||||
// 每条规则包含匹配条件(分组/账号)和独立的模型定价。
|
||||
// 多条规则按 SortOrder 排序,先命中为准。
|
||||
type AccountStatsPricingRule struct {
|
||||
ID int64
|
||||
ChannelID int64
|
||||
Name string
|
||||
GroupIDs []int64
|
||||
AccountIDs []int64
|
||||
SortOrder int
|
||||
Pricing []ChannelModelPricing // 规则内的模型定价(复用现有定价结构)
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// ChannelModelPricing 渠道模型定价条目
|
||||
@ -192,6 +196,26 @@ func (c *Channel) Clone() *Channel {
|
||||
cp.ModelMapping[platform] = inner
|
||||
}
|
||||
}
|
||||
if c.AccountStatsPricingRules != nil {
|
||||
cp.AccountStatsPricingRules = make([]AccountStatsPricingRule, len(c.AccountStatsPricingRules))
|
||||
for i, rule := range c.AccountStatsPricingRules {
|
||||
cp.AccountStatsPricingRules[i] = rule
|
||||
if rule.GroupIDs != nil {
|
||||
cp.AccountStatsPricingRules[i].GroupIDs = make([]int64, len(rule.GroupIDs))
|
||||
copy(cp.AccountStatsPricingRules[i].GroupIDs, rule.GroupIDs)
|
||||
}
|
||||
if rule.AccountIDs != nil {
|
||||
cp.AccountStatsPricingRules[i].AccountIDs = make([]int64, len(rule.AccountIDs))
|
||||
copy(cp.AccountStatsPricingRules[i].AccountIDs, rule.AccountIDs)
|
||||
}
|
||||
if rule.Pricing != nil {
|
||||
cp.AccountStatsPricingRules[i].Pricing = make([]ChannelModelPricing, len(rule.Pricing))
|
||||
for j := range rule.Pricing {
|
||||
cp.AccountStatsPricingRules[i].Pricing[j] = rule.Pricing[j].Clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return &cp
|
||||
}
|
||||
|
||||
|
||||
@ -416,6 +416,15 @@ func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64)
|
||||
return ch.Clone(), nil
|
||||
}
|
||||
|
||||
// GetGroupPlatform 获取分组的平台标识(从缓存)
|
||||
func (s *ChannelService) GetGroupPlatform(ctx context.Context, groupID int64) string {
|
||||
cache, err := s.loadCache(ctx)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cache.groupPlatform[groupID]
|
||||
}
|
||||
|
||||
// channelLookup 热路径公共查找结果
|
||||
type channelLookup struct {
|
||||
cache *channelCache
|
||||
@ -656,16 +665,17 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
||||
}
|
||||
|
||||
channel := &Channel{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Status: StatusActive,
|
||||
BillingModelSource: input.BillingModelSource,
|
||||
RestrictModels: input.RestrictModels,
|
||||
GroupIDs: input.GroupIDs,
|
||||
ModelPricing: input.ModelPricing,
|
||||
ModelMapping: input.ModelMapping,
|
||||
Features: input.Features,
|
||||
FeaturesConfig: input.FeaturesConfig,
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Status: StatusActive,
|
||||
BillingModelSource: input.BillingModelSource,
|
||||
RestrictModels: input.RestrictModels,
|
||||
GroupIDs: input.GroupIDs,
|
||||
ModelPricing: input.ModelPricing,
|
||||
ModelMapping: input.ModelMapping,
|
||||
Features: input.Features,
|
||||
ApplyPricingToAccountStats: input.ApplyPricingToAccountStats,
|
||||
AccountStatsPricingRules: input.AccountStatsPricingRules,
|
||||
}
|
||||
if channel.BillingModelSource == "" {
|
||||
channel.BillingModelSource = BillingModelSourceChannelMapped
|
||||
@ -754,8 +764,11 @@ func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel,
|
||||
if input.BillingModelSource != "" {
|
||||
channel.BillingModelSource = input.BillingModelSource
|
||||
}
|
||||
if input.FeaturesConfig != nil {
|
||||
channel.FeaturesConfig = input.FeaturesConfig
|
||||
if input.ApplyPricingToAccountStats != nil {
|
||||
channel.ApplyPricingToAccountStats = *input.ApplyPricingToAccountStats
|
||||
}
|
||||
if input.AccountStatsPricingRules != nil {
|
||||
channel.AccountStatsPricingRules = *input.AccountStatsPricingRules
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -922,27 +935,29 @@ func detectConflicts(entries []modelEntry, platform, errCode, label string) erro
|
||||
|
||||
// CreateChannelInput 创建渠道输入
|
||||
type CreateChannelInput struct {
|
||||
Name string
|
||||
Description string
|
||||
GroupIDs []int64
|
||||
ModelPricing []ChannelModelPricing
|
||||
ModelMapping map[string]map[string]string // platform → {src→dst}
|
||||
BillingModelSource string
|
||||
RestrictModels bool
|
||||
Features string
|
||||
FeaturesConfig map[string]any
|
||||
Name string
|
||||
Description string
|
||||
GroupIDs []int64
|
||||
ModelPricing []ChannelModelPricing
|
||||
ModelMapping map[string]map[string]string // platform → {src→dst}
|
||||
BillingModelSource string
|
||||
RestrictModels bool
|
||||
Features string
|
||||
ApplyPricingToAccountStats bool
|
||||
AccountStatsPricingRules []AccountStatsPricingRule
|
||||
}
|
||||
|
||||
// UpdateChannelInput 更新渠道输入
|
||||
type UpdateChannelInput struct {
|
||||
Name string
|
||||
Description *string
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
ModelPricing *[]ChannelModelPricing
|
||||
ModelMapping map[string]map[string]string // platform → {src→dst}
|
||||
BillingModelSource string
|
||||
RestrictModels *bool
|
||||
Features *string
|
||||
FeaturesConfig map[string]any
|
||||
Name string
|
||||
Description *string
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
ModelPricing *[]ChannelModelPricing
|
||||
ModelMapping map[string]map[string]string // platform → {src→dst}
|
||||
BillingModelSource string
|
||||
RestrictModels *bool
|
||||
Features *string
|
||||
ApplyPricingToAccountStats *bool
|
||||
AccountStatsPricingRules *[]AccountStatsPricingRule
|
||||
}
|
||||
|
||||
@ -7559,6 +7559,23 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
||||
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
|
||||
requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
|
||||
|
||||
// 计算账号统计定价费用
|
||||
if apiKey.GroupID != nil {
|
||||
usageLog.AccountStatsCost = resolveAccountStatsCost(
|
||||
ctx, s.channelService, s.billingService,
|
||||
account.ID, *apiKey.GroupID, billingModel,
|
||||
UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
||||
},
|
||||
1, // requestCount
|
||||
"", // serviceTier: Anthropic 平台不使用 service tier
|
||||
)
|
||||
}
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
|
||||
@ -4569,6 +4569,15 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
// 计算账号统计定价费用
|
||||
if apiKey.GroupID != nil {
|
||||
usageLog.AccountStatsCost = resolveAccountStatsCost(
|
||||
ctx, s.channelService, s.billingService,
|
||||
account.ID, *apiKey.GroupID, billingModel,
|
||||
tokens, 1, serviceTier,
|
||||
)
|
||||
}
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
|
||||
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
|
||||
@ -146,6 +146,8 @@ type UsageLog struct {
|
||||
RateMultiplier float64
|
||||
// AccountRateMultiplier 账号计费倍率快照(nil 表示历史数据,按 1.0 处理)
|
||||
AccountRateMultiplier *float64
|
||||
// AccountStatsCost 账号统计定价预计算费用(nil = 使用默认公式 total_cost × account_rate_multiplier)
|
||||
AccountStatsCost *float64
|
||||
|
||||
BillingType int8
|
||||
RequestType RequestType
|
||||
|
||||
38
backend/migrations/101_add_account_stats_pricing.sql
Normal file
38
backend/migrations/101_add_account_stats_pricing.sql
Normal file
@ -0,0 +1,38 @@
|
||||
-- Account statistics pricing: allow channels to configure custom pricing for account cost tracking.
|
||||
|
||||
-- 1. Channel-level toggle
|
||||
ALTER TABLE channels ADD COLUMN IF NOT EXISTS apply_pricing_to_account_stats BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
||||
-- 2. Account stats pricing rules (ordered list per channel)
|
||||
CREATE TABLE IF NOT EXISTS channel_account_stats_pricing_rules (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
|
||||
name VARCHAR(100) NOT NULL DEFAULT '',
|
||||
group_ids BIGINT[] NOT NULL DEFAULT '{}',
|
||||
account_ids BIGINT[] NOT NULL DEFAULT '{}',
|
||||
sort_order INT NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_cas_pricing_rules_channel_id ON channel_account_stats_pricing_rules(channel_id);
|
||||
|
||||
-- 3. Model pricing for each rule (same structure as channel_model_pricing)
|
||||
CREATE TABLE IF NOT EXISTS channel_account_stats_model_pricing (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
rule_id BIGINT NOT NULL REFERENCES channel_account_stats_pricing_rules(id) ON DELETE CASCADE,
|
||||
platform VARCHAR(50) NOT NULL DEFAULT '',
|
||||
models JSONB NOT NULL DEFAULT '[]',
|
||||
billing_mode VARCHAR(20) NOT NULL DEFAULT 'token',
|
||||
input_price NUMERIC(20,10),
|
||||
output_price NUMERIC(20,10),
|
||||
cache_write_price NUMERIC(20,10),
|
||||
cache_read_price NUMERIC(20,10),
|
||||
image_output_price NUMERIC(20,10),
|
||||
per_request_price NUMERIC(20,10),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_cas_model_pricing_rule_id ON channel_account_stats_model_pricing(rule_id);
|
||||
|
||||
-- 4. Usage logs: pre-computed account stats cost (NULL = use default formula)
|
||||
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS account_stats_cost NUMERIC(20,10);
|
||||
@ -34,6 +34,14 @@ export interface ChannelModelPricing {
|
||||
intervals: PricingInterval[]
|
||||
}
|
||||
|
||||
export interface AccountStatsPricingRule {
|
||||
id?: number
|
||||
name: string
|
||||
group_ids: number[]
|
||||
account_ids: number[]
|
||||
pricing: ChannelModelPricing[]
|
||||
}
|
||||
|
||||
export interface Channel {
|
||||
id: number
|
||||
name: string
|
||||
@ -41,10 +49,11 @@ export interface Channel {
|
||||
status: string
|
||||
billing_model_source: string // "requested" | "upstream"
|
||||
restrict_models: boolean
|
||||
features_config?: Record<string, unknown>
|
||||
group_ids: number[]
|
||||
model_pricing: ChannelModelPricing[]
|
||||
model_mapping: Record<string, Record<string, string>> // platform → {src→dst}
|
||||
apply_pricing_to_account_stats: boolean
|
||||
account_stats_pricing_rules: AccountStatsPricingRule[]
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
@ -57,7 +66,8 @@ export interface CreateChannelRequest {
|
||||
model_mapping?: Record<string, Record<string, string>>
|
||||
billing_model_source?: string
|
||||
restrict_models?: boolean
|
||||
features_config?: Record<string, unknown>
|
||||
apply_pricing_to_account_stats?: boolean
|
||||
account_stats_pricing_rules?: AccountStatsPricingRule[]
|
||||
}
|
||||
|
||||
export interface UpdateChannelRequest {
|
||||
@ -69,7 +79,8 @@ export interface UpdateChannelRequest {
|
||||
model_mapping?: Record<string, Record<string, string>>
|
||||
billing_model_source?: string
|
||||
restrict_models?: boolean
|
||||
features_config?: Record<string, unknown>
|
||||
apply_pricing_to_account_stats?: boolean
|
||||
account_stats_pricing_rules?: AccountStatsPricingRule[]
|
||||
}
|
||||
|
||||
interface PaginatedResponse<T> {
|
||||
|
||||
@ -1844,7 +1844,18 @@ export default {
|
||||
noPlatforms: 'Click "Add Platform" to start configuring the channel',
|
||||
mappingCount: 'mappings',
|
||||
pricingEntry: 'Pricing Entry',
|
||||
noModels: 'No models added'
|
||||
noModels: 'No models added',
|
||||
applyPricingToAccountStats: 'Apply Pricing to Account Stats',
|
||||
applyPricingToAccountStatsDesc: 'When enabled, account statistics cost will use channel model pricing. Account rate multiplier still applies.',
|
||||
accountStatsPricingRules: 'Custom Account Stats Pricing Rules',
|
||||
addRule: 'Add Rule',
|
||||
noRulesConfigured: 'No custom rules configured. Channel model pricing above will be used.',
|
||||
ruleName: 'Rule name (optional)',
|
||||
ruleGroups: 'Groups',
|
||||
ruleAccounts: 'Account IDs',
|
||||
ruleAccountsPlaceholder: 'Enter account IDs, comma-separated',
|
||||
ruleModelPricing: 'Model Pricing',
|
||||
noGroupsInChannel: 'No groups selected in platform tabs above'
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
@ -1923,7 +1923,18 @@ export default {
|
||||
noPlatforms: '点击"添加平台"开始配置渠道',
|
||||
mappingCount: '条映射',
|
||||
pricingEntry: '定价配置',
|
||||
noModels: '未添加模型'
|
||||
noModels: '未添加模型',
|
||||
applyPricingToAccountStats: '应用模型定价到账号统计',
|
||||
applyPricingToAccountStatsDesc: '启用后,账号统计费用将使用渠道模型定价计算。账号自身的统计倍率仍然生效。',
|
||||
accountStatsPricingRules: '自定义账号统计定价规则',
|
||||
addRule: '添加规则',
|
||||
noRulesConfigured: '未配置自定义规则,将使用上方的模型定价。',
|
||||
ruleName: '规则名称(可选)',
|
||||
ruleGroups: '分组',
|
||||
ruleAccounts: '账号 ID',
|
||||
ruleAccountsPlaceholder: '输入账号 ID,逗号分隔',
|
||||
ruleModelPricing: '模型定价',
|
||||
noGroupsInChannel: '上方平台标签页中未选择分组'
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
@ -306,24 +306,6 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Web Search Emulation (Anthropic only) -->
|
||||
<div v-if="section.platform === 'anthropic'" class="border-t border-gray-200 pt-3 dark:border-dark-600">
|
||||
<div class="flex items-center justify-between">
|
||||
<div>
|
||||
<label class="text-xs font-medium text-orange-600 dark:text-orange-400">
|
||||
{{ t('admin.channels.form.webSearchEmulation') }}
|
||||
</label>
|
||||
<p v-if="webSearchGlobalEnabled" class="mt-0.5 text-[11px] text-amber-500 dark:text-amber-400">
|
||||
{{ t('admin.channels.form.webSearchEmulationHint') }}
|
||||
</p>
|
||||
<p v-else class="mt-0.5 text-[11px] text-gray-400">
|
||||
{{ t('admin.channels.form.webSearchEmulationGlobalDisabled') }}
|
||||
</p>
|
||||
</div>
|
||||
<Toggle v-model="section.web_search_emulation" :disabled="!webSearchGlobalEnabled" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Model Mapping -->
|
||||
<div>
|
||||
<div class="mb-1 flex items-center justify-between">
|
||||
@ -398,6 +380,143 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Account Stats Pricing (always visible, not tied to platform tabs) -->
|
||||
<div class="mt-6 border-t border-gray-200 pt-4 dark:border-dark-700">
|
||||
<!-- Toggle -->
|
||||
<div class="flex items-center justify-between mb-3">
|
||||
<div>
|
||||
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
{{ t('admin.channels.form.applyPricingToAccountStats', 'Apply Pricing to Account Stats') }}
|
||||
</label>
|
||||
<p class="mt-0.5 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.applyPricingToAccountStatsDesc', 'When enabled, account statistics cost will use channel model pricing. Account rate multiplier still applies.') }}
|
||||
</p>
|
||||
</div>
|
||||
<Toggle
|
||||
:modelValue="form.apply_pricing_to_account_stats"
|
||||
@update:modelValue="form.apply_pricing_to_account_stats = $event"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- Custom rules (only when toggle is on) -->
|
||||
<div v-if="form.apply_pricing_to_account_stats" class="mt-4 space-y-4">
|
||||
<div class="flex items-center justify-between">
|
||||
<h4 class="text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
{{ t('admin.channels.form.accountStatsPricingRules', 'Custom Account Stats Pricing Rules') }}
|
||||
</h4>
|
||||
<button
|
||||
type="button"
|
||||
@click="addAccountStatsRule()"
|
||||
class="rounded-lg border border-primary-300 px-3 py-1 text-xs font-medium text-primary-600 hover:bg-primary-50 dark:border-primary-600 dark:text-primary-400 dark:hover:bg-primary-900/20"
|
||||
>
|
||||
+ {{ t('admin.channels.form.addRule', 'Add Rule') }}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<p
|
||||
v-if="form.account_stats_pricing_rules.length === 0"
|
||||
class="text-xs italic text-gray-400 dark:text-gray-500"
|
||||
>
|
||||
{{ t('admin.channels.form.noRulesConfigured', 'No custom rules configured. Channel model pricing above will be used.') }}
|
||||
</p>
|
||||
|
||||
<!-- Rule cards -->
|
||||
<div
|
||||
v-for="(rule, ruleIndex) in form.account_stats_pricing_rules"
|
||||
:key="ruleIndex"
|
||||
class="space-y-3 rounded-lg border border-gray-200 p-4 dark:border-dark-600"
|
||||
>
|
||||
<div class="flex items-center justify-between">
|
||||
<input
|
||||
v-model="rule.name"
|
||||
:placeholder="t('admin.channels.form.ruleName', 'Rule name (optional)')"
|
||||
class="bg-transparent text-sm font-medium text-gray-700 placeholder-gray-400 outline-none dark:text-gray-300"
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
@click="removeAccountStatsRule(ruleIndex)"
|
||||
class="text-xs text-red-500 hover:text-red-700"
|
||||
>
|
||||
{{ t('common.delete', 'Delete') }}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Group selection (multi-select from channel's groups) -->
|
||||
<div>
|
||||
<label class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.ruleGroups', 'Groups') }}
|
||||
</label>
|
||||
<div class="mt-1 flex flex-wrap gap-1">
|
||||
<label
|
||||
v-for="gid in allFormGroupIds"
|
||||
:key="gid"
|
||||
class="inline-flex cursor-pointer items-center gap-1 rounded-md border px-2 py-1 text-xs transition-colors"
|
||||
:class="rule.group_ids.includes(gid)
|
||||
? 'border-primary-300 bg-primary-50 dark:border-primary-700 dark:bg-primary-900/20'
|
||||
: 'border-gray-200 hover:bg-gray-50 dark:border-dark-600 dark:hover:bg-dark-700'"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="rule.group_ids.includes(gid)"
|
||||
class="h-3 w-3 rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
@change="rule.group_ids.includes(gid) ? rule.group_ids.splice(rule.group_ids.indexOf(gid), 1) : rule.group_ids.push(gid)"
|
||||
/>
|
||||
<span>{{ getGroupNameById(gid) }}</span>
|
||||
</label>
|
||||
</div>
|
||||
<p v-if="allFormGroupIds.length === 0" class="mt-1 text-xs text-gray-400">
|
||||
{{ t('admin.channels.form.noGroupsInChannel', 'No groups selected in platform tabs above') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Account IDs input -->
|
||||
<div>
|
||||
<label class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.ruleAccounts', 'Account IDs') }}
|
||||
</label>
|
||||
<input
|
||||
:value="rule.account_ids.join(', ')"
|
||||
@change="rule.account_ids = parseAccountIdsInput(($event.target as HTMLInputElement).value)"
|
||||
:placeholder="t('admin.channels.form.ruleAccountsPlaceholder', 'Enter account IDs, comma-separated')"
|
||||
class="input mt-1 text-sm"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- Model Pricing entries -->
|
||||
<div>
|
||||
<div class="mb-1 flex items-center justify-between">
|
||||
<label class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.ruleModelPricing', 'Model Pricing') }}
|
||||
</label>
|
||||
<button
|
||||
type="button"
|
||||
@click="addRulePricingEntry(ruleIndex)"
|
||||
class="text-xs text-primary-600 hover:text-primary-700"
|
||||
>
|
||||
+ {{ t('common.add', 'Add') }}
|
||||
</button>
|
||||
</div>
|
||||
<div
|
||||
v-if="rule.pricing.length === 0"
|
||||
class="rounded border border-dashed border-gray-300 p-2 text-center text-xs text-gray-400 dark:border-dark-500"
|
||||
>
|
||||
{{ t('admin.channels.form.noPricingRules', 'No pricing rules yet. Click "Add" to create one.') }}
|
||||
</div>
|
||||
<div v-else class="space-y-2">
|
||||
<PricingEntryCard
|
||||
v-for="(entry, pIdx) in rule.pricing"
|
||||
:key="pIdx"
|
||||
:entry="entry"
|
||||
platform=""
|
||||
@update="rule.pricing.splice(pIdx, 1, $event)"
|
||||
@remove="removeRulePricingEntry(ruleIndex, pIdx)"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
|
||||
@ -441,9 +560,8 @@
|
||||
import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import { extractApiErrorMessage } from '@/utils/apiError'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import type { Channel, ChannelModelPricing, CreateChannelRequest, UpdateChannelRequest } from '@/api/admin/channels'
|
||||
import type { Channel, ChannelModelPricing, CreateChannelRequest, UpdateChannelRequest, AccountStatsPricingRule } from '@/api/admin/channels'
|
||||
import type { PricingFormEntry } from '@/components/admin/channel/types'
|
||||
import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict, validateIntervals } from '@/components/admin/channel/types'
|
||||
import type { AdminGroup, GroupPlatform } from '@/types'
|
||||
@ -465,18 +583,6 @@ import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
|
||||
const { t } = useI18n()
|
||||
const appStore = useAppStore()
|
||||
|
||||
// Web Search global enabled state (loaded once on mount)
|
||||
const webSearchGlobalEnabled = ref(false)
|
||||
async function loadWebSearchGlobalState() {
|
||||
try {
|
||||
const cfg = await adminAPI.settings.getWebSearchEmulationConfig()
|
||||
webSearchGlobalEnabled.value = cfg?.enabled === true && (cfg?.providers?.length ?? 0) > 0
|
||||
} catch (err: unknown) {
|
||||
console.warn('Failed to load web search global state:', err)
|
||||
webSearchGlobalEnabled.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// ── Platform Section type ──
|
||||
interface PlatformSection {
|
||||
platform: GroupPlatform
|
||||
@ -485,7 +591,6 @@ interface PlatformSection {
|
||||
group_ids: number[]
|
||||
model_mapping: Record<string, string>
|
||||
model_pricing: PricingFormEntry[]
|
||||
web_search_emulation: boolean
|
||||
}
|
||||
|
||||
// ── Table columns ──
|
||||
@ -553,7 +658,14 @@ const form = reactive({
|
||||
status: 'active',
|
||||
restrict_models: false,
|
||||
billing_model_source: 'channel_mapped' as string,
|
||||
platforms: [] as PlatformSection[]
|
||||
platforms: [] as PlatformSection[],
|
||||
apply_pricing_to_account_stats: false,
|
||||
account_stats_pricing_rules: [] as Array<{
|
||||
name: string
|
||||
group_ids: number[]
|
||||
account_ids: number[]
|
||||
pricing: PricingFormEntry[]
|
||||
}>
|
||||
})
|
||||
|
||||
let abortController: AbortController | null = null
|
||||
@ -597,8 +709,7 @@ function addPlatformSection(platform: GroupPlatform) {
|
||||
collapsed: false,
|
||||
group_ids: [],
|
||||
model_mapping: {},
|
||||
model_pricing: [],
|
||||
web_search_emulation: false,
|
||||
model_pricing: []
|
||||
})
|
||||
}
|
||||
|
||||
@ -711,15 +822,89 @@ function renameMappingKey(sectionIdx: number, oldKey: string, newKey: string) {
|
||||
mapping[newKey] = value
|
||||
}
|
||||
|
||||
// ── Account Stats Pricing helpers ──
|
||||
function addAccountStatsRule() {
|
||||
form.account_stats_pricing_rules.push({
|
||||
name: '',
|
||||
group_ids: [],
|
||||
account_ids: [],
|
||||
pricing: []
|
||||
})
|
||||
}
|
||||
|
||||
function addRulePricingEntry(ruleIndex: number) {
|
||||
form.account_stats_pricing_rules[ruleIndex].pricing.push({
|
||||
models: [],
|
||||
billing_mode: 'token',
|
||||
input_price: null,
|
||||
output_price: null,
|
||||
cache_write_price: null,
|
||||
cache_read_price: null,
|
||||
image_output_price: null,
|
||||
per_request_price: null,
|
||||
intervals: []
|
||||
})
|
||||
}
|
||||
|
||||
function removeAccountStatsRule(ruleIndex: number) {
|
||||
form.account_stats_pricing_rules.splice(ruleIndex, 1)
|
||||
}
|
||||
|
||||
function removeRulePricingEntry(ruleIndex: number, pricingIndex: number) {
|
||||
form.account_stats_pricing_rules[ruleIndex].pricing.splice(pricingIndex, 1)
|
||||
}
|
||||
|
||||
function getGroupNameById(groupId: number): string {
|
||||
const group = allGroups.value.find(g => g.id === groupId)
|
||||
return group ? group.name : `#${groupId}`
|
||||
}
|
||||
|
||||
/** Collect all group_ids from enabled platform sections */
|
||||
const allFormGroupIds = computed(() => {
|
||||
const ids = new Set<number>()
|
||||
for (const section of form.platforms) {
|
||||
if (!section.enabled) continue
|
||||
for (const gid of section.group_ids) {
|
||||
ids.add(gid)
|
||||
}
|
||||
}
|
||||
return [...ids]
|
||||
})
|
||||
|
||||
function parseAccountIdsInput(value: string): number[] {
|
||||
return value
|
||||
.split(',')
|
||||
.map(s => parseInt(s.trim()))
|
||||
.filter(n => !isNaN(n) && n > 0)
|
||||
}
|
||||
|
||||
function accountStatsRulesToAPI(): AccountStatsPricingRule[] {
|
||||
return form.account_stats_pricing_rules.map(rule => ({
|
||||
name: rule.name,
|
||||
group_ids: rule.group_ids,
|
||||
account_ids: rule.account_ids,
|
||||
pricing: rule.pricing
|
||||
.filter(p => p.models.length > 0)
|
||||
.map(p => ({
|
||||
platform: '',
|
||||
models: p.models,
|
||||
billing_mode: p.billing_mode,
|
||||
input_price: mTokToPerToken(p.input_price),
|
||||
output_price: mTokToPerToken(p.output_price),
|
||||
cache_write_price: mTokToPerToken(p.cache_write_price),
|
||||
cache_read_price: mTokToPerToken(p.cache_read_price),
|
||||
image_output_price: mTokToPerToken(p.image_output_price),
|
||||
per_request_price: p.per_request_price != null && p.per_request_price !== '' ? Number(p.per_request_price) : null,
|
||||
intervals: formIntervalsToAPI(p.intervals || [])
|
||||
}))
|
||||
}))
|
||||
}
|
||||
|
||||
// ── Form ↔ API conversion ──
|
||||
function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[], model_mapping: Record<string, Record<string, string>>, features_config: Record<string, unknown> } {
|
||||
function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[], model_mapping: Record<string, Record<string, string>> } {
|
||||
const group_ids: number[] = []
|
||||
const model_pricing: ChannelModelPricing[] = []
|
||||
const model_mapping: Record<string, Record<string, string>> = {}
|
||||
// Preserve existing features_config fields not managed by the form
|
||||
const featuresConfig: Record<string, unknown> = editingChannel.value?.features_config
|
||||
? { ...editingChannel.value.features_config }
|
||||
: {}
|
||||
|
||||
for (const section of form.platforms) {
|
||||
if (!section.enabled) continue
|
||||
@ -748,19 +933,7 @@ function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[
|
||||
}
|
||||
}
|
||||
|
||||
// Collect web_search_emulation (only anthropic platform supports it)
|
||||
const wsEmulation: Record<string, boolean> = {}
|
||||
for (const section of form.platforms) {
|
||||
if (!section.enabled) continue
|
||||
if (section.web_search_emulation && section.platform === 'anthropic') {
|
||||
wsEmulation[section.platform] = true
|
||||
}
|
||||
}
|
||||
if (Object.keys(wsEmulation).length > 0) {
|
||||
featuresConfig.web_search_emulation = wsEmulation
|
||||
}
|
||||
|
||||
return { group_ids, model_pricing, model_mapping, features_config: featuresConfig }
|
||||
return { group_ids, model_pricing, model_mapping }
|
||||
}
|
||||
|
||||
function apiToForm(channel: Channel): PlatformSection[] {
|
||||
@ -804,19 +977,13 @@ function apiToForm(channel: Channel): PlatformSection[] {
|
||||
intervals: apiIntervalsToForm(p.intervals || [])
|
||||
} as PricingFormEntry))
|
||||
|
||||
// Read web_search_emulation from features_config
|
||||
const fc = channel.features_config
|
||||
const wsEmulation = fc?.web_search_emulation as Record<string, boolean> | undefined
|
||||
const webSearchEnabled = wsEmulation?.[platform] === true
|
||||
|
||||
sections.push({
|
||||
platform,
|
||||
enabled: true,
|
||||
collapsed: false,
|
||||
group_ids: groupIds,
|
||||
model_mapping: { ...mapping },
|
||||
model_pricing: pricing,
|
||||
web_search_emulation: webSearchEnabled,
|
||||
model_pricing: pricing
|
||||
})
|
||||
}
|
||||
|
||||
@ -841,10 +1008,10 @@ async function loadChannels() {
|
||||
if (ctrl.signal.aborted || abortController !== ctrl) return
|
||||
channels.value = response.items || []
|
||||
pagination.total = response.total
|
||||
} catch (error: unknown) {
|
||||
const e = error as { name?: string; code?: string }
|
||||
if (e?.name === 'AbortError' || e?.code === 'ERR_CANCELED') return
|
||||
appStore.showError(extractApiErrorMessage(error, t('admin.channels.loadError', 'Failed to load channels')))
|
||||
} catch (error: any) {
|
||||
if (error?.name === 'AbortError' || error?.code === 'ERR_CANCELED') return
|
||||
appStore.showError(t('admin.channels.loadError', 'Failed to load channels'))
|
||||
console.error('Error loading channels:', error)
|
||||
} finally {
|
||||
if (abortController === ctrl) {
|
||||
loading.value = false
|
||||
@ -909,6 +1076,8 @@ function resetForm() {
|
||||
form.restrict_models = false
|
||||
form.billing_model_source = 'channel_mapped'
|
||||
form.platforms = []
|
||||
form.apply_pricing_to_account_stats = false
|
||||
form.account_stats_pricing_rules = []
|
||||
activeTab.value = 'basic'
|
||||
}
|
||||
|
||||
@ -926,6 +1095,23 @@ async function openEditDialog(channel: Channel) {
|
||||
form.status = channel.status
|
||||
form.restrict_models = channel.restrict_models || false
|
||||
form.billing_model_source = channel.billing_model_source || 'channel_mapped'
|
||||
form.apply_pricing_to_account_stats = channel.apply_pricing_to_account_stats || false
|
||||
form.account_stats_pricing_rules = (channel.account_stats_pricing_rules || []).map(rule => ({
|
||||
name: rule.name || '',
|
||||
group_ids: [...(rule.group_ids || [])],
|
||||
account_ids: [...(rule.account_ids || [])],
|
||||
pricing: (rule.pricing || []).map(p => ({
|
||||
models: [...(p.models || [])],
|
||||
billing_mode: p.billing_mode,
|
||||
input_price: perTokenToMTok(p.input_price),
|
||||
output_price: perTokenToMTok(p.output_price),
|
||||
cache_write_price: perTokenToMTok(p.cache_write_price),
|
||||
cache_read_price: perTokenToMTok(p.cache_read_price),
|
||||
image_output_price: perTokenToMTok(p.image_output_price),
|
||||
per_request_price: p.per_request_price,
|
||||
intervals: apiIntervalsToForm(p.intervals || [])
|
||||
} as PricingFormEntry))
|
||||
}))
|
||||
// Must load groups first so apiToForm can map groupID → platform
|
||||
await Promise.all([loadGroups(), loadAllChannelsForConflict()])
|
||||
form.platforms = apiToForm(channel)
|
||||
@ -1024,7 +1210,7 @@ async function handleSubmit() {
|
||||
}
|
||||
}
|
||||
|
||||
const { group_ids, model_pricing, model_mapping, features_config } = formToAPI()
|
||||
const { group_ids, model_pricing, model_mapping } = formToAPI()
|
||||
|
||||
submitting.value = true
|
||||
try {
|
||||
@ -1038,7 +1224,8 @@ async function handleSubmit() {
|
||||
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {},
|
||||
billing_model_source: form.billing_model_source,
|
||||
restrict_models: form.restrict_models,
|
||||
features_config,
|
||||
apply_pricing_to_account_stats: form.apply_pricing_to_account_stats,
|
||||
account_stats_pricing_rules: accountStatsRulesToAPI()
|
||||
}
|
||||
await adminAPI.channels.update(editingChannel.value.id, req)
|
||||
appStore.showSuccess(t('admin.channels.updateSuccess', 'Channel updated'))
|
||||
@ -1051,17 +1238,20 @@ async function handleSubmit() {
|
||||
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {},
|
||||
billing_model_source: form.billing_model_source,
|
||||
restrict_models: form.restrict_models,
|
||||
features_config,
|
||||
apply_pricing_to_account_stats: form.apply_pricing_to_account_stats,
|
||||
account_stats_pricing_rules: accountStatsRulesToAPI()
|
||||
}
|
||||
await adminAPI.channels.create(req)
|
||||
appStore.showSuccess(t('admin.channels.createSuccess', 'Channel created'))
|
||||
}
|
||||
closeDialog()
|
||||
loadChannels()
|
||||
} catch (error: unknown) {
|
||||
appStore.showError(extractApiErrorMessage(error, editingChannel.value
|
||||
} catch (error: any) {
|
||||
const msg = error.response?.data?.detail || (editingChannel.value
|
||||
? t('admin.channels.updateError', 'Failed to update channel')
|
||||
: t('admin.channels.createError', 'Failed to create channel')))
|
||||
: t('admin.channels.createError', 'Failed to create channel'))
|
||||
appStore.showError(msg)
|
||||
console.error('Error saving channel:', error)
|
||||
} finally {
|
||||
submitting.value = false
|
||||
}
|
||||
@ -1099,8 +1289,9 @@ async function confirmDelete() {
|
||||
showDeleteDialog.value = false
|
||||
deletingChannel.value = null
|
||||
loadChannels()
|
||||
} catch (error: unknown) {
|
||||
appStore.showError(extractApiErrorMessage(error, t('admin.channels.deleteError', 'Failed to delete channel')))
|
||||
} catch (error: any) {
|
||||
appStore.showError(error.response?.data?.detail || t('admin.channels.deleteError', 'Failed to delete channel'))
|
||||
console.error('Error deleting channel:', error)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1108,7 +1299,6 @@ async function confirmDelete() {
|
||||
onMounted(() => {
|
||||
loadChannels()
|
||||
loadGroups()
|
||||
loadWebSearchGlobalState()
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user