Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit tests for pkg/util #1343

Merged
merged 3 commits into from
Jun 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions pkg/util/enforcement_action.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package util

import (
"errors"
"fmt"

"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
Expand All @@ -22,19 +23,27 @@ var supportedEnforcementActions = []EnforcementAction{Deny, Dryrun, Warn}
// KnownEnforcementActions are all defined EnforcementActions.
var KnownEnforcementActions = []EnforcementAction{Deny, Dryrun, Warn, Unrecognized}

// ErrEnforcementAction indicates the passed EnforcementAction is not valid.
var ErrEnforcementAction = errors.New("unrecognized enforcementAction")

// ErrInvalidSpecEnforcementAction indicates that we were unable to parse the
// spec.enforcementAction field as it was not a string.
var ErrInvalidSpecEnforcementAction = errors.New("spec.enforcementAction must be a string")

func ValidateEnforcementAction(input EnforcementAction) error {
for _, n := range supportedEnforcementActions {
if input == n {
return nil
}
}
return fmt.Errorf("could not find the provided enforcementAction value %s within the supported list %v", input, supportedEnforcementActions)
return fmt.Errorf("%w: %q is not within the supported list %v",
ErrEnforcementAction, input, supportedEnforcementActions)
}

func GetEnforcementAction(item map[string]interface{}) (EnforcementAction, error) {
enforcementActionSpec, _, err := unstructured.NestedString(item, "spec", "enforcementAction")
if err != nil {
return "", err
return "", fmt.Errorf("%w: %v", ErrInvalidSpecEnforcementAction, err)
}
enforcementAction := EnforcementAction(enforcementActionSpec)
// default enforcementAction is deny
Expand Down
94 changes: 84 additions & 10 deletions pkg/util/enforcement_action_test.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,94 @@
package util

import "testing"
import (
"errors"
"testing"
)

func TestValidateEnforcementAction(t *testing.T) {
err := ValidateEnforcementAction("")
if err == nil {
t.Errorf("ValidateEnforcementAction should error when enforcementAction is not recognized, %v", err)
testCases := []struct {
name string
action EnforcementAction
wantErr error
}{
{
name: "empty string",
action: "",
wantErr: ErrEnforcementAction,
},
{
action: "notsupported",
wantErr: ErrEnforcementAction,
},
{
action: Dryrun,
},
}

err = ValidateEnforcementAction("notsupported")
if err == nil {
t.Errorf("ValidateEnforcementAction should error when enforcementAction is not recognized, %v", err)
for _, tc := range testCases {
if tc.name == "" {
tc.name = string(tc.action)
}
t.Run(tc.name, func(t *testing.T) {
err := ValidateEnforcementAction(tc.action)
if !errors.Is(err, tc.wantErr) {
t.Errorf("got ValidateEnforcementAction(%q) == %v, want %v",
tc.action, err, tc.wantErr)
}
})
}
}

func TestGetEnforcementAction(t *testing.T) {
testCases := []struct {
name string
item map[string]interface{}
want EnforcementAction
wantErr error
}{
{
name: "empty item",
item: map[string]interface{}{},
want: Deny,
},
{
name: "invalid spec.enforcementAction",
item: map[string]interface{}{
"spec": []string{},
},
wantErr: ErrInvalidSpecEnforcementAction,
},
{
name: "unsupported spec.enforcementAction",
item: map[string]interface{}{
"spec": map[string]interface{}{
"enforcementAction": "notsupported",
},
},
want: Unrecognized,
},
{
name: "valid spec.enforcementAction",
item: map[string]interface{}{
"spec": map[string]interface{}{
"enforcementAction": string(Dryrun),
},
},
want: Dryrun,
},
}

err = ValidateEnforcementAction("dryrun")
if err != nil {
t.Errorf("ValidateEnforcementAction should not error when enforcementAction is recognized, %v", err)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := GetEnforcementAction(tc.item)
if !errors.Is(err, tc.wantErr) {
t.Fatalf("got GetEnforcementAction() error = %v, want %v",
err, tc.wantErr)
}
if got != tc.want {
t.Errorf("got GetEnforcementAction() = %v, want %v",
got, tc.want)
}
})
}
}
11 changes: 9 additions & 2 deletions pkg/util/pack.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package util

import (
"errors"
"fmt"
"strings"

Expand All @@ -12,17 +13,23 @@ import (
"sigs.k8s.io/controller-runtime/pkg/reconcile"
)

// ErrInvalidPackedName indicates that the packed name of the request to be
// unpacked was invalid.
var ErrInvalidPackedName = errors.New("invalid packed name, want request.Name to match 'gvk:[Kind].[Version].[Group]:[Name]'")

// UnpackRequest unpacks the GVK from a reconcile.Request and returns the separated components.
// GVK is encoded as "Kind.Version.Group".
// Requests are expected to be in the format: {Name: "gvk:EncodedGVK:Name", Namespace: Namespace}
func UnpackRequest(r reconcile.Request) (schema.GroupVersionKind, reconcile.Request, error) {
fields := strings.SplitN(r.Name, ":", 3)
if len(fields) != 3 || fields[0] != "gvk" {
return schema.GroupVersionKind{}, reconcile.Request{}, fmt.Errorf("invalid packed name: %s", r.Name)
return schema.GroupVersionKind{}, reconcile.Request{},
fmt.Errorf("%w: %q", ErrInvalidPackedName, r.Name)
}
gvk, _ := schema.ParseKindArg(fields[1])
if gvk == nil {
return schema.GroupVersionKind{}, reconcile.Request{}, fmt.Errorf("unable to parse gvk: %s", fields[1])
return schema.GroupVersionKind{}, reconcile.Request{},
fmt.Errorf("%w: unable to parse [Kind].[Version].[Group]: %q", ErrInvalidPackedName, fields[1])
}

return *gvk, reconcile.Request{NamespacedName: types.NamespacedName{
Expand Down
107 changes: 107 additions & 0 deletions pkg/util/pack_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package util

import (
"errors"
"testing"

"github.com/google/go-cmp/cmp"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/reconcile"
)

func TestUnpackRequest(t *testing.T) {
testCases := []struct {
name string
request reconcile.Request
wantGVK schema.GroupVersionKind
wantRequest reconcile.Request
wantErr error
}{
{
name: "empty request",
request: reconcile.Request{},
wantErr: ErrInvalidPackedName,
},
{
name: "invalid gvk",
request: reconcile.Request{
NamespacedName: types.NamespacedName{
Name: "gvk:b:c",
},
},
wantErr: ErrInvalidPackedName,
},
{
name: "valid gvk",
request: reconcile.Request{
NamespacedName: types.NamespacedName{
Name: "gvk:Role.v1.rbac:foo",
Namespace: "shipping",
},
},
wantGVK: schema.GroupVersionKind{Kind: "Role", Version: "v1", Group: "rbac"},
wantRequest: reconcile.Request{
NamespacedName: types.NamespacedName{
Name: "foo",
Namespace: "shipping",
},
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
gvk, request, err := UnpackRequest(tc.request)
if !errors.Is(err, tc.wantErr) {
t.Fatalf("got UnpackRequest() err = %v, want %v",
err, tc.wantErr)
}
if diff := cmp.Diff(tc.wantGVK, gvk); diff != "" {
t.Errorf("got UnpackRequest() gvk diff: %v", diff)
}
if diff := cmp.Diff(tc.wantRequest, request); diff != "" {
t.Errorf("got UnpackRequest() request diff: %v", diff)
}
})
}
}

func TestEventPackerMapFunc(t *testing.T) {
testCases := []struct {
name string
obj client.Object
want []reconcile.Request
}{
{
name: "no object",
obj: nil,
want: nil,
},
{
name: "empty object",
obj: &unstructured.Unstructured{},
want: []reconcile.Request{
{NamespacedName: types.NamespacedName{Name: "gvk:.v1.:"}},
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got := EventPackerMapFunc()(tc.obj)
if diff := cmp.Diff(tc.want, got); diff != "" {
t.Errorf("got EventPackerMapFunc()(obj) diff: %v", diff)
}

for _, r := range got {
_, _, err := UnpackRequest(r)
if err != nil {
t.Errorf("got invalid Request: %v", err)
}
}
})
}
}