swirl/security/jwt.go
2021-12-15 17:26:45 +08:00

152 lines
3.5 KiB
Go

package security
import (
"errors"
"strings"
"time"
"github.com/cuigh/auxo/data"
"github.com/cuigh/auxo/log"
"github.com/cuigh/auxo/net/web"
"github.com/cuigh/auxo/security"
"github.com/cuigh/auxo/util/cast"
"github.com/cuigh/swirl/misc"
"github.com/dgrijalva/jwt-go"
)
var (
	// ErrNoNeedRefresh is the sentinel returned by refreshToken when the
	// current token is not close enough to its expiry to be replaced.
	ErrNoNeedRefresh = errors.New("no need to refresh")
)
// JWT is a web filter that identifies the request's user from a JSON Web Token
// and, when that token is within five minutes of expiry, sends back a freshly
// issued one. Build it with NewIdentifier, or set KeyFunc (required; Apply
// panics without it) and optionally the other exported fields; Apply fills in
// defaults for Schema, Sources and Identifier when they are unset.
type JWT struct {
	// Schema is the authentication scheme expected in front of the token in a
	// "header" source, e.g. "Bearer" (Apply's default).
	Schema string
	// Sources lists where to look for the token, tried in order until one
	// yields a non-empty value. Name is one of "header", "cookie", "form" or
	// "query"; Value is the header / cookie / form-field / query-parameter name.
	Sources data.Options
	// KeyFunc supplies the key used both to verify incoming tokens (jwt.Parse
	// in Apply) and to sign new ones (CreateToken).
	KeyFunc jwt.Keyfunc
	// Identifier converts a successfully parsed token into the request's user.
	Identifier func(token *jwt.Token) web.User
	// tokenExpiry is the lifetime, in seconds, of tokens issued by CreateToken.
	tokenExpiry int64
	// logger receives parse (debug) and refresh (error) failures.
	logger log.Logger
}
// NewIdentifier returns a web.Filter that identifies the user from a JWT sent
// as "Bearer <token>" in the Authorization request header.
//
// Tokens are HMAC-verified (and signed by CreateToken) with
// misc.Options.TokenKey. When that option is empty, the hard-coded key "swirl"
// is used and a warning is logged, so deployments should always configure
// token_key. Issued tokens live for misc.Options.TokenExpiry, or 30 minutes
// when it is zero or negative.
func NewIdentifier() web.Filter {
	logger := log.Get("security")

	key := misc.Options.TokenKey
	if key == "" {
		key = "swirl"
		logger.Warnf("Swirl is using default token key as token_key isn't configured, this may cause security problems")
	}

	expiry := misc.Options.TokenExpiry
	if expiry <= 0 {
		// A non-positive expiry would make every issued token already expired.
		expiry = 30 * time.Minute
	}

	return &JWT{
		logger:      logger,
		tokenExpiry: int64(expiry.Seconds()),
		Schema:      "Bearer",
		Sources: data.Options{
			{Name: "header", Value: web.HeaderAuthorization},
		},
		KeyFunc: func(token *jwt.Token) (interface{}, error) {
			// Only hand out the shared secret for HMAC-signed tokens. The
			// jwt-go docs require key functions to validate the algorithm so a
			// token cannot pick a verifier the key was never meant for.
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, errors.New("unexpected token signing method")
			}
			// TODO: use user salt as key?
			return []byte(key), nil
		},
		Identifier: func(token *jwt.Token) web.User {
			claims := token.Claims.(jwt.MapClaims)
			return security.NewUser(cast.ToString(claims["sub"]), cast.ToString(claims["name"]))
		},
	}
}
// Apply implements web.Filter. It first fills in defaults for unset fields:
// Schema "Bearer", the Authorization header as the only source, an Identifier
// reading the "sub"/"name" claims, a "security" logger and a 30-minute token
// lifetime. It then returns a handler that runs on every request:
//   - it extracts a token from j.Sources and validates it with j.KeyFunc;
//   - on success it stores the identified user on the context;
//   - when the token is close to expiry, it writes a refreshed token to the
//     Authorization response header.
//
// Requests without a valid token are passed to next with no user set;
// rejecting them is left to later handlers.
//
// Apply panics if KeyFunc is nil, because tokens could not be verified.
func (j *JWT) Apply(next web.HandlerFunc) web.HandlerFunc {
	if j.KeyFunc == nil {
		panic("KeyFunc is required")
	}
	if j.Schema == "" {
		j.Schema = "Bearer"
	}
	if len(j.Sources) == 0 {
		j.Sources = data.Options{
			{Name: "header", Value: web.HeaderAuthorization},
		}
	}
	if j.Identifier == nil {
		j.Identifier = func(token *jwt.Token) web.User {
			claims := token.Claims.(jwt.MapClaims)
			return security.NewUser(cast.ToString(claims["sub"]), cast.ToString(claims["name"]))
		}
	}
	// logger and tokenExpiry are only initialised by NewIdentifier. Default
	// them so a JWT built as a struct literal neither calls a nil logger on the
	// first bad token nor hands out refreshed tokens that are already expired.
	if j.logger == nil {
		j.logger = log.Get("security")
	}
	if j.tokenExpiry <= 0 {
		j.tokenExpiry = int64((30 * time.Minute).Seconds())
	}

	return func(ctx web.Context) error {
		raw := j.extractToken(ctx)
		if raw == "" {
			return next(ctx)
		}

		token, err := jwt.Parse(raw, j.KeyFunc)
		if err != nil {
			// Invalid or expired tokens are routine; treat the request as anonymous.
			j.logger.Debugf("failed to parse token: %s", err)
			return next(ctx)
		}

		user := j.Identifier(token)
		ctx.SetUser(user)

		refreshed, err := j.refreshToken(user, token)
		switch {
		case err == nil:
			ctx.SetHeader(web.HeaderAuthorization, refreshed)
		case !errors.Is(err, ErrNoNeedRefresh):
			j.logger.Errorf("failed to refresh token: %s", err)
		}
		return next(ctx)
	}
}
// extractToken returns the first non-empty token found in j.Sources, which are
// tried in order; it returns "" when none yields a value.
//
// A "header" source is expected to hold "<Schema> <token>". The scheme is
// matched case-insensitively and stripped. A header value that does not start
// with the scheme is returned unchanged, so a bare token is accepted as well.
func (j *JWT) extractToken(ctx web.Context) (token string) {
	for _, src := range j.Sources {
		switch src.Name {
		case "header":
			token = trimAuthSchema(ctx.Header(src.Value), j.Schema)
		case "cookie":
			if cookie, err := ctx.Cookie(src.Value); err == nil {
				token = cookie.Value
			}
		case "form":
			token = ctx.Form(src.Value)
		case "query":
			token = ctx.Query(src.Value)
		}
		if token != "" {
			return token
		}
	}
	return ""
}

