From 76047e00738ba3cbd259d130bf3d0e97adaff7ce Mon Sep 17 00:00:00 2001 From: Yun-Tang Hsu Date: Mon, 12 Sep 2022 18:14:46 -0700 Subject: [PATCH] Add NetworkPolicyRecommnedation rest handler Add unit-test for rest.go Add unit-test for controller.go Signed-off-by: Yun-Tang Hsu --- .../templates/theia-manager/clusterrole.yaml | 8 +- cmd/theia-manager/theia-manager.go | 23 +- pkg/apis/crd/v1alpha1/types.go | 2 +- pkg/apis/intelligence/v1alpha1/types.go | 19 +- .../v1alpha1/zz_generated.deepcopy.go | 7 +- .../networkpolicyrecommendation/rest.go | 84 +++-- .../networkpolicyrecommendation/rest_test.go | 182 +++++++++++ .../networkpolicyrecommendation/controller.go | 126 ++++++- .../controller_test.go | 308 ++++++++++++++++++ pkg/controller/utils.go | 140 ++++++++ pkg/querier/querier.go | 6 +- .../commands/policy_recommendation_status.go | 238 +++----------- 12 files changed, 893 insertions(+), 250 deletions(-) create mode 100644 pkg/apiserver/registry/intelligence/networkpolicyrecommendation/rest_test.go create mode 100644 pkg/controller/networkpolicyrecommendation/controller_test.go create mode 100644 pkg/controller/utils.go diff --git a/build/charts/theia/templates/theia-manager/clusterrole.yaml b/build/charts/theia/templates/theia-manager/clusterrole.yaml index ee9b57fab..3b1a67501 100644 --- a/build/charts/theia/templates/theia-manager/clusterrole.yaml +++ b/build/charts/theia/templates/theia-manager/clusterrole.yaml @@ -18,5 +18,11 @@ rules: verbs: ["get", "list", "watch"] - apiGroups: ["crd.theia.antrea.io"] resources: ["networkpolicyrecommendations"] - verbs: ["get", "list", "watch"] + verbs: ["get", "list", "watch", "create", "delete"] + - apiGroups: [ "" ] + resources: [ "pods" ] + verbs: [ "list"] + - apiGroups: [ "" ] + resources: [ "services", "secrets" ] + verbs: [ "get" ] {{- end }} diff --git a/cmd/theia-manager/theia-manager.go b/cmd/theia-manager/theia-manager.go index cfe847eb5..d88dcb8b1 100644 --- a/cmd/theia-manager/theia-manager.go +++ b/cmd/theia-manager/theia-manager.go @@ -21,6 +21,7 @@ import ( "antrea.io/antrea/pkg/log" "antrea.io/antrea/pkg/signals" "antrea.io/antrea/pkg/util/cipher" + "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" "k8s.io/klog/v2" @@ -52,10 +53,16 @@ func run(o *Options) error { if err != nil { return fmt.Errorf("error when generating CRD client: %v", err) } + k8sClient, err := createK8sClient() + if err != nil { + return fmt.Errorf("error when creating K8s client: %v", err) + } crdInformerFactory := crdinformers.NewSharedInformerFactory(crdClient, informerDefaultResync) npRecommendationInformer := crdInformerFactory.Crd().V1alpha1().NetworkPolicyRecommendations() - npRecoController := networkpolicyrecommendation.NewNPRecommendationController(crdClient, npRecommendationInformer) - + npRecoController, err := networkpolicyrecommendation.NewNPRecommendationController(crdClient, k8sClient, npRecommendationInformer) + if err != nil { + return fmt.Errorf("error when creating networkPolicyRecommendation controller: %v", err) + } cipherSuites, err := cipher.GenerateCipherSuitesList(o.config.APIServer.TLSCipherSuites) if err != nil { return fmt.Errorf("error when generating Cipher Suite list: %v", err) @@ -77,3 +84,15 @@ func run(o *Options) error { klog.InfoS("Stopping theia manager") return nil } + +func createK8sClient() (kubernetes.Interface, error) { + config, err := rest.InClusterConfig() + if err != nil { + return nil, err + } + k8sClient, err := kubernetes.NewForConfig(config) + if err != nil { + return nil, err + } + return k8sClient, nil +} diff --git a/pkg/apis/crd/v1alpha1/types.go b/pkg/apis/crd/v1alpha1/types.go index ecebc6c6d..856d63e5e 100644 --- a/pkg/apis/crd/v1alpha1/types.go +++ b/pkg/apis/crd/v1alpha1/types.go @@ -35,7 +35,7 @@ type NetworkPolicyRecommendation struct { } type NetworkPolicyRecommendationSpec struct { - Type string `json:"type,omitempty"` + JobType string `json:"jobType,omitempty"` Limit int `json:"limit,omitempty"` PolicyType string `json:"policyType,omitempty"` StartTime metav1.Time `json:"startTime,omitempty"` diff --git a/pkg/apis/intelligence/v1alpha1/types.go b/pkg/apis/intelligence/v1alpha1/types.go index 388296ef9..91eb3e04f 100644 --- a/pkg/apis/intelligence/v1alpha1/types.go +++ b/pkg/apis/intelligence/v1alpha1/types.go @@ -37,8 +37,8 @@ type NetworkPolicyRecommendation struct { Type string `json:"jobType,omitempty"` Limit int `json:"limit,omitempty"` PolicyType string `json:"policyType,omitempty"` - StartInterval metav1.Time `json:"startInterval,omitempty"` - EndInterval metav1.Time `json:"endInterval,omitempty"` + IntervalStart metav1.Time `json:"intervalStart,omitempty"` + IntervalEnd metav1.Time `json:"intervalEnd,omitempty"` NSAllowList []string `json:"nsAllowList,omitempty"` ExcludeLabels bool `json:"excludeLabels,omitempty"` ToServices bool `json:"toServices,omitempty"` @@ -51,14 +51,13 @@ type NetworkPolicyRecommendation struct { } type NetworkPolicyRecommendationStatus struct { - State string `json:"state,omitempty"` - SparkApplication string `json:"sparkApplication,omitempty"` - CompletedStages int `json:"completedStages,omitempty"` - TotalStages int `json:"totalStages,omitempty"` - RecommendationOutcome string `json:"recommendationOutcome,omitempty"` - CompletionTimestamp metav1.Time `json:"completionTimestamp,omitempty"` - ErrorCode string `json:"errorCode,omitempty"` - ErrorMsg string `json:"errorMsg,omitempty"` + State string `json:"state,omitempty"` + SparkApplication string `json:"sparkApplication,omitempty"` + CompletedStages int `json:"completedStages,omitempty"` + TotalStages int `json:"totalStages,omitempty"` + RecommendedNetworkPolicy string `json:"recommendedNetworkPolicy,omitempty"` + ErrorCode string `json:"errorCode,omitempty"` + ErrorMsg string `json:"errorMsg,omitempty"` } // +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object diff --git a/pkg/apis/intelligence/v1alpha1/zz_generated.deepcopy.go b/pkg/apis/intelligence/v1alpha1/zz_generated.deepcopy.go index ae94343a0..c5fde9ce5 100644 --- a/pkg/apis/intelligence/v1alpha1/zz_generated.deepcopy.go +++ b/pkg/apis/intelligence/v1alpha1/zz_generated.deepcopy.go @@ -28,14 +28,14 @@ func (in *NetworkPolicyRecommendation) DeepCopyInto(out *NetworkPolicyRecommenda *out = *in out.TypeMeta = in.TypeMeta in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) - in.StartInterval.DeepCopyInto(&out.StartInterval) - in.EndInterval.DeepCopyInto(&out.EndInterval) + in.IntervalStart.DeepCopyInto(&out.IntervalStart) + in.IntervalEnd.DeepCopyInto(&out.IntervalEnd) if in.NSAllowList != nil { in, out := &in.NSAllowList, &out.NSAllowList *out = make([]string, len(*in)) copy(*out, *in) } - in.Status.DeepCopyInto(&out.Status) + out.Status = in.Status return } @@ -93,7 +93,6 @@ func (in *NetworkPolicyRecommendationList) DeepCopyObject() runtime.Object { // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *NetworkPolicyRecommendationStatus) DeepCopyInto(out *NetworkPolicyRecommendationStatus) { *out = *in - in.CompletionTimestamp.DeepCopyInto(&out.CompletionTimestamp) return } diff --git a/pkg/apiserver/registry/intelligence/networkpolicyrecommendation/rest.go b/pkg/apiserver/registry/intelligence/networkpolicyrecommendation/rest.go index 67da70338..13ded82b9 100644 --- a/pkg/apiserver/registry/intelligence/networkpolicyrecommendation/rest.go +++ b/pkg/apiserver/registry/intelligence/networkpolicyrecommendation/rest.go @@ -16,6 +16,7 @@ package networkpolicyrecommendation import ( "context" + "fmt" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/apis/meta/internalversion" @@ -23,6 +24,7 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apiserver/pkg/registry/rest" + crdv1alpha1 "antrea.io/theia/pkg/apis/crd/v1alpha1" intelligence "antrea.io/theia/pkg/apis/intelligence/v1alpha1" "antrea.io/theia/pkg/querier" ) @@ -33,9 +35,11 @@ type REST struct { } var ( - _ rest.Scoper = &REST{} - _ rest.Getter = &REST{} - _ rest.Lister = &REST{} + _ rest.Scoper = &REST{} + _ rest.Getter = &REST{} + _ rest.Lister = &REST{} + _ rest.Creater = &REST{} + _ rest.GracefulDeleter = &REST{} ) // NewREST returns a REST object that will work against API services. @@ -48,30 +52,14 @@ func (r *REST) New() runtime.Object { } func (r *REST) getNetworkPolicyRecommendation(name string) *intelligence.NetworkPolicyRecommendation { - npReco, err := r.npRecommendationQuerier.GetNetworkPolicyRecommendation("flow-visibility", name) + npReco, err := r.npRecommendationQuerier.GetNetworkPolicyRecommendation(name) if err != nil { return nil } - - job := new(intelligence.NetworkPolicyRecommendation) - job.Name = npReco.Name - job.Type = npReco.Spec.Type - job.Limit = npReco.Spec.Limit - job.PolicyType = npReco.Spec.PolicyType - job.StartInterval = npReco.Spec.StartTime - job.EndInterval = npReco.Spec.EndTime - job.NSAllowList = npReco.Spec.NSAllowList - job.ExcludeLabels = npReco.Spec.ExcludeLabels - job.ToServices = npReco.Spec.ToServices - job.ExecutorInstances = npReco.Spec.ExecutorInstances - job.DriverCoreRequest = npReco.Spec.DriverCoreRequest - job.DriverMemory = npReco.Spec.DriverMemory - job.ExecutorCoreRequest = npReco.Spec.ExecutorCoreRequest - job.ExecutorMemory = npReco.Spec.ExecutorMemory - return job + return npReco } -func (r *REST) Get(ctx context.Context, name string, options *metav1.GetOptions) (runtime.Object, error) { +func (r *REST) Get(_ context.Context, name string, _ *metav1.GetOptions) (runtime.Object, error) { job := r.getNetworkPolicyRecommendation(name) if job == nil { return nil, errors.NewNotFound(intelligence.Resource("networkpolicyrecommendations"), name) @@ -83,9 +71,12 @@ func (r *REST) NewList() runtime.Object { return &intelligence.NetworkPolicyRecommendationList{} } -func (r *REST) List(ctx context.Context, options *internalversion.ListOptions) (runtime.Object, error) { - list := new(intelligence.NetworkPolicyRecommendationList) - return list, nil +func (r *REST) List(_ context.Context, _ *internalversion.ListOptions) (runtime.Object, error) { + itemList, err := r.npRecommendationQuerier.ListNetworkPolicyRecommendation() + if err != nil { + return nil, errors.NewBadRequest(fmt.Sprintf("cannot retrieve npr from controller. err:%s", err)) + } + return itemList, nil } func (r *REST) NamespaceScoped() bool { @@ -95,3 +86,46 @@ func (r *REST) NamespaceScoped() bool { func (r *REST) ConvertToTable(ctx context.Context, obj runtime.Object, tableOptions runtime.Object) (*metav1.Table, error) { return rest.NewDefaultTableConvertor(intelligence.Resource("networkpolicyrecommendations")).ConvertToTable(ctx, obj, tableOptions) } + +func (r *REST) Create(_ context.Context, obj runtime.Object, _ rest.ValidateObjectFunc, _ *metav1.CreateOptions) (runtime.Object, error) { + npReco, ok := obj.(*intelligence.NetworkPolicyRecommendation) + if !ok { + return nil, errors.NewBadRequest(fmt.Sprintf("not a NetworkPolicyRecommendation object: %T", obj)) + } + existNPReco, _ := r.npRecommendationQuerier.GetNetworkPolicyRecommendation(npReco.Name) + if existNPReco != nil { + return nil, errors.NewBadRequest(fmt.Sprintf("networkPolicyRecommendation job exists, name: %s", npReco.Name)) + } + job := new(crdv1alpha1.NetworkPolicyRecommendation) + job.Name = npReco.Name + job.Spec.JobType = npReco.Type + job.Spec.Limit = npReco.Limit + job.Spec.PolicyType = npReco.PolicyType + job.Spec.StartTime = npReco.IntervalStart + job.Spec.EndTime = npReco.IntervalEnd + job.Spec.NSAllowList = npReco.NSAllowList + job.Spec.ExcludeLabels = npReco.ExcludeLabels + job.Spec.ToServices = npReco.ToServices + job.Spec.ExecutorInstances = npReco.ExecutorInstances + job.Spec.DriverCoreRequest = npReco.DriverCoreRequest + job.Spec.DriverMemory = npReco.DriverMemory + job.Spec.ExecutorCoreRequest = npReco.ExecutorCoreRequest + job.Spec.ExecutorMemory = npReco.ExecutorMemory + _, err := r.npRecommendationQuerier.CreateNetworkPolicyRecommendation(job) + if err != nil { + return nil, err + } + return &metav1.Status{Status: metav1.StatusSuccess}, nil +} + +func (r *REST) Delete(_ context.Context, name string, _ rest.ValidateObjectFunc, _ *metav1.DeleteOptions) (runtime.Object, bool, error) { + _, err := r.npRecommendationQuerier.GetNetworkPolicyRecommendation(name) + if err != nil { + return nil, false, errors.NewBadRequest(fmt.Sprintf("networkPolicyRecommendation job doesn't exist, name: %s", name)) + } + err = r.npRecommendationQuerier.DeleteNetworkPolicyRecommendation(name) + if err != nil { + return nil, false, err + } + return &metav1.Status{Status: metav1.StatusSuccess}, false, nil +} diff --git a/pkg/apiserver/registry/intelligence/networkpolicyrecommendation/rest_test.go b/pkg/apiserver/registry/intelligence/networkpolicyrecommendation/rest_test.go new file mode 100644 index 000000000..79688babc --- /dev/null +++ b/pkg/apiserver/registry/intelligence/networkpolicyrecommendation/rest_test.go @@ -0,0 +1,182 @@ +// Copyright 2020 Antrea Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package networkpolicyrecommendation + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/apis/meta/internalversion" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + + "antrea.io/theia/pkg/apis/crd/v1alpha1" + crdv1alpha1 "antrea.io/theia/pkg/apis/crd/v1alpha1" + intelligence "antrea.io/theia/pkg/apis/intelligence/v1alpha1" +) + +type fakeQuerier struct{} + +func TestREST_Get(t *testing.T) { + tests := []struct { + name string + nprName string + expectErr error + expectResult *intelligence.NetworkPolicyRecommendation + }{ + { + name: "Not Found case", + nprName: "npr-1", + expectErr: errors.NewNotFound(intelligence.Resource("networkpolicyrecommendations"), "npr-1"), + expectResult: nil, + }, + { + name: "Successful Get case", + nprName: "npr-2", + expectErr: nil, + expectResult: &intelligence.NetworkPolicyRecommendation{Type: "NPR", PolicyType: "Allow"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := NewREST(&fakeQuerier{}) + npr, err := r.Get(context.TODO(), tt.nprName, &v1.GetOptions{}) + assert.Equal(t, err, tt.expectErr) + if npr != nil { + assert.Equal(t, tt.expectResult, npr.(*intelligence.NetworkPolicyRecommendation)) + } else { + assert.Nil(t, tt.expectResult) + } + }) + } +} + +func TestREST_Delete(t *testing.T) { + tests := []struct { + name string + nprName string + expectErr error + }{ + { + name: "Job doesn't exist case", + nprName: "npr-1", + expectErr: errors.NewBadRequest(fmt.Sprintf("networkPolicyRecommendation job doesn't exist, name: %s", "npr-1")), + }, + { + name: "Successful Delete case", + nprName: "npr-2", + expectErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := NewREST(&fakeQuerier{}) + _, _, err := r.Delete(context.TODO(), tt.nprName, nil, &v1.DeleteOptions{}) + assert.Equal(t, err, tt.expectErr) + }) + } +} + +func TestREST_Create(t *testing.T) { + tests := []struct { + name string + obj runtime.Object + expectErr error + expectResult runtime.Object + }{ + { + name: "Wrong object case", + obj: &crdv1alpha1.NetworkPolicyRecommendation{}, + expectErr: errors.NewBadRequest(fmt.Sprintf("not a NetworkPolicyRecommendation object: %T", &crdv1alpha1.NetworkPolicyRecommendation{})), + expectResult: nil, + }, + { + name: "Job already exists case", + obj: &intelligence.NetworkPolicyRecommendation{ + TypeMeta: v1.TypeMeta{}, + ObjectMeta: v1.ObjectMeta{Name: "npr-2"}, + }, + expectErr: errors.NewBadRequest(fmt.Sprintf("networkPolicyRecommendation job exists, name: %s", "npr-2")), + expectResult: nil, + }, + { + name: "Successful Create case", + obj: &intelligence.NetworkPolicyRecommendation{ + TypeMeta: v1.TypeMeta{}, + ObjectMeta: v1.ObjectMeta{Name: "npr-1"}, + }, + expectErr: nil, + expectResult: &v1.Status{Status: v1.StatusSuccess}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := NewREST(&fakeQuerier{}) + result, err := r.Create(context.TODO(), tt.obj, nil, &v1.CreateOptions{}) + assert.Equal(t, err, tt.expectErr) + assert.Equal(t, tt.expectResult, result) + }) + } +} + +func TestREST_List(t *testing.T) { + tests := []struct { + name string + expectResult []intelligence.NetworkPolicyRecommendation + }{ + { + name: "Successful List case", + expectResult: []intelligence.NetworkPolicyRecommendation{ + {ObjectMeta: v1.ObjectMeta{Name: "npr-1"}}, + {ObjectMeta: v1.ObjectMeta{Name: "npr-2"}}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := NewREST(&fakeQuerier{}) + itemList, err := r.List(context.TODO(), &internalversion.ListOptions{}) + assert.NoError(t, err) + nprList, ok := itemList.(*intelligence.NetworkPolicyRecommendationList) + assert.True(t, ok) + assert.ElementsMatch(t, tt.expectResult, nprList.Items) + }) + } +} + +func (c *fakeQuerier) GetNetworkPolicyRecommendation(name string) (*intelligence.NetworkPolicyRecommendation, error) { + if name == "npr-1" { + return nil, fmt.Errorf("not found") + } + return &intelligence.NetworkPolicyRecommendation{Type: "NPR", PolicyType: "Allow"}, nil +} + +func (c *fakeQuerier) CreateNetworkPolicyRecommendation(*v1alpha1.NetworkPolicyRecommendation) (*v1alpha1.NetworkPolicyRecommendation, error) { + return nil, nil +} + +func (c *fakeQuerier) DeleteNetworkPolicyRecommendation(name string) error { + return nil +} + +func (c *fakeQuerier) ListNetworkPolicyRecommendation() (*intelligence.NetworkPolicyRecommendationList, error) { + return &intelligence.NetworkPolicyRecommendationList{Items: []intelligence.NetworkPolicyRecommendation{ + {ObjectMeta: v1.ObjectMeta{Name: "npr-1"}}, + {ObjectMeta: v1.ObjectMeta{Name: "npr-2"}}, + }}, nil +} diff --git a/pkg/controller/networkpolicyrecommendation/controller.go b/pkg/controller/networkpolicyrecommendation/controller.go index e60f0a1cd..d6437a7b0 100644 --- a/pkg/controller/networkpolicyrecommendation/controller.go +++ b/pkg/controller/networkpolicyrecommendation/controller.go @@ -15,19 +15,26 @@ package networkpolicyrecommendation import ( + "context" + "database/sql" + "fmt" "time" apimachineryerrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" apimachinerytypes "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/kubernetes" "k8s.io/client-go/tools/cache" "k8s.io/client-go/util/workqueue" "k8s.io/klog/v2" crdv1alpha1 "antrea.io/theia/pkg/apis/crd/v1alpha1" + intelligence "antrea.io/theia/pkg/apis/intelligence/v1alpha1" "antrea.io/theia/pkg/client/clientset/versioned" crdv1a1informers "antrea.io/theia/pkg/client/informers/externalversions/crd/v1alpha1" "antrea.io/theia/pkg/client/listers/crd/v1alpha1" + "antrea.io/theia/pkg/controller" ) const ( @@ -38,7 +45,8 @@ const ( minRetryDelay = 5 * time.Second maxRetryDelay = 300 * time.Second // Default number of workers processing an Service change. - defaultWorkers = 4 + defaultWorkers = 4 + defaultNameSpace = "flow-visibility" ) type NPRecommendationController struct { @@ -48,19 +56,31 @@ type NPRecommendationController struct { npRecommendationLister v1alpha1.NetworkPolicyRecommendationLister npRecommendationSynced cache.InformerSynced // queue maintains the Service objects that need to be synced. - queue workqueue.RateLimitingInterface + queue workqueue.RateLimitingInterface + connect *sql.DB } func NewNPRecommendationController( crdClient versioned.Interface, + k8sClient kubernetes.Interface, npRecommendationInformer crdv1a1informers.NetworkPolicyRecommendationInformer, -) *NPRecommendationController { +) (*NPRecommendationController, error) { + err := controller.CheckClickHousePod(k8sClient) + if err != nil { + return nil, fmt.Errorf("error when checking ClickHouse status: %v", err) + } + connect, err := controller.SetupClickHouseConnection(k8sClient) + if err != nil { + return nil, fmt.Errorf("error when connecting to ClickHouse: %v", err) + } + c := &NPRecommendationController{ crdClient: crdClient, queue: workqueue.NewNamedRateLimitingQueue(workqueue.NewItemExponentialFailureRateLimiter(minRetryDelay, maxRetryDelay), "npRecommendation"), npRecommendationInformer: npRecommendationInformer.Informer(), npRecommendationLister: npRecommendationInformer.Lister(), npRecommendationSynced: npRecommendationInformer.Informer().HasSynced, + connect: connect, } c.npRecommendationInformer.AddEventHandlerWithResyncPeriod( @@ -71,7 +91,7 @@ func NewNPRecommendationController( resyncPeriod, ) - return c + return c, nil } func (c *NPRecommendationController) addNPRecommendation(obj interface{}) { @@ -173,6 +193,100 @@ func (c *NPRecommendationController) syncNPRecommendation(key apimachinerytypes. return nil } -func (c *NPRecommendationController) GetNetworkPolicyRecommendation(namespace, name string) (*crdv1alpha1.NetworkPolicyRecommendation, error) { - return c.npRecommendationLister.NetworkPolicyRecommendations(namespace).Get(name) +func (c *NPRecommendationController) GetNetworkPolicyRecommendation(name string) (*intelligence.NetworkPolicyRecommendation, error) { + npReco, err := c.npRecommendationLister.NetworkPolicyRecommendations(defaultNameSpace).Get(name) + if err != nil { + return nil, fmt.Errorf("error when finding NetworkPolicyRecommendations CR: %v", err) + } + intelli := new(intelligence.NetworkPolicyRecommendation) + err = c.copyNetworkPolicyRecommendation(intelli, npReco) + if err != nil { + return nil, fmt.Errorf("error when copying NetworkPolicyRecommendations CR: %v", err) + } + return intelli, nil +} + +func (c *NPRecommendationController) CreateNetworkPolicyRecommendation(npReco *crdv1alpha1.NetworkPolicyRecommendation) (*crdv1alpha1.NetworkPolicyRecommendation, error) { + return c.crdClient.CrdV1alpha1().NetworkPolicyRecommendations(defaultNameSpace).Create(context.TODO(), npReco, metav1.CreateOptions{}) +} + +func (c *NPRecommendationController) DeleteNetworkPolicyRecommendation(name string) error { + // delete NetworkPolicyRecommendation and RecommendedNetworkPolicy + result, _ := c.getRecommendedNetworkPolicyResult(name) + if result != "" { + err := c.deleteRecommendedNetworkPolicyResult(name) + if err != nil { + return fmt.Errorf("error when delete result in ClickHouse: %v", err) + } + } + err := c.crdClient.CrdV1alpha1().NetworkPolicyRecommendations(defaultNameSpace).Delete(context.TODO(), name, metav1.DeleteOptions{}) + return err +} + +func (c *NPRecommendationController) ListNetworkPolicyRecommendation() (*intelligence.NetworkPolicyRecommendationList, error) { + npRecoItems, err := c.crdClient.CrdV1alpha1().NetworkPolicyRecommendations(defaultNameSpace).List(context.TODO(), metav1.ListOptions{}) + if err != nil { + return nil, fmt.Errorf("error when getting NetworkPolicyRecommendationsList: %v", err) + } + items := make([]intelligence.NetworkPolicyRecommendation, 0, len(npRecoItems.Items)) + for _, npReco := range npRecoItems.Items { + job := intelligence.NetworkPolicyRecommendation{} + err = c.copyNetworkPolicyRecommendation(&job, &npReco) + if err != nil { + return nil, fmt.Errorf("error when copying NetworkPolicyRecommendation CR: %v", err) + } + items = append(items, job) + } + list := &intelligence.NetworkPolicyRecommendationList{Items: items} + return list, nil +} + +// getRecommendedNetworkPolicyResult is used to get Recommended Network Policy in ClickHouse +func (c *NPRecommendationController) getRecommendedNetworkPolicyResult(id string) (string, error) { + var recoResult string + query := "SELECT yamls FROM recommendations WHERE id = (?)" + err := c.connect.QueryRow(query, id).Scan(&recoResult) + if err != nil { + return recoResult, fmt.Errorf("failed to get Recommended Network Policy Result with id %s: %v", id, err) + } + return recoResult, nil +} + +// deleteRecommendedNetworkPolicyResult is used to delete Recommended Network Policy in ClickHouse +func (c *NPRecommendationController) deleteRecommendedNetworkPolicyResult(recoID string) error { + query := "ALTER TABLE recommendations_local ON CLUSTER '{cluster}' DELETE WHERE id = (?)" + _, err := c.connect.Exec(query, recoID) + if err != nil { + return fmt.Errorf("failed to delete Recommended Network Policy Result with id %s: %v", recoID, err) + } + return nil +} + +// copyNetworkPolicyRecommendation is used to copy NetworkPolicyRecommendation from crd to intelligence +func (c *NPRecommendationController) copyNetworkPolicyRecommendation(intelli *intelligence.NetworkPolicyRecommendation, crd *crdv1alpha1.NetworkPolicyRecommendation) error { + intelli.Name = crd.Name + intelli.Type = crd.Spec.JobType + intelli.Limit = crd.Spec.Limit + intelli.PolicyType = crd.Spec.PolicyType + intelli.IntervalStart = crd.Spec.StartTime + intelli.IntervalEnd = crd.Spec.EndTime + intelli.NSAllowList = crd.Spec.NSAllowList + intelli.ExcludeLabels = crd.Spec.ExcludeLabels + intelli.ToServices = crd.Spec.ToServices + intelli.ExecutorInstances = crd.Spec.ExecutorInstances + intelli.DriverCoreRequest = crd.Spec.DriverCoreRequest + intelli.DriverMemory = crd.Spec.DriverMemory + intelli.ExecutorCoreRequest = crd.Spec.ExecutorCoreRequest + intelli.ExecutorMemory = crd.Spec.ExecutorMemory + intelli.Status.State = crd.Status.State + // todo: need to check and add other status field. + if intelli.Status.State != "COMPLETE" { + return nil + } + result, err := c.getRecommendedNetworkPolicyResult(crd.Name) + if err != nil { + return fmt.Errorf("error when getting result from ClickHouse: %v", err) + } + intelli.Status.RecommendedNetworkPolicy = result + return nil } diff --git a/pkg/controller/networkpolicyrecommendation/controller_test.go b/pkg/controller/networkpolicyrecommendation/controller_test.go new file mode 100644 index 000000000..9c15115cc --- /dev/null +++ b/pkg/controller/networkpolicyrecommendation/controller_test.go @@ -0,0 +1,308 @@ +package networkpolicyrecommendation + +import ( + "fmt" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/util/workqueue" + + crdv1alpha1 "antrea.io/theia/pkg/apis/crd/v1alpha1" + intelligence "antrea.io/theia/pkg/apis/intelligence/v1alpha1" + "antrea.io/theia/pkg/client/clientset/versioned/fake" + crdinformers "antrea.io/theia/pkg/client/informers/externalversions" +) + +const informerDefaultResync = 12 * time.Hour + +var ( + npr1 = &crdv1alpha1.NetworkPolicyRecommendation{ + ObjectMeta: metav1.ObjectMeta{ + Name: "npr1", + Namespace: defaultNameSpace, + }, + Spec: crdv1alpha1.NetworkPolicyRecommendationSpec{ + JobType: "Initial", + }, + Status: crdv1alpha1.NetworkPolicyRecommendationStatus{ + State: "Pending", + }, + } + npr2 = &crdv1alpha1.NetworkPolicyRecommendation{ + ObjectMeta: metav1.ObjectMeta{ + Name: "npr2", + Namespace: defaultNameSpace, + }, + Spec: crdv1alpha1.NetworkPolicyRecommendationSpec{ + JobType: "Initial", + }, + Status: crdv1alpha1.NetworkPolicyRecommendationStatus{ + State: "COMPLETE", + }, + } + npr3 = &crdv1alpha1.NetworkPolicyRecommendation{ + ObjectMeta: metav1.ObjectMeta{ + Name: "npr3", + Namespace: defaultNameSpace, + }, + Spec: crdv1alpha1.NetworkPolicyRecommendationSpec{ + JobType: "Subsequent", + }, + Status: crdv1alpha1.NetworkPolicyRecommendationStatus{ + State: "COMPLETE", + }, + } + npr1Intelli = &intelligence.NetworkPolicyRecommendation{ + ObjectMeta: metav1.ObjectMeta{ + Name: "npr1", + }, + Type: "Initial", + Status: intelligence.NetworkPolicyRecommendationStatus{ + State: "Pending", + }, + } + npr2Intelli = &intelligence.NetworkPolicyRecommendation{ + ObjectMeta: metav1.ObjectMeta{ + Name: "npr2", + }, + Type: "Initial", + Status: intelligence.NetworkPolicyRecommendationStatus{ + State: "COMPLETE", + RecommendedNetworkPolicy: "RNP-test-npr2", + }, + } + npr3Intelli = &intelligence.NetworkPolicyRecommendation{ + ObjectMeta: metav1.ObjectMeta{ + Name: "npr3", + }, + Type: "Subsequent", + Status: intelligence.NetworkPolicyRecommendationStatus{ + State: "COMPLETE", + RecommendedNetworkPolicy: "RNP-test-npr3", + }, + } +) + +func initTestObjects(t *testing.T, stopCh chan struct{}) (*NPRecommendationController, sqlmock.Sqlmock) { + connect, mockConnect, err := sqlmock.New() + assert.NoError(t, err) + fakeClient := fake.NewSimpleClientset(npr1, npr2, npr3) + crdInformerFactory := crdinformers.NewSharedInformerFactory(fakeClient, informerDefaultResync) + npRecommendationInformer := crdInformerFactory.Crd().V1alpha1().NetworkPolicyRecommendations() + + c := &NPRecommendationController{ + crdClient: fakeClient, + queue: workqueue.NewNamedRateLimitingQueue(workqueue.NewItemExponentialFailureRateLimiter(minRetryDelay, maxRetryDelay), "npRecommendation"), + npRecommendationInformer: npRecommendationInformer.Informer(), + npRecommendationLister: npRecommendationInformer.Lister(), + npRecommendationSynced: npRecommendationInformer.Informer().HasSynced, + connect: connect, + } + crdInformerFactory.Start(stopCh) + // Wait until npr propagates to the informer + err = waitPropagationToInformer("npr1", c) + require.NoError(t, err) + return c, mockConnect +} + +func waitPropagationToInformer(name string, c *NPRecommendationController) error { + // Wait until npr propagates to the informer + err := wait.PollImmediate(100*time.Millisecond, 3*time.Second, func() (bool, error) { + _, err := c.npRecommendationLister.NetworkPolicyRecommendations(defaultNameSpace).Get(name) + if err != nil { + return false, nil + } + return true, nil + }) + return err +} + +func TestGetNetworkPolicyRecommendation(t *testing.T) { + stopCh := make(chan struct{}) + defer close(stopCh) + c, mockConnect := initTestObjects(t, stopCh) + + tests := []struct { + name string + nprName string + nprNameSpace string + expectedError error + expectedResult *intelligence.NetworkPolicyRecommendation + }{ + { + name: "NPR not exist case", + nprName: "empty", + expectedError: fmt.Errorf("error when finding NetworkPolicyRecommendations CR"), + expectedResult: nil, + }, + { + name: "Status Complete but no data in ClickHouse", + nprName: "npr3", + expectedError: fmt.Errorf("error when copying NetworkPolicyRecommendations CR"), + expectedResult: nil, + }, + { + name: "Successful case", + nprName: "npr1", + expectedError: nil, + expectedResult: npr1Intelli, + }, + { + name: "Successful case with Status Complete", + nprName: "npr2", + expectedError: nil, + expectedResult: npr2Intelli, + }, + } + rows := mockConnect.NewRows([]string{"RNP"}).AddRow("RNP-test-npr2") + query := "SELECT yamls FROM recommendations WHERE id = (?)" + mockConnect.ExpectQuery(query).WithArgs("npr2").WillReturnRows(rows) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + intelli, err := c.GetNetworkPolicyRecommendation(tt.nprName) + if tt.expectedError != nil { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError.Error()) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedResult, intelli) + } + }) + } +} + +func TestCreateNetworkPolicyRecommendation(t *testing.T) { + stopCh := make(chan struct{}) + defer close(stopCh) + c, _ := initTestObjects(t, stopCh) + + nprCreate := &crdv1alpha1.NetworkPolicyRecommendation{ + ObjectMeta: metav1.ObjectMeta{ + Name: "nprCreate", + Namespace: defaultNameSpace, + }, + Spec: crdv1alpha1.NetworkPolicyRecommendationSpec{ + JobType: "initial", + }, + Status: crdv1alpha1.NetworkPolicyRecommendationStatus{ + State: "RUNNING", + }, + } + tests := []struct { + name string + nprName string + CreateNPR *crdv1alpha1.NetworkPolicyRecommendation + expectedError error + expectedResult *crdv1alpha1.NetworkPolicyRecommendation + }{ + { + name: "Successful case", + nprName: "nprCreate", + CreateNPR: nprCreate, + expectedError: nil, + expectedResult: nprCreate, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := c.CreateNetworkPolicyRecommendation(nprCreate) + assert.NoError(t, err) + err = waitPropagationToInformer("nprCreate", c) + assert.NoError(t, err) + npr, err := c.npRecommendationLister.NetworkPolicyRecommendations(defaultNameSpace).Get(tt.nprName) + assert.NoError(t, err) + assert.Equal(t, tt.expectedResult, npr) + }) + } +} + +func TestDeleteNetworkPolicyRecommendation(t *testing.T) { + stopCh := make(chan struct{}) + defer close(stopCh) + c, mockConnect := initTestObjects(t, stopCh) + + tests := []struct { + name string + nprName string + expectedError error + }{ + { + name: "No result in ClickHouse and CRD", + nprName: "empty", + expectedError: fmt.Errorf("networkpolicyrecommendations.crd.theia.antrea.io \"empty\" not found"), + }, + { + name: "Has result in CRD but No result in ClickHouse", + nprName: "npr1", + expectedError: nil, + }, + { + name: "Has result in CRD and ClickHouse", + nprName: "npr2", + expectedError: nil, + }, + } + rows := mockConnect.NewRows([]string{"RNP"}).AddRow("RNP-test") + queryGet := "SELECT yamls FROM recommendations WHERE id = (?)" + mockConnect.ExpectQuery(queryGet).WithArgs("npr2").WillReturnRows(rows) + query := "ALTER TABLE recommendations_local ON CLUSTER '{cluster}' DELETE WHERE id = (?)" + mockConnect.ExpectExec(query).WithArgs("npr2").WillReturnResult(sqlmock.NewResult(1, 1)) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := c.DeleteNetworkPolicyRecommendation(tt.nprName) + if tt.expectedError != nil { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError.Error()) + } else { + assert.NoError(t, err) + err = wait.PollImmediate(100*time.Millisecond, 3*time.Second, func() (bool, error) { + _, err := c.npRecommendationLister.NetworkPolicyRecommendations(defaultNameSpace).Get(tt.nprName) + if err != nil { + return true, nil + } + return false, nil + }) + assert.NoError(t, err) + } + }) + } +} + +func TestListNetworkPolicyRecommendation(t *testing.T) { + stopCh := make(chan struct{}) + defer close(stopCh) + c, mockConnect := initTestObjects(t, stopCh) + + tests := []struct { + name string + nprName string + expectedResult *intelligence.NetworkPolicyRecommendationList + }{ + { + name: "Successful case", + expectedResult: &intelligence.NetworkPolicyRecommendationList{ + Items: []intelligence.NetworkPolicyRecommendation{ + *npr1Intelli, *npr2Intelli, *npr3Intelli, + }, + }, + }, + } + rows := mockConnect.NewRows([]string{"RNP"}).AddRow("RNP-test-npr2") + query := "SELECT yamls FROM recommendations WHERE id = (?)" + mockConnect.ExpectQuery(query).WithArgs("npr2").WillReturnRows(rows) + rows = mockConnect.NewRows([]string{"RNP"}).AddRow("RNP-test-npr3") + mockConnect.ExpectQuery(query).WithArgs("npr3").WillReturnRows(rows) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + list, err := c.ListNetworkPolicyRecommendation() + assert.NoError(t, err) + assert.ElementsMatch(t, tt.expectedResult.Items, list.Items) + }) + } +} diff --git a/pkg/controller/utils.go b/pkg/controller/utils.go new file mode 100644 index 000000000..0f58522de --- /dev/null +++ b/pkg/controller/utils.go @@ -0,0 +1,140 @@ +// Copyright 2022 Antrea Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package controller + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/ClickHouse/clickhouse-go" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/kubernetes" + + "antrea.io/theia/pkg/theia/commands/config" +) + +func CheckClickHousePod(clientset kubernetes.Interface) error { + // Check the ClickHouse deployment in flow-visibility namespace + pods, err := clientset.CoreV1().Pods(config.FlowVisibilityNS).List(context.TODO(), metav1.ListOptions{ + LabelSelector: "app=clickhouse", + }) + if err != nil { + return fmt.Errorf("error %v when finding the ClickHouse Pod, please check the deployment of the ClickHouse", err) + } + if len(pods.Items) < 1 { + return fmt.Errorf("can't find the ClickHouse Pod, please check the deployment of ClickHouse") + } + hasRunningPod := false + for _, pod := range pods.Items { + if pod.Status.Phase == "Running" { + hasRunningPod = true + break + } + } + if !hasRunningPod { + return fmt.Errorf("can't find a running ClickHouse Pod, please check the deployment of ClickHouse") + } + return nil +} + +func GetServiceAddr(clientset kubernetes.Interface, serviceName string) (string, int, error) { + var serviceIP string + var servicePort int + service, err := clientset.CoreV1().Services(config.FlowVisibilityNS).Get(context.TODO(), serviceName, metav1.GetOptions{}) + if err != nil { + return serviceIP, servicePort, fmt.Errorf("error when finding the Service %s: %v", serviceName, err) + } + serviceIP = service.Spec.ClusterIP + for _, port := range service.Spec.Ports { + if port.Name == "tcp" { + servicePort = int(port.Port) + } + } + if servicePort == 0 { + return serviceIP, servicePort, fmt.Errorf("error when finding the Service %s: %v", serviceName, err) + } + return serviceIP, servicePort, nil +} + +func getClickHouseSecret(clientset kubernetes.Interface) (username []byte, password []byte, err error) { + secret, err := clientset.CoreV1().Secrets(config.FlowVisibilityNS).Get(context.TODO(), "clickhouse-secret", metav1.GetOptions{}) + if err != nil { + return username, password, fmt.Errorf("error %v when finding the ClickHouse secret, please check the deployment of ClickHouse", err) + } + username, ok := secret.Data["username"] + if !ok { + return username, password, fmt.Errorf("error when getting the ClickHouse username") + } + password, ok = secret.Data["password"] + if !ok { + return username, password, fmt.Errorf("error when getting the ClickHouse password") + } + return username, password, nil +} + +func connectClickHouse(clientset kubernetes.Interface, url string) (*sql.DB, error) { + var connect *sql.DB + var connErr error + connRetryInterval := 1 * time.Second + connTimeout := 10 * time.Second + + // Connect to ClickHouse in a loop + if err := wait.PollImmediate(connRetryInterval, connTimeout, func() (bool, error) { + // Open the database and ping it + var err error + connect, err = sql.Open("clickhouse", url) + if err != nil { + connErr = fmt.Errorf("failed to open ClickHouse: %v", err) + return false, nil + } + if err := connect.Ping(); err != nil { + if exception, ok := err.(*clickhouse.Exception); ok { + connErr = fmt.Errorf("failed to ping ClickHouse: %v", exception.Message) + } else { + connErr = fmt.Errorf("failed to ping ClickHouse: %v", err) + } + return false, nil + } else { + return true, nil + } + }); err != nil { + return nil, fmt.Errorf("failed to connect to ClickHouse after %s: %v", connTimeout, connErr) + } + return connect, nil +} + +func SetupClickHouseConnection(clientset kubernetes.Interface) (connect *sql.DB, err error) { + service := "clickhouse-clickhouse" + serviceIP, servicePort, err := GetServiceAddr(clientset, service) + if err != nil { + return nil, fmt.Errorf("error when getting the ClickHouse Service address: %v", err) + } + endpoint := fmt.Sprintf("tcp://%s:%d", serviceIP, servicePort) + + // Connect to ClickHouse and execute query + username, password, err := getClickHouseSecret(clientset) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s?debug=false&username=%s&password=%s", endpoint, username, password) + connect, err = connectClickHouse(clientset, url) + if err != nil { + return nil, fmt.Errorf("error when connecting to ClickHouse, %v", err) + } + return connect, nil +} diff --git a/pkg/querier/querier.go b/pkg/querier/querier.go index 26d885c09..856d2877e 100644 --- a/pkg/querier/querier.go +++ b/pkg/querier/querier.go @@ -16,8 +16,12 @@ package querier import ( "antrea.io/theia/pkg/apis/crd/v1alpha1" + intelligence "antrea.io/theia/pkg/apis/intelligence/v1alpha1" ) type NPRecommendationQuerier interface { - GetNetworkPolicyRecommendation(namespace, name string) (*v1alpha1.NetworkPolicyRecommendation, error) + GetNetworkPolicyRecommendation(name string) (*intelligence.NetworkPolicyRecommendation, error) + CreateNetworkPolicyRecommendation(*v1alpha1.NetworkPolicyRecommendation) (*v1alpha1.NetworkPolicyRecommendation, error) + DeleteNetworkPolicyRecommendation(name string) error + ListNetworkPolicyRecommendation() (*intelligence.NetworkPolicyRecommendationList, error) } diff --git a/pkg/theia/commands/policy_recommendation_status.go b/pkg/theia/commands/policy_recommendation_status.go index c57b32228..e8bfc11b1 100644 --- a/pkg/theia/commands/policy_recommendation_status.go +++ b/pkg/theia/commands/policy_recommendation_status.go @@ -15,21 +15,9 @@ package commands import ( - "context" - "encoding/json" "fmt" - "io" - "net/http" - "strings" - "time" "github.com/spf13/cobra" - "k8s.io/apimachinery/pkg/util/wait" - "k8s.io/client-go/kubernetes" - "k8s.io/klog/v2" - - "antrea.io/theia/pkg/theia/commands/config" - sparkv1 "antrea.io/theia/third_party/sparkoperator/v1beta2" ) // policyRecommendationStatusCmd represents the policy-recommendation status command @@ -47,211 +35,61 @@ $ theia policy-recommendation status e998433e-accb-4888-9fc8-06563f073e86 Use Service ClusterIP when checking the current status of job with ID e998433e-accb-4888-9fc8-06563f073e86 $ theia policy-recommendation status e998433e-accb-4888-9fc8-06563f073e86 --use-cluster-ip `, - RunE: func(cmd *cobra.Command, args []string) error { - recoID, err := cmd.Flags().GetString("id") - if err != nil { - return err - } - if recoID == "" && len(args) == 1 { - recoID = args[0] - } - err = ParseRecommendationID(recoID) - if err != nil { - return err - } - kubeconfig, err := ResolveKubeConfig(cmd) - if err != nil { - return err - } - clientset, err := CreateK8sClient(kubeconfig) - if err != nil { - return fmt.Errorf("couldn't create k8s client using given kubeconfig, %v", err) - } - endpoint, err := cmd.Flags().GetString("clickhouse-endpoint") - if err != nil { - return err - } - if endpoint != "" { - err = ParseEndpoint(endpoint) - if err != nil { - return err - } - } - useClusterIP, err := cmd.Flags().GetBool("use-cluster-ip") - if err != nil { - return err - } - - err = PolicyRecoPreCheck(clientset) - if err != nil { - return err - } - var state, errorMessage string - // Check the ClickHouse first because completed jobs will store results in ClickHouse - _, err = getPolicyRecommendationResult(clientset, kubeconfig, endpoint, useClusterIP, "", recoID) - if err != nil { - state, err = getPolicyRecommendationStatus(clientset, recoID) - if err != nil { - return err - } - if state == "" { - state = "NEW" - } - if state == "RUNNING" { - var endpoint string - service := fmt.Sprintf("pr-%s-ui-svc", recoID) - if useClusterIP { - serviceIP, servicePort, err := GetServiceAddr(clientset, service) - if err != nil { - klog.V(2).ErrorS(err, "error when getting the progress of the job, cannot get Spark Monitor Service address") - } else { - endpoint = fmt.Sprintf("tcp://%s:%d", serviceIP, servicePort) - } - } else { - servicePort := 4040 - listenAddress := "localhost" - listenPort := 4040 - pf, err := StartPortForward(kubeconfig, service, servicePort, listenAddress, listenPort) - if err != nil { - klog.V(2).ErrorS(err, "error when getting the progress of the job, cannot forward port") - } else { - endpoint = fmt.Sprintf("http://%s:%d", listenAddress, listenPort) - defer pf.Stop() - } - } - // Check the working progress of running recommendation job - if endpoint != "" { - stateProgress, err := getPolicyRecommendationProgress(endpoint) - if err != nil { - klog.V(2).ErrorS(err, "failed to get the progress of the job") - } - state += stateProgress - } - } - errorMessage, err = getPolicyRecommendationErrorMsg(clientset, recoID) - if err != nil { - return err - } - } else { - state = "COMPLETED" - } - fmt.Printf("Status of this policy recommendation job is %s\n", state) - if errorMessage != "" { - fmt.Printf("Error message: %s\n", errorMessage) - } - return nil - }, + RunE: policyRecommendationStatus, } -func getSparkAppByRecommendationID(clientset kubernetes.Interface, id string) (sparkApp sparkv1.SparkApplication, err error) { - err = clientset.CoreV1().RESTClient(). - Get(). - AbsPath("/apis/sparkoperator.k8s.io/v1beta2"). - Namespace(config.FlowVisibilityNS). - Resource("sparkapplications"). - Name("pr-" + id). - Do(context.TODO()). - Into(&sparkApp) - if err != nil { - return sparkApp, err - } - return sparkApp, nil +func init() { + policyRecommendationCmd.AddCommand(policyRecommendationStatusCmd) + policyRecommendationStatusCmd.Flags().StringP( + "name", + "", + "", + "Name of the policy recommendation job.", + ) } -func getPolicyRecommendationStatus(clientset kubernetes.Interface, id string) (string, error) { - sparkApplication, err := getSparkAppByRecommendationID(clientset, id) +func policyRecommendationStatus(cmd *cobra.Command, args []string) error { + prName, err := cmd.Flags().GetString("name") if err != nil { - return "", err + return err } - state := strings.TrimSpace(string(sparkApplication.Status.AppState.State)) - if state == "" { - state = "NEW" + if prName == "" && len(args) == 1 { + prName = args[0] } - return state, nil -} - -func getPolicyRecommendationErrorMsg(clientset kubernetes.Interface, id string) (string, error) { - sparkApplication, err := getSparkAppByRecommendationID(clientset, id) + err = ParseRecommendationName(prName) if err != nil { - return "", err + return err } - errorMessage := strings.TrimSpace(string(sparkApplication.Status.AppState.ErrorMessage)) - return errorMessage, nil -} - -func getPolicyRecommendationProgress(baseUrl string) (string, error) { - // Get the id of current Spark application - url := fmt.Sprintf("%s/api/v1/applications", baseUrl) - response, err := getResponseFromSparkMonitoringSvc(url) + useClusterIP, err := cmd.Flags().GetBool("use-cluster-ip") if err != nil { - return "", fmt.Errorf("failed to get response from the Spark Monitoring Service: %v", err) + return err } - var getAppsResult []map[string]interface{} - json.Unmarshal([]byte(response), &getAppsResult) - if len(getAppsResult) != 1 { - return "", fmt.Errorf("wrong Spark Application number, expected 1, got %d", len(getAppsResult)) - } - sparkAppID := getAppsResult[0]["id"] - // Check the percentage of completed stages - url = fmt.Sprintf("%s/api/v1/applications/%s/stages", baseUrl, sparkAppID) - response, err = getResponseFromSparkMonitoringSvc(url) + theiaClient, pf, err := SetupTheiaClientAndConnection(cmd, useClusterIP) if err != nil { - return "", fmt.Errorf("failed to get response from the Spark Monitoring Service: %v", err) - } - var getStagesResult []map[string]interface{} - json.Unmarshal([]byte(response), &getStagesResult) - NumStageResult := len(getStagesResult) - if NumStageResult < 1 { - return "", fmt.Errorf("wrong Spark Application stages number, expected at least 1, got %d", NumStageResult) + return fmt.Errorf("couldn't setup theia client: %v", err) } - completedStages := 0 - for _, stage := range getStagesResult { - if stage["status"] == "COMPLETE" || stage["status"] == "SKIPPED" { - completedStages++ - } + if pf != nil { + defer pf.Stop() } - return fmt.Sprintf(": %d/%d (%d%%) stages completed", completedStages, NumStageResult, completedStages*100/NumStageResult), nil -} - -func getResponseFromSparkMonitoringSvc(url string) ([]byte, error) { - sparkMonitoringClient := http.Client{} - request, err := http.NewRequest(http.MethodGet, url, nil) + npr, err := getPolicyRecommendationByName(theiaClient, prName) if err != nil { - return nil, err + return fmt.Errorf("error when getting policy recommendation job by using job name: %v", err) } - var res *http.Response - var getErr error - connRetryInterval := 1 * time.Second - connTimeout := 10 * time.Second - if err := wait.PollImmediate(connRetryInterval, connTimeout, func() (bool, error) { - res, err = sparkMonitoringClient.Do(request) - if err != nil { - getErr = err - return false, nil + state := npr.Status.State + if state == "RUNNING" { + completedStages := npr.Status.CompletedStages + totalStages := npr.Status.TotalStages + if totalStages < 1 { + return fmt.Errorf("wrong Spark Application stages number, expected at least 1, got %d", totalStages) } - return true, nil - }); err != nil { - return nil, getErr + stateProgress := fmt.Sprintf(": %d/%d (%d%%) stages completed", completedStages, totalStages, completedStages*100/totalStages) + state += stateProgress } - if res == nil { - return nil, fmt.Errorf("response is nil") - } - if res.Body != nil { - defer res.Body.Close() - } - body, readErr := io.ReadAll(res.Body) - if readErr != nil { - return nil, readErr - } - return body, nil -} + errorMessage := npr.Status.ErrorMsg -func init() { - policyRecommendationCmd.AddCommand(policyRecommendationStatusCmd) - policyRecommendationStatusCmd.Flags().StringP( - "id", - "i", - "", - "ID of the policy recommendation Spark job.", - ) + fmt.Printf("Status of this policy recommendation job is %s\n", state) + if errorMessage != "" { + fmt.Printf("Error message: %s\n", errorMessage) + } + return nil }