Skip to content

Commit

Permalink
Merge pull request #842 from bishtsaurabh5/protocol-validation
Browse files Browse the repository at this point in the history
Add protocol specific validation for Gateway api
  • Loading branch information
k8s-ci-robot authored Sep 28, 2021
2 parents e237162 + 907a029 commit b088c14
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 1 deletion.
60 changes: 59 additions & 1 deletion apis/v1alpha2/validation/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,75 @@ limitations under the License.
package validation

import (
"fmt"

"k8s.io/apimachinery/pkg/util/validation/field"

gatewayv1a2 "sigs.k8s.io/gateway-api/apis/v1alpha2"
)

var (
// set of protocols for which we need to validate that hostname is empty
protocolsHostnameInvalid = map[gatewayv1a2.ProtocolType]struct{}{
gatewayv1a2.TCPProtocolType: {},
gatewayv1a2.UDPProtocolType: {},
}
// set of protocols for which TLSConfig shall not be present
protocolsTLSInvalid = map[gatewayv1a2.ProtocolType]struct{}{
gatewayv1a2.HTTPProtocolType: {},
gatewayv1a2.UDPProtocolType: {},
gatewayv1a2.TCPProtocolType: {},
}
)

// ValidateGateway validates gw according to the Gateway API specification.
// For additional details of the Gateway spec, refer to:
// https://gateway-api.sigs.k8s.io/spec/#gateway.networking.k8s.io/v1alpha2.Gateway
//
// Validation that is not possible with CRD annotations may be added here in the future.
// See https://github.com/kubernetes-sigs/gateway-api/issues/868 for more information.
func ValidateGateway(gw *gatewayv1a2.Gateway) field.ErrorList {
return nil
return validateGatewaySpec(&gw.Spec, field.NewPath("spec"))
}

// validateGatewaySpec validates whether required fields of spec are set according to the
// Gateway API specification.
func validateGatewaySpec(spec *gatewayv1a2.GatewaySpec, path *field.Path) field.ErrorList {
return validateGatewayListeners(spec.Listeners, path.Child("listeners"))
}

// validateGatewayListeners validates whether required fields of listeners are set according
// to the Gateway API specification.
func validateGatewayListeners(listeners []gatewayv1a2.Listener, path *field.Path) field.ErrorList {
var errs field.ErrorList
errs = append(errs, validateListenerTLSConfig(listeners, path)...)
errs = append(errs, validateListenerHostname(listeners, path)...)
return errs
}

func validateListenerTLSConfig(listeners []gatewayv1a2.Listener, path *field.Path) field.ErrorList {
var errs field.ErrorList
for i, l := range listeners {
if isProtocolInSubset(l.Protocol, protocolsTLSInvalid) && l.TLS != nil {
errs = append(errs, field.Forbidden(path.Index(i).Child("tls"), fmt.Sprintf("should be empty for protocol %v", l.Protocol)))
}
}
return errs
}

func isProtocolInSubset(protocol gatewayv1a2.ProtocolType, set map[gatewayv1a2.ProtocolType]struct{}) bool {
_, ok := set[protocol]
return ok
}

// validateListenerHostname validates each listener hostname
// should be empty in case protocol is TCP or UDP
func validateListenerHostname(listeners []gatewayv1a2.Listener, path *field.Path) field.ErrorList {
var errs field.ErrorList
for i, h := range listeners {
if isProtocolInSubset(h.Protocol, protocolsHostnameInvalid) && h.Hostname != nil {
errs = append(errs, field.Forbidden(path.Index(i).Child("hostname"), fmt.Sprintf("should be empty for protocol %v", h.Protocol)))
}
}
return errs
}
97 changes: 97 additions & 0 deletions apis/v1alpha2/validation/gateway_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
Copyright 2021 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package validation

import (
"testing"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

gatewayv1a2 "sigs.k8s.io/gateway-api/apis/v1alpha2"
)

func TestValidateGateway(t *testing.T) {
listeners := []gatewayv1a2.Listener{
{
Hostname: nil,
},
}
baseGateway := gatewayv1a2.Gateway{
ObjectMeta: metav1.ObjectMeta{
Name: "foo",
Namespace: metav1.NamespaceDefault,
},
Spec: gatewayv1a2.GatewaySpec{
GatewayClassName: "foo",
Listeners: listeners,
},
}
tlsConfig := gatewayv1a2.GatewayTLSConfig{}

testCases := map[string]struct {
mutate func(gw *gatewayv1a2.Gateway)
expectErrsOnFields []string
}{
"tls config present with http protocol": {
mutate: func(gw *gatewayv1a2.Gateway) {
gw.Spec.Listeners[0].Protocol = gatewayv1a2.HTTPProtocolType
gw.Spec.Listeners[0].TLS = &tlsConfig
},
expectErrsOnFields: []string{"spec.listeners[0].tls"},
},
"tls config present with tcp protocol": {
mutate: func(gw *gatewayv1a2.Gateway) {
gw.Spec.Listeners[0].Protocol = gatewayv1a2.TCPProtocolType
gw.Spec.Listeners[0].TLS = &tlsConfig
},
expectErrsOnFields: []string{"spec.listeners[0].tls"},
},
"hostname present with tcp protocol": {
mutate: func(gw *gatewayv1a2.Gateway) {
hostname := gatewayv1a2.Hostname("foo.bar.com")
gw.Spec.Listeners[0].Hostname = &hostname
gw.Spec.Listeners[0].Protocol = gatewayv1a2.TCPProtocolType
},
expectErrsOnFields: []string{"spec.listeners[0].hostname"},
},
"hostname present with udp protocol": {
mutate: func(gw *gatewayv1a2.Gateway) {
hostname := gatewayv1a2.Hostname("foo.bar.com")
gw.Spec.Listeners[0].Hostname = &hostname
gw.Spec.Listeners[0].Protocol = gatewayv1a2.UDPProtocolType
},
expectErrsOnFields: []string{"spec.listeners[0].hostname"},
},
}

for name, tc := range testCases {
tc := tc
t.Run(name, func(t *testing.T) {
gw := baseGateway.DeepCopy()
tc.mutate(gw)
errs := ValidateGateway(gw)
if len(tc.expectErrsOnFields) != len(errs) {
t.Fatalf("Expected %d errors, got %d errors: %v", len(tc.expectErrsOnFields), len(errs), errs)
}
for i, err := range errs {
if err.Field != tc.expectErrsOnFields[i] {
t.Errorf("Expected error on field: %s, got: %s", tc.expectErrsOnFields[i], err.Error())
}
}
})
}
}

0 comments on commit b088c14

Please sign in to comment.