Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add SDS support #56

Merged
merged 6 commits into from
Jan 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ integration.docker: docker
docker run -it -e "XDS=ads" test -debug
docker run -it -e "XDS=xds" test -debug
docker run -it -e "XDS=rest" test -debug
docker run -it -e "XDS=ads" test -debug -tls
docker run -it -e "XDS=xds" test -debug -tls
docker run -it -e "XDS=rest" test -debug -tls

#-----------------
#-- code generaion
Expand Down
5 changes: 5 additions & 0 deletions pkg/cache/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/gogo/protobuf/proto"

v2 "github.com/envoyproxy/go-control-plane/envoy/api/v2"
"github.com/envoyproxy/go-control-plane/envoy/api/v2/auth"
hcm "github.com/envoyproxy/go-control-plane/envoy/config/filter/network/http_connection_manager/v2"
"github.com/envoyproxy/go-control-plane/pkg/util"
)
Expand All @@ -35,6 +36,7 @@ const (
ClusterType = typePrefix + "Cluster"
RouteType = typePrefix + "RouteConfiguration"
ListenerType = typePrefix + "Listener"
SecretType = typePrefix + "auth.Secret"

// AnyType is used only by ADS
AnyType = ""
Expand All @@ -47,6 +49,7 @@ var (
ClusterType,
RouteType,
ListenerType,
SecretType,
}
)

Expand All @@ -61,6 +64,8 @@ func GetResourceName(res Resource) string {
return v.GetName()
case *v2.Listener:
return v.GetName()
case *auth.Secret:
return v.GetName()
default:
return ""
}
Expand Down
25 changes: 16 additions & 9 deletions pkg/cache/simple_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ var (
cache.RouteType: []string{routeName},
cache.ListenerType: nil,
}

testTypes = []string{
cache.EndpointType,
cache.ClusterType,
cache.RouteType,
cache.ListenerType,
}
)

