From 2739fd1b8ca3fa1425b1db0ed209a5d0514a099e Mon Sep 17 00:00:00 2001 From: igolaizola <11333576+igolaizola@users.noreply.github.com> Date: Fri, 5 Jul 2024 22:29:58 +0200 Subject: [PATCH] Check token expiration --- go.mod | 1 + go.sum | 2 ++ pkg/cmd/extend/extend.go | 5 +++- pkg/cmd/generate/generate.go | 5 +++- pkg/runway/client.go | 46 +++++++++++++++++++++++++++--------- 5 files changed, 46 insertions(+), 13 deletions(-) diff --git a/go.mod b/go.mod index 729247c..6b3749d 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.20 require ( github.com/bogdanfinn/fhttp v0.5.28 github.com/bogdanfinn/tls-client v1.7.5 + github.com/golang-jwt/jwt v3.2.2+incompatible github.com/peterbourgon/ff/v3 v3.3.0 ) diff --git a/go.sum b/go.sum index f647f49..ccf24a7 100644 --- a/go.sum +++ b/go.sum @@ -11,6 +11,8 @@ github.com/cloudflare/circl v1.3.6/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUK github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I= diff --git a/pkg/cmd/extend/extend.go b/pkg/cmd/extend/extend.go index 02ef6b3..2ec6674 100644 --- a/pkg/cmd/extend/extend.go +++ b/pkg/cmd/extend/extend.go @@ -40,12 +40,15 @@ func Run(ctx context.Context, cfg *Config) error { if cfg.Token == "" { return fmt.Errorf("token is required") } - client := runway.New(&runway.Config{ + client, err := runway.New(&runway.Config{ Token: cfg.Token, Wait: cfg.Wait, Debug: cfg.Debug, Proxy: cfg.Proxy, }) + if err != nil { + return fmt.Errorf("vidai: couldn't create client: %w", err) + } base := strings.TrimSuffix(filepath.Base(cfg.Input), filepath.Ext(cfg.Input)) diff --git a/pkg/cmd/generate/generate.go b/pkg/cmd/generate/generate.go index df0773b..322f246 100644 --- a/pkg/cmd/generate/generate.go +++ b/pkg/cmd/generate/generate.go @@ -38,12 +38,15 @@ func Run(ctx context.Context, cfg *Config) error { if cfg.Token == "" { return fmt.Errorf("token is required") } - client := runway.New(&runway.Config{ + client, err := runway.New(&runway.Config{ Token: cfg.Token, Wait: cfg.Wait, Debug: cfg.Debug, Proxy: cfg.Proxy, }) + if err != nil { + return fmt.Errorf("vidai: couldn't create client: %w", err) + } var imageURL string if cfg.Image != "" { diff --git a/pkg/runway/client.go b/pkg/runway/client.go index 0492840..86de8c8 100644 --- a/pkg/runway/client.go +++ b/pkg/runway/client.go @@ -14,16 +14,18 @@ import ( "time" http "github.com/bogdanfinn/fhttp" + "github.com/golang-jwt/jwt" "github.com/igopr/vidai/pkg/fhttp" "github.com/igopr/vidai/pkg/ratelimit" ) type Client struct { - client fhttp.Client - debug bool - ratelimit ratelimit.Lock - token string - teamID int + client fhttp.Client + debug bool + ratelimit ratelimit.Lock + token string + expiration time.Time + teamID int } type Config struct { @@ -33,18 +35,37 @@ type Config struct { Proxy string } -func New(cfg *Config) *Client { +func New(cfg *Config) (*Client, error) { wait := cfg.Wait if wait == 0 { wait = 1 * time.Second } + // Parse the JWT + parser := jwt.Parser{} + t, _, err := parser.ParseUnverified(cfg.Token, jwt.MapClaims{}) + if err != nil { + return nil, fmt.Errorf("runway: couldn't parse token: %w", err) + } + claims, ok := t.Claims.(jwt.MapClaims) + if !ok { + return nil, fmt.Errorf("runway: couldn't parse claims: %w", err) + } + exp, ok := claims["exp"].(float64) + if !ok { + return nil, fmt.Errorf("runway: couldn't parse expiration: %w", err) + } + expiration := time.Unix(int64(exp), 0) + if expiration.Before(time.Now()) { + return nil, fmt.Errorf("runway: token expired") + } client := fhttp.NewClient(2*time.Minute, true, cfg.Proxy) return &Client{ - client: client, - ratelimit: ratelimit.New(wait), - debug: cfg.Debug, - token: cfg.Token, - } + client: client, + ratelimit: ratelimit.New(wait), + debug: cfg.Debug, + token: cfg.Token, + expiration: expiration, + }, nil } func (c *Client) log(format string, args ...interface{}) { @@ -61,6 +82,9 @@ var backoff = []time.Duration{ } func (c *Client) do(ctx context.Context, method, path string, in, out any) ([]byte, error) { + if time.Now().After(c.expiration) { + return nil, fmt.Errorf("runway: token expired") + } maxAttempts := 3 attempts := 0 var err error