aboutsummaryrefslogtreecommitdiff
path: root/item
diff options
context:
space:
mode:
Diffstat (limited to 'item')
-rw-r--r--item/controller.go15
-rw-r--r--item/hooks.go6
-rw-r--r--item/service.go27
-rw-r--r--item/validators.go72
4 files changed, 80 insertions, 40 deletions
diff --git a/item/controller.go b/item/controller.go
index cf9683d..9993688 100644
--- a/item/controller.go
+++ b/item/controller.go
@@ -116,6 +116,7 @@ func handleDelBrand (ctx *gin.Context) {
var brand Brand
brand.ID = uint(id)
+
uId, ok := ctx.Get("UserID")
if !ok {
ctx.Error(e.ErrUnauthorized)
@@ -126,6 +127,13 @@ func handleDelBrand (ctx *gin.Context) {
userId := uId.(uint)
brand.UserID = userId
+ err = checkBrandOwnership(brand.ID, brand.UserID)
+ if err != nil {
+ ctx.Error(err)
+ ctx.Abort()
+ return
+ }
+
err = brand.del()
if err != nil {
ctx.Error(err)
@@ -210,6 +218,13 @@ func handleDelItem (ctx *gin.Context) {
userId := uId.(uint)
item.UserID = userId
+ err = checkItemOwnership(item.ID, item.UserID)
+ if err != nil {
+ ctx.Error(err)
+ ctx.Abort()
+ return
+ }
+
err = item.del()
if err != nil {
ctx.Error(err)
diff --git a/item/hooks.go b/item/hooks.go
index 558a8cb..5a27114 100644
--- a/item/hooks.go
+++ b/item/hooks.go
@@ -25,7 +25,8 @@ import (
func (i *SavedItem) BeforeSave(tx *gorm.DB) error {
var err error
- err = checkIfBrandExists(i.BrandID, i.UserID)
+ // also checks if brand actually exists
+ err = checkBrandOwnership(i.BrandID, i.UserID)
if err != nil {
return err
}
@@ -53,5 +54,8 @@ func (b *Brand) BeforeDelete(tx *gorm.DB) error {
return errors.ErrNoWhereCondition
}
+ // delete all items
+ db.Where("brand_id = ? and user_id = ?", b.ID, b.UserID).Delete(&SavedItem{})
+
return nil
}
diff --git a/item/service.go b/item/service.go
index fb03adc..80faff0 100644
--- a/item/service.go
+++ b/item/service.go
@@ -19,26 +19,15 @@ package item
import (
"vidhukant.com/openbills/errors"
- e "vidhukant.com/openbills/errors"
)
func getBrandItems(items *[]SavedItem, id, userId uint) error {
- // check if brand id is valid and is owned by user
- var count int64
- err := db.Model(&Brand{}).
- Select("id").
- Where("id = ? and user_id = ?", id, userId).
- Count(&count).
- Error
-
+ err := checkBrandOwnership(id, userId)
if err != nil {
return err
- }
-
- if count == 0 {
- return errors.ErrBrandNotFound
}
+ // get items
res := db.Model(&SavedItem{}).Where("brand_id = ?", id).Find(&items)
// TODO: handle potential errors
@@ -48,7 +37,7 @@ func getBrandItems(items *[]SavedItem, id, userId uint) error {
// returns 404 if either row doesn't exist or if the user doesn't own it
if res.RowsAffected == 0 {
- return e.ErrEmptyResponse
+ return errors.ErrEmptyResponse
}
return nil
@@ -63,7 +52,7 @@ func getBrands(brands *[]Brand, userId uint) error {
}
if res.RowsAffected == 0 {
- return e.ErrEmptyResponse
+ return errors.ErrEmptyResponse
}
return nil
@@ -75,8 +64,8 @@ func (b *Brand) upsert() error {
return res.Error
}
-// TODO: delete all items upon brand deletion
func (b *Brand) del() error {
+ // delete brand
res := db.Where("id = ? and user_id = ?", b.ID, b.UserID).Delete(b)
// TODO: handle potential errors
@@ -86,7 +75,7 @@ func (b *Brand) del() error {
// returns 404 if either row doesn't exist or if the user doesn't own it
if res.RowsAffected == 0 {
- return e.ErrNotFound
+ return errors.ErrNotFound
}
return nil
@@ -101,7 +90,7 @@ func getItems(items *[]SavedItem, userId uint) error {
}
if res.RowsAffected == 0 {
- return e.ErrEmptyResponse
+ return errors.ErrEmptyResponse
}
return nil
@@ -123,7 +112,7 @@ func (i *SavedItem) del() error {
// returns 404 if either row doesn't exist or if the user doesn't own it
if res.RowsAffected == 0 {
- return e.ErrNotFound
+ return errors.ErrNotFound
}
return nil
diff --git a/item/validators.go b/item/validators.go
index 996a5d7..e931843 100644
--- a/item/validators.go
+++ b/item/validators.go
@@ -51,26 +51,6 @@ func (b *Brand) validate() error {
return nil
}
-func checkIfBrandExists(id, userId uint) error {
- // check if brand id is valid and is owned by user
- var count int64
- err := db.Model(&Brand{}).
- Select("id").
- Where("id = ? and user_id = ?", id, userId).
- Count(&count).
- Error
-
- if err != nil {
- return err
- }
-
- if count == 0 {
- return errors.ErrBrandNotFound
- }
-
- return nil
-}
-
func (i *SavedItem) validate() error {
// trim whitespaces
i.Name = strings.TrimSpace(i.Name)
@@ -109,3 +89,55 @@ func (i *SavedItem) validate() error {
return nil
}
+
+func checkBrandOwnership(brandId, userId uint) error {
+ var brand Brand
+ err := db.
+ Select("id", "user_id").
+ Where("id = ?", brandId).
+ Find(&brand).
+ Error
+
+ // TODO: handle potential errors
+ if err != nil {
+ return err
+ }
+
+ // brand doesn't exist
+ if brand.ID == 0 {
+ return errors.ErrBrandNotFound
+ }
+
+ // user doesn't own this brand
+ if brand.UserID != userId {
+ return errors.ErrForbidden
+ }
+
+ return nil
+}
+
+func checkItemOwnership(itemId, userId uint) error {
+ var item SavedItem
+ err := db.
+ Select("id", "user_id").
+ Where("id = ?", itemId).
+ Find(&item).
+ Error
+
+ // TODO: handle potential errors
+ if err != nil {
+ return err
+ }
+
+ // item doesn't exist
+ if item.ID == 0 {
+ return errors.ErrNotFound
+ }
+
+ // user doesn't own this item
+ if item.UserID != userId {
+ return errors.ErrForbidden
+ }
+
+ return nil
+}