summaryrefslogtreecommitdiff
path: root/auth
diff options
context:
space:
mode:
Diffstat (limited to 'auth')
-rw-r--r--auth/auth.go16
-rw-r--r--auth/jwt.go30
-rw-r--r--auth/jwt_middleware.go30
3 files changed, 57 insertions, 19 deletions
diff --git a/auth/auth.go b/auth/auth.go
index 7bf251f..1048f82 100644
--- a/auth/auth.go
+++ b/auth/auth.go
@@ -51,17 +51,13 @@ func Routes(route *gin.Engine) {
})
r.POST("/refresh", verifyRefreshToken(), func (ctx *gin.Context) {
- userId := ctx.MustGet("userId")
- if userId != "" {
- accessToken, err := newAccessToken(userId.(string))
- if err != nil {
- log.Printf("Error while generating new access token: %v", err)
- ctx.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"message": "Internal Server Error (cannot refresh session)"})
- } else {
- ctx.JSON(http.StatusOK, gin.H{"accessToken": accessToken})
- }
+ u := ctx.MustGet("user").(user.User)
+ accessToken, err := newAccessToken(u.Id.Hex())
+ if err != nil {
+ log.Printf("Error while generating new access token: %v", err)
+ ctx.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"message": "Internal Server Error (cannot refresh session)"})
} else {
- ctx.JSON(http.StatusBadRequest, gin.H{"message": "invalid user info"})
+ ctx.JSON(http.StatusOK, gin.H{"accessToken": accessToken})
}
})
}
diff --git a/auth/jwt.go b/auth/jwt.go
index 2d2ea8e..66a4f12 100644
--- a/auth/jwt.go
+++ b/auth/jwt.go
@@ -18,11 +18,21 @@
package auth
import (
- "github.com/golang-jwt/jwt/v4"
+ "github.com/MikunoNaka/OpenBills-server/user"
"github.com/MikunoNaka/OpenBills-server/util"
+ "github.com/golang-jwt/jwt/v4"
+ "go.mongodb.org/mongo-driver/bson"
+ "go.mongodb.org/mongo-driver/bson/primitive"
+
+ "context"
+ "errors"
"time"
)
+var (
+ errUserNotFound error = errors.New("user does not exist")
+)
+
var accessSecret []byte
var refreshSecret []byte
func init() {
@@ -56,18 +66,30 @@ func newAccessToken(userId string) (string, error) {
* for enhanced security
*/
func newRefreshToken(userId string) (string, int64, error) {
- // TODO: store in DB
- expiresAt := time.Now().Add(time.Hour * 12).Unix()
+ // convert id from string to ObjectID
+ id, _ := primitive.ObjectIDFromHex(userId)
+ // check if user exists
+ var u user.User
+ if err := db.FindOne(context.TODO(), bson.M{"_id": id}).Decode(&u); err != nil {
+ return "", 0, errUserNotFound
+ }
+
+ // generate refresh token
+ expiresAt := time.Now().Add(time.Hour * 12).Unix()
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims {
Issuer: userId,
ExpiresAt: expiresAt,
})
-
token, err := claims.SignedString(refreshSecret)
if err != nil {
return "", expiresAt, err
}
+ // store refresh token in db with unique session name for ease in identification
+ sessionName := time.Now().Format("01-02-2006.15:04:05") + "-" + u.UserName
+ u.Sessions = append(u.Sessions, user.Session{Name: sessionName, Token: token})
+ db.UpdateOne(context.TODO(), bson.M{"_id": id}, bson.D{{"$set", u}})
+
return token, expiresAt, nil
}
diff --git a/auth/jwt_middleware.go b/auth/jwt_middleware.go
index 22d1fd7..8dd77b8 100644
--- a/auth/jwt_middleware.go
+++ b/auth/jwt_middleware.go
@@ -18,9 +18,13 @@
package auth
import (
- "net/http"
+ "github.com/MikunoNaka/OpenBills-server/user"
+ "go.mongodb.org/mongo-driver/bson/primitive"
+ "go.mongodb.org/mongo-driver/bson"
"github.com/golang-jwt/jwt/v4"
"github.com/gin-gonic/gin"
+ "net/http"
+ "context"
)
func Authorize() gin.HandlerFunc {
@@ -51,11 +55,27 @@ func verifyRefreshToken() gin.HandlerFunc {
token, err := jwt.ParseWithClaims(refreshToken, &jwt.StandardClaims{}, func(token *jwt.Token) (interface{}, error) {
return []byte(refreshSecret), nil
})
- if err != nil {
+ if err != nil { // invalid token
ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"message": "refresh token expired"})
- } else {
- ctx.Set("userId", token.Claims.(*jwt.StandardClaims).Issuer)
- ctx.Next()
+ } else { // valid token
+ // convert id from string to ObjectID
+ id, _ := primitive.ObjectIDFromHex(token.Claims.(*jwt.StandardClaims).Issuer)
+
+ // check if user exists
+ var u user.User
+ if err := db.FindOne(context.TODO(), bson.M{"_id": id}).Decode(&u); err != nil {
+ ctx.AbortWithStatusJSON(http.StatusNotFound, gin.H{"message": "user not found"})
+ } else {
+ // check if this refreshToken is in DB
+ for _, i := range u.Sessions {
+ if i.Token == refreshToken {
+ ctx.Set("user", u)
+ ctx.Next()
+ } else {
+ ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"message": "refresh token expired"})
+ }
+ }
+ }
}
} else {
// invalid Authorization header