diff options
-rw-r--r-- | auth/controller.go | 2 | ||||
-rw-r--r-- | customer/controller.go | 48 | ||||
-rw-r--r-- | customer/customer.go | 5 | ||||
-rw-r--r-- | customer/service.go | 7 | ||||
-rw-r--r-- | customer/validators.go | 8 | ||||
-rw-r--r-- | errors/errors.go | 5 | ||||
-rw-r--r-- | errors/status.go | 5 | ||||
-rw-r--r-- | item/item.go | 3 |
8 files changed, 73 insertions, 10 deletions
diff --git a/auth/controller.go b/auth/controller.go index 93211dd..277a85a 100644 --- a/auth/controller.go +++ b/auth/controller.go @@ -85,7 +85,7 @@ func handleSignIn (ctx *gin.Context) { AuthClaims { jwt.RegisteredClaims { IssuedAt: jwt.NewNumericDate(time.Now()), - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 2)), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 2)), }, u.ID, }, diff --git a/customer/controller.go b/customer/controller.go index 9381c45..ae6101f 100644 --- a/customer/controller.go +++ b/customer/controller.go @@ -31,6 +31,15 @@ func handleGetSingleCustomer (ctx *gin.Context) { return } + uId, ok := ctx.Get("UserID") + if !ok { + ctx.Error(e.ErrUnauthorized) + ctx.Abort() + return + } + + userId := uId.(uint) + var customer Customer err = getCustomer(&customer, uint(id)) @@ -40,6 +49,12 @@ func handleGetSingleCustomer (ctx *gin.Context) { return } + if customer.UserID != userId { + ctx.Error(e.ErrForbidden) + ctx.Abort() + return + } + ctx.JSON(http.StatusOK, gin.H{ "message": "success", "data": customer, @@ -49,7 +64,16 @@ func handleGetSingleCustomer (ctx *gin.Context) { func handleGetCustomers (ctx *gin.Context) { var customers []Customer - err := getCustomers(&customers) + uId, ok := ctx.Get("UserID") + if !ok { + ctx.Error(e.ErrUnauthorized) + ctx.Abort() + return + } + + userId := uId.(uint) + + err := getCustomers(&customers, userId) if err != nil { ctx.Error(err) ctx.Abort() @@ -66,6 +90,17 @@ func handleSaveCustomer (ctx *gin.Context) { var customer Customer ctx.Bind(&customer) + uId, ok := ctx.Get("UserID") + if !ok { + ctx.Error(e.ErrUnauthorized) + ctx.Abort() + return + } + + userId := uId.(uint) + customer.UserID = userId + customer.Contact.UserID = userId + err := customer.upsert() if err != nil { ctx.Error(err) @@ -89,6 +124,17 @@ func handleDelCustomer (ctx *gin.Context) { var customer Customer customer.ID = uint(id) + uId, ok := ctx.Get("UserID") + if !ok { + ctx.Error(e.ErrUnauthorized) + ctx.Abort() + return + } + + userId := uId.(uint) + customer.UserID = userId + + // TODO: if userid and customer's user id don't match, dont delete err = customer.del() if err != nil { ctx.Error(err) diff --git a/customer/customer.go b/customer/customer.go index 5f25e2d..23c630d 100644 --- a/customer/customer.go +++ b/customer/customer.go @@ -20,6 +20,7 @@ package customer import ( "gorm.io/gorm" d "vidhukant.com/openbills/db" + "vidhukant.com/openbills/user" ) var db *gorm.DB @@ -31,6 +32,8 @@ func init() { type CustomerContact struct { gorm.Model + UserID uint `json:"-"` + User user.User `json:"-"` CustomerID uint Name string Phone string @@ -58,6 +61,8 @@ type CustomerShippingAddress struct { type Customer struct { gorm.Model + UserID uint `json:"-"` + User user.User `json:"-"` Name string Gstin string Contact CustomerContact diff --git a/customer/service.go b/customer/service.go index c5d7cb8..f1108c6 100644 --- a/customer/service.go +++ b/customer/service.go @@ -36,8 +36,8 @@ func getCustomer(customer *Customer, id uint) error { return nil } -func getCustomers(customers *[]Customer) error { - res := db.Find(&customers) +func getCustomers(customers *[]Customer, userId uint) error { + res := db.Where("user_id = ?", userId).Find(&customers) // TODO: handle potential errors if res.Error != nil { @@ -58,13 +58,14 @@ func (c *Customer) upsert() error { } func (c *Customer) del() error { - res := db.Delete(c) + res := db.Where("id = ? and user_id = ?", c.ID, c.UserID).Delete(c) // TODO: handle potential errors if res.Error != nil { return res.Error } + // returns 404 if either row doesn't exist or if the user doesn't own it if res.RowsAffected == 0 { return e.ErrNotFound } diff --git a/customer/validators.go b/customer/validators.go index 6c51ad9..bfd244f 100644 --- a/customer/validators.go +++ b/customer/validators.go @@ -26,11 +26,11 @@ import ( // NOTE: very inefficient and really really really dumb but it works // TODO: find a better (or even a remotely good) way -func validateContactField(field, value string) error { +func validateContactField(field, value string, userId uint) error { if value != "" { var count int64 err := db.Model(&CustomerContact{}). - Where(field + " = ?", value). + Where(field + " = ? and user_id = ?", value, userId). Count(&count). Error @@ -64,7 +64,7 @@ func (c *CustomerContact) validate() error { var err error for _, i := range [][]string{{"phone", c.Phone}, {"email", c.Email}, {"website", c.Website}} { - err = validateContactField(i[0], i[1]) + err = validateContactField(i[0], i[1], c.UserID) if err != nil { return err } @@ -90,7 +90,7 @@ func (c *Customer) validate() error { var count int64 err := db.Model(&Customer{}). Select("gstin"). - Where("gstin = ?", c.Gstin). + Where("gstin = ? and user_id = ?", c.Gstin, c.UserID). Count(&count). Error diff --git a/errors/errors.go b/errors/errors.go index 16e1646..6716fdc 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -42,6 +42,9 @@ var ( ErrUnauthorized = errors.New("Unauthorized") ErrSessionExpired = errors.New("Session Expired") + // 403 + ErrForbidden = errors.New("You Are Not Authorized To Access This Resource") + // 404 ErrNotFound = errors.New("Not Found") ErrBrandNotFound = errors.New("This Brand Does Not Exist") @@ -56,5 +59,5 @@ var ( ErrNonUniqueBrandItem = errors.New("Item With Same Name And Brand Already Exists") // 500 - ErrInternalServerError = errors.New("Internal Server Error") + ErrInternalServerError = errors.New("Internal Server Error") ) diff --git a/errors/status.go b/errors/status.go index 4767aeb..1fa33d1 100644 --- a/errors/status.go +++ b/errors/status.go @@ -49,6 +49,11 @@ func StatusCodeFromErr(err error) int { return http.StatusUnauthorized } + // 403 + if errors.Is(err, ErrForbidden) { + return http.StatusForbidden + } + // 404 if errors.Is(err, ErrNotFound) || errors.Is(err, ErrBrandNotFound) { diff --git a/item/item.go b/item/item.go index 4da9c78..839cbe0 100644 --- a/item/item.go +++ b/item/item.go @@ -20,6 +20,7 @@ package item import ( "gorm.io/gorm" d "vidhukant.com/openbills/db" + "vidhukant.com/openbills/user" ) var db *gorm.DB @@ -35,6 +36,8 @@ type Brand struct { } type Item struct { + UserID uint `json:"-"` + User user.User `json:"-"` BrandID uint Brand Brand UnitOfMeasure string // TODO: probably has to be a custom type |