From b30a1166e0c7bc93b5a41d2ead09e349f5a6329c Mon Sep 17 00:00:00 2001 From: Sam Batschelet Date: Fri, 4 May 2018 16:16:39 -0400 Subject: [PATCH] auth: fix panic using WithRoot and improve JWT coverage --- auth/jwt_test.go | 6 ++++++ auth/store.go | 9 ++++++--- auth/store_test.go | 27 ++++++++++++++++++--------- tests/e2e/ctl_v3_auth_test.go | 10 ++++++++-- 4 files changed, 38 insertions(+), 14 deletions(-) diff --git a/auth/jwt_test.go b/auth/jwt_test.go index 4499b5b62aa..926651057da 100644 --- a/auth/jwt_test.go +++ b/auth/jwt_test.go @@ -16,6 +16,7 @@ package auth import ( "context" + "fmt" "testing" "go.uber.org/zap" @@ -94,3 +95,8 @@ func TestJWTBad(t *testing.T) { } opts["priv-key"] = jwtPrivKey } + +// testJWTOpts is useful for passing to NewTokenProvider which requires a string. +func testJWTOpts() string { + return fmt.Sprintf("%s,pub-key=%s,priv-key=%s,sign-method=RS256", tokenTypeJWT, jwtPubKey, jwtPrivKey) +} diff --git a/auth/store.go b/auth/store.go index 3f305a1a088..44df8787328 100644 --- a/auth/store.go +++ b/auth/store.go @@ -72,6 +72,9 @@ const ( rootUser = "root" rootRole = "root" + tokenTypeSimple = "simple" + tokenTypeJWT = "jwt" + revBytesLen = 8 ) @@ -1255,7 +1258,7 @@ func NewTokenProvider( } switch tokenType { - case "simple": + case tokenTypeSimple: if lg != nil { lg.Warn("simple token is not cryptographically signed") } else { @@ -1263,7 +1266,7 @@ func NewTokenProvider( } return newTokenProviderSimple(lg, indexWaiter), nil - case "jwt": + case tokenTypeJWT: return newTokenProviderJWT(lg, typeSpecificOpts) case "": @@ -1289,7 +1292,7 @@ func (as *authStore) WithRoot(ctx context.Context) context.Context { } var ctxForAssign context.Context - if ts := as.tokenProvider.(*tokenSimple); ts != nil { + if ts, ok := as.tokenProvider.(*tokenSimple); ok && ts != nil { ctx1 := context.WithValue(ctx, AuthenticateParamIndex{}, uint64(0)) prefix, err := ts.genTokenPrefix() if err != nil { diff --git a/auth/store_test.go b/auth/store_test.go index 4a459232afc..05d2c03a80f 100644 --- a/auth/store_test.go +++ b/auth/store_test.go @@ -48,7 +48,7 @@ func TestNewAuthStoreRevision(t *testing.T) { b, tPath := backend.NewDefaultTmpBackend() defer os.Remove(tPath) - tp, err := NewTokenProvider(zap.NewExample(), "simple", dummyIndexWaiter) + tp, err := NewTokenProvider(zap.NewExample(), tokenTypeSimple, dummyIndexWaiter) if err != nil { t.Fatal(err) } @@ -78,7 +78,7 @@ func TestNewAuthStoreBcryptCost(t *testing.T) { b, tPath := backend.NewDefaultTmpBackend() defer os.Remove(tPath) - tp, err := NewTokenProvider(zap.NewExample(), "simple", dummyIndexWaiter) + tp, err := NewTokenProvider(zap.NewExample(), tokenTypeSimple, dummyIndexWaiter) if err != nil { t.Fatal(err) } @@ -98,7 +98,7 @@ func TestNewAuthStoreBcryptCost(t *testing.T) { func setupAuthStore(t *testing.T) (store *authStore, teardownfunc func(t *testing.T)) { b, tPath := backend.NewDefaultTmpBackend() - tp, err := NewTokenProvider(zap.NewExample(), "simple", dummyIndexWaiter) + tp, err := NewTokenProvider(zap.NewExample(), tokenTypeSimple, dummyIndexWaiter) if err != nil { t.Fatal(err) } @@ -535,7 +535,7 @@ func TestAuthInfoFromCtxRace(t *testing.T) { b, tPath := backend.NewDefaultTmpBackend() defer os.Remove(tPath) - tp, err := NewTokenProvider(zap.NewExample(), "simple", dummyIndexWaiter) + tp, err := NewTokenProvider(zap.NewExample(), tokenTypeSimple, dummyIndexWaiter) if err != nil { t.Fatal(err) } @@ -601,7 +601,7 @@ func TestRecoverFromSnapshot(t *testing.T) { as.Close() - tp, err := NewTokenProvider(zap.NewExample(), "simple", dummyIndexWaiter) + tp, err := NewTokenProvider(zap.NewExample(), tokenTypeSimple, dummyIndexWaiter) if err != nil { t.Fatal(err) } @@ -683,7 +683,7 @@ func TestRolesOrder(t *testing.T) { b, tPath := backend.NewDefaultTmpBackend() defer os.Remove(tPath) - tp, err := NewTokenProvider(zap.NewExample(), "simple", dummyIndexWaiter) + tp, err := NewTokenProvider(zap.NewExample(), tokenTypeSimple, dummyIndexWaiter) if err != nil { t.Fatal(err) } @@ -724,12 +724,21 @@ func TestRolesOrder(t *testing.T) { } } -// TestAuthInfoFromCtxWithRoot ensures "WithRoot" properly embeds token in the context. -func TestAuthInfoFromCtxWithRoot(t *testing.T) { +func TestAuthInfoFromCtxWithRootSimple(t *testing.T) { + testAuthInfoFromCtxWithRoot(t, tokenTypeSimple) +} + +func TestAuthInfoFromCtxWithRootJWT(t *testing.T) { + opts := testJWTOpts() + testAuthInfoFromCtxWithRoot(t, opts) +} + +// testAuthInfoFromCtxWithRoot ensures "WithRoot" properly embeds token in the context. +func testAuthInfoFromCtxWithRoot(t *testing.T, opts string) { b, tPath := backend.NewDefaultTmpBackend() defer os.Remove(tPath) - tp, err := NewTokenProvider(zap.NewExample(), "simple", dummyIndexWaiter) + tp, err := NewTokenProvider(zap.NewExample(), opts, dummyIndexWaiter) if err != nil { t.Fatal(err) } diff --git a/tests/e2e/ctl_v3_auth_test.go b/tests/e2e/ctl_v3_auth_test.go index f0c313e5f59..928687d8424 100644 --- a/tests/e2e/ctl_v3_auth_test.go +++ b/tests/e2e/ctl_v3_auth_test.go @@ -30,6 +30,7 @@ func TestCtlV3AuthRoleUpdate(t *testing.T) { testCtl(t, authRoleUpdateT func TestCtlV3AuthUserDeleteDuringOps(t *testing.T) { testCtl(t, authUserDeleteDuringOpsTest) } func TestCtlV3AuthRoleRevokeDuringOps(t *testing.T) { testCtl(t, authRoleRevokeDuringOpsTest) } func TestCtlV3AuthTxn(t *testing.T) { testCtl(t, authTestTxn) } +func TestCtlV3AuthTxnJWT(t *testing.T) { testCtl(t, authTestTxn, withCfg(configJWT)) } func TestCtlV3AuthPrefixPerm(t *testing.T) { testCtl(t, authTestPrefixPerm) } func TestCtlV3AuthMemberAdd(t *testing.T) { testCtl(t, authTestMemberAdd) } func TestCtlV3AuthMemberRemove(t *testing.T) { @@ -41,11 +42,15 @@ func TestCtlV3AuthRevokeWithDelete(t *testing.T) { testCtl(t, authTestRevokeWith func TestCtlV3AuthInvalidMgmt(t *testing.T) { testCtl(t, authTestInvalidMgmt) } func TestCtlV3AuthFromKeyPerm(t *testing.T) { testCtl(t, authTestFromKeyPerm) } func TestCtlV3AuthAndWatch(t *testing.T) { testCtl(t, authTestWatch) } +func TestCtlV3AuthAndWatchJWT(t *testing.T) { testCtl(t, authTestWatch, withCfg(configJWT)) } func TestCtlV3AuthLeaseTestKeepAlive(t *testing.T) { testCtl(t, authLeaseTestKeepAlive) } func TestCtlV3AuthLeaseTestTimeToLiveExpired(t *testing.T) { testCtl(t, authLeaseTestTimeToLiveExpired) } func TestCtlV3AuthLeaseGrantLeases(t *testing.T) { testCtl(t, authLeaseTestLeaseGrantLeases) } -func TestCtlV3AuthLeaseRevoke(t *testing.T) { testCtl(t, authLeaseTestLeaseRevoke) } +func TestCtlV3AuthLeaseGrantLeasesJWT(t *testing.T) { + testCtl(t, authLeaseTestLeaseGrantLeases, withCfg(configJWT)) +} +func TestCtlV3AuthLeaseRevoke(t *testing.T) { testCtl(t, authLeaseTestLeaseRevoke) } func TestCtlV3AuthRoleGet(t *testing.T) { testCtl(t, authTestRoleGet) } func TestCtlV3AuthUserGet(t *testing.T) { testCtl(t, authTestUserGet) } @@ -55,7 +60,8 @@ func TestCtlV3AuthDefrag(t *testing.T) { testCtl(t, authTestDefrag) } func TestCtlV3AuthEndpointHealth(t *testing.T) { testCtl(t, authTestEndpointHealth, withQuorum()) } -func TestCtlV3AuthSnapshot(t *testing.T) { testCtl(t, authTestSnapshot) } +func TestCtlV3AuthSnapshot(t *testing.T) { testCtl(t, authTestSnapshot) } +func TestCtlV3AuthSnapshotJWT(t *testing.T) { testCtl(t, authTestSnapshot, withCfg(configJWT)) } func TestCtlV3AuthCertCNAndUsername(t *testing.T) { testCtl(t, authTestCertCNAndUsername, withCfg(configClientTLSCertAuth)) }