// trimAuthSchema strips a leading "<schema> " from an Authorization-style
// header value.
//
// The scheme comparison is case-insensitive (RFC 7235 §2.1). A value made up
// only of the scheme yields "". A value that does not begin with the scheme
// followed by a space (e.g. a bare token, or "BearerXYZ") is returned as-is.
// Bounds are checked before slicing, so a value like "Bearer" cannot panic.
func trimAuthSchema(value, schema string) string {
	n := len(schema)
	if n == 0 || len(value) < n || !strings.EqualFold(value[:n], schema) {
		return value
	}
	rest := value[n:]
	if rest == "" {
		// Scheme with no credentials.
		return ""
	}
	if rest[0] != ' ' {
		// Only shares a prefix with the scheme; not this scheme at all.
		return value
	}
	return strings.TrimSpace(rest)
}
// refreshToken issues a replacement token for user via CreateToken when fewer
// than five minutes remain before the "exp" claim of token. Otherwise it
// returns ErrNoNeedRefresh. It assumes token.Claims is jwt.MapClaims, which is
// what jwt.Parse produces. NOTE(review): a missing "exp" presumably casts to 0,
// which also triggers a refresh — confirm cast.ToInt64 semantics.
func (j *JWT) refreshToken(user web.User, token *jwt.Token) (string, error) {
	const refreshWindow = 5 * 60 // seconds

	claims := token.Claims.(jwt.MapClaims)
	remaining := cast.ToInt64(claims["exp"]) - time.Now().Unix()
	if remaining >= refreshWindow {
		return "", ErrNoNeedRefresh
	}

	fresh, err := j.CreateToken(user.ID(), user.Name())
	if err != nil {
		return "", err
	}
	return fresh, nil
}
// CreateToken returns a new HS256-signed token for the given user. The token
// carries these claims:
//   - "sub": the user ID;
//   - "name": the user name;
//   - "iat": the issue time, as Unix seconds;
//   - "exp": the issue time plus j.tokenExpiry seconds.
//
// The signing key comes from j.KeyFunc; any error from it or from signing is
// returned.
func (j *JWT) CreateToken(id, name string) (string, error) {
	issuedAt := time.Now().Unix()
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"sub":  id,
		"name": name,
		"iat":  issuedAt,
		"exp":  issuedAt + j.tokenExpiry,
	})

	key, err := j.KeyFunc(token)
	if err == nil {
		return token.SignedString(key)
	}
	return "", err
}