From c88ad2ee2db41b678120da70253c5162a0c5dc54 Mon Sep 17 00:00:00 2001 From: Smriti Dahal <93288516+smritidahal653@users.noreply.github.com> Date: Mon, 10 Apr 2023 16:58:36 -0700 Subject: [PATCH] test: added unit test for getGPUSKU (#525) --- pkg/provider/aci.go | 2 +- pkg/provider/aci_test.go | 73 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/pkg/provider/aci.go b/pkg/provider/aci.go index 7dc2a195..74943766 100644 --- a/pkg/provider/aci.go +++ b/pkg/provider/aci.go @@ -1207,7 +1207,7 @@ func (p *ACIProvider) getGPUSKU(pod *v1.Pod) (azaciv2.GpuSKU, error) { } } - return "", fmt.Errorf("the pod requires GPU SKU %s, but ACI only supports SKUs %v in region %s", desiredSKU, p.region, p.gpuSKUs) + return "", fmt.Errorf("the pod requires GPU SKU %s, but ACI only supports SKUs %v in region %s", desiredSKU, p.gpuSKUs, p.region) } return p.gpuSKUs[0], nil diff --git a/pkg/provider/aci_test.go b/pkg/provider/aci_test.go index 6f8a709a..2d12faf4 100644 --- a/pkg/provider/aci_test.go +++ b/pkg/provider/aci_test.go @@ -1785,5 +1785,78 @@ func TestGetContainerLogs(t *testing.T) { }) } +} + +func TestGetGPUSKU(t *testing.T) { + podName := "pod-" + uuid.New().String() + podNamespace := "ns-" + uuid.New().String() + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + aciMocks := createNewACIMock() + + provider, err := createTestProvider(aciMocks, NewMockConfigMapLister(mockCtrl), + NewMockSecretLister(mockCtrl), NewMockPodLister(mockCtrl)) + if err != nil { + t.Fatal("failed to create the test provider", err) + } + cases := []struct { + description string + gpuSkus []azaciv2.GpuSKU + desiredSku string + expectedError error + }{ + { + description: "gpuTypeAnnotation is not set but ACI provides gpusku", + gpuSkus: []azaciv2.GpuSKU{azaciv2.GpuSKUK80, azaciv2.GpuSKUP100}, + desiredSku: "", + expectedError: nil, + }, + { + description: "gpuTypeAnnotation is set and the desired sku is supported by ACI", + gpuSkus: []azaciv2.GpuSKU{azaciv2.GpuSKUK80, azaciv2.GpuSKUP100}, + desiredSku: "P100", + expectedError: nil, + }, + { + description: "gpuTypeAnnotation is set but the desired sku is not supported by ACI", + gpuSkus: []azaciv2.GpuSKU{azaciv2.GpuSKUK80, azaciv2.GpuSKUP100}, + desiredSku: "P120", + expectedError: fmt.Errorf("the pod requires GPU SKU P120, but ACI only supports SKUs %v in region %s", []azaciv2.GpuSKU{azaciv2.GpuSKUK80, azaciv2.GpuSKUP100}, provider.region), + }, + { + description: "ACI doesn't provide any gpusku", + gpuSkus: []azaciv2.GpuSKU{}, + desiredSku: "", + expectedError: fmt.Errorf("the pod requires GPU resource, but ACI doesn't provide GPU enabled container group in region %s", provider.region), + }, + } + + for _, tc := range cases { + t.Run(tc.description, func(t *testing.T) { + provider.gpuSKUs = tc.gpuSkus + + pod := testsutil.CreatePodObj(podName, podNamespace) + if len(tc.desiredSku) > 0 { + pod.Annotations = map[string]string{} + pod.Annotations[gpuTypeAnnotation] = tc.desiredSku + } + + gpuSKU, err := provider.getGPUSKU(pod) + + if tc.expectedError != nil { + assert.Equal(t, err.Error(), tc.expectedError.Error(), "Error messages should match") + assert.Equal(t, string(gpuSKU), "", "No GPU SKU should be returned") + } else { + assert.NilError(t, err, "no error should be returned") + if len(tc.desiredSku) == 0 { + assert.Equal(t, gpuSKU, tc.gpuSkus[0], "Since no desired SKU was set, the first gpuSKU in the list should be returned") + } else { + assert.Equal(t, string(gpuSKU), tc.desiredSku, "Desired SKU should be returned") + } + } + }) + } }