diff options
-rw-r--r-- | auth/auth.go | 6 | ||||
-rw-r--r-- | auth/controller.go | 10 | ||||
-rw-r--r-- | errors/errors.go | 1 | ||||
-rw-r--r-- | errors/status.go | 1 | ||||
-rw-r--r-- | main.go | 2 | ||||
-rw-r--r-- | user/service.go | 19 | ||||
-rw-r--r-- | user/user.go | 5 |
7 files changed, 37 insertions, 7 deletions
diff --git a/auth/auth.go b/auth/auth.go index 6797c91..ae2db9b 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -25,3 +25,9 @@ type AuthClaims struct { jwt.RegisteredClaims UserID uint `json:"userid"` } + +type LoginReq struct { + AccountName string + Method string + Password string +} diff --git a/auth/controller.go b/auth/controller.go index 277a85a..1e6f7fe 100644 --- a/auth/controller.go +++ b/auth/controller.go @@ -25,6 +25,7 @@ import ( "vidhukant.com/openbills/user" "net/http" "time" + "fmt" ) var ( @@ -68,12 +69,15 @@ func handleSignUp (ctx *gin.Context) { } func handleSignIn (ctx *gin.Context) { - var u user.User - ctx.Bind(&u) + var req LoginReq + ctx.Bind(&req) + + fmt.Println(req) var err error + var u user.User - err = user.CheckPassword(u.ID, u.Password) + err = user.CheckPassword(&u, req.AccountName, req.Method, req.Password) if err != nil { // TODO: handle potential errors ctx.Error(err) diff --git a/errors/errors.go b/errors/errors.go index 3f1efc5..b5f4f63 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -38,6 +38,7 @@ var ( ErrInvalidGSTPercentage = errors.New("Invalid GST Percentage") ErrPasswordTooShort = errors.New("Password Is Too Short") ErrPasswordTooLong = errors.New("Password Is Too Long") + ErrInvalidLoginMethod = errors.New("Login Method Can Only Be 'email' Or 'username'") // 401 ErrWrongPassword = errors.New("Wrong Password") diff --git a/errors/status.go b/errors/status.go index 7a23ddf..c7fc2a4 100644 --- a/errors/status.go +++ b/errors/status.go @@ -40,6 +40,7 @@ func StatusCodeFromErr(err error) int { errors.Is(err, ErrInvalidUnitPrice) || errors.Is(err, ErrPasswordTooShort) || errors.Is(err, ErrPasswordTooLong) || + errors.Is(err, ErrInvalidLoginMethod) || errors.Is(err, ErrInvalidGSTPercentage) { return http.StatusBadRequest } @@ -37,7 +37,7 @@ import ( "log" ) -const OPENBILLS_VERSION = "v0.0.5" +const OPENBILLS_VERSION = "v0.0.6" func init() { if viper.GetBool("production_mode") { diff --git a/user/service.go b/user/service.go index 5e0632b..4544cb4 100644 --- a/user/service.go +++ b/user/service.go @@ -27,6 +27,25 @@ func (u *User) Create() error { return res.Error } +func GetUserWithAccountName(user *User, accountName, method string) error { + if method != "username" && method != "email" { + return e.ErrInvalidLoginMethod + } + + res := db.Where(method + " = ?", accountName).Find(&user) + + // TODO: handle potential errors + if res.Error != nil { + return res.Error + } + + if res.RowsAffected == 0 { + return e.ErrNotFound + } + + return nil +} + func GetUser(user *User, id uint) error { res := db.Find(&user, id) diff --git a/user/user.go b/user/user.go index 68ceb47..ee36e95 100644 --- a/user/user.go +++ b/user/user.go @@ -44,9 +44,8 @@ type User struct { IsVerified bool } -func CheckPassword(id uint, pass string) error { - var user User - err := GetUser(&user, id) +func CheckPassword(user *User, accountName, method, pass string) error { + err := GetUserWithAccountName(user, accountName, method) if err != nil { return err } |