type logger struct {
Expand All @@ -81,7 +88,7 @@ func TestSnapshotCache(t *testing.T) {
case <-time.After(time.Second / 4):
}

for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
t.Run(typ, func(t *testing.T) {
value, _ := c.CreateWatch(v2.DiscoveryRequest{TypeUrl: typ, ResourceNames: names[typ]})
select {
Expand All @@ -105,7 +112,7 @@ func TestSnapshotCacheFetch(t *testing.T) {
t.Fatal(err)
}

for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
t.Run(typ, func(t *testing.T) {
resp, err := c.Fetch(context.Background(), v2.DiscoveryRequest{TypeUrl: typ, ResourceNames: names[typ]})
if err != nil || resp == nil {
Expand Down Expand Up @@ -133,13 +140,13 @@ func TestSnapshotCacheFetch(t *testing.T) {
func TestSnapshotCacheWatch(t *testing.T) {
c := cache.NewSnapshotCache(true, group{}, logger{t: t})
watches := make(map[string]chan cache.Response)
for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
watches[typ], _ = c.CreateWatch(v2.DiscoveryRequest{TypeUrl: typ, ResourceNames: names[typ]})
}
if err := c.SetSnapshot(key, snapshot); err != nil {
t.Fatal(err)
}
for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
t.Run(typ, func(t *testing.T) {
select {
case out := <-watches[typ]:
Expand All @@ -156,10 +163,10 @@ func TestSnapshotCacheWatch(t *testing.T) {
}

// open new watches with the latest version
for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
watches[typ], _ = c.CreateWatch(v2.DiscoveryRequest{TypeUrl: typ, ResourceNames: names[typ], VersionInfo: version})
}
if count := c.GetStatusInfo(key).GetNumWatches(); count != len(cache.ResponseTypes) {
if count := c.GetStatusInfo(key).GetNumWatches(); count != len(testTypes) {
t.Errorf("watches should be created for the latest version: %d", count)
}

Expand All @@ -169,7 +176,7 @@ func TestSnapshotCacheWatch(t *testing.T) {
if err := c.SetSnapshot(key, snapshot2); err != nil {
t.Fatal(err)
}
if count := c.GetStatusInfo(key).GetNumWatches(); count != len(cache.ResponseTypes)-1 {
if count := c.GetStatusInfo(key).GetNumWatches(); count != len(testTypes)-1 {
t.Errorf("watches should be preserved for all but one: %d", count)
}

Expand Down Expand Up @@ -215,7 +222,7 @@ func TestConcurrentSetWatch(t *testing.T) {

func TestSnapshotCacheWatchCancel(t *testing.T) {
c := cache.NewSnapshotCache(true, group{}, logger{t: t})
for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
_, cancel := c.CreateWatch(v2.DiscoveryRequest{TypeUrl: typ, ResourceNames: names[typ]})
cancel()
}
Expand All @@ -224,7 +231,7 @@ func TestSnapshotCacheWatchCancel(t *testing.T) {
t.Error("got 0, want status info for the node")
}

for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
if count := c.GetStatusInfo(key).GetNumWatches(); count > 0 {
t.Errorf("watches should be released for %s", typ)
}
Expand Down
7 changes: 7 additions & 0 deletions pkg/cache/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ type Snapshot struct {

// Listeners are items in the LDS response payload.
Listeners Resources

// Secrets are items in the SDS response payload.
Secrets Resources
}

// NewSnapshot creates a snapshot from response types and a version.
Expand Down Expand Up @@ -117,6 +120,8 @@ func (s *Snapshot) GetResources(typ string) map[string]Resource {
return s.Routes.Items
case ListenerType:
return s.Listeners.Items
case SecretType:
return s.Secrets.Items
}
return nil
}
Expand All @@ -135,6 +140,8 @@ func (s *Snapshot) GetVersion(typ string) string {
return s.Routes.Version
case ListenerType:
return s.Listeners.Version
case SecretType:
return s.Secrets.Version
}
return ""
}
2 changes: 2 additions & 0 deletions pkg/server/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ func (h *HTTPGateway) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
typeURL = cache.ListenerType
case "/v2/discovery:routes":
typeURL = cache.RouteType
case "/v2/discovery:secrets":
typeURL = cache.SecretType
default:
http.Error(resp, "no endpoint", http.StatusNotFound)
return
Expand Down
34 changes: 34 additions & 0 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type Server interface {
v2.RouteDiscoveryServiceServer
v2.ListenerDiscoveryServiceServer
discovery.AggregatedDiscoveryServiceServer
discovery.SecretDiscoveryServiceServer

// Fetch is the universal fetch method.
Fetch(context.Context, *v2.DiscoveryRequest) (*v2.DiscoveryResponse, error)
Expand Down Expand Up @@ -89,16 +90,19 @@ type watches struct {
clusters chan cache.Response
routes chan cache.Response
listeners chan cache.Response
secrets chan cache.Response

endpointCancel func()
clusterCancel func()
routeCancel func()
listenerCancel func()
secretCancel func()

endpointNonce string
clusterNonce string
routeNonce string
listenerNonce string
secretNonce string
}

// Cancel all watches
Expand All @@ -115,6 +119,9 @@ func (values watches) Cancel() {
if values.listenerCancel != nil {
values.listenerCancel()
}
if values.secretCancel != nil {
values.secretCancel()
}
}

func createResponse(resp *cache.Response, typeURL string) (*v2.DiscoveryResponse, error) {
Expand Down Expand Up @@ -223,6 +230,16 @@ func (s *server) process(stream stream, reqCh <-chan *v2.DiscoveryRequest, defau
}
values.listenerNonce = nonce

case resp, more := <-values.secrets:
if !more {
return status.Errorf(codes.Unavailable, "secrets watch failed")
}
nonce, err := send(resp, cache.SecretType)
if err != nil {
return err
}
values.secretNonce = nonce

case req, more := <-reqCh:
// input stream ended or errored out
if !more {
Expand Down Expand Up @@ -270,6 +287,11 @@ func (s *server) process(stream stream, reqCh <-chan *v2.DiscoveryRequest, defau
values.listenerCancel()
}
values.listeners, values.listenerCancel = s.cache.CreateWatch(*req)
case req.TypeUrl == cache.SecretType && (values.secretNonce == "" || values.secretNonce == nonce):
if values.secretCancel != nil {
values.secretCancel()
}
values.secrets, values.secretCancel = s.cache.CreateWatch(*req)
}
}
}
Expand Down Expand Up @@ -323,6 +345,10 @@ func (s *server) StreamListeners(stream v2.ListenerDiscoveryService_StreamListen
return s.handler(stream, cache.ListenerType)
}

func (s *server) StreamSecrets(stream discovery.SecretDiscoveryService_StreamSecretsServer) error {
return s.handler(stream, cache.SecretType)
}

// Fetch is the universal fetch method.
func (s *server) Fetch(ctx context.Context, req *v2.DiscoveryRequest) (*v2.DiscoveryResponse, error) {
if s.callbacks != nil {
Expand Down Expand Up @@ -373,6 +399,14 @@ func (s *server) FetchListeners(ctx context.Context, req *v2.DiscoveryRequest) (
return s.Fetch(ctx, req)
}

func (s *server) FetchSecrets(ctx context.Context, req *v2.DiscoveryRequest) (*v2.DiscoveryResponse, error) {
if req == nil {
return nil, status.Errorf(codes.Unavailable, "empty request")
}
req.TypeUrl = cache.SecretType
return s.Fetch(ctx, req)
}

func (s *server) IncrementalAggregatedResources(_ discovery.AggregatedDiscoveryService_IncrementalAggregatedResourcesServer) error {
return errors.New("not implemented")
}
Expand Down
24 changes: 15 additions & 9 deletions pkg/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,16 @@ var (
Id: "test-id",
Cluster: "test-cluster",
}
endpoint = resource.MakeEndpoint(clusterName, 8080)
cluster = resource.MakeCluster(resource.Ads, clusterName)
route = resource.MakeRoute(routeName, clusterName)
listener = resource.MakeHTTPListener(resource.Ads, listenerName, 80, routeName)
endpoint = resource.MakeEndpoint(clusterName, 8080)
cluster = resource.MakeCluster(resource.Ads, clusterName)
route = resource.MakeRoute(routeName, clusterName)
listener = resource.MakeHTTPListener(resource.Ads, listenerName, 80, routeName)
testTypes = []string{
cache.EndpointType,
cache.ClusterType,
cache.RouteType,
cache.ListenerType,
}
)

func makeResponses() map[string][]cache.Response {
Expand All @@ -193,7 +199,7 @@ func makeResponses() map[string][]cache.Response {
}

func TestResponseHandlers(t *testing.T) {
for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
t.Run(typ, func(t *testing.T) {
config := makeMockConfigWatcher()
config.responses = makeResponses()
Expand Down Expand Up @@ -304,7 +310,7 @@ func TestFetch(t *testing.T) {
}

func TestWatchClosed(t *testing.T) {
for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
t.Run(typ, func(t *testing.T) {
config := makeMockConfigWatcher()
config.closeWatch = true
Expand All @@ -328,7 +334,7 @@ func TestWatchClosed(t *testing.T) {
}

func TestSendError(t *testing.T) {
for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
t.Run(typ, func(t *testing.T) {
config := makeMockConfigWatcher()
config.responses = makeResponses()
Expand All @@ -353,7 +359,7 @@ func TestSendError(t *testing.T) {
}

func TestStaleNonce(t *testing.T) {
for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
t.Run(typ, func(t *testing.T) {
config := makeMockConfigWatcher()
config.responses = makeResponses()
Expand Down Expand Up @@ -466,7 +472,7 @@ func TestAggregateRequestType(t *testing.T) {
}

func TestCallbackError(t *testing.T) {
for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
t.Run(typ, func(t *testing.T) {
config := makeMockConfigWatcher()
config.responses = makeResponses()
Expand Down
13 changes: 12 additions & 1 deletion pkg/test/main/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package main

import (
"context"
cryptotls "crypto/tls"
"flag"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -51,6 +52,7 @@ var (
clusters int
httpListeners int
tcpListeners int
tls bool

nodeID string
)
Expand All @@ -70,6 +72,7 @@ func init() {
flag.IntVar(&httpListeners, "http", 2, "Number of HTTP listeners (and RDS configs)")
flag.IntVar(&tcpListeners, "tcp", 2, "Number of TCP pass-through listeners")
flag.StringVar(&nodeID, "nodeID", "test-id", "Node ID")
flag.BoolVar(&tls, "tls", false, "Enable TLS on all listeners and use SDS for secret delivery")
}

// main returns code 1 if any of the batches failed to pass all requests
Expand Down Expand Up @@ -98,6 +101,7 @@ func main() {
NumClusters: clusters,
NumHTTPListeners: httpListeners,
NumTCPListeners: tcpListeners,
TLS: tls,
}

// start the xDS server
Expand Down Expand Up @@ -164,8 +168,15 @@ func callEcho() (int, int) {
go func(i int) {
client := http.Client{
Timeout: 100 * time.Millisecond,
Transport: &http.Transport{
TLSClientConfig: &cryptotls.Config{InsecureSkipVerify: true},
},
}
req, err := client.Get(fmt.Sprintf("http://localhost:%d", basePort+uint(i)))
proto := "http"
if tls {
proto = "https"
}
req, err := client.Get(fmt.Sprintf("%s://127.0.0.1:%d", proto, basePort+uint(i)))
if err != nil {
ch <- err
return
Expand Down
Loading