Zuncle fb4b266bac fix(channel): 统一渠道分析时间筛选范围
让渠道分析弹窗的概览卡片和趋势图共用同一时间筛选范围,
并将默认行为调整为历史全量展示。前端改为在未选择日期时
不再回退最近 12 天,后端新增统一时间窗口解析逻辑,使
overview 与 daily 在全量、指定区间和 days 模式下保持一致。
2026-04-09 17:51:42 +08:00

743 lines
24 KiB
Go
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package channel
import (
"context"
"errors"
"strconv"
"strings"
"time"
"bindbox-game/internal/pkg/logger"
"bindbox-game/internal/repository/mysql"
"bindbox-game/internal/repository/mysql/dao"
"bindbox-game/internal/repository/mysql/model"
"gorm.io/gen/field"
"gorm.io/gorm"
)
// Service manages sales channels: CRUD, per-channel statistics, and binding
// users to channels.
type Service interface {
	// Create inserts a new channel and returns the inserted row.
	Create(ctx context.Context, in CreateInput) (*model.Channels, error)
	// Modify updates only the non-nil fields of in on channel id.
	Modify(ctx context.Context, id int64, in ModifyInput) error
	// Delete deletes channel id.
	Delete(ctx context.Context, id int64) error
	// List pages channels (optionally filtered by name/code) together with
	// per-channel user counts and paid amounts.
	List(ctx context.Context, in ListInput) (items []*ChannelWithStat, total int64, err error)
	// GetStats returns overview totals and a daily trend for one channel.
	// startDate/endDate are YYYY-MM-DD strings; days selects a trailing window.
	GetStats(ctx context.Context, channelID int64, days int, startDate, endDate string) (*StatsOutput, error)
	// GetByID loads one channel, returning ErrChannelNotFound when it is absent.
	GetByID(ctx context.Context, id int64) (*model.Channels, error)
	// SearchUsers pages users by keyword (ID / mobile / nickname) and/or channel.
	SearchUsers(ctx context.Context, in SearchUsersInput) (items []*ChannelUserItem, total int64, err error)
	// BindUsers moves the given users into channelID within one transaction
	// and reports a per-user outcome.
	BindUsers(ctx context.Context, channelID int64, userIDs []int64) (*BindUsersOutput, error)
}
// service is the Service implementation returned by New. Queries that only
// read go through readDB; inserts, updates, deletes and transactions go
// through writeDB.
type service struct {
	logger  logger.CustomLogger
	readDB  *dao.Query
	writeDB *dao.Query
}
// New builds the channel Service. Read queries are bound to the repo's read
// connection (GetDbR) and write queries to its write connection (GetDbW).
func New(l logger.CustomLogger, db mysql.Repo) Service {
	reader := dao.Use(db.GetDbR())
	writer := dao.Use(db.GetDbW())
	return &service{
		logger:  l,
		readDB:  reader,
		writeDB: writer,
	}
}
// CreateInput carries the fields of a new channel.
type CreateInput struct {
	Name    string
	Code    string
	Type    string
	Remarks string
}

// ModifyInput is a partial update: nil fields are left unchanged.
type ModifyInput struct {
	Name    *string
	Type    *string
	Remarks *string
}

// ListInput filters and pages the channel list. Name, when non-empty, is
// matched as a substring of the channel name or code; non-positive Page and
// PageSize fall back to 1 and 20.
type ListInput struct {
	Name     string
	Page     int
	PageSize int
}
// ChannelWithStat is a channel row plus the aggregates shown on the list page.
type ChannelWithStat struct {
	*model.Channels
	UserCount       int64 `json:"user_count"`        // number of users with this channel_id
	PaidAmountCents int64 `json:"paid_amount_cents"` // GMV: sum of total_amount over qualifying orders (cents)
	PaidAmount      int64 `json:"paid_amount"`       // PaidAmountCents / 100, truncated (yuan)
}

// StatsOutput is the GetStats result: totals for the stats window plus the
// day-by-day trend.
type StatsOutput struct {
	Overview StatsOverview    `json:"overview"`
	Daily    []StatsDailyItem `json:"daily"`
}

// StatsOverview holds a channel's totals over the stats window. Money fields
// suffixed with Cents are in cents; the un-suffixed counterparts are in yuan.
type StatsOverview struct {
	TotalUsers       int64 `json:"total_users"`
	TotalOrders      int64 `json:"total_orders"`
	TotalGMV         int64 `json:"total_gmv"`
	TotalPaidCents   int64 `json:"total_paid_cents"`
	TotalCostCents   int64 `json:"total_cost_cents"`   // total cost (cents)
	TotalProfitCents int64 `json:"total_profit_cents"` // profit/loss (cents) = paid - cost
	TotalCost        int64 `json:"total_cost"`         // total cost (yuan)
	TotalProfit      int64 `json:"total_profit"`       // profit/loss (yuan)
	CashCents        int64 `json:"cash_cents"`         // paid in cash (cents)
	CouponCents      int64 `json:"coupon_cents"`       // covered by coupons (cents)
	PointsCents      int64 `json:"points_cents"`       // covered by points (cents)
}

