diff options
Diffstat (limited to 'item')
-rw-r--r-- | item/controller.go | 15 | ||||
-rw-r--r-- | item/hooks.go | 6 | ||||
-rw-r--r-- | item/service.go | 27 | ||||
-rw-r--r-- | item/validators.go | 72 |
4 files changed, 80 insertions, 40 deletions
diff --git a/item/controller.go b/item/controller.go index cf9683d..9993688 100644 --- a/item/controller.go +++ b/item/controller.go @@ -116,6 +116,7 @@ func handleDelBrand (ctx *gin.Context) { var brand Brand brand.ID = uint(id) + uId, ok := ctx.Get("UserID") if !ok { ctx.Error(e.ErrUnauthorized) @@ -126,6 +127,13 @@ func handleDelBrand (ctx *gin.Context) { userId := uId.(uint) brand.UserID = userId + err = checkBrandOwnership(brand.ID, brand.UserID) + if err != nil { + ctx.Error(err) + ctx.Abort() + return + } + err = brand.del() if err != nil { ctx.Error(err) @@ -210,6 +218,13 @@ func handleDelItem (ctx *gin.Context) { userId := uId.(uint) item.UserID = userId + err = checkItemOwnership(item.ID, item.UserID) + if err != nil { + ctx.Error(err) + ctx.Abort() + return + } + err = item.del() if err != nil { ctx.Error(err) diff --git a/item/hooks.go b/item/hooks.go index 558a8cb..5a27114 100644 --- a/item/hooks.go +++ b/item/hooks.go @@ -25,7 +25,8 @@ import ( func (i *SavedItem) BeforeSave(tx *gorm.DB) error { var err error - err = checkIfBrandExists(i.BrandID, i.UserID) + // also checks if brand actually exists + err = checkBrandOwnership(i.BrandID, i.UserID) if err != nil { return err } @@ -53,5 +54,8 @@ func (b *Brand) BeforeDelete(tx *gorm.DB) error { return errors.ErrNoWhereCondition } + // delete all items + db.Where("brand_id = ? and user_id = ?", b.ID, b.UserID).Delete(&SavedItem{}) + return nil } diff --git a/item/service.go b/item/service.go index fb03adc..80faff0 100644 --- a/item/service.go +++ b/item/service.go @@ -19,26 +19,15 @@ package item import ( "vidhukant.com/openbills/errors" - e "vidhukant.com/openbills/errors" ) func getBrandItems(items *[]SavedItem, id, userId uint) error { - // check if brand id is valid and is owned by user - var count int64 - err := db.Model(&Brand{}). - Select("id"). - Where("id = ? and user_id = ?", id, userId). - Count(&count). - Error - + err := checkBrandOwnership(id, userId) if err != nil { return err - } - - if count == 0 { - return errors.ErrBrandNotFound } + // get items res := db.Model(&SavedItem{}).Where("brand_id = ?", id).Find(&items) // TODO: handle potential errors @@ -48,7 +37,7 @@ func getBrandItems(items *[]SavedItem, id, userId uint) error { // returns 404 if either row doesn't exist or if the user doesn't own it if res.RowsAffected == 0 { - return e.ErrEmptyResponse + return errors.ErrEmptyResponse } return nil @@ -63,7 +52,7 @@ func getBrands(brands *[]Brand, userId uint) error { } if res.RowsAffected == 0 { - return e.ErrEmptyResponse + return errors.ErrEmptyResponse } return nil @@ -75,8 +64,8 @@ func (b *Brand) upsert() error { return res.Error } -// TODO: delete all items upon brand deletion func (b *Brand) del() error { + // delete brand res := db.Where("id = ? and user_id = ?", b.ID, b.UserID).Delete(b) // TODO: handle potential errors @@ -86,7 +75,7 @@ func (b *Brand) del() error { // returns 404 if either row doesn't exist or if the user doesn't own it if res.RowsAffected == 0 { - return e.ErrNotFound + return errors.ErrNotFound } return nil @@ -101,7 +90,7 @@ func getItems(items *[]SavedItem, userId uint) error { } if res.RowsAffected == 0 { - return e.ErrEmptyResponse + return errors.ErrEmptyResponse } return nil @@ -123,7 +112,7 @@ func (i *SavedItem) del() error { // returns 404 if either row doesn't exist or if the user doesn't own it if res.RowsAffected == 0 { - return e.ErrNotFound + return errors.ErrNotFound } return nil diff --git a/item/validators.go b/item/validators.go index 996a5d7..e931843 100644 --- a/item/validators.go +++ b/item/validators.go @@ -51,26 +51,6 @@ func (b *Brand) validate() error { return nil } -func checkIfBrandExists(id, userId uint) error { - // check if brand id is valid and is owned by user - var count int64 - err := db.Model(&Brand{}). - Select("id"). - Where("id = ? and user_id = ?", id, userId). - Count(&count). - Error - - if err != nil { - return err - } - - if count == 0 { - return errors.ErrBrandNotFound - } - - return nil -} - func (i *SavedItem) validate() error { // trim whitespaces i.Name = strings.TrimSpace(i.Name) @@ -109,3 +89,55 @@ func (i *SavedItem) validate() error { return nil } + +func checkBrandOwnership(brandId, userId uint) error { + var brand Brand + err := db. + Select("id", "user_id"). + Where("id = ?", brandId). + Find(&brand). + Error + + // TODO: handle potential errors + if err != nil { + return err + } + + // brand doesn't exist + if brand.ID == 0 { + return errors.ErrBrandNotFound + } + + // user doesn't own this brand + if brand.UserID != userId { + return errors.ErrForbidden + } + + return nil +} + +func checkItemOwnership(itemId, userId uint) error { + var item SavedItem + err := db. + Select("id", "user_id"). + Where("id = ?", itemId). + Find(&item). + Error + + // TODO: handle potential errors + if err != nil { + return err + } + + // item doesn't exist + if item.ID == 0 { + return errors.ErrNotFound + } + + // user doesn't own this item + if item.UserID != userId { + return errors.ErrForbidden + } + + return nil +} |