aboutsummaryrefslogtreecommitdiff
path: root/item
diff options
context:
space:
mode:
Diffstat (limited to 'item')
-rw-r--r--item/controller.go73
-rw-r--r--item/hooks.go2
-rw-r--r--item/service.go22
-rw-r--r--item/validators.go10
4 files changed, 89 insertions, 18 deletions
diff --git a/item/controller.go b/item/controller.go
index b4e27c1..cf9683d 100644
--- a/item/controller.go
+++ b/item/controller.go
@@ -31,8 +31,17 @@ func handleGetBrandItems (ctx *gin.Context) {
return
}
+ uId, ok := ctx.Get("UserID")
+ if !ok {
+ ctx.Error(e.ErrUnauthorized)
+ ctx.Abort()
+ return
+ }
+
+ userId := uId.(uint)
+
var items []SavedItem
- err = getBrandItems(&items, uint(id))
+ err = getBrandItems(&items, uint(id), userId)
if err != nil {
ctx.Error(err)
ctx.Abort()
@@ -48,7 +57,16 @@ func handleGetBrandItems (ctx *gin.Context) {
func handleGetBrands (ctx *gin.Context) {
var brands []Brand
- err := getBrands(&brands)
+ uId, ok := ctx.Get("UserID")
+ if !ok {
+ ctx.Error(e.ErrUnauthorized)
+ ctx.Abort()
+ return
+ }
+
+ userId := uId.(uint)
+
+ err := getBrands(&brands, userId)
if err != nil {
ctx.Error(err)
ctx.Abort()
@@ -65,6 +83,16 @@ func handleSaveBrand (ctx *gin.Context) {
var brand Brand
ctx.Bind(&brand)
+ uId, ok := ctx.Get("UserID")
+ if !ok {
+ ctx.Error(e.ErrUnauthorized)
+ ctx.Abort()
+ return
+ }
+
+ userId := uId.(uint)
+ brand.UserID = userId
+
err := brand.upsert()
if err != nil {
ctx.Error(err)
@@ -88,6 +116,16 @@ func handleDelBrand (ctx *gin.Context) {
var brand Brand
brand.ID = uint(id)
+ uId, ok := ctx.Get("UserID")
+ if !ok {
+ ctx.Error(e.ErrUnauthorized)
+ ctx.Abort()
+ return
+ }
+
+ userId := uId.(uint)
+ brand.UserID = userId
+
err = brand.del()
if err != nil {
ctx.Error(err)
@@ -103,7 +141,16 @@ func handleDelBrand (ctx *gin.Context) {
func handleGetItems (ctx *gin.Context) {
var items []SavedItem
- err := getItems(&items)
+ uId, ok := ctx.Get("UserID")
+ if !ok {
+ ctx.Error(e.ErrUnauthorized)
+ ctx.Abort()
+ return
+ }
+
+ userId := uId.(uint)
+
+ err := getItems(&items, userId)
if err != nil {
ctx.Error(err)
ctx.Abort()
@@ -120,6 +167,16 @@ func handleSaveItem (ctx *gin.Context) {
var item SavedItem
ctx.Bind(&item)
+ uId, ok := ctx.Get("UserID")
+ if !ok {
+ ctx.Error(e.ErrUnauthorized)
+ ctx.Abort()
+ return
+ }
+
+ userId := uId.(uint)
+ item.UserID = userId
+
err := item.upsert()
if err != nil {
ctx.Error(err)
@@ -143,6 +200,16 @@ func handleDelItem (ctx *gin.Context) {
var item SavedItem
item.ID = uint(id)
+ uId, ok := ctx.Get("UserID")
+ if !ok {
+ ctx.Error(e.ErrUnauthorized)
+ ctx.Abort()
+ return
+ }
+
+ userId := uId.(uint)
+ item.UserID = userId
+
err = item.del()
if err != nil {
ctx.Error(err)
diff --git a/item/hooks.go b/item/hooks.go
index ddb9a44..558a8cb 100644
--- a/item/hooks.go
+++ b/item/hooks.go
@@ -25,7 +25,7 @@ import (
func (i *SavedItem) BeforeSave(tx *gorm.DB) error {
var err error
- err = checkIfBrandExists(i.BrandID)
+ err = checkIfBrandExists(i.BrandID, i.UserID)
if err != nil {
return err
}
diff --git a/item/service.go b/item/service.go
index c8a72f6..fb03adc 100644
--- a/item/service.go
+++ b/item/service.go
@@ -22,12 +22,12 @@ import (
e "vidhukant.com/openbills/errors"
)
-func getBrandItems(items *[]SavedItem, id uint) error {
- // check if id is valid
+func getBrandItems(items *[]SavedItem, id, userId uint) error {
+ // check if brand id is valid and is owned by user
var count int64
err := db.Model(&Brand{}).
Select("id").
- Where("id = ?", id).
+ Where("id = ? and user_id = ?", id, userId).
Count(&count).
Error
@@ -46,6 +46,7 @@ func getBrandItems(items *[]SavedItem, id uint) error {
return res.Error
}
+ // returns 404 if either row doesn't exist or if the user doesn't own it
if res.RowsAffected == 0 {
return e.ErrEmptyResponse
}
@@ -53,8 +54,8 @@ func getBrandItems(items *[]SavedItem, id uint) error {
return nil
}
-func getBrands(brands *[]Brand) error {
- res := db.Find(&brands)
+func getBrands(brands *[]Brand, userId uint) error {
+ res := db.Where("user_id = ?", userId).Find(&brands)
// TODO: handle potential errors
if res.Error != nil {
@@ -74,14 +75,16 @@ func (b *Brand) upsert() error {
return res.Error
}
+// TODO: delete all items upon brand deletion
func (b *Brand) del() error {
- res := db.Delete(b)
+ res := db.Where("id = ? and user_id = ?", b.ID, b.UserID).Delete(b)
// TODO: handle potential errors
if res.Error != nil {
return res.Error
}
+ // returns 404 if either row doesn't exist or if the user doesn't own it
if res.RowsAffected == 0 {
return e.ErrNotFound
}
@@ -89,8 +92,8 @@ func (b *Brand) del() error {
return nil
}
-func getItems(items *[]SavedItem) error {
- res := db.Preload("Brand").Find(&items)
+func getItems(items *[]SavedItem, userId uint) error {
+ res := db.Where("user_id = ?", userId).Preload("Brand").Find(&items)
// TODO: handle potential errors
if res.Error != nil {
@@ -111,13 +114,14 @@ func (i *SavedItem) upsert() error {
}
func (i *SavedItem) del() error {
- res := db.Delete(i)
+ res := db.Where("id = ? and user_id = ?", i.ID, i.UserID).Delete(i)
// TODO: handle potential errors
if res.Error != nil {
return res.Error
}
+ // returns 404 if either row doesn't exist or if the user doesn't own it
if res.RowsAffected == 0 {
return e.ErrNotFound
}
diff --git a/item/validators.go b/item/validators.go
index 09162ab..996a5d7 100644
--- a/item/validators.go
+++ b/item/validators.go
@@ -36,7 +36,7 @@ func (b *Brand) validate() error {
var count int64
err := db.Model(&Brand{}).
Select("name").
- Where("name = ?", b.Name).
+ Where("name = ? and user_id = ?", b.Name, b.UserID).
Count(&count).
Error
@@ -51,12 +51,12 @@ func (b *Brand) validate() error {
return nil
}
-func checkIfBrandExists(id uint) error {
- // check if brand id is valid
+func checkIfBrandExists(id, userId uint) error {
+ // check if brand id is valid and is owned by user
var count int64
err := db.Model(&Brand{}).
Select("id").
- Where("id = ?", id).
+ Where("id = ? and user_id = ?", id, userId).
Count(&count).
Error
@@ -95,7 +95,7 @@ func (i *SavedItem) validate() error {
var count int64
err = db.Model(&SavedItem{}).
Select("name, brand_id").
- Where("brand_id = ? and name = ?", i.BrandID, i.Name).
+ Where("brand_id = ? and name = ? and user_id = ?", i.BrandID, i.Name, i.UserID).
Count(&count).
Error