aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVidhu Kant Sharma <vidhukant@vidhukant.com>2023-09-03 15:00:03 +0530
committerVidhu Kant Sharma <vidhukant@vidhukant.com>2023-09-03 15:00:03 +0530
commit95dfc551f7eaaf6e8ebdefce1b733951354ac40d (patch)
treeecabef39a94084c456eaf32bb508d8809d592c9f
parentfc83df70b787e447bf31f4d99fa723c7e38544f2 (diff)
added JWT authorization middleware
-rw-r--r--auth/middleware.go77
-rw-r--r--errors/errors.go5
-rw-r--r--errors/status.go8
-rw-r--r--main.go9
4 files changed, 96 insertions, 3 deletions
diff --git a/auth/middleware.go b/auth/middleware.go
new file mode 100644
index 0000000..299a7be
--- /dev/null
+++ b/auth/middleware.go
@@ -0,0 +1,77 @@
+/* openbills - Server for web based Libre Billing Software
+ * Copyright (C) 2023 Vidhu Kant Sharma <vidhukant@vidhukant.com>
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package auth
+
+import (
+ "vidhukant.com/openbills/errors"
+ "github.com/gin-gonic/gin"
+ "github.com/golang-jwt/jwt/v5"
+ "strings"
+ "time"
+)
+
+func getBearerToken(header []string) (string, error) {
+ if len(header) > 0 {
+ s := strings.Split(header[0], "Bearer ")
+ if len(s) == 2 {
+ return s[1], nil
+ } else {
+ return "", errors.ErrInvalidAuthHeader
+ }
+ } else {
+ return "", errors.ErrInvalidAuthHeader
+ }
+}
+
+func Authorize() gin.HandlerFunc {
+ return func(ctx *gin.Context) {
+ bearerToken, err := getBearerToken(ctx.Request.Header["Authorization"])
+ if err != nil {
+ ctx.Error(err)
+ ctx.Abort()
+ return
+ }
+
+ tk, err := jwt.ParseWithClaims(bearerToken, &AuthClaims{}, func (token *jwt.Token) (interface{}, error) {
+ return []byte(AUTH_KEY), nil
+ })
+
+ claims, ok := tk.Claims.(*AuthClaims)
+ if !ok {
+ ctx.Error(errors.ErrInvalidAuthHeader)
+ ctx.Abort()
+ return
+ }
+
+ if !tk.Valid {
+ eat := claims.ExpiresAt.Unix()
+ if eat != 0 && eat < time.Now().Unix() {
+ ctx.Error(errors.ErrSessionExpired)
+ } else {
+ ctx.Error(errors.ErrUnauthorized)
+ }
+
+ ctx.Abort()
+ return
+ }
+
+ ctx.Set("UserID", claims.UserID)
+
+ ctx.Next()
+ }
+}
diff --git a/errors/errors.go b/errors/errors.go
index 1cae027..16e1646 100644
--- a/errors/errors.go
+++ b/errors/errors.go
@@ -23,7 +23,7 @@ import (
var (
// 204
- ErrEmptyResponse = errors.New("No Records Found")
+ ErrEmptyResponse = errors.New("No Records Found")
// 400
ErrNoWhereCondition = errors.New("No Where Condition")
@@ -38,6 +38,9 @@ var (
// 401
ErrWrongPassword = errors.New("Wrong Password")
+ ErrInvalidAuthHeader = errors.New("Invalid Authorization Header")
+ ErrUnauthorized = errors.New("Unauthorized")
+ ErrSessionExpired = errors.New("Session Expired")
// 404
ErrNotFound = errors.New("Not Found")
diff --git a/errors/status.go b/errors/status.go
index b0da1ae..4767aeb 100644
--- a/errors/status.go
+++ b/errors/status.go
@@ -41,6 +41,14 @@ func StatusCodeFromErr(err error) int {
return http.StatusBadRequest
}
+ // 401
+ if errors.Is(err, ErrWrongPassword) ||
+ errors.Is(err, ErrUnauthorized) ||
+ errors.Is(err, ErrInvalidAuthHeader) ||
+ errors.Is(err, ErrSessionExpired) {
+ return http.StatusUnauthorized
+ }
+
// 404
if errors.Is(err, ErrNotFound) ||
errors.Is(err, ErrBrandNotFound) {
diff --git a/main.go b/main.go
index dfeec0c..6b9704c 100644
--- a/main.go
+++ b/main.go
@@ -51,8 +51,13 @@ func main() {
{
user.Routes(api)
auth.Routes(api)
- customer.Routes(api)
- item.Routes(api)
+ }
+
+ protected := api.Group("/")
+ protected.Use(auth.Authorize())
+ {
+ customer.Routes(protected)
+ item.Routes(protected)
}
r.Run(":" + viper.GetString("port"))