aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--auth/controller.go2
-rw-r--r--customer/controller.go48
-rw-r--r--customer/customer.go5
-rw-r--r--customer/service.go7
-rw-r--r--customer/validators.go8
-rw-r--r--errors/errors.go5
-rw-r--r--errors/status.go5
-rw-r--r--item/item.go3
8 files changed, 73 insertions, 10 deletions
diff --git a/auth/controller.go b/auth/controller.go
index 93211dd..277a85a 100644
--- a/auth/controller.go
+++ b/auth/controller.go
@@ -85,7 +85,7 @@ func handleSignIn (ctx *gin.Context) {
AuthClaims {
jwt.RegisteredClaims {
IssuedAt: jwt.NewNumericDate(time.Now()),
- ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 2)),
+ ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 2)),
},
u.ID,
},
diff --git a/customer/controller.go b/customer/controller.go
index 9381c45..ae6101f 100644
--- a/customer/controller.go
+++ b/customer/controller.go
@@ -31,6 +31,15 @@ func handleGetSingleCustomer (ctx *gin.Context) {
return
}
+ uId, ok := ctx.Get("UserID")
+ if !ok {
+ ctx.Error(e.ErrUnauthorized)
+ ctx.Abort()
+ return
+ }
+
+ userId := uId.(uint)
+
var customer Customer
err = getCustomer(&customer, uint(id))
@@ -40,6 +49,12 @@ func handleGetSingleCustomer (ctx *gin.Context) {
return
}
+ if customer.UserID != userId {
+ ctx.Error(e.ErrForbidden)
+ ctx.Abort()
+ return
+ }
+
ctx.JSON(http.StatusOK, gin.H{
"message": "success",
"data": customer,
@@ -49,7 +64,16 @@ func handleGetSingleCustomer (ctx *gin.Context) {
func handleGetCustomers (ctx *gin.Context) {
var customers []Customer
- err := getCustomers(&customers)
+ uId, ok := ctx.Get("UserID")
+ if !ok {
+ ctx.Error(e.ErrUnauthorized)
+ ctx.Abort()
+ return
+ }
+
+ userId := uId.(uint)
+
+ err := getCustomers(&customers, userId)
if err != nil {
ctx.Error(err)
ctx.Abort()
@@ -66,6 +90,17 @@ func handleSaveCustomer (ctx *gin.Context) {
var customer Customer
ctx.Bind(&customer)
+ uId, ok := ctx.Get("UserID")
+ if !ok {
+ ctx.Error(e.ErrUnauthorized)
+ ctx.Abort()
+ return
+ }
+
+ userId := uId.(uint)
+ customer.UserID = userId
+ customer.Contact.UserID = userId
+
err := customer.upsert()
if err != nil {
ctx.Error(err)
@@ -89,6 +124,17 @@ func handleDelCustomer (ctx *gin.Context) {
var customer Customer
customer.ID = uint(id)
+ uId, ok := ctx.Get("UserID")
+ if !ok {
+ ctx.Error(e.ErrUnauthorized)
+ ctx.Abort()
+ return
+ }
+
+ userId := uId.(uint)
+ customer.UserID = userId
+
+ // TODO: if userid and customer's user id don't match, dont delete
err = customer.del()
if err != nil {
ctx.Error(err)
diff --git a/customer/customer.go b/customer/customer.go
index 5f25e2d..23c630d 100644
--- a/customer/customer.go
+++ b/customer/customer.go
@@ -20,6 +20,7 @@ package customer
import (
"gorm.io/gorm"
d "vidhukant.com/openbills/db"
+ "vidhukant.com/openbills/user"
)
var db *gorm.DB
@@ -31,6 +32,8 @@ func init() {
type CustomerContact struct {
gorm.Model
+ UserID uint `json:"-"`
+ User user.User `json:"-"`
CustomerID uint
Name string
Phone string
@@ -58,6 +61,8 @@ type CustomerShippingAddress struct {
type Customer struct {
gorm.Model
+ UserID uint `json:"-"`
+ User user.User `json:"-"`
Name string
Gstin string
Contact CustomerContact
diff --git a/customer/service.go b/customer/service.go
index c5d7cb8..f1108c6 100644
--- a/customer/service.go
+++ b/customer/service.go
@@ -36,8 +36,8 @@ func getCustomer(customer *Customer, id uint) error {
return nil
}
-func getCustomers(customers *[]Customer) error {
- res := db.Find(&customers)
+func getCustomers(customers *[]Customer, userId uint) error {
+ res := db.Where("user_id = ?", userId).Find(&customers)
// TODO: handle potential errors
if res.Error != nil {
@@ -58,13 +58,14 @@ func (c *Customer) upsert() error {
}
func (c *Customer) del() error {
- res := db.Delete(c)
+ res := db.Where("id = ? and user_id = ?", c.ID, c.UserID).Delete(c)
// TODO: handle potential errors
if res.Error != nil {
return res.Error
}
+ // returns 404 if either row doesn't exist or if the user doesn't own it
if res.RowsAffected == 0 {
return e.ErrNotFound
}
diff --git a/customer/validators.go b/customer/validators.go
index 6c51ad9..bfd244f 100644
--- a/customer/validators.go
+++ b/customer/validators.go
@@ -26,11 +26,11 @@ import (
// NOTE: very inefficient and really really really dumb but it works
// TODO: find a better (or even a remotely good) way
-func validateContactField(field, value string) error {
+func validateContactField(field, value string, userId uint) error {
if value != "" {
var count int64
err := db.Model(&CustomerContact{}).
- Where(field + " = ?", value).
+ Where(field + " = ? and user_id = ?", value, userId).
Count(&count).
Error
@@ -64,7 +64,7 @@ func (c *CustomerContact) validate() error {
var err error
for _, i := range [][]string{{"phone", c.Phone}, {"email", c.Email}, {"website", c.Website}} {
- err = validateContactField(i[0], i[1])
+ err = validateContactField(i[0], i[1], c.UserID)
if err != nil {
return err
}
@@ -90,7 +90,7 @@ func (c *Customer) validate() error {
var count int64
err := db.Model(&Customer{}).
Select("gstin").
- Where("gstin = ?", c.Gstin).
+ Where("gstin = ? and user_id = ?", c.Gstin, c.UserID).
Count(&count).
Error
diff --git a/errors/errors.go b/errors/errors.go
index 16e1646..6716fdc 100644
--- a/errors/errors.go
+++ b/errors/errors.go
@@ -42,6 +42,9 @@ var (
ErrUnauthorized = errors.New("Unauthorized")
ErrSessionExpired = errors.New("Session Expired")
+ // 403
+ ErrForbidden = errors.New("You Are Not Authorized To Access This Resource")
+
// 404
ErrNotFound = errors.New("Not Found")
ErrBrandNotFound = errors.New("This Brand Does Not Exist")
@@ -56,5 +59,5 @@ var (
ErrNonUniqueBrandItem = errors.New("Item With Same Name And Brand Already Exists")
// 500
- ErrInternalServerError = errors.New("Internal Server Error")
+ ErrInternalServerError = errors.New("Internal Server Error")
)
diff --git a/errors/status.go b/errors/status.go
index 4767aeb..1fa33d1 100644
--- a/errors/status.go
+++ b/errors/status.go
@@ -49,6 +49,11 @@ func StatusCodeFromErr(err error) int {
return http.StatusUnauthorized
}
+ // 403
+ if errors.Is(err, ErrForbidden) {
+ return http.StatusForbidden
+ }
+
// 404
if errors.Is(err, ErrNotFound) ||
errors.Is(err, ErrBrandNotFound) {
diff --git a/item/item.go b/item/item.go
index 4da9c78..839cbe0 100644
--- a/item/item.go
+++ b/item/item.go
@@ -20,6 +20,7 @@ package item
import (
"gorm.io/gorm"
d "vidhukant.com/openbills/db"
+ "vidhukant.com/openbills/user"
)
var db *gorm.DB
@@ -35,6 +36,8 @@ type Brand struct {
}
type Item struct {
+ UserID uint `json:"-"`
+ User user.User `json:"-"`
BrandID uint
Brand Brand
UnitOfMeasure string // TODO: probably has to be a custom type