// StatsDailyItem is one day of the stats trend; Date identifies the day.
type StatsDailyItem struct {
	Date        string `json:"date"`
	UserCount   int64  `json:"user_count"`
	OrderCount  int64  `json:"order_count"`
	GMV         int64  `json:"gmv"`
	PaidCents   int64  `json:"paid_cents"`
	CostCents   int64  `json:"cost_cents"`   // cost for the day (cents)
	ProfitCents int64  `json:"profit_cents"` // profit/loss for the day (cents)
	CashCents   int64  `json:"cash_cents"`   // cash for the day (cents)
	CouponCents int64  `json:"coupon_cents"` // coupons for the day (cents)
	PointsCents int64  `json:"points_cents"` // points for the day (cents)
}
// SearchUsersInput filters SearchUsers. At least one of Keyword (matched
// against user ID, mobile, or nickname) and ChannelID (> 0) must be set.
type SearchUsersInput struct {
	Keyword   string
	ChannelID int64
	Page      int
	PageSize  int
}

// ChannelUserItem is one SearchUsers result: a user plus the name and code of
// their channel, joined from the channels table.
type ChannelUserItem struct {
	ID          int64  `json:"id"`
	Nickname    string `json:"nickname"`
	Mobile      string `json:"mobile"`
	Avatar      string `json:"avatar"`
	ChannelID   int64  `json:"channel_id"`
	ChannelName string `json:"channel_name"`
	ChannelCode string `json:"channel_code"`
}

// BindUsersOutput summarises a BindUsers call. Details holds one entry per
// distinct positive user ID, in request order.
type BindUsersOutput struct {
	SuccessCount int              `json:"success_count"`
	FailedCount  int              `json:"failed_count"`
	SkippedCount int              `json:"skipped_count"`
	Details      []BindUserDetail `json:"details"`
}

// BindUserDetail is the outcome of binding a single user.
type BindUserDetail struct {
	UserID       int64  `json:"user_id"`
	Status       string `json:"status"`            // success | failed | skipped
	Message      string `json:"message,omitempty"` // reason, e.g. user_not_found or already_in_channel
	OldChannelID int64  `json:"old_channel_id"`    // channel before binding; 0 when the user was not found
	NewChannelID int64  `json:"new_channel_id"`
}
// ErrChannelNotFound reports that the referenced channel does not exist.
var ErrChannelNotFound = errors.New("channel_not_found")

// ErrBindUsersEmpty reports that BindUsers received no positive user ID.
var ErrBindUsersEmpty = errors.New("bind_users_empty")

// ErrBindUsersTooMany reports that BindUsers received more distinct user IDs
// than a single call accepts.
var ErrBindUsersTooMany = errors.New("bind_users_too_many")

// ErrSearchKeywordEmpty reports that SearchUsers was given neither a keyword
// nor a channel filter.
var ErrSearchKeywordEmpty = errors.New("search_keyword_empty")
// orderAmountRow is a scan target for an order's actual_amount and created_at.
// NOTE(review): not referenced in this file — confirm it is still used
// elsewhere in the package before removing it.
type orderAmountRow struct {
	ActualAmount int64
	CreatedAt    time.Time
}

