diff --git a/internal/api/user/request_shipping_batch_app.go b/internal/api/user/request_shipping_batch_app.go index db625e7..798a1e0 100755 --- a/internal/api/user/request_shipping_batch_app.go +++ b/internal/api/user/request_shipping_batch_app.go @@ -46,8 +46,13 @@ func (h *handler) RequestShippingBatch() core.HandlerFunc { } userID := int64(ctx.SessionUserInfo().Id) - // 运费校验:不满 5 件须已支付运费订单 - if len(req.InventoryIDs) < shippingFeeThreshold { + needFee, reason, err := h.user.CheckShippingFeeRequirement(ctx.RequestContext(), userID, req.InventoryIDs) + if err != nil { + ctx.AbortWithError(core.Error(http.StatusBadRequest, 150004, err.Error())) + return + } + + if needFee { paid, _ := h.readDB.Orders.WithContext(ctx.RequestContext()). Where( h.readDB.Orders.UserID.Eq(userID), @@ -55,7 +60,13 @@ func (h *handler) RequestShippingBatch() core.HandlerFunc { h.readDB.Orders.Status.Eq(2), ).Count() if paid == 0 { - ctx.AbortWithError(core.Error(http.StatusBadRequest, 150003, "不满5件需先支付运费")) + msg := "需先支付运费" + if reason == shippingFeeReasonContainsNonFreeShipping { + msg = "所选商品包含不包邮商品,需先支付运费" + } else if reason == shippingFeeReasonBelowThreshold { + msg = "不满5件需先支付运费" + } + ctx.AbortWithError(core.Error(http.StatusBadRequest, 150003, msg)) return } } diff --git a/internal/api/user/shipping_fee_preorder_app.go b/internal/api/user/shipping_fee_preorder_app.go index 0efda10..c4e9e20 100644 --- a/internal/api/user/shipping_fee_preorder_app.go +++ b/internal/api/user/shipping_fee_preorder_app.go @@ -13,9 +13,11 @@ import ( ) const ( - shippingFeeThreshold = 5 // 低于此件数收运费 - shippingFeeCents = 1000 // 运费金额(分),10 元 - shippingFeeSourceType = int32(5) // orders.source_type: 5 = 运费订单 + shippingFeeThreshold = 5 + shippingFeeCents = 1000 // 运费金额(分),10 元 + shippingFeeSourceType = int32(5) // orders.source_type: 5 = 运费订单 + shippingFeeReasonBelowThreshold = "below_threshold" + shippingFeeReasonContainsNonFreeShipping = "contains_non_free_shipping_item" ) type shippingFeePreorderRequest struct { @@ -26,9 +28,40 @@ type shippingFeePreorderResponse struct { OrderNo string `json:"order_no"` } +type shippingFeeCheckResponse struct { + NeedFee bool `json:"need_fee"` + Reason string `json:"reason,omitempty"` + FeeCents int64 `json:"fee_cents"` +} + +func (h *handler) ShippingFeeCheck() core.HandlerFunc { + return func(ctx core.Context) { + req := new(shippingFeePreorderRequest) + rsp := &shippingFeeCheckResponse{FeeCents: shippingFeeCents} + if err := ctx.ShouldBindJSON(req); err != nil { + ctx.AbortWithError(core.Error(http.StatusBadRequest, code.ParamBindError, validation.Error(err))) + return + } + if len(req.InventoryIDs) == 0 { + ctx.AbortWithError(core.Error(http.StatusBadRequest, code.ParamBindError, "inventory_ids 不能为空")) + return + } + + userID := int64(ctx.SessionUserInfo().Id) + needFee, reason, err := h.user.CheckShippingFeeRequirement(ctx.RequestContext(), userID, req.InventoryIDs) + if err != nil { + ctx.AbortWithError(core.Error(http.StatusBadRequest, 150002, err.Error())) + return + } + rsp.NeedFee = needFee + rsp.Reason = reason + ctx.Payload(rsp) + } +} + // ShippingFeePreorder 创建运费订单 // @Summary 创建运费订单 -// @Description 选中件数不满 5 件时,创建 10 元运费订单并返回 order_no;前端再调用 /pay/wechat/jsapi/preorder 发起支付;满 5 件包邮无需调用 +// @Description 选中商品命中运费规则时,创建 10 元运费订单并返回 order_no;前端再调用 /pay/wechat/jsapi/preorder 发起支付;无需运费时不应调用 // @Tags APP端.用户 // @Accept json // @Produce json @@ -50,12 +83,17 @@ func (h *handler) ShippingFeePreorder() core.HandlerFunc { ctx.AbortWithError(core.Error(http.StatusBadRequest, code.ParamBindError, "inventory_ids 不能为空")) return } - if len(req.InventoryIDs) >= shippingFeeThreshold { - ctx.AbortWithError(core.Error(http.StatusBadRequest, 150001, fmt.Sprintf("件数满 %d 件,无需支付运费", shippingFeeThreshold))) - return - } userID := int64(ctx.SessionUserInfo().Id) + needFee, _, err := h.user.CheckShippingFeeRequirement(ctx.RequestContext(), userID, req.InventoryIDs) + if err != nil { + ctx.AbortWithError(core.Error(http.StatusBadRequest, 150002, err.Error())) + return + } + if !needFee { + ctx.AbortWithError(core.Error(http.StatusBadRequest, 150001, fmt.Sprintf("件数满 %d 件且均非不包邮分类商品,无需支付运费", shippingFeeThreshold))) + return + } remarkBytes, _ := json.Marshal(req.InventoryIDs) diff --git a/internal/api/user/synthesis_app.go b/internal/api/user/synthesis_app.go index d9d487a..8e680f4 100644 --- a/internal/api/user/synthesis_app.go +++ b/internal/api/user/synthesis_app.go @@ -45,6 +45,23 @@ func (h *handler) DoSynthesis() core.HandlerFunc { } } +func (h *handler) DoBatchSynthesis() core.HandlerFunc { + return func(ctx core.Context) { + req := new(synthesizeRequest) + if err := ctx.ShouldBindJSON(req); err != nil || req.RecipeID <= 0 { + ctx.AbortWithError(core.Error(http.StatusBadRequest, code.ParamBindError, "invalid recipe_id")) + return + } + userID := int64(ctx.SessionUserInfo().Id) + result, err := h.synthesis.BatchSynthesize(ctx.RequestContext(), userID, req.RecipeID) + if err != nil { + ctx.AbortWithError(core.Error(http.StatusBadRequest, code.ServerError, err.Error())) + return + } + ctx.Payload(result) + } +} + func (h *handler) ListSynthesisLogsForUser() core.HandlerFunc { return func(ctx core.Context) { userID := int64(ctx.SessionUserInfo().Id) diff --git a/internal/router/router.go b/internal/router/router.go index da82666..5ad27e3 100755 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -538,6 +538,7 @@ func NewHTTPMux(logger logger.CustomLogger, db mysql.Repo) (core.Mux, func(), er lotteryGroup.POST("/users/:user_id/points/redeem-item-card", userHandler.RedeemPointsToItemCard()) // 资产操作(发货/回收) + lotteryGroup.POST("/users/:user_id/inventory/shipping-fee/check", userHandler.ShippingFeeCheck()) lotteryGroup.POST("/users/:user_id/inventory/shipping-fee/preorder", userHandler.ShippingFeePreorder()) lotteryGroup.POST("/users/:user_id/inventory/request-shipping", userHandler.RequestShippingBatch()) lotteryGroup.POST("/users/:user_id/inventory/cancel-shipping", userHandler.CancelShipping()) @@ -547,6 +548,7 @@ func NewHTTPMux(logger logger.CustomLogger, db mysql.Repo) (core.Mux, func(), er // 碎片合成 appAuthApiRouter.GET("/users/:user_id/synthesis/recipes", userHandler.ListSynthesisRecipesForUser()) appAuthApiRouter.POST("/users/:user_id/synthesis/do", userHandler.DoSynthesis()) + appAuthApiRouter.POST("/users/:user_id/synthesis/do-batch", userHandler.DoBatchSynthesis()) appAuthApiRouter.GET("/users/:user_id/synthesis/logs", userHandler.ListSynthesisLogsForUser()) // 对对碰其他接口(不需要严查黑名单,或者已在preorder查过) diff --git a/internal/service/synthesis/synthesis.go b/internal/service/synthesis/synthesis.go index 51c6bec..408e931 100644 --- a/internal/service/synthesis/synthesis.go +++ b/internal/service/synthesis/synthesis.go @@ -10,6 +10,7 @@ import ( "bindbox-game/internal/repository/mysql/model" "gorm.io/gorm" + "gorm.io/gorm/clause" ) type Service interface { @@ -20,6 +21,7 @@ type Service interface { DeleteRecipe(ctx context.Context, id int64) error GetAvailableRecipesForUser(ctx context.Context, userID int64) ([]*UserRecipeView, error) Synthesize(ctx context.Context, userID int64, recipeID int64) (*model.UserInventory, error) + BatchSynthesize(ctx context.Context, userID int64, recipeID int64) (*BatchSynthesizeResult, error) ListLogs(ctx context.Context, page, size int, userID *int64) (list []*SynthesisLogView, total int64, err error) } @@ -51,12 +53,22 @@ type UserMaterialView struct { } type UserRecipeView struct { - ID int64 `json:"id"` - Name string `json:"name"` - Description string `json:"description"` - TargetProduct *model.Products `json:"target_product"` - CanSynthesize bool `json:"can_synthesize"` - Materials []UserMaterialView `json:"materials"` + ID int64 `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + TargetProduct *model.Products `json:"target_product"` + CanSynthesize bool `json:"can_synthesize"` + MaxSynthesizeCount int64 `json:"max_synthesize_count"` + Materials []UserMaterialView `json:"materials"` +} + +type BatchSynthesizeResult struct { + RecipeID int64 `json:"recipe_id"` + TargetProductID int64 `json:"target_product_id"` + TargetProductName string `json:"target_product_name"` + SynthesizedCount int64 `json:"synthesized_count"` + ProducedInventoryIDs []int64 `json:"produced_inventory_ids"` + ConsumedInventoryCount int `json:"consumed_inventory_count"` } type SynthesisLogView struct { @@ -241,10 +253,11 @@ func (s *service) GetAvailableRecipesForUser(ctx context.Context, userID int64) Name: r.Name, Description: r.Description, TargetProduct: &targetProduct, - CanSynthesize: true, Materials: make([]UserMaterialView, 0, len(materials)), } + maxSynthesizeCount := int64(0) + initialized := false for _, m := range materials { var p model.Products db.WithContext(ctx).Where("id = ?", m.FragmentProductID).First(&p) @@ -254,9 +267,15 @@ func (s *service) GetAvailableRecipesForUser(ctx context.Context, userID int64) Where("user_id = ? AND product_id = ? AND status = 1", userID, m.FragmentProductID). Count(&ownedCount) - if ownedCount < int64(m.RequiredCount) { - view.CanSynthesize = false + currentCount := int64(0) + if m.RequiredCount > 0 { + currentCount = ownedCount / int64(m.RequiredCount) } + if !initialized || currentCount < maxSynthesizeCount { + maxSynthesizeCount = currentCount + initialized = true + } + image := "" if p.ImagesJSON != "" { var imgs []string @@ -272,12 +291,34 @@ func (s *service) GetAvailableRecipesForUser(ctx context.Context, userID int64) OwnedCount: ownedCount, }) } + view.MaxSynthesizeCount = maxSynthesizeCount + view.CanSynthesize = maxSynthesizeCount > 0 result = append(result, view) } return result, nil } func (s *service) Synthesize(ctx context.Context, userID int64, recipeID int64) (*model.UserInventory, error) { + result, err := s.batchSynthesize(ctx, userID, recipeID, 1) + if err != nil { + return nil, err + } + if len(result.ProducedInventoryIDs) == 0 { + return nil, fmt.Errorf("synthesis_failed") + } + + var newInv model.UserInventory + if err := s.repo.GetDbR().WithContext(ctx).Where("id = ?", result.ProducedInventoryIDs[0]).First(&newInv).Error; err != nil { + return nil, err + } + return &newInv, nil +} + +func (s *service) BatchSynthesize(ctx context.Context, userID int64, recipeID int64) (*BatchSynthesizeResult, error) { + return s.batchSynthesize(ctx, userID, recipeID, 0) +} + +func (s *service) batchSynthesize(ctx context.Context, userID int64, recipeID int64, limitTimes int64) (*BatchSynthesizeResult, error) { db := s.repo.GetDbR() var recipe model.FragmentSynthesisRecipes @@ -302,16 +343,42 @@ func (s *service) Synthesize(ctx context.Context, userID int64, recipeID int64) InventoryIDs []int64 } toConsume := make([]materialConsume, 0, len(materials)) + maxTimes := int64(0) + initialized := false for _, m := range materials { + var ownedCount int64 + db.WithContext(ctx).Model(&model.UserInventory{}). + Where("user_id = ? AND product_id = ? AND status = 1", userID, m.FragmentProductID). + Count(&ownedCount) + + currentTimes := int64(0) + if m.RequiredCount > 0 { + currentTimes = ownedCount / int64(m.RequiredCount) + } + if !initialized || currentTimes < maxTimes { + maxTimes = currentTimes + initialized = true + } + } + + if limitTimes > 0 && maxTimes > limitTimes { + maxTimes = limitTimes + } + if maxTimes <= 0 { + return nil, fmt.Errorf("insufficient_fragments") + } + + for _, m := range materials { + requiredTotal := int(m.RequiredCount) * int(maxTimes) var invList []*model.UserInventory db.WithContext(ctx). Where("user_id = ? AND product_id = ? AND status = 1", userID, m.FragmentProductID). Order("id ASC"). - Limit(int(m.RequiredCount)). + Limit(requiredTotal). Find(&invList) - if int32(len(invList)) < m.RequiredCount { + if len(invList) < requiredTotal { return nil, fmt.Errorf("insufficient_fragments") } ids := make([]int64, len(invList)) @@ -325,52 +392,84 @@ func (s *service) Synthesize(ctx context.Context, userID int64, recipeID int64) }) } - var newInv model.UserInventory + result := &BatchSynthesizeResult{ + RecipeID: recipeID, + TargetProductID: recipe.TargetProductID, + TargetProductName: targetProduct.Name, + SynthesizedCount: maxTimes, + } + wdb := s.repo.GetDbW() err := wdb.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - allConsumedIDs := make([]int64, 0) + consumedByRound := make([][]int64, int(maxTimes)) + allConsumedCount := 0 + for _, mc := range toConsume { var locked []model.UserInventory - if err := tx.Raw("SELECT * FROM user_inventory WHERE id IN ? AND user_id = ? AND status = 1 FOR UPDATE", mc.InventoryIDs, userID).Scan(&locked).Error; err != nil { + query := tx.WithContext(ctx).Where("id IN ? AND user_id = ? AND status = 1", mc.InventoryIDs, userID) + if tx.Dialector.Name() != "sqlite" { + query = query.Clauses(clause.Locking{Strength: "UPDATE"}) + } + if err := query.Find(&locked).Error; err != nil { return err } - if int32(len(locked)) < mc.Required { + if len(locked) < len(mc.InventoryIDs) { return fmt.Errorf("insufficient_fragments") } - if err := tx.Exec( - "UPDATE user_inventory SET status = 2, updated_at = NOW(3), remark = CONCAT(IFNULL(remark,''), '|synthesis_consumed:recipe_', ?) WHERE id IN ? AND user_id = ? AND status = 1", - recipeID, mc.InventoryIDs, userID, - ).Error; err != nil { + updates := map[string]interface{}{ + "status": 2, + "updated_at": time.Now(), + } + if tx.Dialector.Name() == "sqlite" { + updates["remark"] = gorm.Expr("COALESCE(remark, '') || ?", fmt.Sprintf("|synthesis_consumed:recipe_%d", recipeID)) + } else { + updates["remark"] = gorm.Expr("CONCAT(IFNULL(remark,''), ?)", fmt.Sprintf("|synthesis_consumed:recipe_%d", recipeID)) + } + if err := tx.Model(&model.UserInventory{}). + Where("id IN ? AND user_id = ? AND status = 1", mc.InventoryIDs, userID). + Updates(updates).Error; err != nil { return err } - allConsumedIDs = append(allConsumedIDs, mc.InventoryIDs...) + allConsumedCount += len(mc.InventoryIDs) + for round := int64(0); round < maxTimes; round++ { + start := int(round) * int(mc.Required) + end := start + int(mc.Required) + consumedByRound[round] = append(consumedByRound[round], mc.InventoryIDs[start:end]...) + } } - newInv = model.UserInventory{ - UserID: userID, - ProductID: recipe.TargetProductID, - ValueCents: targetProduct.Price, - Status: 1, - Remark: fmt.Sprintf("synthesis_produced:recipe_%d", recipeID), - } - if err := tx.Omit("ValueSnapshotAt", "ShippingNo").Create(&newInv).Error; err != nil { - return err - } + result.ConsumedInventoryCount = allConsumedCount + result.ProducedInventoryIDs = make([]int64, 0, int(maxTimes)) + for round := int64(0); round < maxTimes; round++ { + newInv := model.UserInventory{ + UserID: userID, + ProductID: recipe.TargetProductID, + ValueCents: targetProduct.Price, + Status: 1, + Remark: fmt.Sprintf("batch_synthesis_produced:recipe_%d:round_%d", recipeID, round+1), + } + if err := tx.Omit("ValueSnapshotAt", "ShippingNo").Create(&newInv).Error; err != nil { + return err + } + result.ProducedInventoryIDs = append(result.ProducedInventoryIDs, newInv.ID) - consumedJSON, _ := json.Marshal(allConsumedIDs) - log := &model.FragmentSynthesisLogs{ - UserID: userID, - RecipeID: recipeID, - ConsumedInventoryIDs: string(consumedJSON), - ProducedInventoryID: newInv.ID, + consumedJSON, _ := json.Marshal(consumedByRound[round]) + log := &model.FragmentSynthesisLogs{ + UserID: userID, + RecipeID: recipeID, + ConsumedInventoryIDs: string(consumedJSON), + ProducedInventoryID: newInv.ID, + } + if err := tx.Create(log).Error; err != nil { + return err + } } - return tx.Create(log).Error + return nil }) - if err != nil { return nil, err } - return &newInv, nil + return result, nil } func (s *service) ListLogs(ctx context.Context, page, size int, userID *int64) ([]*SynthesisLogView, int64, error) { diff --git a/internal/service/synthesis/synthesis_test.go b/internal/service/synthesis/synthesis_test.go index 9ec963e..f00fec5 100644 --- a/internal/service/synthesis/synthesis_test.go +++ b/internal/service/synthesis/synthesis_test.go @@ -5,6 +5,7 @@ import ( "testing" "bindbox-game/internal/repository/mysql" + "bindbox-game/internal/repository/mysql/model" "gorm.io/driver/sqlite" "gorm.io/gorm" @@ -18,24 +19,69 @@ func newSynthesisServiceForTest(t *testing.T) *service { t.Fatalf("open sqlite failed: %v", err) } - if err := db.Exec(` - CREATE TABLE product_categories ( + statements := []string{ + `CREATE TABLE product_categories ( id INTEGER PRIMARY KEY AUTOINCREMENT, is_fragment INTEGER NOT NULL DEFAULT 0, deleted_at DATETIME NULL - ); - `).Error; err != nil { - t.Fatalf("create product_categories failed: %v", err) - } - - if err := db.Exec(` - CREATE TABLE products ( + );`, + `CREATE TABLE products ( id INTEGER PRIMARY KEY AUTOINCREMENT, category_id INTEGER NOT NULL DEFAULT 0, + name TEXT NOT NULL DEFAULT '', + price INTEGER NOT NULL DEFAULT 0, + status INTEGER NOT NULL DEFAULT 1, + images_json TEXT NOT NULL DEFAULT '', deleted_at DATETIME NULL - ); - `).Error; err != nil { - t.Fatalf("create products failed: %v", err) + );`, + `CREATE TABLE fragment_synthesis_recipes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL DEFAULT '', + description TEXT NOT NULL DEFAULT '', + target_product_id INTEGER NOT NULL DEFAULT 0, + status INTEGER NOT NULL DEFAULT 1, + created_at DATETIME NULL, + updated_at DATETIME NULL, + deleted_at DATETIME NULL + );`, + `CREATE TABLE fragment_synthesis_recipe_materials ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + recipe_id INTEGER NOT NULL DEFAULT 0, + fragment_product_id INTEGER NOT NULL DEFAULT 0, + required_count INTEGER NOT NULL DEFAULT 0, + created_at DATETIME NULL, + updated_at DATETIME NULL, + deleted_at DATETIME NULL + );`, + `CREATE TABLE fragment_synthesis_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + created_at DATETIME NULL, + user_id INTEGER NOT NULL DEFAULT 0, + recipe_id INTEGER NOT NULL DEFAULT 0, + consumed_inventory_ids TEXT NOT NULL DEFAULT '', + produced_inventory_id INTEGER NOT NULL DEFAULT 0 + );`, + `CREATE TABLE user_inventory ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + created_at DATETIME NULL, + updated_at DATETIME NULL, + user_id INTEGER NOT NULL DEFAULT 0, + product_id INTEGER NOT NULL DEFAULT 0, + value_cents INTEGER NOT NULL DEFAULT 0, + value_source INTEGER NOT NULL DEFAULT 0, + value_snapshot_at DATETIME NULL, + order_id INTEGER NOT NULL DEFAULT 0, + activity_id INTEGER NOT NULL DEFAULT 0, + reward_id INTEGER NOT NULL DEFAULT 0, + status INTEGER NOT NULL DEFAULT 1, + shipping_no TEXT NOT NULL DEFAULT '', + remark TEXT NOT NULL DEFAULT '' + );`, + } + for _, stmt := range statements { + if err := db.Exec(stmt).Error; err != nil { + t.Fatalf("exec schema failed: %v", err) + } } return New(mysql.NewTestRepo(db)).(*service) @@ -93,3 +139,111 @@ func TestValidateRecipeProducts_ValidCombination(t *testing.T) { } } +func TestBatchSynthesizeProducesAllPossibleItems(t *testing.T) { + svc := newSynthesisServiceForTest(t) + ctx := context.Background() + seedBatchSynthesisFixture(t, svc) + + result, err := svc.BatchSynthesize(ctx, 1001, 1) + if err != nil { + t.Fatalf("batch synthesize failed: %v", err) + } + if result.SynthesizedCount != 3 { + t.Fatalf("expected 3 syntheses, got %d", result.SynthesizedCount) + } + if len(result.ProducedInventoryIDs) != 3 { + t.Fatalf("expected 3 produced ids, got %d", len(result.ProducedInventoryIDs)) + } + if result.ConsumedInventoryCount != 9 { + t.Fatalf("expected 9 consumed inventory items, got %d", result.ConsumedInventoryCount) + } + + assertInventoryStatusCount(t, svc, 1001, 11, 2, 6) + assertInventoryStatusCount(t, svc, 1001, 12, 2, 3) + assertInventoryStatusCount(t, svc, 1001, 10, 1, 3) + assertSynthesisLogCount(t, svc, 1001, 1, 3) +} + +func TestBatchSynthesizeUsesShortestMaterial(t *testing.T) { + svc := newSynthesisServiceForTest(t) + ctx := context.Background() + seedBatchSynthesisFixture(t, svc) + + if err := svc.repo.GetDbW().Exec("INSERT INTO user_inventory(user_id, product_id, value_cents, status, remark) VALUES (1001, 12, 0, 1, 'extra_fragment')").Error; err != nil { + t.Fatalf("seed extra fragment failed: %v", err) + } + + result, err := svc.BatchSynthesize(ctx, 1001, 1) + if err != nil { + t.Fatalf("batch synthesize failed: %v", err) + } + if result.SynthesizedCount != 3 { + t.Fatalf("expected shortest material to cap at 3, got %d", result.SynthesizedCount) + } + assertInventoryStatusCount(t, svc, 1001, 12, 1, 1) +} + +func TestBatchSynthesizeFailsWhenInsufficient(t *testing.T) { + svc := newSynthesisServiceForTest(t) + ctx := context.Background() + seedBatchSynthesisFixture(t, svc) + + if err := svc.repo.GetDbW().Exec("DELETE FROM user_inventory WHERE user_id = ? AND product_id = ?", 1001, 12).Error; err != nil { + t.Fatalf("clear fragments failed: %v", err) + } + + _, err := svc.BatchSynthesize(ctx, 1001, 1) + if err == nil || err.Error() != "insufficient_fragments" { + t.Fatalf("expected insufficient_fragments, got %v", err) + } +} + +func seedBatchSynthesisFixture(t *testing.T, svc *service) { + t.Helper() + db := svc.repo.GetDbW() + + statements := []string{ + "INSERT INTO product_categories(id, is_fragment) VALUES (1, 1), (2, 0)", + "INSERT INTO products(id, category_id, name, price, status) VALUES (10, 2, '目标商品', 1999, 1), (11, 1, '碎片A', 0, 1), (12, 1, '碎片B', 0, 1)", + "INSERT INTO fragment_synthesis_recipes(id, name, description, target_product_id, status) VALUES (1, '配方1', '测试配方', 10, 1)", + "INSERT INTO fragment_synthesis_recipe_materials(id, recipe_id, fragment_product_id, required_count) VALUES (1, 1, 11, 2), (2, 1, 12, 1)", + } + for _, stmt := range statements { + if err := db.Exec(stmt).Error; err != nil { + t.Fatalf("seed fixture failed: %v", err) + } + } + + for i := 0; i < 6; i++ { + if err := db.Create(&model.UserInventory{UserID: 1001, ProductID: 11, ValueCents: 0, Status: 1, Remark: "fragment_a"}).Error; err != nil { + t.Fatalf("seed fragment a failed: %v", err) + } + } + for i := 0; i < 3; i++ { + if err := db.Create(&model.UserInventory{UserID: 1001, ProductID: 12, ValueCents: 0, Status: 1, Remark: "fragment_b"}).Error; err != nil { + t.Fatalf("seed fragment b failed: %v", err) + } + } +} + +func assertInventoryStatusCount(t *testing.T, svc *service, userID, productID int64, status int32, want int64) { + t.Helper() + var count int64 + if err := svc.repo.GetDbR().Model(&model.UserInventory{}).Where("user_id = ? AND product_id = ? AND status = ?", userID, productID, status).Count(&count).Error; err != nil { + t.Fatalf("count inventory failed: %v", err) + } + if count != want { + t.Fatalf("expected %d inventory rows for product %d status %d, got %d", want, productID, status, count) + } +} + +func assertSynthesisLogCount(t *testing.T, svc *service, userID, recipeID int64, want int64) { + t.Helper() + var count int64 + if err := svc.repo.GetDbR().Model(&model.FragmentSynthesisLogs{}).Where("user_id = ? AND recipe_id = ?", userID, recipeID).Count(&count).Error; err != nil { + t.Fatalf("count logs failed: %v", err) + } + if count != want { + t.Fatalf("expected %d synthesis logs, got %d", want, count) + } +} diff --git a/internal/service/user/address_share.go b/internal/service/user/address_share.go index e3d98ff..1bbad82 100755 --- a/internal/service/user/address_share.go +++ b/internal/service/user/address_share.go @@ -15,6 +15,17 @@ import ( "gorm.io/gorm" ) +const ( + shippingFeeThreshold = 5 + shippingFeeReasonBelowThreshold = "below_threshold" + shippingFeeReasonContainsNonFreeShipping = "contains_non_free_shipping_item" +) + +var nonFreeShippingCategoryIDs = map[int64]struct{}{ + 14: {}, + 15: {}, +} + type shareClaims struct { OwnerUserID int64 `json:"owner_user_id"` InventoryID int64 `json:"inventory_id"` @@ -319,6 +330,63 @@ func generateBatchNo(userID int64) string { return fmt.Sprintf("B%d%d", userID, time.Now().UnixNano()/1000000) } +func (s *service) CheckShippingFeeRequirement(ctx context.Context, userID int64, inventoryIDs []int64) (bool, string, error) { + uniqMap := make(map[int64]struct{}, len(inventoryIDs)) + uniq := make([]int64, 0, len(inventoryIDs)) + for _, id := range inventoryIDs { + if id <= 0 { + continue + } + if _, ok := uniqMap[id]; ok { + continue + } + uniqMap[id] = struct{}{} + uniq = append(uniq, id) + } + if len(uniq) == 0 { + return false, "", fmt.Errorf("invalid inventory_ids") + } + + invList, err := s.readDB.UserInventory.WithContext(ctx). + Where(s.readDB.UserInventory.ID.In(uniq...)). + Find() + if err != nil { + return false, "", err + } + + productIDSet := make(map[int64]struct{}, len(invList)) + productIDs := make([]int64, 0, len(invList)) + for _, inv := range invList { + if inv == nil || inv.UserID != userID || inv.ProductID <= 0 { + continue + } + if _, ok := productIDSet[inv.ProductID]; ok { + continue + } + productIDSet[inv.ProductID] = struct{}{} + productIDs = append(productIDs, inv.ProductID) + } + + if len(productIDs) > 0 { + products, err := s.readDB.Products.WithContext(ctx). + Where(s.readDB.Products.ID.In(productIDs...)). + Find() + if err != nil { + return false, "", err + } + for _, product := range products { + if _, ok := nonFreeShippingCategoryIDs[product.CategoryID]; ok { + return true, shippingFeeReasonContainsNonFreeShipping, nil + } + } + } + + if len(uniq) < shippingFeeThreshold { + return true, shippingFeeReasonBelowThreshold, nil + } + return false, "", nil +} + func (s *service) RequestShippings(ctx context.Context, userID int64, inventoryIDs []int64, addressID *int64) (addrID int64, batchNo string, success []int64, skipped []struct { ID int64 Reason string diff --git a/internal/service/user/request_shipping_batch_test.go b/internal/service/user/request_shipping_batch_test.go index 5507627..ce37406 100755 --- a/internal/service/user/request_shipping_batch_test.go +++ b/internal/service/user/request_shipping_batch_test.go @@ -29,7 +29,6 @@ func TestRequestShippings_EmptyInventoryIDs(t *testing.T) { db, _ := setupMockDBForShipping(t) svc := newTestService(db) - // Empty inventory IDs should return failed with "invalid_params" _, _, _, _, failed, err := svc.RequestShippings(context.Background(), 1, []int64{}, nil) assert.NoError(t, err) assert.Len(t, failed, 1) @@ -40,7 +39,6 @@ func TestRequestShippings_AllZeroInventoryIDs(t *testing.T) { db, _ := setupMockDBForShipping(t) svc := newTestService(db) - // All zero or negative IDs should be filtered, resulting in empty uniq list _, _, _, _, failed, err := svc.RequestShippings(context.Background(), 1, []int64{0, -1, 0}, nil) assert.NoError(t, err) assert.Len(t, failed, 1) @@ -51,17 +49,88 @@ func TestRequestShippings_NoDefaultAddress(t *testing.T) { db, mock := setupMockDBForShipping(t) svc := newTestService(db) - // Mock default address query - return no rows + mock.ExpectQuery("SELECT .* FROM `user_addresses`"). + WillReturnRows(sqlmock.NewRows(nil)) mock.ExpectQuery("SELECT .* FROM `user_addresses`"). WillReturnRows(sqlmock.NewRows(nil)) - // Mock all addresses query - return empty - mock.ExpectQuery("SELECT .* FROM `user_addresses`"). - WillReturnRows(sqlmock.NewRows(nil)) - - // With valid IDs but no address, should return no_default_address error _, _, _, _, failed, err := svc.RequestShippings(context.Background(), 1, []int64{1, 2}, nil) assert.NoError(t, err) assert.Len(t, failed, 1) assert.Equal(t, "no_default_address", failed[0].Reason) } + +func TestCheckShippingFeeRequirement_BelowThreshold(t *testing.T) { + db, mock := setupMockDBForShipping(t) + svc := newTestService(db) + + mock.ExpectQuery("SELECT .* FROM `user_inventory`"). + WillReturnRows(sqlmock.NewRows([]string{"id", "user_id", "product_id"}). + AddRow(1, 99, 101). + AddRow(2, 99, 102). + AddRow(3, 99, 103). + AddRow(4, 99, 104)) + mock.ExpectQuery("SELECT .* FROM `products`"). + WillReturnRows(sqlmock.NewRows([]string{"id", "category_id"}). + AddRow(101, 1). + AddRow(102, 2). + AddRow(103, 3). + AddRow(104, 4)) + + needFee, reason, err := svc.CheckShippingFeeRequirement(context.Background(), 99, []int64{1, 2, 3, 4}) + assert.NoError(t, err) + assert.True(t, needFee) + assert.Equal(t, shippingFeeReasonBelowThreshold, reason) +} + +func TestCheckShippingFeeRequirement_FreeWhenThresholdReached(t *testing.T) { + db, mock := setupMockDBForShipping(t) + svc := newTestService(db) + + mock.ExpectQuery("SELECT .* FROM `user_inventory`"). + WillReturnRows(sqlmock.NewRows([]string{"id", "user_id", "product_id"}). + AddRow(1, 99, 101). + AddRow(2, 99, 102). + AddRow(3, 99, 103). + AddRow(4, 99, 104). + AddRow(5, 99, 105)) + mock.ExpectQuery("SELECT .* FROM `products`"). + WillReturnRows(sqlmock.NewRows([]string{"id", "category_id"}). + AddRow(101, 1). + AddRow(102, 2). + AddRow(103, 3). + AddRow(104, 4). + AddRow(105, 5)) + + needFee, reason, err := svc.CheckShippingFeeRequirement(context.Background(), 99, []int64{1, 2, 3, 4, 5}) + assert.NoError(t, err) + assert.False(t, needFee) + assert.Equal(t, "", reason) +} + +func TestCheckShippingFeeRequirement_NonFreeCategoryOverridesThreshold(t *testing.T) { + db, mock := setupMockDBForShipping(t) + svc := newTestService(db) + + mock.ExpectQuery("SELECT .* FROM `user_inventory`"). + WillReturnRows(sqlmock.NewRows([]string{"id", "user_id", "product_id"}). + AddRow(1, 99, 101). + AddRow(2, 99, 102). + AddRow(3, 99, 103). + AddRow(4, 99, 104). + AddRow(5, 99, 105). + AddRow(6, 99, 106)) + mock.ExpectQuery("SELECT .* FROM `products`"). + WillReturnRows(sqlmock.NewRows([]string{"id", "category_id"}). + AddRow(101, 1). + AddRow(102, 2). + AddRow(103, 3). + AddRow(104, 4). + AddRow(105, 14). + AddRow(106, 5)) + + needFee, reason, err := svc.CheckShippingFeeRequirement(context.Background(), 99, []int64{1, 2, 3, 4, 5, 6}) + assert.NoError(t, err) + assert.True(t, needFee) + assert.Equal(t, shippingFeeReasonContainsNonFreeShipping, reason) +} diff --git a/internal/service/user/user.go b/internal/service/user/user.go index 91641cc..a8c0660 100755 --- a/internal/service/user/user.go +++ b/internal/service/user/user.go @@ -57,6 +57,7 @@ type Service interface { SubmitAddressShare(ctx context.Context, shareToken string, name string, mobile string, province string, city string, district string, address string, submittedByUserID *int64, submittedIP *string) (int64, error) RequestShipping(ctx context.Context, userID int64, inventoryID int64) (int64, error) CancelShipping(ctx context.Context, userID int64, inventoryID int64, batchNo string) (int64, error) + CheckShippingFeeRequirement(ctx context.Context, userID int64, inventoryIDs []int64) (needFee bool, reason string, err error) RequestShippings(ctx context.Context, userID int64, inventoryIDs []int64, addressID *int64) (addrID int64, batchNo string, success []int64, skipped []struct { ID int64 Reason string