diff options
| -rw-r--r-- | auth/auth.go | 6 | ||||
| -rw-r--r-- | auth/controller.go | 10 | ||||
| -rw-r--r-- | errors/errors.go | 1 | ||||
| -rw-r--r-- | errors/status.go | 1 | ||||
| -rw-r--r-- | main.go | 2 | ||||
| -rw-r--r-- | user/service.go | 19 | ||||
| -rw-r--r-- | user/user.go | 5 | 
7 files changed, 37 insertions, 7 deletions
diff --git a/auth/auth.go b/auth/auth.go index 6797c91..ae2db9b 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -25,3 +25,9 @@ type AuthClaims struct {  	jwt.RegisteredClaims  	UserID uint `json:"userid"`  } + +type LoginReq struct { +	AccountName string +	Method      string +	Password    string +} diff --git a/auth/controller.go b/auth/controller.go index 277a85a..1e6f7fe 100644 --- a/auth/controller.go +++ b/auth/controller.go @@ -25,6 +25,7 @@ import (  	"vidhukant.com/openbills/user"  	"net/http"  	"time" +	"fmt"  )  var ( @@ -68,12 +69,15 @@ func handleSignUp (ctx *gin.Context) {  }  func handleSignIn (ctx *gin.Context) { -	var u user.User -	ctx.Bind(&u) +	var req LoginReq +	ctx.Bind(&req) + +	fmt.Println(req)  	var err error +	var u user.User -	err = user.CheckPassword(u.ID, u.Password) +	err = user.CheckPassword(&u, req.AccountName, req.Method, req.Password)  	if err != nil {  		// TODO: handle potential errors  		ctx.Error(err) diff --git a/errors/errors.go b/errors/errors.go index 3f1efc5..b5f4f63 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -38,6 +38,7 @@ var (  	ErrInvalidGSTPercentage = errors.New("Invalid GST Percentage")  	ErrPasswordTooShort     = errors.New("Password Is Too Short")  	ErrPasswordTooLong      = errors.New("Password Is Too Long") +	ErrInvalidLoginMethod   = errors.New("Login Method Can Only Be 'email' Or 'username'")  	// 401  	ErrWrongPassword        = errors.New("Wrong Password") diff --git a/errors/status.go b/errors/status.go index 7a23ddf..c7fc2a4 100644 --- a/errors/status.go +++ b/errors/status.go @@ -40,6 +40,7 @@ func StatusCodeFromErr(err error) int {  		errors.Is(err, ErrInvalidUnitPrice) ||  		errors.Is(err, ErrPasswordTooShort) ||  		errors.Is(err, ErrPasswordTooLong) || +		errors.Is(err, ErrInvalidLoginMethod) ||  		errors.Is(err, ErrInvalidGSTPercentage) {  		return http.StatusBadRequest  	} @@ -37,7 +37,7 @@ import (  	"log"  ) -const OPENBILLS_VERSION = "v0.0.5" +const OPENBILLS_VERSION = "v0.0.6"  func init() {  	if viper.GetBool("production_mode") { diff --git a/user/service.go b/user/service.go index 5e0632b..4544cb4 100644 --- a/user/service.go +++ b/user/service.go @@ -27,6 +27,25 @@ func (u *User) Create() error {  	return res.Error  } +func GetUserWithAccountName(user *User, accountName, method string) error { +	if method != "username" && method != "email" { +		return e.ErrInvalidLoginMethod +	} + +	res := db.Where(method + " = ?", accountName).Find(&user) + +	// TODO: handle potential errors +	if res.Error != nil { +		return res.Error +	} + +	if res.RowsAffected == 0 { +		return e.ErrNotFound +	} + +	return nil +} +  func GetUser(user *User, id uint) error {  	res := db.Find(&user, id) diff --git a/user/user.go b/user/user.go index 68ceb47..ee36e95 100644 --- a/user/user.go +++ b/user/user.go @@ -44,9 +44,8 @@ type User struct {  	IsVerified bool  } -func CheckPassword(id uint, pass string) error { -	var user User -	err := GetUser(&user, id) +func CheckPassword(user *User, accountName, method, pass string) error { +	err := GetUserWithAccountName(user, accountName, method)  	if err != nil {  		return err  	}  |