diff options
-rw-r--r-- | conf/conf.go | 2 | ||||
-rw-r--r-- | customer/controller.go | 9 | ||||
-rw-r--r-- | customer/customer.go | 24 | ||||
-rw-r--r-- | customer/hooks.go | 5 | ||||
-rw-r--r-- | customer/validators.go | 59 |
5 files changed, 55 insertions, 44 deletions
diff --git a/conf/conf.go b/conf/conf.go index 6cf0d99..8afe1f4 100644 --- a/conf/conf.go +++ b/conf/conf.go @@ -73,7 +73,6 @@ func validateConf() { } } -// TODO: validate config func init() { viper.SetConfigName("openbills") viper.AddConfigPath("/etc/openbills") @@ -94,7 +93,6 @@ func init() { viper.SetDefault("instance.description", "Libre Billing Software") viper.SetDefault("instance.url", "https://openbills.vidhukant.com") - // TODO: exit if these 3 have unallowed fields viper.SetDefault("username.allowed_characters", "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789.-_") viper.SetDefault("username.min_username_length", 2) viper.SetDefault("username.max_username_length", 20) diff --git a/customer/controller.go b/customer/controller.go index ae6101f..2bacd02 100644 --- a/customer/controller.go +++ b/customer/controller.go @@ -99,7 +99,6 @@ func handleSaveCustomer (ctx *gin.Context) { userId := uId.(uint) customer.UserID = userId - customer.Contact.UserID = userId err := customer.upsert() if err != nil { @@ -134,7 +133,13 @@ func handleDelCustomer (ctx *gin.Context) { userId := uId.(uint) customer.UserID = userId - // TODO: if userid and customer's user id don't match, dont delete + err = checkCustomerOwnership(customer.ID, customer.UserID) + if err != nil { + ctx.Error(err) + ctx.Abort() + return + } + err = customer.del() if err != nil { ctx.Error(err) diff --git a/customer/customer.go b/customer/customer.go index 23c630d..e411ad5 100644 --- a/customer/customer.go +++ b/customer/customer.go @@ -27,18 +27,7 @@ var db *gorm.DB func init() { db = d.DB - db.AutoMigrate(&Customer{}, &CustomerContact{}, &CustomerBillingAddress{}, &CustomerShippingAddress{}) -} - -type CustomerContact struct { - gorm.Model - UserID uint `json:"-"` - User user.User `json:"-"` - CustomerID uint - Name string - Phone string - Email string - Website string + db.AutoMigrate(&Customer{}, &CustomerBillingAddress{}, &CustomerShippingAddress{}) } type Address struct { @@ -64,8 +53,11 @@ type Customer struct { UserID uint `json:"-"` User user.User `json:"-"` Name string - Gstin string - Contact CustomerContact - BillingAddress CustomerBillingAddress - ShippingAddresses []CustomerShippingAddress + Gstin string + ContactName string + Phone string + Email string + Website string + //BillingAddress CustomerBillingAddress + //ShippingAddresses []CustomerShippingAddress } diff --git a/customer/hooks.go b/customer/hooks.go index 020e1c5..ac246f3 100644 --- a/customer/hooks.go +++ b/customer/hooks.go @@ -30,11 +30,6 @@ func (c *Customer) BeforeSave(tx *gorm.DB) error { return err } - err = c.Contact.validate() - if err != nil { - return err - } - return nil } diff --git a/customer/validators.go b/customer/validators.go index bfd244f..2a37394 100644 --- a/customer/validators.go +++ b/customer/validators.go @@ -29,8 +29,9 @@ import ( func validateContactField(field, value string, userId uint) error { if value != "" { var count int64 - err := db.Model(&CustomerContact{}). - Where(field + " = ? and user_id = ?", value, userId). + err := db.Model(&Customer{}). + //Select(""). + Where("user_id = ? and " + field + " = ?", userId, value). Count(&count). Error @@ -55,29 +56,15 @@ func validateContactField(field, value string, userId uint) error { return nil } -func (c *CustomerContact) validate() error { +func (c *Customer) validate() error { // trim whitespaces c.Name = strings.TrimSpace(c.Name) + c.Gstin = strings.TrimSpace(c.Gstin) + c.ContactName = strings.TrimSpace(c.Name) c.Phone = strings.TrimSpace(c.Phone) c.Email = strings.TrimSpace(c.Email) c.Website = strings.TrimSpace(c.Website) - var err error - for _, i := range [][]string{{"phone", c.Phone}, {"email", c.Email}, {"website", c.Website}} { - err = validateContactField(i[0], i[1], c.UserID) - if err != nil { - return err - } - } - - return nil -} - -func (c *Customer) validate() error { - // trim whitespaces - c.Name = strings.TrimSpace(c.Name) - c.Gstin = strings.TrimSpace(c.Gstin) - // don't validate if GSTIN is empty if c.Gstin != "" { // GSTIN regex validation @@ -103,5 +90,39 @@ func (c *Customer) validate() error { } } + var err error + for _, i := range [][]string{{"phone", c.Phone}, {"email", c.Email}, {"website", c.Website}} { + err = validateContactField(i[0], i[1], c.UserID) + if err != nil { + return err + } + } + + return nil +} + +func checkCustomerOwnership(customerId, userId uint) error { + var customer Customer + err := db. + Select("id", "user_id"). + Where("id = ?", customerId). + Find(&customer). + Error + + // TODO: handle potential errors + if err != nil { + return err + } + + // customer doesn't exist + if customer.ID == 0 { + return errors.ErrNotFound + } + + // user doesn't own this customer + if customer.UserID != userId { + return errors.ErrForbidden + } + return nil } |