diff options
Diffstat (limited to 'customer')
| -rw-r--r-- | customer/controller.go | 9 | ||||
| -rw-r--r-- | customer/customer.go | 24 | ||||
| -rw-r--r-- | customer/hooks.go | 5 | ||||
| -rw-r--r-- | customer/validators.go | 59 | 
4 files changed, 55 insertions, 42 deletions
diff --git a/customer/controller.go b/customer/controller.go index ae6101f..2bacd02 100644 --- a/customer/controller.go +++ b/customer/controller.go @@ -99,7 +99,6 @@ func handleSaveCustomer (ctx *gin.Context) {  	userId := uId.(uint)  	customer.UserID = userId -	customer.Contact.UserID = userId  	err := customer.upsert()  	if err != nil { @@ -134,7 +133,13 @@ func handleDelCustomer (ctx *gin.Context) {  	userId := uId.(uint)  	customer.UserID = userId -	// TODO: if userid and customer's user id don't match, dont delete +	err = checkCustomerOwnership(customer.ID, customer.UserID) +	if err != nil { +		ctx.Error(err) +		ctx.Abort() +		return +	} +  	err = customer.del()  	if err != nil {  		ctx.Error(err) diff --git a/customer/customer.go b/customer/customer.go index 23c630d..e411ad5 100644 --- a/customer/customer.go +++ b/customer/customer.go @@ -27,18 +27,7 @@ var db *gorm.DB  func init() {  	db = d.DB -	db.AutoMigrate(&Customer{}, &CustomerContact{}, &CustomerBillingAddress{}, &CustomerShippingAddress{}) -} - -type CustomerContact struct { -	gorm.Model -	UserID     uint      `json:"-"` -	User       user.User `json:"-"` -	CustomerID uint -	Name       string -	Phone      string -	Email      string -	Website    string +	db.AutoMigrate(&Customer{}, &CustomerBillingAddress{}, &CustomerShippingAddress{})  }  type Address struct { @@ -64,8 +53,11 @@ type Customer struct {  	UserID            uint      `json:"-"`  	User              user.User `json:"-"`  	Name              string -	Gstin             string  -	Contact           CustomerContact -	BillingAddress    CustomerBillingAddress -	ShippingAddresses []CustomerShippingAddress +	Gstin             string +	ContactName       string +	Phone             string +	Email             string +	Website           string +	//BillingAddress    CustomerBillingAddress +	//ShippingAddresses []CustomerShippingAddress  } diff --git a/customer/hooks.go b/customer/hooks.go index 020e1c5..ac246f3 100644 --- a/customer/hooks.go +++ b/customer/hooks.go @@ -30,11 +30,6 @@ func (c *Customer) BeforeSave(tx *gorm.DB) error {  		return err  	} -	err = c.Contact.validate() -	if err != nil { -		return err -	} -  	return nil  } diff --git a/customer/validators.go b/customer/validators.go index bfd244f..2a37394 100644 --- a/customer/validators.go +++ b/customer/validators.go @@ -29,8 +29,9 @@ import (  func validateContactField(field, value string, userId uint) error {  	if value != "" {  		var count int64 -		err := db.Model(&CustomerContact{}). -			Where(field + " = ? and user_id = ?", value, userId). +		err := db.Model(&Customer{}). +			//Select(""). +			Where("user_id = ? and " + field + " = ?", userId, value).  			Count(&count).  			Error @@ -55,29 +56,15 @@ func validateContactField(field, value string, userId uint) error {  	return nil  } -func (c *CustomerContact) validate() error { +func (c *Customer) validate() error {  	// trim whitespaces  	c.Name = strings.TrimSpace(c.Name) +	c.Gstin = strings.TrimSpace(c.Gstin) +	c.ContactName = strings.TrimSpace(c.Name)  	c.Phone = strings.TrimSpace(c.Phone)  	c.Email = strings.TrimSpace(c.Email)  	c.Website = strings.TrimSpace(c.Website) -	var err error -	for _, i := range [][]string{{"phone", c.Phone}, {"email", c.Email}, {"website", c.Website}} { -		err = validateContactField(i[0], i[1], c.UserID) -		if err != nil { -			return err -		} -	} - -	return nil -} - -func (c *Customer) validate() error { -	// trim whitespaces -	c.Name = strings.TrimSpace(c.Name) -	c.Gstin = strings.TrimSpace(c.Gstin) -  	// don't validate if GSTIN is empty  	if c.Gstin != "" {  	  // GSTIN regex validation @@ -103,5 +90,39 @@ func (c *Customer) validate() error {  		}  	} +	var err error +	for _, i := range [][]string{{"phone", c.Phone}, {"email", c.Email}, {"website", c.Website}} { +		err = validateContactField(i[0], i[1], c.UserID) +		if err != nil { +			return err +		} +	} + +	return nil +} + +func checkCustomerOwnership(customerId, userId uint) error { +	var customer Customer +	err := db. +		Select("id", "user_id"). +		Where("id = ?", customerId). +		Find(&customer). +		Error + +	// TODO: handle potential errors +	if err != nil { +		return err +  } + +	// customer doesn't exist +	if customer.ID == 0 { +		return errors.ErrNotFound +	} + +	// user doesn't own this customer +	if customer.UserID != userId { +		return errors.ErrForbidden +	} +  	return nil  }  |