diff --git a/go.mod b/go.mod index 9149c359..bb4c1b91 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 github.com/sirupsen/logrus v1.6.0 + github.com/stretchr/testify v1.7.0 github.com/thoas/go-funk v0.9.1 github.com/virtual-kubelet/node-cli v0.8.0 github.com/virtual-kubelet/virtual-kubelet v1.6.0 @@ -70,6 +71,7 @@ require ( github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_golang v1.11.1 // indirect github.com/prometheus/client_model v0.2.0 // indirect github.com/prometheus/common v0.26.0 // indirect diff --git a/go.sum b/go.sum index 624888bf..a225d4bf 100644 --- a/go.sum +++ b/go.sum @@ -454,6 +454,7 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/thoas/go-funk v0.9.1 h1:O549iLZqPpTUQ10ykd26sZhzD+rmR5pWhuElrhbC20M= github.com/thoas/go-funk v0.9.1/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= @@ -752,6 +753,7 @@ gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 h1:tQIYjPdBoyREyB9XMu+nnTclpTYkz2zFM+lzLJFO4gQ= gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= diff --git a/pkg/provider/aci_network.go b/pkg/network/aci_network.go similarity index 65% rename from pkg/provider/aci_network.go rename to pkg/network/aci_network.go index e80d4c92..68f4bdfa 100644 --- a/pkg/provider/aci_network.go +++ b/pkg/network/aci_network.go @@ -2,7 +2,7 @@ Copyright (c) Microsoft Corporation. Licensed under the Apache 2.0 license. */ -package provider +package network import ( "context" @@ -12,6 +12,7 @@ import ( "strings" "github.com/pkg/errors" + "github.com/virtual-kubelet/azure-aci/pkg/util" utilvalidation "k8s.io/apimachinery/pkg/util/validation" azaci "github.com/Azure/azure-sdk-for-go/services/containerinstance/mgmt/2021-10-01/containerinstance" @@ -25,103 +26,89 @@ import ( // DNS configuration settings const ( - maxDNSNameservers = 3 - maxDNSSearchPaths = 6 - maxDNSSearchListChars = 256 + maxDNSNameservers = 3 + maxDNSSearchPaths = 6 + maxDNSSearchListChars = 256 + subnetDelegationService = "Microsoft.ContainerInstance/containerGroups" ) -func (p *ACIProvider) setVNETConfig(ctx context.Context, azConfig *auth.Config) error { +type ProviderNetwork struct { + VnetSubscriptionID string + VnetName string + VnetResourceGroup string + SubnetName string + SubnetCIDR string + KubeDNSIP string +} + +func (pn *ProviderNetwork) SetVNETConfig(ctx context.Context, azConfig *auth.Config) error { // the VNET subscription ID by default is authentication subscription ID. // We need to override when using cross subscription virtual network resource - p.vnetSubscriptionID = azConfig.AuthConfig.SubscriptionID + pn.VnetSubscriptionID = azConfig.AuthConfig.SubscriptionID if vnetSubscriptionID := os.Getenv("ACI_VNET_SUBSCRIPTION_ID"); vnetSubscriptionID != "" { - p.vnetSubscriptionID = vnetSubscriptionID + pn.VnetSubscriptionID = vnetSubscriptionID } if vnetName := os.Getenv("ACI_VNET_NAME"); vnetName != "" { - p.vnetName = vnetName - } else if p.vnetName == "" { + pn.VnetName = vnetName + } else if pn.VnetName == "" { return errors.New("vnet name can not be empty please set ACI_VNET_NAME") } if vnetResourceGroup := os.Getenv("ACI_VNET_RESOURCE_GROUP"); vnetResourceGroup != "" { - p.vnetResourceGroup = vnetResourceGroup - } else if p.vnetResourceGroup == "" { + pn.VnetResourceGroup = vnetResourceGroup + } else if pn.VnetResourceGroup == "" { return errors.New("vnet resourceGroup can not be empty please set ACI_VNET_RESOURCE_GROUP") } // Set subnet properties. - if subnetName := os.Getenv("ACI_SUBNET_NAME"); p.vnetName != "" && subnetName != "" { - p.subnetName = subnetName + if subnetName := os.Getenv("ACI_SUBNET_NAME"); pn.VnetName != "" && subnetName != "" { + pn.SubnetName = subnetName } if subnetCIDR := os.Getenv("ACI_SUBNET_CIDR"); subnetCIDR != "" { - if p.subnetName == "" { + if pn.SubnetName == "" { return fmt.Errorf("subnet CIDR defined but no subnet name, subnet name is required to set a subnet CIDR") } if _, _, err := net.ParseCIDR(subnetCIDR); err != nil { return fmt.Errorf("error parsing provided subnet range: %v", err) } - p.subnetCIDR = subnetCIDR + pn.SubnetCIDR = subnetCIDR } - if p.subnetName != "" { - if err := p.setupNetwork(ctx, azConfig); err != nil { + if pn.SubnetName != "" { + if err := pn.setupNetwork(ctx, azConfig); err != nil { return fmt.Errorf("error setting up network: %v", err) } - masterURI := os.Getenv("MASTER_URI") - if masterURI == "" { - masterURI = "10.0.0.1" - } - - clusterCIDR := os.Getenv("CLUSTER_CIDR") - if clusterCIDR == "" { - clusterCIDR = "10.240.0.0/16" - } - - // setup aci extensions - kubeExtensions, err := client2.GetKubeProxyExtension(serviceAccountSecretMountPath, masterURI, clusterCIDR) - if err != nil { - return fmt.Errorf("error creating kube proxy extension: %v", err) - } - - p.containerGroupExtensions = append(p.containerGroupExtensions, kubeExtensions) - - enableRealTimeMetricsExtension := os.Getenv("ENABLE_REAL_TIME_METRICS") - if enableRealTimeMetricsExtension == "true" { - realtimeExtension := client2.GetRealtimeMetricsExtension() - p.containerGroupExtensions = append(p.containerGroupExtensions, realtimeExtension) - } - if kubeDNSIP := os.Getenv("KUBE_DNS_IP"); kubeDNSIP != "" { - p.kubeDNSIP = kubeDNSIP + pn.KubeDNSIP = kubeDNSIP } } return nil } -func (p *ACIProvider) setupNetwork(ctx context.Context, azConfig *auth.Config) error { +func (pn *ProviderNetwork) setupNetwork(ctx context.Context, azConfig *auth.Config) error { c := aznetwork.NewSubnetsClient(azConfig.AuthConfig.SubscriptionID) c.Authorizer = azConfig.Authorizer createSubnet := true - subnet, err := c.Get(ctx, p.vnetResourceGroup, p.vnetName, p.subnetName, "") + subnet, err := c.Get(ctx, pn.VnetResourceGroup, pn.VnetName, pn.SubnetName, "") if err != nil && !network.IsNotFound(err) { return fmt.Errorf("error while looking up subnet: %v", err) } - if network.IsNotFound(err) && p.subnetCIDR == "" { - return fmt.Errorf("subnet '%s' is not found in vnet '%s' in resource group '%s' and subscription '%s' and subnet CIDR is not specified", p.subnetName, p.vnetName, p.vnetResourceGroup, p.vnetSubscriptionID) + if network.IsNotFound(err) && pn.SubnetCIDR == "" { + return fmt.Errorf("subnet '%s' is not found in vnet '%s' in resource group '%s' and subscription '%s' and subnet CIDR is not specified", pn.SubnetName, pn.VnetName, pn.VnetResourceGroup, pn.VnetSubscriptionID) } if err == nil { - if p.subnetCIDR == "" { - p.subnetCIDR = *subnet.SubnetPropertiesFormat.AddressPrefix + if pn.SubnetCIDR == "" { + pn.SubnetCIDR = *subnet.SubnetPropertiesFormat.AddressPrefix } - if p.subnetCIDR != *subnet.SubnetPropertiesFormat.AddressPrefix { - return fmt.Errorf("found subnet '%s' using different CIDR: '%s'. desired: '%s'", p.subnetName, *subnet.SubnetPropertiesFormat.AddressPrefix, p.subnetCIDR) + if pn.SubnetCIDR != *subnet.SubnetPropertiesFormat.AddressPrefix { + return fmt.Errorf("found subnet '%s' using different CIDR: '%s'. desired: '%s'", pn.SubnetName, *subnet.SubnetPropertiesFormat.AddressPrefix, pn.SubnetCIDR) } if subnet.SubnetPropertiesFormat.RouteTable != nil { - return fmt.Errorf("unable to delegate subnet '%s' to Azure Container Instance since it references the route table '%s'", p.subnetName, *subnet.SubnetPropertiesFormat.RouteTable.ID) + return fmt.Errorf("unable to delegate subnet '%s' to Azure Container Instance since it references the route table '%s'", pn.SubnetName, *subnet.SubnetPropertiesFormat.RouteTable.ID) } if subnet.SubnetPropertiesFormat.ServiceAssociationLinks != nil { for _, l := range *subnet.SubnetPropertiesFormat.ServiceAssociationLinks { @@ -130,7 +117,7 @@ func (p *ACIProvider) setupNetwork(ctx context.Context, azConfig *auth.Config) e createSubnet = false break } else { - return fmt.Errorf("unable to delegate subnet '%s' to Azure Container Instance as it is used by other Azure resource: '%v'", p.subnetName, l) + return fmt.Errorf("unable to delegate subnet '%s' to Azure Container Instance as it is used by other Azure resource: '%v'", pn.SubnetName, l) } } } @@ -152,9 +139,9 @@ func (p *ACIProvider) setupNetwork(ctx context.Context, azConfig *auth.Config) e ) subnet = aznetwork.Subnet{ - Name: &p.subnetName, + Name: &pn.SubnetName, SubnetPropertiesFormat: &aznetwork.SubnetPropertiesFormat{ - AddressPrefix: &p.subnetCIDR, + AddressPrefix: &pn.SubnetCIDR, Delegations: &[]aznetwork.Delegation{ { Name: &delegationName, @@ -166,7 +153,7 @@ func (p *ACIProvider) setupNetwork(ctx context.Context, azConfig *auth.Config) e }, }, } - _, err = c.CreateOrUpdate(ctx, p.vnetResourceGroup, p.vnetName, p.subnetName, subnet) + _, err = c.CreateOrUpdate(ctx, pn.VnetResourceGroup, pn.VnetName, pn.SubnetName, subnet) if err != nil { return fmt.Errorf("error creating subnet: %v", err) } @@ -174,32 +161,31 @@ func (p *ACIProvider) setupNetwork(ctx context.Context, azConfig *auth.Config) e return nil } -func (p *ACIProvider) amendVnetResources(ctx context.Context, cg client2.ContainerGroupWrapper, pod *v1.Pod) { - if p.subnetName == "" { +func (pn *ProviderNetwork) AmendVnetResources(ctx context.Context, cg client2.ContainerGroupWrapper, pod *v1.Pod, clusterDomain string) { + if pn.SubnetName == "" { return } - subnetID := "/subscriptions/" + p.vnetSubscriptionID + "/resourceGroups/" + p.vnetResourceGroup + "/providers/Microsoft.Network/virtualNetworks/" + p.vnetName + "/subnets/" + p.subnetName + subnetID := "/subscriptions/" + pn.VnetSubscriptionID + "/resourceGroups/" + pn.VnetResourceGroup + "/providers/Microsoft.Network/virtualNetworks/" + pn.VnetName + "/subnets/" + pn.SubnetName cgIDList := []azaci.ContainerGroupSubnetID{{ID: &subnetID}} cg.ContainerGroupPropertiesWrapper.ContainerGroupProperties.SubnetIds = &cgIDList - cg.ContainerGroupPropertiesWrapper.ContainerGroupProperties.DNSConfig = p.getDNSConfig(ctx, pod) - cg.ContainerGroupPropertiesWrapper.Extensions = p.containerGroupExtensions + cg.ContainerGroupPropertiesWrapper.ContainerGroupProperties.DNSConfig = getDNSConfig(ctx, pod, pn.KubeDNSIP, clusterDomain) } -func (p *ACIProvider) getDNSConfig(ctx context.Context, pod *v1.Pod) *azaci.DNSConfiguration { +func getDNSConfig(ctx context.Context, pod *v1.Pod, kubeDNSIP, clusterDomain string) *azaci.DNSConfiguration { nameServers := make([]string, 0) searchDomains := make([]string, 0) if pod.Spec.DNSPolicy == v1.DNSClusterFirst || pod.Spec.DNSPolicy == v1.DNSClusterFirstWithHostNet { - nameServers = append(nameServers, p.kubeDNSIP) - searchDomains = p.generateSearchesForDNSClusterFirst(pod.Spec.DNSConfig, pod) + nameServers = append(nameServers, kubeDNSIP) + searchDomains = generateSearchesForDNSClusterFirst(pod.Spec.DNSConfig, pod, clusterDomain) } options := make([]string, 0) if pod.Spec.DNSConfig != nil { - nameServers = omitDuplicates(append(nameServers, pod.Spec.DNSConfig.Nameservers...)) - searchDomains = omitDuplicates(append(searchDomains, pod.Spec.DNSConfig.Searches...)) + nameServers = util.OmitDuplicates(append(nameServers, pod.Spec.DNSConfig.Nameservers...)) + searchDomains = util.OmitDuplicates(append(searchDomains, pod.Spec.DNSConfig.Searches...)) for _, option := range pod.Spec.DNSConfig.Options { op := option.Name @@ -226,22 +212,21 @@ func (p *ACIProvider) getDNSConfig(ctx context.Context, pod *v1.Pod) *azaci.DNSC } // This is taken from the kubelet equivalent - https://github.com/kubernetes/kubernetes/blob/d24fe8a801748953a5c34fd34faa8005c6ad1770/pkg/kubelet/network/dns/dns.go#L141-L151 -func (p *ACIProvider) generateSearchesForDNSClusterFirst(dnsConfig *v1.PodDNSConfig, pod *v1.Pod) []string { - +func generateSearchesForDNSClusterFirst(dnsConfig *v1.PodDNSConfig, pod *v1.Pod, clusterDomain string) []string { hostSearch := make([]string, 0) if dnsConfig != nil { hostSearch = dnsConfig.Searches } - if p.clusterDomain == "" { + if clusterDomain == "" { return hostSearch } - nsSvcDomain := fmt.Sprintf("%s.svc.%s", pod.Namespace, p.clusterDomain) - svcDomain := fmt.Sprintf("svc.%s", p.clusterDomain) - clusterSearch := []string{nsSvcDomain, svcDomain, p.clusterDomain} + nsSvcDomain := fmt.Sprintf("%s.svc.%s", pod.Namespace, clusterDomain) + svcDomain := fmt.Sprintf("svc.%s", clusterDomain) + clusterSearch := []string{nsSvcDomain, svcDomain, clusterDomain} - return omitDuplicates(append(clusterSearch, hostSearch...)) + return util.OmitDuplicates(append(clusterSearch, hostSearch...)) } // https://github.com/kubernetes/kubernetes/blob/4276ed36282405d026d8072e0ebed4f1da49070d/pkg/kubelet/network/dns/dns.go#L101-L149 @@ -298,12 +283,3 @@ func formDNSSearchFitsLimits(ctx context.Context, searches []string) string { return strings.Join(searches, " ") } - -func getProtocol(pro v1.Protocol) azaci.ContainerNetworkProtocol { - switch pro { - case v1.ProtocolUDP: - return azaci.ContainerNetworkProtocolUDP - default: - return azaci.ContainerNetworkProtocolTCP - } -} diff --git a/pkg/network/aci_network_test.go b/pkg/network/aci_network_test.go new file mode 100644 index 00000000..24ea1c2d --- /dev/null +++ b/pkg/network/aci_network_test.go @@ -0,0 +1,137 @@ +package network + +import ( + "context" + "fmt" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + testsutil "github.com/virtual-kubelet/azure-aci/pkg/tests" + v1 "k8s.io/api/core/v1" +) + +func TestGetDNSConfig(t *testing.T) { + kubeDNSIP := "10.0.0.10" + clusterDomain := "fakeClusterDomain" + podName := "pod-" + uuid.New().String() + podNamespace := "ns-" + uuid.New().String() + + testCases := []struct { + desc string + prepPodFunc func(p *v1.Pod) + kubeDNSIP bool + shouldHaveClusterDomain bool + }{ + { + desc: fmt.Sprint("Pod with DNSPolicy == ", v1.DNSClusterFirst, "with DNSConfig"), + prepPodFunc: func(p *v1.Pod) { + p.Spec.DNSPolicy = v1.DNSClusterFirst + p.Spec.DNSConfig = &v1.PodDNSConfig{ + Nameservers: []string{"clusterFirstNS"}, + Searches: []string{"clusterFirstSearches"}, + } + }, + kubeDNSIP: true, + shouldHaveClusterDomain: true, + }, + { + desc: fmt.Sprint("Pod with DNSPolicy == ", v1.DNSClusterFirstWithHostNet, "with DNSConfig"), + prepPodFunc: func(p *v1.Pod) { + p.Spec.DNSPolicy = v1.DNSClusterFirstWithHostNet + p.Spec.DNSConfig = &v1.PodDNSConfig{ + Nameservers: []string{"clusterFirstWithHostNettNS"}, + Searches: []string{"clusterFirstWithHostNetSearches"}, + } + }, + kubeDNSIP: true, + shouldHaveClusterDomain: true, + }, + { + desc: "Pod with other valid DNSPolicy and DNSConfig", + prepPodFunc: func(p *v1.Pod) { + p.Spec.DNSPolicy = v1.DNSDefault + p.Spec.DNSConfig = &v1.PodDNSConfig{ + Nameservers: []string{"defaultNS"}, + Searches: []string{"defaultSearches"}, + } + }, + kubeDNSIP: false, + shouldHaveClusterDomain: false, + }, + } + for i, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + ctx := context.TODO() + testPod := testsutil.CreatePodObj(podName, podNamespace) + tc.prepPodFunc(testPod) + aciDNSConfig := getDNSConfig(ctx, testPod, kubeDNSIP, clusterDomain) + + if tc.kubeDNSIP { + assert.Contains(t, *aciDNSConfig.NameServers, kubeDNSIP, "test [%d]", i) + } + if tc.shouldHaveClusterDomain { + assert.Contains(t, *aciDNSConfig.SearchDomains, clusterDomain, "test [%d]", i) + } + }) + } +} + +func TestFormDNSSearchFitsLimits(t *testing.T) { + testCases := []struct { + desc string + hostNames []string + resultSearch []string + expandedDNSConfig bool + }{ + { + desc: "3 search paths", + hostNames: []string{"testNS.svc.TEST", "svc.TEST", "TEST"}, + resultSearch: []string{"testNS.svc.TEST", "svc.TEST", "TEST"}, + }, + { + desc: fmt.Sprint("5 search paths will get omitted to the max (", maxDNSNameservers, ")"), + hostNames: []string{"testNS.svc.TEST", "svc.TEST", "TEST", "AA", "BB"}, + resultSearch: []string{"testNS.svc.TEST", "svc.TEST", "TEST"}, + }, + } + + for i, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + ctx := context.TODO() + dnsSearch := formDNSNameserversFitsLimits(ctx, tc.hostNames) + assert.EqualValues(t, tc.resultSearch, dnsSearch, "test [%d]", i) + }) + } +} + +// https://github.com/kubernetes/kubernetes/blob/4276ed36282405d026d8072e0ebed4f1da49070d/pkg/kubelet/network/dns/dns_test.go#L246 +func TestFormDNSNameserversFitsLimits(t *testing.T) { + testCases := []struct { + desc string + nameservers []string + expectedNameserver []string + }{ + { + desc: "valid: 1 nameserver", + nameservers: []string{"127.0.0.1"}, + expectedNameserver: []string{"127.0.0.1"}, + }, + { + desc: "valid: 3 nameservers", + nameservers: []string{"127.0.0.1", "10.0.0.10", "8.8.8.8"}, + expectedNameserver: []string{"127.0.0.1", "10.0.0.10", "8.8.8.8"}, + }, + { + desc: "invalid: 4 nameservers, trimmed to 3", + nameservers: []string{"127.0.0.1", "10.0.0.10", "8.8.8.8", "1.2.3.4"}, + expectedNameserver: []string{"127.0.0.1", "10.0.0.10", "8.8.8.8"}, + }, + } + + for _, tc := range testCases { + ctx := context.TODO() + appliedNameservers := formDNSNameserversFitsLimits(ctx, tc.nameservers) + assert.EqualValues(t, tc.expectedNameserver, appliedNameservers, tc.desc) + } +} diff --git a/pkg/provider/aci.go b/pkg/provider/aci.go index d25481d1..96858c46 100644 --- a/pkg/provider/aci.go +++ b/pkg/provider/aci.go @@ -24,6 +24,8 @@ import ( client2 "github.com/virtual-kubelet/azure-aci/pkg/client" "github.com/virtual-kubelet/azure-aci/pkg/featureflag" "github.com/virtual-kubelet/azure-aci/pkg/metrics" + "github.com/virtual-kubelet/azure-aci/pkg/network" + "github.com/virtual-kubelet/azure-aci/pkg/util" "github.com/virtual-kubelet/azure-aci/pkg/validation" "github.com/virtual-kubelet/node-cli/manager" "github.com/virtual-kubelet/virtual-kubelet/errdefs" @@ -41,7 +43,6 @@ const ( virtualKubeletDNSNameLabel = "virtualkubelet.io/dnsnamelabel" - subnetDelegationService = "Microsoft.ContainerInstance/containerGroups" // Parameter names defined in azure file CSI driver, refer to // https://github.com/kubernetes-sigs/azurefile-csi-driver/blob/master/docs/driver-parameters.md azureFileShareName = "shareName" @@ -72,6 +73,7 @@ type ACIProvider struct { resourceManager *manager.ResourceManager containerGroupExtensions []*client2.Extension enabledFeatures *featureflag.FlagIdentifier + providernetwork network.ProviderNetwork resourceGroup string region string @@ -85,13 +87,7 @@ type ACIProvider struct { internalIP string daemonEndpointPort int32 diagnostics *azaci.ContainerGroupDiagnostics - subnetName string - subnetCIDR string - vnetSubscriptionID string - vnetName string - vnetResourceGroup string clusterDomain string - kubeDNSIP string tracker *PodsTracker *metrics.ACIPodMetricsProvider @@ -197,12 +193,12 @@ func NewACIProvider(ctx context.Context, config string, azConfig auth.Config, az if azConfig.AKSCredential != nil { p.resourceGroup = azConfig.AKSCredential.ResourceGroup p.region = azConfig.AKSCredential.Region - p.vnetName = azConfig.AKSCredential.VNetName - p.vnetResourceGroup = azConfig.AKSCredential.VNetResourceGroup + p.providernetwork.VnetName = azConfig.AKSCredential.VNetName + p.providernetwork.VnetResourceGroup = azConfig.AKSCredential.VNetResourceGroup } - if p.vnetResourceGroup == "" { - p.vnetResourceGroup = p.resourceGroup + if p.providernetwork.VnetResourceGroup == "" { + p.providernetwork.VnetResourceGroup = p.resourceGroup } // If the log analytics file has been specified, load workspace credentials from the file if logAnalyticsAuthFile := os.Getenv("LOG_ANALYTICS_AUTH_LOCATION"); logAnalyticsAuthFile != "" { @@ -255,10 +251,17 @@ func NewACIProvider(ctx context.Context, config string, azConfig auth.Config, az return nil, err } - if err := p.setVNETConfig(ctx, &azConfig); err != nil { + if err := p.providernetwork.SetVNETConfig(ctx, &azConfig); err != nil { return nil, err } + if p.providernetwork.SubnetName != "" { + err = p.setACIExtensions(ctx) + if err != nil { + return nil, err + } + } + p.ACIPodMetricsProvider = metrics.NewACIPodMetricsProvider(nodeName, p.resourceGroup, p.resourceManager, p.azClientsAPIs) return &p, err } @@ -337,7 +340,7 @@ func (p *ACIProvider) CreatePod(ctx context.Context, pod *v1.Pod) error { }) } } - if len(ports) > 0 && p.subnetName == "" { + if len(ports) > 0 && p.providernetwork.SubnetName == "" { cg.ContainerGroupPropertiesWrapper.ContainerGroupProperties.IPAddress = &azaci.IPAddress{ Ports: &ports, Type: azaci.ContainerGroupIPAddressTypePublic, @@ -359,13 +362,41 @@ func (p *ACIProvider) CreatePod(ctx context.Context, pod *v1.Pod) error { "CreationTimestamp": &podCreationTimestamp, } - p.amendVnetResources(ctx, *cg, pod) + p.providernetwork.AmendVnetResources(ctx, *cg, pod, p.clusterDomain) + + cg.ContainerGroupPropertiesWrapper.Extensions = p.containerGroupExtensions log.G(ctx).Infof("start creating pod %v", pod.Name) // TODO: Run in a go routine to not block workers, and use tracker.UpdatePodStatus() based on result. return p.azClientsAPIs.CreateContainerGroup(ctx, p.resourceGroup, pod.Namespace, pod.Name, cg) } +// setACIExtensions +func (p *ACIProvider) setACIExtensions(ctx context.Context) error { + masterURI := os.Getenv("MASTER_URI") + if masterURI == "" { + masterURI = "10.0.0.1" + } + clusterCIDR := os.Getenv("CLUSTER_CIDR") + if clusterCIDR == "" { + clusterCIDR = "10.240.0.0/16" + } + + kubeExtensions, err := client2.GetKubeProxyExtension(serviceAccountSecretMountPath, masterURI, clusterCIDR) + if err != nil { + return fmt.Errorf("error creating kube proxy extension: %v", err) + } + + p.containerGroupExtensions = append(p.containerGroupExtensions, kubeExtensions) + + enableRealTimeMetricsExtension := os.Getenv("ENABLE_REAL_TIME_METRICS") + if enableRealTimeMetricsExtension == "true" { + realtimeExtension := client2.GetRealtimeMetricsExtension() + p.containerGroupExtensions = append(p.containerGroupExtensions, realtimeExtension) + } + return nil +} + func (p *ACIProvider) getDiagnostics(pod *v1.Pod) *azaci.ContainerGroupDiagnostics { if p.diagnostics != nil && p.diagnostics.LogAnalytics != nil && p.diagnostics.LogAnalytics.LogType == azaci.LogAnalyticsLogTypeContainerInsights { d := *p.diagnostics @@ -960,7 +991,7 @@ func (p *ACIProvider) getContainers(pod *v1.Pod) (*[]azaci.Container, error) { containerPorts := aciContainer.Ports containerPortsList := append(*containerPorts, azaci.ContainerPort{ Port: &podContainers[c].Ports[i].ContainerPort, - Protocol: getProtocol(podContainers[c].Ports[i].Protocol), + Protocol: util.GetProtocol(podContainers[c].Ports[i].Protocol), }) aciContainer.Ports = &containerPortsList } diff --git a/pkg/provider/aci_init_container_test.go b/pkg/provider/aci_init_container_test.go index 5a1710d3..13a93460 100644 --- a/pkg/provider/aci_init_container_test.go +++ b/pkg/provider/aci_init_container_test.go @@ -178,7 +178,7 @@ func TestCreatePodWithInitContainers(t *testing.T) { t.Run(tc.description, func(t *testing.T) { ctx := context.TODO() - + resourceManager, err := manager.NewResourceManager( NewMockPodLister(mockCtrl), NewMockSecretLister(mockCtrl), diff --git a/pkg/provider/aci_test.go b/pkg/provider/aci_test.go index bfc19796..ae57022c 100644 --- a/pkg/provider/aci_test.go +++ b/pkg/provider/aci_test.go @@ -19,6 +19,7 @@ import ( "github.com/virtual-kubelet/azure-aci/pkg/auth" "github.com/virtual-kubelet/azure-aci/pkg/client" testsutil "github.com/virtual-kubelet/azure-aci/pkg/tests" + "github.com/virtual-kubelet/azure-aci/pkg/util" "github.com/virtual-kubelet/node-cli/manager" "gotest.tools/assert" @@ -152,7 +153,7 @@ func TestCreatePodWithoutResourceSpec(t *testing.T) { } if err := provider.CreatePod(context.Background(), pod); err != nil { - t.Fatal("Failed to create pod", err) + t.Fatal("failed to create pod", err) } } @@ -873,3 +874,59 @@ func TestCreatedPodWithContainerPort(t *testing.T) { }) } } + +func TestGetPodWithContainerID(t *testing.T) { + podName := "pod-" + uuid.New().String() + podNamespace := "ns-" + uuid.New().String() + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + podLister := NewMockPodLister(mockCtrl) + + mockPodsNamespaceLister := NewMockPodNamespaceLister(mockCtrl) + podLister.EXPECT().Pods(podNamespace).Return(mockPodsNamespaceLister) + mockPodsNamespaceLister.EXPECT().Get(podName). + Return(testsutil.CreatePodObj(podName, podNamespace), nil) + + err := azConfig.SetAuthConfig() + if err != nil { + t.Fatal("failed to get auth configuration", err) + } + + aciMocks := createNewACIMock() + cgID := "" + aciMocks.MockGetContainerGroupInfo = func(ctx context.Context, resourceGroup, namespace, name, nodeName string) (*azaci.ContainerGroup, error) { + + cg := testsutil.CreateContainerGroupObj(podName, podNamespace, "Succeeded", + testsutil.CreateACIContainersListObj("Running", "Initializing", testsutil.CgCreationTime.Add(time.Second*2), testsutil.CgCreationTime.Add(time.Second*3), false, false, false), "Succeeded") + cgID = *cg.ID + return cg, nil + } + + resourceManager, err := manager.NewResourceManager( + podLister, + NewMockSecretLister(mockCtrl), + NewMockConfigMapLister(mockCtrl), + NewMockServiceLister(mockCtrl), + NewMockPersistentVolumeClaimLister(mockCtrl), + NewMockPersistentVolumeLister(mockCtrl)) + if err != nil { + t.Fatal("Unable to prepare the mocks for resourceManager", err) + } + + provider, err := createTestProvider(aciMocks, resourceManager) + if err != nil { + t.Fatal("failed to create the test provider", err) + } + + pod, err := provider.GetPod(context.Background(), podNamespace, podName) + if err != nil { + t.Fatal("Failed to get pod", err) + } + + assert.Check(t, &pod != nil, "Response pod should not be nil") + assert.Check(t, is.Equal(1, len(pod.Status.ContainerStatuses)), "1 container status is expected") + assert.Check(t, is.Equal(testsutil.TestContainerName, pod.Status.ContainerStatuses[0].Name), "Container name in the container status doesn't match") + assert.Check(t, is.Equal(testsutil.TestImageNginx, pod.Status.ContainerStatuses[0].Image), "Container image in the container status doesn't match") + assert.Check(t, is.Equal(util.GetContainerID(&cgID, &testsutil.TestContainerName), pod.Status.ContainerStatuses[0].ContainerID), "Container ID in the container status is not expected") +} diff --git a/pkg/provider/aci_utils_test.go b/pkg/provider/aci_utils_test.go deleted file mode 100644 index c4938ea6..00000000 --- a/pkg/provider/aci_utils_test.go +++ /dev/null @@ -1,75 +0,0 @@ -/* -Copyright (c) Microsoft Corporation. -Licensed under the Apache 2.0 license. -*/ -package provider - -import ( - "context" - "testing" - "time" - - azaci "github.com/Azure/azure-sdk-for-go/services/containerinstance/mgmt/2021-10-01/containerinstance" - "github.com/golang/mock/gomock" - "github.com/google/uuid" - testsutil "github.com/virtual-kubelet/azure-aci/pkg/tests" - "github.com/virtual-kubelet/node-cli/manager" - "gotest.tools/assert" - is "gotest.tools/assert/cmp" -) - -func TestGetPodWithContainerID(t *testing.T) { - podName := "pod-" + uuid.New().String() - podNamespace := "ns-" + uuid.New().String() - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - podLister := NewMockPodLister(mockCtrl) - - mockPodsNamespaceLister := NewMockPodNamespaceLister(mockCtrl) - podLister.EXPECT().Pods(podNamespace).Return(mockPodsNamespaceLister) - mockPodsNamespaceLister.EXPECT().Get(podName). - Return(testsutil.CreatePodObj(podName, podNamespace), nil) - - err := azConfig.SetAuthConfig() - if err != nil { - t.Fatal("failed to get auth configuration", err) - } - - aciMocks := createNewACIMock() - cgID := "" - aciMocks.MockGetContainerGroupInfo = func(ctx context.Context, resourceGroup, namespace, name, nodeName string) (*azaci.ContainerGroup, error) { - - cg := testsutil.CreateContainerGroupObj(podName, podNamespace, "Succeeded", - testsutil.CreateACIContainersListObj("Running", "Initializing", testsutil.CgCreationTime.Add(time.Second*2), testsutil.CgCreationTime.Add(time.Second*3), false, false, false), "Succeeded") - cgID = *cg.ID - return cg, nil - } - - resourceManager, err := manager.NewResourceManager( - podLister, - NewMockSecretLister(mockCtrl), - NewMockConfigMapLister(mockCtrl), - NewMockServiceLister(mockCtrl), - NewMockPersistentVolumeClaimLister(mockCtrl), - NewMockPersistentVolumeLister(mockCtrl)) - if err != nil { - t.Fatal("Unable to prepare the mocks for resourceManager", err) - } - - provider, err := createTestProvider(aciMocks, resourceManager) - if err != nil { - t.Fatal("failed to create the test provider", err) - } - - pod, err := provider.GetPod(context.Background(), podNamespace, podName) - if err != nil { - t.Fatal("Failed to get pod", err) - } - - assert.Check(t, &pod != nil, "Response pod should not be nil") - assert.Check(t, is.Equal(1, len(pod.Status.ContainerStatuses)), "1 container status is expected") - assert.Check(t, is.Equal(testsutil.TestContainerName, pod.Status.ContainerStatuses[0].Name), "Container name in the container status doesn't match") - assert.Check(t, is.Equal(testsutil.TestImageNginx, pod.Status.ContainerStatuses[0].Image), "Container image in the container status doesn't match") - assert.Check(t, is.Equal(getContainerID(&cgID, &testsutil.TestContainerName), pod.Status.ContainerStatuses[0].ContainerID), "Container ID in the container status is not expected") -} diff --git a/pkg/provider/aci_volumes_test.go b/pkg/provider/aci_volumes_test.go index 9ab3283f..1bcfc896 100644 --- a/pkg/provider/aci_volumes_test.go +++ b/pkg/provider/aci_volumes_test.go @@ -201,18 +201,16 @@ func TestCreatedPodWithAzureFilesVolume(t *testing.T) { err = provider.CreatePod(context.Background(), pod) - for _, vol := range tc.volumes { - if vol.Name == azureFileVolumeName1 || vol.Name == azureFileVolumeName2 { - if tc.expectedError != nil { - assert.Equal(t, tc.expectedError.Error(), err.Error()) - } else { - assert.NilError(t, tc.expectedError, err) - } - } + if tc.expectedError == nil { + assert.NilError(t, tc.expectedError, err) + } else { + assert.Equal(t, tc.expectedError.Error(), err.Error()) } + }) } } + func TestCreatePodWithProjectedVolume(t *testing.T) { projectedVolumeName := "projectedvolume" fakeSecretName := "fake-secret" @@ -516,15 +514,12 @@ func TestCreatePodWithCSIVolume(t *testing.T) { err = provider.CreatePod(context.Background(), pod) - for _, vol := range tc.volumes { - if vol.Name == azureFileVolumeName { - if tc.expectedError != nil { - assert.Equal(t, tc.expectedError.Error(), err.Error()) - } else { - assert.NilError(t, tc.expectedError, err) - } - } + if tc.expectedError == nil { + assert.NilError(t, tc.expectedError, err) + } else { + assert.Equal(t, tc.expectedError.Error(), err.Error()) } + }) } } diff --git a/pkg/provider/config.go b/pkg/provider/config.go index 864d0179..1188d5c3 100644 --- a/pkg/provider/config.go +++ b/pkg/provider/config.go @@ -58,7 +58,7 @@ func (p *ACIProvider) loadConfig(r io.Reader) error { // default subnet name if config.SubnetName != "" { - p.subnetName = config.SubnetName + p.providernetwork.SubnetName = config.SubnetName } if config.SubnetCIDR != "" { if config.SubnetName == "" { diff --git a/pkg/provider/containergroup_to_pod.go b/pkg/provider/containergroup_to_pod.go index c704b5fd..a5587ef2 100644 --- a/pkg/provider/containergroup_to_pod.go +++ b/pkg/provider/containergroup_to_pod.go @@ -10,6 +10,7 @@ import ( azaci "github.com/Azure/azure-sdk-for-go/services/containerinstance/mgmt/2021-10-01/containerinstance" "github.com/pkg/errors" "github.com/virtual-kubelet/azure-aci/pkg/tests" + "github.com/virtual-kubelet/azure-aci/pkg/util" "github.com/virtual-kubelet/azure-aci/pkg/validation" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -58,7 +59,7 @@ func (p *ACIProvider) getPodStatusFromContainerGroup(cg *azaci.ContainerGroup) ( RestartCount: *containersList[i].InstanceView.RestartCount, Image: *containersList[i].Image, ImageID: "", - ContainerID: getContainerID(cg.ID, containersList[i].Name), + ContainerID: util.GetContainerID(cg.ID, containersList[i].Name), } if getPodPhaseFromACIState(*containersList[i].InstanceView.CurrentState.State) != v1.PodRunning && diff --git a/pkg/provider/aci_utils.go b/pkg/util/aci_utils.go similarity index 57% rename from pkg/provider/aci_utils.go rename to pkg/util/aci_utils.go index f358eedc..42a4a244 100644 --- a/pkg/provider/aci_utils.go +++ b/pkg/util/aci_utils.go @@ -2,16 +2,19 @@ Copyright (c) Microsoft Corporation. Licensed under the Apache 2.0 license. */ -package provider +package util import ( "crypto/sha256" "encoding/hex" "fmt" "strings" + + "github.com/Azure/azure-sdk-for-go/services/containerinstance/mgmt/2021-10-01/containerinstance" + v1 "k8s.io/api/core/v1" ) -func getContainerID(cgID, containerName *string) string { +func GetContainerID(cgID, containerName *string) string { if cgID == nil { return "" } @@ -26,7 +29,7 @@ func getContainerID(cgID, containerName *string) string { return fmt.Sprintf("aci://%s", hex.EncodeToString(hashBytes)) } -func omitDuplicates(strs []string) []string { +func OmitDuplicates(strs []string) []string { uniqueStrs := make(map[string]bool) var ret []string @@ -38,3 +41,12 @@ func omitDuplicates(strs []string) []string { } return ret } + +func GetProtocol(pro v1.Protocol) containerinstance.ContainerNetworkProtocol { + switch pro { + case v1.ProtocolUDP: + return containerinstance.ContainerNetworkProtocolUDP + default: + return containerinstance.ContainerNetworkProtocolTCP + } +}