// GMVBreakdown splits GMV by payment method. All amounts are in cents; per
// order, total_amount = actual_amount + discount_amount + points_amount.
type GMVBreakdown struct {
	Total  int64 // sum of total_amount (cents)
	Cash   int64 // sum of actual_amount, paid in cash (cents)
	Coupon int64 // sum of discount_amount, covered by coupons (cents)
	Points int64 // sum of points_amount, covered by points (cents)
}
// calcGMVByTotalAmount sums a channel's GMV on the orders' list price
// (total_amount) and splits it by payment method:
//
//	total_amount = actual_amount (cash) + discount_amount (coupon) + points_amount (points)
//
// orderFilter is a WHERE fragment whose single placeholder is bound to
// channelID; it may reference the users table, which is joined on
// orders.user_id. When both startDate and endDate are non-nil, orders are
// limited to created_at within [startDate, endDate].
//
// It returns the overall breakdown and a breakdown keyed by each order's
// created_at formatted with dateFmt. Query errors are not reported: a failed
// scan simply yields zero totals.
func (s *service) calcGMVByTotalAmount(ctx context.Context, channelID int64, dateFmt string, orderFilter string, startDate, endDate *time.Time) (GMVBreakdown, map[string]GMVBreakdown) {
	type gmvRow struct {
		TotalAmount    int64
		ActualAmount   int64
		DiscountAmount int64
		PointsAmount   int64
		CreatedAt      time.Time
	}

	query := s.readDB.Orders.WithContext(ctx).UnderlyingDB().Table("orders").
		Joins("JOIN users ON users.id = orders.user_id").
		Select("orders.total_amount, orders.actual_amount, orders.discount_amount, orders.points_amount, orders.created_at").
		Where(orderFilter, channelID)
	if startDate != nil && endDate != nil {
		query = query.Where("orders.created_at >= ? AND orders.created_at <= ?", *startDate, *endDate)
	}

	var found []gmvRow
	query.Scan(&found)

	accumulate := func(dst *GMVBreakdown, r *gmvRow) {
		dst.Total += r.TotalAmount
		dst.Cash += r.ActualAmount
		dst.Coupon += r.DiscountAmount
		dst.Points += r.PointsAmount
	}

	var overall GMVBreakdown
	perDay := map[string]GMVBreakdown{}
	for i := range found {
		r := &found[i]
		accumulate(&overall, r)
		day := r.CreatedAt.Format(dateFmt)
		bucket := perDay[day]
		accumulate(&bucket, r)
		perDay[day] = bucket
	}
	return overall, perDay
}
// calcCostByInventory computes the prize cost of a channel's users from their
// inventory, applying item-card multipliers:
//
//	cost = SUM(unit value × multiplier / 1000)   (integer division per row)
//
// The unit value falls back in order: user_inventory.value_cents (when
// non-zero) → activity_reward_settings.price_snapshot_cents → products.price
// → 0. The multiplier is system_item_cards.reward_multiplier_x1000, raised to
// 1000 (×1.0) when there is no card or the value is below 1000.
//
// Only inventory with status 1 or 3 whose remark does not contain "void" is
// counted, and only when it has no order (order_id 0/NULL) or comes from an
// order with status 2, source_type 2/3/4 and total_amount > 0. Rows are
// attributed by the owning user's users.channel_id. When both startDate and
// endDate are non-nil, rows are limited to user_inventory.created_at within
// [startDate, endDate].
//
// It returns the total and per-day buckets keyed by user_inventory.created_at
// formatted with dateFmt. Query errors are not reported: a failed scan yields
// zero cost.
// NOTE(review): GetStats in this file uses calcCostByDrawSource instead.
func (s *service) calcCostByInventory(ctx context.Context, channelID int64, dateFmt string, startDate, endDate *time.Time) (int64, map[string]int64) {
	type inventoryCost struct {
		UnitCost   int64
		Multiplier int64
		CreatedAt  time.Time
	}

	db := s.readDB.UserInventory.WithContext(ctx).UnderlyingDB().
		Table("user_inventory").
		Select(`
COALESCE(NULLIF(user_inventory.value_cents, 0), activity_reward_settings.price_snapshot_cents, products.price, 0) AS unit_cost,
CASE WHEN COALESCE(system_item_cards.reward_multiplier_x1000, 1000) < 1000 THEN 1000 ELSE COALESCE(system_item_cards.reward_multiplier_x1000, 1000) END AS multiplier,
user_inventory.created_at
`)
	for _, join := range []string{
		"JOIN users ON users.id = user_inventory.user_id",
		"LEFT JOIN orders ON orders.id = user_inventory.order_id",
		"LEFT JOIN activity_reward_settings ON activity_reward_settings.id = user_inventory.reward_id",
		"LEFT JOIN products ON products.id = user_inventory.product_id",
		"LEFT JOIN user_item_cards ON user_item_cards.id = orders.item_card_id",
		"LEFT JOIN system_item_cards ON system_item_cards.id = user_item_cards.card_id",
	} {
		db = db.Joins(join)
	}
	db = db.
		Where("users.channel_id = ? AND users.deleted_at IS NULL", channelID).
		Where("user_inventory.status IN ?", []int{1, 3}).
		Where("COALESCE(user_inventory.remark, '') NOT LIKE ?", "%void%").
		Where("(orders.status = 2 OR user_inventory.order_id = 0 OR user_inventory.order_id IS NULL)").
		Where("(orders.source_type IN (2,3,4) OR user_inventory.order_id = 0 OR user_inventory.order_id IS NULL)").
		Where("(orders.total_amount > 0 OR user_inventory.order_id = 0 OR user_inventory.order_id IS NULL)")
	if startDate != nil && endDate != nil {
		db = db.Where("user_inventory.created_at >= ? AND user_inventory.created_at <= ?", *startDate, *endDate)
	}

	var costs []inventoryCost
	db.Scan(&costs)

	var sum int64
	perDay := map[string]int64{}
	for _, c := range costs {
		weighted := c.UnitCost * c.Multiplier / 1000
		sum += weighted
		perDay[c.CreatedAt.Format(dateFmt)] += weighted
	}
	return sum, perDay
}
// calcCostByDrawSource sums a channel's prize cost from draw records, using the
// same cost basis as activity profit/loss:
//   - source: activity_draw_logs → activity_reward_settings → products.cost_price;
//   - quantity: activity_reward_settings.drop_quantity, or 1 when it is 0/NULL;
//   - multiplier: one extra unit when the order's item card was used on this
//     draw (user_item_cards.used_draw_log_id), has effect_type 1 and a
//     reward_multiplier_x1000 of at least 2000.
//
// Cost is attributed through orders.user_id → users.channel_id rather than the
// current owner of the inventory, so prizes that were later transferred still
// count toward the channel of the user who placed the order. Only orders with
// status 2, source_type 2/3/4 and no ext_order_id are counted. When both
// startDate and endDate are non-nil, orders are limited to paid_at within
// [startDate, endDate].
//
// It returns the total and per-day buckets keyed by orders.paid_at formatted
// with dateFmt. Query errors are not reported: a failed scan yields zero cost.
func (s *service) calcCostByDrawSource(ctx context.Context, channelID int64, dateFmt string, startDate, endDate *time.Time) (int64, map[string]int64) {
	type costRow struct {
		CostCents int64
		PaidAt    time.Time
	}
	q := s.readDB.ActivityDrawLogs.WithContext(ctx).UnderlyingDB().
		Table("activity_draw_logs").
		Select(`
SUM(COALESCE(products.cost_price, 0) * (
COALESCE(NULLIF(activity_reward_settings.drop_quantity, 0), 1) +
CASE WHEN user_item_cards.used_draw_log_id = activity_draw_logs.id AND system_item_cards.effect_type = 1 AND system_item_cards.reward_multiplier_x1000 >= 2000 THEN 1 ELSE 0 END
)) AS cost_cents,
orders.paid_at
`).
		Joins("JOIN orders ON orders.id = activity_draw_logs.order_id").
		Joins("JOIN users ON users.id = orders.user_id").
		Joins("LEFT JOIN activity_reward_settings ON activity_reward_settings.id = activity_draw_logs.reward_id").
		Joins("LEFT JOIN products ON products.id = activity_reward_settings.product_id").
		Joins("LEFT JOIN user_item_cards ON user_item_cards.id = orders.item_card_id").
		Joins("LEFT JOIN system_item_cards ON system_item_cards.id = user_item_cards.card_id").
		Where("users.channel_id = ? AND users.deleted_at IS NULL", channelID).
		Where("orders.status = 2").
		Where("orders.source_type IN ?", []int{2, 3, 4})
	// Grouped explicitly (as in the other order filters of this package) so
	// the OR can never combine with neighbouring AND conditions, regardless of
	// how GORM assembles the WHERE clause.
	q = q.Where("(orders.ext_order_id = '' OR orders.ext_order_id IS NULL)")
	if startDate != nil && endDate != nil {
		q = q.Where("orders.paid_at >= ? AND orders.paid_at <= ?", *startDate, *endDate)
	}

	// One row per order: the SUM aggregates that order's draw logs.
	var rows []costRow
	q.Group("orders.id, orders.paid_at").Scan(&rows)

	var total int64
	byDate := make(map[string]int64)
	for _, r := range rows {
		total += r.CostCents
		byDate[r.PaidAt.Format(dateFmt)] += r.CostCents
	}
	return total, byDate
}
// Create inserts a channel built from in and returns the inserted row.
func (s *service) Create(ctx context.Context, in CreateInput) (*model.Channels, error) {
	ch := &model.Channels{
		Name:    in.Name,
		Code:    in.Code,
		Type:    in.Type,
		Remarks: in.Remarks,
	}
	err := s.writeDB.Channels.WithContext(ctx).Create(ch)
	if err != nil {
		return nil, err
	}
	return ch, nil
}
// Modify applies a partial update to channel id: only the non-nil fields of in
// are written. It is a no-op returning nil when in sets nothing, and it does
// not report whether any row matched id.
func (s *service) Modify(ctx context.Context, id int64, in ModifyInput) error {
	changes := make(map[string]any, 3)
	for column, value := range map[string]*string{
		"name":    in.Name,
		"type":    in.Type,
		"remarks": in.Remarks,
	} {
		if value != nil {
			changes[column] = *value
		}
	}
	if len(changes) == 0 {
		return nil
	}

	ch := s.writeDB.Channels
	_, err := ch.WithContext(ctx).Where(ch.ID.Eq(id)).Updates(changes)
	return err
}
// Delete deletes the channel with the given id. The affected-row count is not
// inspected, so deleting a missing id returns nil.
func (s *service) Delete(ctx context.Context, id int64) error {
	ch := s.writeDB.Channels
	if _, err := ch.WithContext(ctx).Where(ch.ID.Eq(id)).Delete(); err != nil {
		return err
	}
	return nil
}
// List returns one page of channels, newest (highest ID) first, with
// per-channel aggregates:
//   - UserCount: non-deleted users whose channel_id is the channel;
//   - PaidAmountCents: sum of total_amount over those users' qualifying orders
//     (status 2, source_type 2/3/4, total_amount > 0, actual_amount > 0, no
//     ext_order_id); PaidAmount is the same value / 100 (yuan, truncated).
//
// in.Name, when non-empty, keeps channels whose name OR code contains it.
// Non-positive Page/PageSize default to 1/20. total counts all matching
// channels. Any database error is returned with nil items and zero total.
func (s *service) List(ctx context.Context, in ListInput) (items []*ChannelWithStat, total int64, err error) {
	if in.Page <= 0 {
		in.Page = 1
	}
	if in.PageSize <= 0 {
		in.PageSize = 20
	}

	ch := s.readDB.Channels
	q := ch.WithContext(ctx)
	if in.Name != "" {
		like := "%" + in.Name + "%"
		// Group the OR as one condition so it cannot bleed into any other
		// condition on the query (e.g. an automatic soft-delete filter).
		q = q.Where(field.Or(ch.Name.Like(like), ch.Code.Like(like)))
	}
	total, err = q.Count()
	if err != nil {
		return nil, 0, err
	}

	channels, err := q.Order(ch.ID.Desc()).Limit(in.PageSize).Offset((in.Page - 1) * in.PageSize).Find()
	if err != nil {
		return nil, 0, err
	}
	if len(channels) == 0 {
		return nil, total, nil
	}

	channelIDs := make([]int64, 0, len(channels))
	for _, c := range channels {
		channelIDs = append(channelIDs, c.ID)
	}

	// User counts per channel.
	type userCountRow struct {
		ChannelID int64
		Count     int64
	}
	var userCounts []userCountRow
	if err = s.readDB.Users.WithContext(ctx).UnderlyingDB().Table("users").
		Select("channel_id, count(*) as count").
		Where("channel_id IN ? AND deleted_at IS NULL", channelIDs).
		Group("channel_id").
		Scan(&userCounts).Error; err != nil {
		return nil, 0, err
	}
	usersByChannel := make(map[int64]int64, len(userCounts))
	for _, r := range userCounts {
		usersByChannel[r.ChannelID] = r.Count
	}

	// Paid GMV per channel, aggregated in SQL rather than loading every order.
	type paidRow struct {
		ChannelID   int64
		TotalAmount int64
	}
	var paidRows []paidRow
	if err = s.readDB.Orders.WithContext(ctx).UnderlyingDB().Table("orders").
		Joins("JOIN users ON users.id = orders.user_id").
		Select("users.channel_id, SUM(orders.total_amount) AS total_amount").
		Where("users.channel_id IN ?", channelIDs).
		Where("users.deleted_at IS NULL AND orders.status = 2 AND orders.total_amount > 0 AND orders.actual_amount > 0 AND orders.source_type IN (2,3,4) AND (orders.ext_order_id = '' OR orders.ext_order_id IS NULL)").
		Group("users.channel_id").
		Scan(&paidRows).Error; err != nil {
		return nil, 0, err
	}
	paidByChannel := make(map[int64]int64, len(paidRows))
	for _, r := range paidRows {
		paidByChannel[r.ChannelID] = r.TotalAmount
	}

	items = make([]*ChannelWithStat, 0, len(channels))
	for _, c := range channels {
		paidCents := paidByChannel[c.ID]
		items = append(items, &ChannelWithStat{
			Channels:        c,
			UserCount:       usersByChannel[c.ID],
			PaidAmountCents: paidCents,
			PaidAmount:      paidCents / 100,
		})
	}
	return items, total, nil
}
// resolveStatsRange resolves the single time window shared by the GetStats
// overview and its daily trend. Modes, in priority order:
//
//  1. Explicit range: both startDateStr and endDateStr set (YYYY-MM-DD, parsed
//     in time.Local). The window is [start 00:00:00, end 23:59:59].
//  2. days > 0: from 00:00 of the day (days-1) days before now, through now.
//  3. Full history: from 00:00 (in now's location) of the channel's earliest
//     activity through now. Earliest activity is the earlier of its first
//     qualifying order and its first registered non-deleted user. Users who
//     joined before the first order are therefore not cut off.
//
// In full-history mode it returns (nil, nil, nil) when the channel has neither
// orders nor users. Otherwise it returns a parse or database error if one occurs.
func (s *service) resolveStatsRange(ctx context.Context, channelID int64, days int, startDateStr, endDateStr string, now time.Time) (*time.Time, *time.Time, error) {
	if startDateStr != "" && endDateStr != "" {
		startDate, err := time.ParseInLocation("2006-01-02", startDateStr, time.Local)
		if err != nil {
			return nil, nil, err
		}
		endDate, err := time.ParseInLocation("2006-01-02", endDateStr, time.Local)
		if err != nil {
			return nil, nil, err
		}
		// Last second of the end day; AddDate stays correct across DST changes.
		endDate = endDate.AddDate(0, 0, 1).Add(-time.Second)
		return &startDate, &endDate, nil
	}
	if days > 0 {
		startDate := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()).AddDate(0, 0, -days+1)
		endDate := now
		return &startDate, &endDate, nil
	}

	type minTimeRow struct {
		MinCreatedAt *time.Time `gorm:"column:min_created_at"`
	}
	scanMin := func(q *gorm.DB) (*time.Time, error) {
		var row minTimeRow
		if err := q.Scan(&row).Error; err != nil {
			return nil, err
		}
		return row.MinCreatedAt, nil
	}

	firstOrderAt, err := scanMin(s.readDB.Orders.WithContext(ctx).UnderlyingDB().Table("orders").
		Joins("JOIN users ON users.id = orders.user_id").
		Select("MIN(orders.created_at) as min_created_at").
		Where("users.channel_id = ? AND users.deleted_at IS NULL AND orders.status = 2 AND orders.total_amount > 0 AND orders.actual_amount > 0 AND orders.source_type IN (2,3,4) AND (orders.ext_order_id = '' OR orders.ext_order_id IS NULL)", channelID))
	if err != nil {
		return nil, nil, err
	}
	firstUserAt, err := scanMin(s.readDB.Users.WithContext(ctx).UnderlyingDB().Table("users").
		Select("MIN(users.created_at) as min_created_at").
		Where("users.channel_id = ? AND users.deleted_at IS NULL", channelID))
	if err != nil {
		return nil, nil, err
	}

	earliest := firstOrderAt
	if earliest == nil || (firstUserAt != nil && firstUserAt.Before(*earliest)) {
		earliest = firstUserAt
	}
	if earliest == nil {
		return nil, nil, nil
	}

	// Truncate to the day as seen in now's location, not the DB driver's.
	first := earliest.In(now.Location())
	startDate := time.Date(first.Year(), first.Month(), first.Day(), 0, 0, 0, 0, now.Location())
	endDate := now
	return &startDate, &endDate, nil
}
// GetStats returns a channel's overview totals and day-by-day trend, both over
// the window chosen by resolveStatsRange: an explicit startDate/endDate
// (YYYY-MM-DD, both required), else the last `days` days, else the channel's
// full history. When no window can be resolved, only the unfiltered overview
// is returned and Daily stays nil.
//
// Users count non-deleted channel users by created_at. Orders and GMV count
// qualifying orders by orders.created_at. Cost comes from
// calcCostByDrawSource, which buckets by orders.paid_at. Every daily entry's
// ProfitCents is PaidCents - CostCents, so the daily values add up to the
// overview.
//
// Errors: ErrChannelNotFound for an unknown channel, a parse error for a
// malformed date, or any database error from the count queries.
func (s *service) GetStats(ctx context.Context, channelID int64, days int, startDateStr, endDateStr string) (*StatsOutput, error) {
	const dayLayout = "2006-01-02"
	now := time.Now()

	if _, err := s.readDB.Channels.WithContext(ctx).Where(s.readDB.Channels.ID.Eq(channelID)).First(); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrChannelNotFound
		}
		return nil, err
	}

	// Counted source_type values: 2 = mini-program draw, 3 = pair-match,
	// 4 = ichiban-kuji / punch card. Excluded: 1 = store direct purchase and
	// 5 = live-stream draw (Douyin shop).
	// actual_amount > 0 drops orders paid entirely by a punch card, so they are
	// not counted in GMV a second time on top of the order that bought the card.
	orderFilter := "users.channel_id = ? AND users.deleted_at IS NULL AND orders.status = 2 AND orders.total_amount > 0 AND orders.actual_amount > 0 AND orders.source_type IN (2,3,4) AND (orders.ext_order_id = '' OR orders.ext_order_id IS NULL)"

	statsStart, statsEnd, err := s.resolveStatsRange(ctx, channelID, days, startDateStr, endDateStr, now)
	if err != nil {
		return nil, err
	}
	windowed := statsStart != nil && statsEnd != nil

	out := &StatsOutput{}

	// ========== 1. Overview (same window as the trend) ==========
	// Same criteria as the daily user series (non-deleted channel users).
	usersQuery := s.readDB.Users.WithContext(ctx).UnderlyingDB().Table("users").
		Where("channel_id = ? AND deleted_at IS NULL", channelID)
	if windowed {
		usersQuery = usersQuery.Where("created_at >= ? AND created_at <= ?", *statsStart, *statsEnd)
	}
	if err := usersQuery.Count(&out.Overview.TotalUsers).Error; err != nil {
		return nil, err
	}

	type countResult struct{ Count int64 }
	var cr countResult
	ordersQuery := s.readDB.Orders.WithContext(ctx).UnderlyingDB().Table("orders").
		Joins("JOIN users ON users.id = orders.user_id").
		Select("count(*) as count").
		Where(orderFilter, channelID)
	if windowed {
		ordersQuery = ordersQuery.Where("orders.created_at >= ? AND orders.created_at <= ?", *statsStart, *statsEnd)
	}
	if err := ordersQuery.Scan(&cr).Error; err != nil {
		return nil, err
	}
	out.Overview.TotalOrders = cr.Count

	// GMV and cost are each queried once; the per-day buckets feed the trend.
	gmv, gmvByDay := s.calcGMVByTotalAmount(ctx, channelID, dayLayout, orderFilter, statsStart, statsEnd)
	out.Overview.TotalPaidCents = gmv.Total
	out.Overview.TotalGMV = gmv.Total / 100
	out.Overview.CashCents = gmv.Cash
	out.Overview.CouponCents = gmv.Coupon
	out.Overview.PointsCents = gmv.Points

	// Cost is attributed through the original order / draw source.
	totalCost, costByDay := s.calcCostByDrawSource(ctx, channelID, dayLayout, statsStart, statsEnd)
	out.Overview.TotalCostCents = totalCost
	out.Overview.TotalCost = totalCost / 100
	out.Overview.TotalProfitCents = gmv.Total - totalCost
	out.Overview.TotalProfit = out.Overview.TotalProfitCents / 100

	// ========== 2. Daily trend (same window as the overview) ==========
	if !windowed {
		return out, nil
	}
	startDate, endDate := *statsStart, *statsEnd

	dateMap := make(map[string]*StatsDailyItem)
	var dateList []string
	for d := startDate; !d.After(endDate); d = d.AddDate(0, 0, 1) {
		key := d.Format(dayLayout)
		dateList = append(dateList, key)
		dateMap[key] = &StatsDailyItem{Date: key}
	}

	type dailyCount struct {
		Date  string
		Count int64
	}

	var dailyUsers []dailyCount
	if err := s.readDB.Users.WithContext(ctx).UnderlyingDB().Table("users").
		Select("DATE_FORMAT(created_at, '%Y-%m-%d') as date, count(*) as count").
		Where("channel_id = ? AND deleted_at IS NULL AND created_at >= ? AND created_at <= ?", channelID, startDate, endDate).
		Group("date").Scan(&dailyUsers).Error; err != nil {
		return nil, err
	}
	for _, u := range dailyUsers {
		if item, ok := dateMap[u.Date]; ok {
			item.UserCount = u.Count
		}
	}

	var dailyOrders []dailyCount
	if err := s.readDB.Orders.WithContext(ctx).UnderlyingDB().Table("orders").
		Joins("JOIN users ON users.id = orders.user_id").
		Select("DATE_FORMAT(orders.created_at, '%Y-%m-%d') as date, count(*) as count").
		Where(orderFilter+" AND orders.created_at >= ? AND orders.created_at <= ?", channelID, startDate, endDate).
		Group("date").Scan(&dailyOrders).Error; err != nil {
		return nil, err
	}
	for _, o := range dailyOrders {
		if item, ok := dateMap[o.Date]; ok {
			item.OrderCount = o.Count
		}
	}

	for dateKey, paid := range gmvByDay {
		if item, ok := dateMap[dateKey]; ok {
			item.PaidCents = paid.Total
			item.GMV = paid.Total / 100
			item.CashCents = paid.Cash
			item.CouponCents = paid.Coupon
			item.PointsCents = paid.Points
		}
	}
	for dateKey, cost := range costByDay {
		if item, ok := dateMap[dateKey]; ok {
			item.CostCents = cost
		}
	}

	for _, key := range dateList {
		item := dateMap[key]
		// Computed for every day, including days with revenue but no cost,
		// so the daily profits sum to the overview profit.
		item.ProfitCents = item.PaidCents - item.CostCents
		out.Daily = append(out.Daily, *item)
	}
	return out, nil
}
// GetByID loads the channel with the given id. Non-positive ids and missing
// rows both yield ErrChannelNotFound; other database errors are returned as-is.
func (s *service) GetByID(ctx context.Context, id int64) (*model.Channels, error) {
	if id <= 0 {
		return nil, ErrChannelNotFound
	}
	channels := s.readDB.Channels
	found, err := channels.WithContext(ctx).Where(channels.ID.Eq(id)).First()
	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrChannelNotFound
	default:
		return nil, err
	}
}
// SearchUsers pages users left-joined with their channel's name and code.
//
// The keyword is trimmed and then matched as a substring of mobile or
// nickname. It also matches the exact user ID when it parses as an integer.
// ChannelID > 0 restricts results to that channel. At least one of the two
// filters is required, otherwise ErrSearchKeywordEmpty is returned. Page
// defaults to 1; PageSize defaults to 20 and is capped at 50. Results are
// ordered by user ID, descending.
func (s *service) SearchUsers(ctx context.Context, in SearchUsersInput) (items []*ChannelUserItem, total int64, err error) {
	kw := strings.TrimSpace(in.Keyword)
	if kw == "" && in.ChannelID <= 0 {
		return nil, 0, ErrSearchKeywordEmpty
	}

	page := in.Page
	if page <= 0 {
		page = 1
	}
	size := in.PageSize
	switch {
	case size <= 0:
		size = 20
	case size > 50:
		size = 50
	}

	users, chans := s.readDB.Users, s.readDB.Channels
	q := users.WithContext(ctx).ReadDB().
		LeftJoin(chans, chans.ID.EqCol(users.ChannelID)).
		Select(
			users.ID,
			users.Nickname,
			users.Mobile,
			users.Avatar,
			users.ChannelID,
			chans.Name.As("channel_name"),
			chans.Code.As("channel_code"),
		)
	if in.ChannelID > 0 {
		q = q.Where(users.ChannelID.Eq(in.ChannelID))
	}
	if kw != "" {
		like := "%" + kw + "%"
		matchers := []field.Expr{users.Mobile.Like(like), users.Nickname.Like(like)}
		if id, convErr := strconv.ParseInt(kw, 10, 64); convErr == nil {
			// An all-digit keyword may also be an exact user ID; keep it first in the OR.
			matchers = append([]field.Expr{users.ID.Eq(id)}, matchers...)
		}
		q = q.Where(field.Or(matchers...))
	}

	if total, err = q.Count(); err != nil {
		return nil, 0, err
	}

	// Same field names and types as ChannelUserItem, so each row converts directly.
	type userRow struct {
		ID          int64
		Nickname    string
		Mobile      string
		Avatar      string
		ChannelID   int64
		ChannelName string
		ChannelCode string
	}
	var found []userRow
	if err = q.Order(users.ID.Desc()).Offset((page - 1) * size).Limit(size).Scan(&found); err != nil {
		return nil, 0, err
	}

	items = make([]*ChannelUserItem, 0, len(found))
	for _, r := range found {
		item := ChannelUserItem(r)
		items = append(items, &item)
	}
	return items, total, nil
}
// BindUsers moves the given users into channelID inside one write transaction
// and reports a per-user outcome.
//
// Non-positive and duplicate IDs are dropped, keeping first-seen order.
// ErrBindUsersEmpty is returned when nothing remains, and ErrBindUsersTooMany
// when more than 200 remain. Within the transaction the channel must exist,
// otherwise ErrChannelNotFound is returned. Then each user is classified as:
//   - failed  / "user_not_found":     no such user row;
//   - skipped / "already_in_channel": already bound to channelID;
//   - success:                        moved (OldChannelID records the previous channel).
//
// All successful users are moved with a single UPDATE. Any database error
// rolls back the whole batch and is returned with a nil result.
func (s *service) BindUsers(ctx context.Context, channelID int64, userIDs []int64) (*BindUsersOutput, error) {
	if len(userIDs) == 0 {
		return nil, ErrBindUsersEmpty
	}

	seen := make(map[int64]struct{}, len(userIDs))
	ids := make([]int64, 0, len(userIDs))
	for _, uid := range userIDs {
		if uid <= 0 {
			continue
		}
		if _, dup := seen[uid]; dup {
			continue
		}
		seen[uid] = struct{}{}
		ids = append(ids, uid)
	}
	if len(ids) == 0 {
		return nil, ErrBindUsersEmpty
	}
	if len(ids) > 200 {
		return nil, ErrBindUsersTooMany
	}

	result := &BindUsersOutput{
		Details: make([]BindUserDetail, 0, len(ids)),
	}
	err := s.writeDB.Transaction(func(tx *dao.Query) error {
		if _, chErr := tx.Channels.WithContext(ctx).Where(tx.Channels.ID.Eq(channelID)).First(); chErr != nil {
			if errors.Is(chErr, gorm.ErrRecordNotFound) {
				return ErrChannelNotFound
			}
			return chErr
		}

		users, findErr := tx.Users.WithContext(ctx).Where(tx.Users.ID.In(ids...)).Find()
		if findErr != nil {
			return findErr
		}
		byID := make(map[int64]*model.Users, len(users))
		for _, u := range users {
			byID[u.ID] = u
		}

		toMove := make([]int64, 0, len(ids))
		for _, uid := range ids {
			detail := BindUserDetail{UserID: uid, NewChannelID: channelID}
			u, ok := byID[uid]
			switch {
			case !ok:
				detail.Status = "failed"
				detail.Message = "user_not_found"
				result.FailedCount++
			case u.ChannelID == channelID:
				detail.OldChannelID = u.ChannelID
				detail.Status = "skipped"
				detail.Message = "already_in_channel"
				result.SkippedCount++
			default:
				detail.OldChannelID = u.ChannelID
				detail.Status = "success"
				result.SuccessCount++
				toMove = append(toMove, uid)
			}
			result.Details = append(result.Details, detail)
		}
		if len(toMove) == 0 {
			return nil
		}

		// One UPDATE for every mover instead of a round trip per user.
		_, updateErr := tx.Users.WithContext(ctx).
			Where(tx.Users.ID.In(toMove...)).
			Update(tx.Users.ChannelID, channelID)
		return updateErr
	})
	if err != nil {
		return nil, err
	}
	return result, nil
}