diff options
Diffstat (limited to 'item')
| -rw-r--r-- | item/controller.go | 73 | ||||
| -rw-r--r-- | item/hooks.go | 2 | ||||
| -rw-r--r-- | item/service.go | 22 | ||||
| -rw-r--r-- | item/validators.go | 10 | 
4 files changed, 89 insertions, 18 deletions
diff --git a/item/controller.go b/item/controller.go index b4e27c1..cf9683d 100644 --- a/item/controller.go +++ b/item/controller.go @@ -31,8 +31,17 @@ func handleGetBrandItems (ctx *gin.Context) {  		return  	} +	uId, ok := ctx.Get("UserID") +	if !ok { +		ctx.Error(e.ErrUnauthorized) +		ctx.Abort() +		return +	} + +	userId := uId.(uint) +  	var items []SavedItem -	err = getBrandItems(&items, uint(id)) +	err = getBrandItems(&items, uint(id), userId)  	if err != nil {  		ctx.Error(err)  		ctx.Abort() @@ -48,7 +57,16 @@ func handleGetBrandItems (ctx *gin.Context) {  func handleGetBrands (ctx *gin.Context) {  	var brands []Brand -	err := getBrands(&brands) +	uId, ok := ctx.Get("UserID") +	if !ok { +		ctx.Error(e.ErrUnauthorized) +		ctx.Abort() +		return +	} + +	userId := uId.(uint) + +	err := getBrands(&brands, userId)  	if err != nil {  		ctx.Error(err)  		ctx.Abort() @@ -65,6 +83,16 @@ func handleSaveBrand (ctx *gin.Context) {  	var brand Brand  	ctx.Bind(&brand) +	uId, ok := ctx.Get("UserID") +	if !ok { +		ctx.Error(e.ErrUnauthorized) +		ctx.Abort() +		return +	} + +	userId := uId.(uint) +	brand.UserID = userId +  	err := brand.upsert()  	if err != nil {  		ctx.Error(err) @@ -88,6 +116,16 @@ func handleDelBrand (ctx *gin.Context) {  	var brand Brand  	brand.ID = uint(id) +	uId, ok := ctx.Get("UserID") +	if !ok { +		ctx.Error(e.ErrUnauthorized) +		ctx.Abort() +		return +	} + +	userId := uId.(uint) +	brand.UserID = userId +  	err = brand.del()  	if err != nil {  		ctx.Error(err) @@ -103,7 +141,16 @@ func handleDelBrand (ctx *gin.Context) {  func handleGetItems (ctx *gin.Context) {  	var items []SavedItem -	err := getItems(&items) +	uId, ok := ctx.Get("UserID") +	if !ok { +		ctx.Error(e.ErrUnauthorized) +		ctx.Abort() +		return +	} + +	userId := uId.(uint) + +	err := getItems(&items, userId)  	if err != nil {  		ctx.Error(err)  		ctx.Abort() @@ -120,6 +167,16 @@ func handleSaveItem (ctx *gin.Context) {  	var item SavedItem  	ctx.Bind(&item) +	uId, ok := ctx.Get("UserID") +	if !ok { +		ctx.Error(e.ErrUnauthorized) +		ctx.Abort() +		return +	} + +	userId := uId.(uint) +	item.UserID = userId +  	err := item.upsert()  	if err != nil {  		ctx.Error(err) @@ -143,6 +200,16 @@ func handleDelItem (ctx *gin.Context) {  	var item SavedItem  	item.ID = uint(id) +	uId, ok := ctx.Get("UserID") +	if !ok { +		ctx.Error(e.ErrUnauthorized) +		ctx.Abort() +		return +	} + +	userId := uId.(uint) +	item.UserID = userId +  	err = item.del()  	if err != nil {  		ctx.Error(err) diff --git a/item/hooks.go b/item/hooks.go index ddb9a44..558a8cb 100644 --- a/item/hooks.go +++ b/item/hooks.go @@ -25,7 +25,7 @@ import (  func (i *SavedItem) BeforeSave(tx *gorm.DB) error {  	var err error -	err = checkIfBrandExists(i.BrandID) +	err = checkIfBrandExists(i.BrandID, i.UserID)  	if err != nil {  		return err  	} diff --git a/item/service.go b/item/service.go index c8a72f6..fb03adc 100644 --- a/item/service.go +++ b/item/service.go @@ -22,12 +22,12 @@ import (  	e "vidhukant.com/openbills/errors"  ) -func getBrandItems(items *[]SavedItem, id uint) error { -	// check if id is valid +func getBrandItems(items *[]SavedItem, id, userId uint) error { +	// check if brand id is valid and is owned by user  	var count int64  	err := db.Model(&Brand{}).  		Select("id"). -		Where("id = ?", id). +		Where("id = ? and user_id = ?", id, userId).  		Count(&count).  		Error @@ -46,6 +46,7 @@ func getBrandItems(items *[]SavedItem, id uint) error {  		return res.Error  	} +	// returns 404 if either row doesn't exist or if the user doesn't own it  	if res.RowsAffected == 0 {  		return e.ErrEmptyResponse  	} @@ -53,8 +54,8 @@ func getBrandItems(items *[]SavedItem, id uint) error {  	return nil  } -func getBrands(brands *[]Brand) error { -	res := db.Find(&brands) +func getBrands(brands *[]Brand, userId uint) error { +	res := db.Where("user_id = ?", userId).Find(&brands)  	// TODO: handle potential errors  	if res.Error != nil { @@ -74,14 +75,16 @@ func (b *Brand) upsert() error {  	return res.Error  } +// TODO: delete all items upon brand deletion  func (b *Brand) del() error { -	res := db.Delete(b) +	res := db.Where("id = ? and user_id = ?", b.ID, b.UserID).Delete(b)  	// TODO: handle potential errors  	if res.Error != nil {  		return res.Error  	} +	// returns 404 if either row doesn't exist or if the user doesn't own it  	if res.RowsAffected == 0 {  		return e.ErrNotFound  	} @@ -89,8 +92,8 @@ func (b *Brand) del() error {  	return nil  } -func getItems(items *[]SavedItem) error { -	res := db.Preload("Brand").Find(&items) +func getItems(items *[]SavedItem, userId uint) error { +	res := db.Where("user_id = ?", userId).Preload("Brand").Find(&items)  	// TODO: handle potential errors  	if res.Error != nil { @@ -111,13 +114,14 @@ func (i *SavedItem) upsert() error {  }  func (i *SavedItem) del() error { -	res := db.Delete(i) +	res := db.Where("id = ? and user_id = ?", i.ID, i.UserID).Delete(i)  	// TODO: handle potential errors  	if res.Error != nil {  		return res.Error  	} +	// returns 404 if either row doesn't exist or if the user doesn't own it  	if res.RowsAffected == 0 {  		return e.ErrNotFound  	} diff --git a/item/validators.go b/item/validators.go index 09162ab..996a5d7 100644 --- a/item/validators.go +++ b/item/validators.go @@ -36,7 +36,7 @@ func (b *Brand) validate() error {  	var count int64  	err := db.Model(&Brand{}).  		Select("name"). -		Where("name = ?", b.Name). +		Where("name = ? and user_id = ?", b.Name, b.UserID).  		Count(&count).  		Error @@ -51,12 +51,12 @@ func (b *Brand) validate() error {  	return nil  } -func checkIfBrandExists(id uint) error { -	// check if brand id is valid +func checkIfBrandExists(id, userId uint) error { +	// check if brand id is valid and is owned by user  	var count int64  	err := db.Model(&Brand{}).  		Select("id"). -		Where("id = ?", id). +		Where("id = ? and user_id = ?", id, userId).  		Count(&count).  		Error @@ -95,7 +95,7 @@ func (i *SavedItem) validate() error {  	var count int64  	err = db.Model(&SavedItem{}).  		Select("name, brand_id"). -		Where("brand_id = ? and name = ?", i.BrandID, i.Name). +		Where("brand_id = ? and name = ? and user_id = ?", i.BrandID, i.Name, i.UserID).  		Count(&count).  		Error  |