diff --git a/apis/v1alpha2/validation/gateway.go b/apis/v1alpha2/validation/gateway.go index 7e7bc27695..6431a83e01 100644 --- a/apis/v1alpha2/validation/gateway.go +++ b/apis/v1alpha2/validation/gateway.go @@ -17,11 +17,27 @@ limitations under the License. package validation import ( + "fmt" + "k8s.io/apimachinery/pkg/util/validation/field" gatewayv1a2 "sigs.k8s.io/gateway-api/apis/v1alpha2" ) +var ( + // set of protocols for which we need to validate that hostname is empty + protocolsHostnameInvalid = map[gatewayv1a2.ProtocolType]struct{}{ + gatewayv1a2.TCPProtocolType: {}, + gatewayv1a2.UDPProtocolType: {}, + } + // set of protocols for which TLSConfig shall not be present + protocolsTLSInvalid = map[gatewayv1a2.ProtocolType]struct{}{ + gatewayv1a2.HTTPProtocolType: {}, + gatewayv1a2.UDPProtocolType: {}, + gatewayv1a2.TCPProtocolType: {}, + } +) + // ValidateGateway validates gw according to the Gateway API specification. // For additional details of the Gateway spec, refer to: // https://gateway-api.sigs.k8s.io/spec/#gateway.networking.k8s.io/v1alpha2.Gateway @@ -29,5 +45,47 @@ import ( // Validation that is not possible with CRD annotations may be added here in the future. // See https://github.com/kubernetes-sigs/gateway-api/issues/868 for more information. func ValidateGateway(gw *gatewayv1a2.Gateway) field.ErrorList { - return nil + return validateGatewaySpec(&gw.Spec, field.NewPath("spec")) +} + +// validateGatewaySpec validates whether required fields of spec are set according to the +// Gateway API specification. +func validateGatewaySpec(spec *gatewayv1a2.GatewaySpec, path *field.Path) field.ErrorList { + return validateGatewayListeners(spec.Listeners, path.Child("listeners")) +} + +// validateGatewayListeners validates whether required fields of listeners are set according +// to the Gateway API specification. +func validateGatewayListeners(listeners []gatewayv1a2.Listener, path *field.Path) field.ErrorList { + var errs field.ErrorList + errs = append(errs, validateListenerTLSConfig(listeners, path)...) + errs = append(errs, validateListenerHostname(listeners, path)...) + return errs +} + +func validateListenerTLSConfig(listeners []gatewayv1a2.Listener, path *field.Path) field.ErrorList { + var errs field.ErrorList + for i, l := range listeners { + if isProtocolInSubset(l.Protocol, protocolsTLSInvalid) && l.TLS != nil { + errs = append(errs, field.Forbidden(path.Index(i).Child("tls"), fmt.Sprintf("should be empty for protocol %v", l.Protocol))) + } + } + return errs +} + +func isProtocolInSubset(protocol gatewayv1a2.ProtocolType, set map[gatewayv1a2.ProtocolType]struct{}) bool { + _, ok := set[protocol] + return ok +} + +// validateListenerHostname validates each listener hostname +// should be empty in case protocol is TCP or UDP +func validateListenerHostname(listeners []gatewayv1a2.Listener, path *field.Path) field.ErrorList { + var errs field.ErrorList + for i, h := range listeners { + if isProtocolInSubset(h.Protocol, protocolsHostnameInvalid) && h.Hostname != nil { + errs = append(errs, field.Forbidden(path.Index(i).Child("hostname"), fmt.Sprintf("should be empty for protocol %v", h.Protocol))) + } + } + return errs } diff --git a/apis/v1alpha2/validation/gateway_test.go b/apis/v1alpha2/validation/gateway_test.go new file mode 100644 index 0000000000..1ce7fc0d0b --- /dev/null +++ b/apis/v1alpha2/validation/gateway_test.go @@ -0,0 +1,97 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package validation + +import ( + "testing" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + gatewayv1a2 "sigs.k8s.io/gateway-api/apis/v1alpha2" +) + +func TestValidateGateway(t *testing.T) { + listeners := []gatewayv1a2.Listener{ + { + Hostname: nil, + }, + } + baseGateway := gatewayv1a2.Gateway{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo", + Namespace: metav1.NamespaceDefault, + }, + Spec: gatewayv1a2.GatewaySpec{ + GatewayClassName: "foo", + Listeners: listeners, + }, + } + tlsConfig := gatewayv1a2.GatewayTLSConfig{} + + testCases := map[string]struct { + mutate func(gw *gatewayv1a2.Gateway) + expectErrsOnFields []string + }{ + "tls config present with http protocol": { + mutate: func(gw *gatewayv1a2.Gateway) { + gw.Spec.Listeners[0].Protocol = gatewayv1a2.HTTPProtocolType + gw.Spec.Listeners[0].TLS = &tlsConfig + }, + expectErrsOnFields: []string{"spec.listeners[0].tls"}, + }, + "tls config present with tcp protocol": { + mutate: func(gw *gatewayv1a2.Gateway) { + gw.Spec.Listeners[0].Protocol = gatewayv1a2.TCPProtocolType + gw.Spec.Listeners[0].TLS = &tlsConfig + }, + expectErrsOnFields: []string{"spec.listeners[0].tls"}, + }, + "hostname present with tcp protocol": { + mutate: func(gw *gatewayv1a2.Gateway) { + hostname := gatewayv1a2.Hostname("foo.bar.com") + gw.Spec.Listeners[0].Hostname = &hostname + gw.Spec.Listeners[0].Protocol = gatewayv1a2.TCPProtocolType + }, + expectErrsOnFields: []string{"spec.listeners[0].hostname"}, + }, + "hostname present with udp protocol": { + mutate: func(gw *gatewayv1a2.Gateway) { + hostname := gatewayv1a2.Hostname("foo.bar.com") + gw.Spec.Listeners[0].Hostname = &hostname + gw.Spec.Listeners[0].Protocol = gatewayv1a2.UDPProtocolType + }, + expectErrsOnFields: []string{"spec.listeners[0].hostname"}, + }, + } + + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + gw := baseGateway.DeepCopy() + tc.mutate(gw) + errs := ValidateGateway(gw) + if len(tc.expectErrsOnFields) != len(errs) { + t.Fatalf("Expected %d errors, got %d errors: %v", len(tc.expectErrsOnFields), len(errs), errs) + } + for i, err := range errs { + if err.Field != tc.expectErrsOnFields[i] { + t.Errorf("Expected error on field: %s, got: %s", tc.expectErrsOnFields[i], err.Error()) + } + } + }) + } +}