aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--auth/auth.go6
-rw-r--r--auth/controller.go10
-rw-r--r--errors/errors.go1
-rw-r--r--errors/status.go1
-rw-r--r--main.go2
-rw-r--r--user/service.go19
-rw-r--r--user/user.go5
7 files changed, 37 insertions, 7 deletions
diff --git a/auth/auth.go b/auth/auth.go
index 6797c91..ae2db9b 100644
--- a/auth/auth.go
+++ b/auth/auth.go
@@ -25,3 +25,9 @@ type AuthClaims struct {
jwt.RegisteredClaims
UserID uint `json:"userid"`
}
+
+type LoginReq struct {
+ AccountName string
+ Method string
+ Password string
+}
diff --git a/auth/controller.go b/auth/controller.go
index 277a85a..1e6f7fe 100644
--- a/auth/controller.go
+++ b/auth/controller.go
@@ -25,6 +25,7 @@ import (
"vidhukant.com/openbills/user"
"net/http"
"time"
+ "fmt"
)
var (
@@ -68,12 +69,15 @@ func handleSignUp (ctx *gin.Context) {
}
func handleSignIn (ctx *gin.Context) {
- var u user.User
- ctx.Bind(&u)
+ var req LoginReq
+ ctx.Bind(&req)
+
+ fmt.Println(req)
var err error
+ var u user.User
- err = user.CheckPassword(u.ID, u.Password)
+ err = user.CheckPassword(&u, req.AccountName, req.Method, req.Password)
if err != nil {
// TODO: handle potential errors
ctx.Error(err)
diff --git a/errors/errors.go b/errors/errors.go
index 3f1efc5..b5f4f63 100644
--- a/errors/errors.go
+++ b/errors/errors.go
@@ -38,6 +38,7 @@ var (
ErrInvalidGSTPercentage = errors.New("Invalid GST Percentage")
ErrPasswordTooShort = errors.New("Password Is Too Short")
ErrPasswordTooLong = errors.New("Password Is Too Long")
+ ErrInvalidLoginMethod = errors.New("Login Method Can Only Be 'email' Or 'username'")
// 401
ErrWrongPassword = errors.New("Wrong Password")
diff --git a/errors/status.go b/errors/status.go
index 7a23ddf..c7fc2a4 100644
--- a/errors/status.go
+++ b/errors/status.go
@@ -40,6 +40,7 @@ func StatusCodeFromErr(err error) int {
errors.Is(err, ErrInvalidUnitPrice) ||
errors.Is(err, ErrPasswordTooShort) ||
errors.Is(err, ErrPasswordTooLong) ||
+ errors.Is(err, ErrInvalidLoginMethod) ||
errors.Is(err, ErrInvalidGSTPercentage) {
return http.StatusBadRequest
}
diff --git a/main.go b/main.go
index 5dda554..76e2831 100644
--- a/main.go
+++ b/main.go
@@ -37,7 +37,7 @@ import (
"log"
)
-const OPENBILLS_VERSION = "v0.0.5"
+const OPENBILLS_VERSION = "v0.0.6"
func init() {
if viper.GetBool("production_mode") {
diff --git a/user/service.go b/user/service.go
index 5e0632b..4544cb4 100644
--- a/user/service.go
+++ b/user/service.go
@@ -27,6 +27,25 @@ func (u *User) Create() error {
return res.Error
}
+func GetUserWithAccountName(user *User, accountName, method string) error {
+ if method != "username" && method != "email" {
+ return e.ErrInvalidLoginMethod
+ }
+
+ res := db.Where(method + " = ?", accountName).Find(&user)
+
+ // TODO: handle potential errors
+ if res.Error != nil {
+ return res.Error
+ }
+
+ if res.RowsAffected == 0 {
+ return e.ErrNotFound
+ }
+
+ return nil
+}
+
func GetUser(user *User, id uint) error {
res := db.Find(&user, id)
diff --git a/user/user.go b/user/user.go
index 68ceb47..ee36e95 100644
--- a/user/user.go
+++ b/user/user.go
@@ -44,9 +44,8 @@ type User struct {
IsVerified bool
}
-func CheckPassword(id uint, pass string) error {
- var user User
- err := GetUser(&user, id)
+func CheckPassword(user *User, accountName, method, pass string) error {
+ err := GetUserWithAccountName(user, accountName, method)
if err != nil {
return err
}