mirror of
https://github.com/cuigh/swirl
synced 2024-12-28 23:02:02 +00:00
152 lines
3.5 KiB
Go
152 lines
3.5 KiB
Go
package security
|
|
|
|
import (
|
|
"errors"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/cuigh/auxo/data"
|
|
"github.com/cuigh/auxo/log"
|
|
"github.com/cuigh/auxo/net/web"
|
|
"github.com/cuigh/auxo/security"
|
|
"github.com/cuigh/auxo/util/cast"
|
|
"github.com/cuigh/swirl/misc"
|
|
"github.com/golang-jwt/jwt"
|
|
)
|
|
|
|
var ErrNoNeedRefresh = errors.New("no need to refresh")
|
|
|
|
type JWT struct {
|
|
Schema string
|
|
Sources data.Options
|
|
KeyFunc jwt.Keyfunc
|
|
Identifier func(token *jwt.Token) web.User
|
|
tokenExpiry int64
|
|
logger log.Logger
|
|
}
|
|
|
|
func NewIdentifier() web.Filter {
|
|
logger := log.Get("security")
|
|
key := misc.Options.TokenKey
|
|
expiry := misc.Options.TokenExpiry
|
|
if key == "" {
|
|
key = "swirl"
|
|
logger.Warnf("Swirl is using default token key as token_key isn't configured, this may cause security problems")
|
|
}
|
|
if expiry == 0 {
|
|
expiry = 30 * time.Minute
|
|
}
|
|
|
|
return &JWT{
|
|
logger: logger,
|
|
tokenExpiry: int64(expiry.Seconds()),
|
|
Schema: "Bearer",
|
|
Sources: data.Options{
|
|
{Name: "header", Value: web.HeaderAuthorization},
|
|
},
|
|
KeyFunc: func(token *jwt.Token) (interface{}, error) {
|
|
// TODO: use user salt as key?
|
|
return []byte(key), nil
|
|
},
|
|
Identifier: func(token *jwt.Token) web.User {
|
|
claims := token.Claims.(jwt.MapClaims)
|
|
return security.NewUser(cast.ToString(claims["sub"]), cast.ToString(claims["name"]))
|
|
},
|
|
}
|
|
}
|
|
|
|
func (j *JWT) Apply(next web.HandlerFunc) web.HandlerFunc {
|
|
if j.KeyFunc == nil {
|
|
panic("KeyFunc is required")
|
|
}
|
|
if j.Schema == "" {
|
|
j.Schema = "Bearer"
|
|
}
|
|
if len(j.Sources) == 0 {
|
|
j.Sources = data.Options{
|
|
{Name: "header", Value: web.HeaderAuthorization},
|
|
}
|
|
}
|
|
if j.Identifier == nil {
|
|
j.Identifier = func(token *jwt.Token) web.User {
|
|
claims := token.Claims.(jwt.MapClaims)
|
|
return security.NewUser(cast.ToString(claims["sub"]), cast.ToString(claims["name"]))
|
|
}
|
|
}
|
|
|
|
return func(ctx web.Context) error {
|
|
ts := j.extractToken(ctx)
|
|
if ts != "" {
|
|
token, err := jwt.Parse(ts, j.KeyFunc)
|
|
if err != nil {
|
|
j.logger.Debugf("failed to parse token: %s", err)
|
|
} else {
|
|
user := j.Identifier(token)
|
|
ctx.SetUser(user)
|
|
if ts, err = j.refreshToken(user, token); err == nil {
|
|
ctx.SetHeader(web.HeaderAuthorization, ts)
|
|
} else if err != ErrNoNeedRefresh {
|
|
j.logger.Errorf("failed to refresh token: %s", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return next(ctx)
|
|
}
|
|
}
|
|
|
|
func (j *JWT) extractToken(ctx web.Context) (token string) {
|
|
for _, src := range j.Sources {
|
|
switch src.Name {
|
|
case "header":
|
|
token = ctx.Header(src.Value)
|
|
if strings.HasPrefix(token, j.Schema) {
|
|
return token[len(j.Schema)+1:]
|
|
}
|
|
case "cookie":
|
|
if cookie, err := ctx.Cookie(src.Value); err == nil {
|
|
token = cookie.Value
|
|
}
|
|
case "form":
|
|
token = ctx.Form(src.Value)
|
|
case "query":
|
|
token = ctx.Query(src.Value)
|
|
}
|
|
if token != "" {
|
|
return
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func (j *JWT) refreshToken(user web.User, token *jwt.Token) (string, error) {
|
|
claims := token.Claims.(jwt.MapClaims)
|
|
expiry := cast.ToInt64(claims["exp"])
|
|
now := time.Now().Unix()
|
|
// refresh token when remaining expiry is less than 5 minutes
|
|
if (expiry - now) < 5*60 {
|
|
ts, err := j.CreateToken(user.ID(), user.Name())
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return ts, nil
|
|
}
|
|
return "", ErrNoNeedRefresh
|
|
}
|
|
|
|
func (j *JWT) CreateToken(id, name string) (string, error) {
|
|
now := time.Now().Unix()
|
|
claims := jwt.MapClaims{
|
|
"name": name,
|
|
"sub": id,
|
|
"iat": now,
|
|
"exp": now + j.tokenExpiry,
|
|
}
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
key, err := j.KeyFunc(token)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return token.SignedString(key)
|
|
}
|