Skip to content

Commit

Permalink
add SDS support (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
kyessenov authored Jan 25, 2019
1 parent f1d16b3 commit 77cedfa
Show file tree
Hide file tree
Showing 11 changed files with 275 additions and 30 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ integration.docker: docker
docker run -it -e "XDS=ads" test -debug
docker run -it -e "XDS=xds" test -debug
docker run -it -e "XDS=rest" test -debug
docker run -it -e "XDS=ads" test -debug -tls
docker run -it -e "XDS=xds" test -debug -tls
docker run -it -e "XDS=rest" test -debug -tls

#-----------------
#-- code generaion
Expand Down
5 changes: 5 additions & 0 deletions pkg/cache/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/gogo/protobuf/proto"

v2 "github.com/envoyproxy/go-control-plane/envoy/api/v2"
"github.com/envoyproxy/go-control-plane/envoy/api/v2/auth"
hcm "github.com/envoyproxy/go-control-plane/envoy/config/filter/network/http_connection_manager/v2"
"github.com/envoyproxy/go-control-plane/pkg/util"
)
Expand All @@ -35,6 +36,7 @@ const (
ClusterType = typePrefix + "Cluster"
RouteType = typePrefix + "RouteConfiguration"
ListenerType = typePrefix + "Listener"
SecretType = typePrefix + "auth.Secret"

// AnyType is used only by ADS
AnyType = ""
Expand All @@ -47,6 +49,7 @@ var (
ClusterType,
RouteType,
ListenerType,
SecretType,
}
)

Expand All @@ -61,6 +64,8 @@ func GetResourceName(res Resource) string {
return v.GetName()
case *v2.Listener:
return v.GetName()
case *auth.Secret:
return v.GetName()
default:
return ""
}
Expand Down
25 changes: 16 additions & 9 deletions pkg/cache/simple_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ var (
cache.RouteType: []string{routeName},
cache.ListenerType: nil,
}

testTypes = []string{
cache.EndpointType,
cache.ClusterType,
cache.RouteType,
cache.ListenerType,
}
)

type logger struct {
Expand All @@ -81,7 +88,7 @@ func TestSnapshotCache(t *testing.T) {
case <-time.After(time.Second / 4):
}

for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
t.Run(typ, func(t *testing.T) {
value, _ := c.CreateWatch(v2.DiscoveryRequest{TypeUrl: typ, ResourceNames: names[typ]})
select {
Expand All @@ -105,7 +112,7 @@ func TestSnapshotCacheFetch(t *testing.T) {
t.Fatal(err)
}

for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
t.Run(typ, func(t *testing.T) {
resp, err := c.Fetch(context.Background(), v2.DiscoveryRequest{TypeUrl: typ, ResourceNames: names[typ]})
if err != nil || resp == nil {
Expand Down Expand Up @@ -133,13 +140,13 @@ func TestSnapshotCacheFetch(t *testing.T) {
func TestSnapshotCacheWatch(t *testing.T) {
c := cache.NewSnapshotCache(true, group{}, logger{t: t})
watches := make(map[string]chan cache.Response)
for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
watches[typ], _ = c.CreateWatch(v2.DiscoveryRequest{TypeUrl: typ, ResourceNames: names[typ]})
}
if err := c.SetSnapshot(key, snapshot); err != nil {
t.Fatal(err)
}
for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
t.Run(typ, func(t *testing.T) {
select {
case out := <-watches[typ]:
Expand All @@ -156,10 +163,10 @@ func TestSnapshotCacheWatch(t *testing.T) {
}

// open new watches with the latest version
for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
watches[typ], _ = c.CreateWatch(v2.DiscoveryRequest{TypeUrl: typ, ResourceNames: names[typ], VersionInfo: version})
}
if count := c.GetStatusInfo(key).GetNumWatches(); count != len(cache.ResponseTypes) {
if count := c.GetStatusInfo(key).GetNumWatches(); count != len(testTypes) {
t.Errorf("watches should be created for the latest version: %d", count)
}

Expand All @@ -169,7 +176,7 @@ func TestSnapshotCacheWatch(t *testing.T) {
if err := c.SetSnapshot(key, snapshot2); err != nil {
t.Fatal(err)
}
if count := c.GetStatusInfo(key).GetNumWatches(); count != len(cache.ResponseTypes)-1 {
if count := c.GetStatusInfo(key).GetNumWatches(); count != len(testTypes)-1 {
t.Errorf("watches should be preserved for all but one: %d", count)
}

Expand Down Expand Up @@ -215,7 +222,7 @@ func TestConcurrentSetWatch(t *testing.T) {

func TestSnapshotCacheWatchCancel(t *testing.T) {
c := cache.NewSnapshotCache(true, group{}, logger{t: t})
for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
_, cancel := c.CreateWatch(v2.DiscoveryRequest{TypeUrl: typ, ResourceNames: names[typ]})
cancel()
}
Expand All @@ -224,7 +231,7 @@ func TestSnapshotCacheWatchCancel(t *testing.T) {
t.Error("got 0, want status info for the node")
}

for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
if count := c.GetStatusInfo(key).GetNumWatches(); count > 0 {
t.Errorf("watches should be released for %s", typ)
}
Expand Down
7 changes: 7 additions & 0 deletions pkg/cache/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ type Snapshot struct {

// Listeners are items in the LDS response payload.
Listeners Resources

// Secrets are items in the SDS response payload.
Secrets Resources
}

// NewSnapshot creates a snapshot from response types and a version.
Expand Down Expand Up @@ -117,6 +120,8 @@ func (s *Snapshot) GetResources(typ string) map[string]Resource {
return s.Routes.Items
case ListenerType:
return s.Listeners.Items
case SecretType:
return s.Secrets.Items
}
return nil
}
Expand All @@ -135,6 +140,8 @@ func (s *Snapshot) GetVersion(typ string) string {
return s.Routes.Version
case ListenerType:
return s.Listeners.Version
case SecretType:
return s.Secrets.Version
}
return ""
}
2 changes: 2 additions & 0 deletions pkg/server/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ func (h *HTTPGateway) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
typeURL = cache.ListenerType
case "/v2/discovery:routes":
typeURL = cache.RouteType
case "/v2/discovery:secrets":
typeURL = cache.SecretType
default:
http.Error(resp, "no endpoint", http.StatusNotFound)
return
Expand Down
34 changes: 34 additions & 0 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type Server interface {
v2.RouteDiscoveryServiceServer
v2.ListenerDiscoveryServiceServer
discovery.AggregatedDiscoveryServiceServer
discovery.SecretDiscoveryServiceServer

// Fetch is the universal fetch method.
Fetch(context.Context, *v2.DiscoveryRequest) (*v2.DiscoveryResponse, error)
Expand Down Expand Up @@ -89,16 +90,19 @@ type watches struct {
clusters chan cache.Response
routes chan cache.Response
listeners chan cache.Response
secrets chan cache.Response

endpointCancel func()
clusterCancel func()
routeCancel func()
listenerCancel func()
secretCancel func()

endpointNonce string
clusterNonce string
routeNonce string
listenerNonce string
secretNonce string
}

// Cancel all watches
Expand All @@ -115,6 +119,9 @@ func (values watches) Cancel() {
if values.listenerCancel != nil {
values.listenerCancel()
}
if values.secretCancel != nil {
values.secretCancel()
}
}

func createResponse(resp *cache.Response, typeURL string) (*v2.DiscoveryResponse, error) {
Expand Down Expand Up @@ -223,6 +230,16 @@ func (s *server) process(stream stream, reqCh <-chan *v2.DiscoveryRequest, defau
}
values.listenerNonce = nonce

case resp, more := <-values.secrets:
if !more {
return status.Errorf(codes.Unavailable, "secrets watch failed")
}
nonce, err := send(resp, cache.SecretType)
if err != nil {
return err
}
values.secretNonce = nonce

case req, more := <-reqCh:
// input stream ended or errored out
if !more {
Expand Down Expand Up @@ -270,6 +287,11 @@ func (s *server) process(stream stream, reqCh <-chan *v2.DiscoveryRequest, defau
values.listenerCancel()
}
values.listeners, values.listenerCancel = s.cache.CreateWatch(*req)
case req.TypeUrl == cache.SecretType && (values.secretNonce == "" || values.secretNonce == nonce):
if values.secretCancel != nil {
values.secretCancel()
}
values.secrets, values.secretCancel = s.cache.CreateWatch(*req)
}
}
}
Expand Down Expand Up @@ -323,6 +345,10 @@ func (s *server) StreamListeners(stream v2.ListenerDiscoveryService_StreamListen
return s.handler(stream, cache.ListenerType)
}

func (s *server) StreamSecrets(stream discovery.SecretDiscoveryService_StreamSecretsServer) error {
return s.handler(stream, cache.SecretType)
}

// Fetch is the universal fetch method.
func (s *server) Fetch(ctx context.Context, req *v2.DiscoveryRequest) (*v2.DiscoveryResponse, error) {
if s.callbacks != nil {
Expand Down Expand Up @@ -373,6 +399,14 @@ func (s *server) FetchListeners(ctx context.Context, req *v2.DiscoveryRequest) (
return s.Fetch(ctx, req)
}

func (s *server) FetchSecrets(ctx context.Context, req *v2.DiscoveryRequest) (*v2.DiscoveryResponse, error) {
if req == nil {
return nil, status.Errorf(codes.Unavailable, "empty request")
}
req.TypeUrl = cache.SecretType
return s.Fetch(ctx, req)
}

func (s *server) IncrementalAggregatedResources(_ discovery.AggregatedDiscoveryService_IncrementalAggregatedResourcesServer) error {
return errors.New("not implemented")
}
Expand Down
24 changes: 15 additions & 9 deletions pkg/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,16 @@ var (
Id: "test-id",
Cluster: "test-cluster",
}
endpoint = resource.MakeEndpoint(clusterName, 8080)
cluster = resource.MakeCluster(resource.Ads, clusterName)
route = resource.MakeRoute(routeName, clusterName)
listener = resource.MakeHTTPListener(resource.Ads, listenerName, 80, routeName)
endpoint = resource.MakeEndpoint(clusterName, 8080)
cluster = resource.MakeCluster(resource.Ads, clusterName)
route = resource.MakeRoute(routeName, clusterName)
listener = resource.MakeHTTPListener(resource.Ads, listenerName, 80, routeName)
testTypes = []string{
cache.EndpointType,
cache.ClusterType,
cache.RouteType,
cache.ListenerType,
}
)

func makeResponses() map[string][]cache.Response {
Expand All @@ -193,7 +199,7 @@ func makeResponses() map[string][]cache.Response {
}

func TestResponseHandlers(t *testing.T) {
for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
t.Run(typ, func(t *testing.T) {
config := makeMockConfigWatcher()
config.responses = makeResponses()
Expand Down Expand Up @@ -304,7 +310,7 @@ func TestFetch(t *testing.T) {
}

func TestWatchClosed(t *testing.T) {
for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
t.Run(typ, func(t *testing.T) {
config := makeMockConfigWatcher()
config.closeWatch = true
Expand All @@ -328,7 +334,7 @@ func TestWatchClosed(t *testing.T) {
}

func TestSendError(t *testing.T) {
for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
t.Run(typ, func(t *testing.T) {
config := makeMockConfigWatcher()
config.responses = makeResponses()
Expand All @@ -353,7 +359,7 @@ func TestSendError(t *testing.T) {
}

func TestStaleNonce(t *testing.T) {
for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
t.Run(typ, func(t *testing.T) {
config := makeMockConfigWatcher()
config.responses = makeResponses()
Expand Down Expand Up @@ -466,7 +472,7 @@ func TestAggregateRequestType(t *testing.T) {
}

func TestCallbackError(t *testing.T) {
for _, typ := range cache.ResponseTypes {
for _, typ := range testTypes {
t.Run(typ, func(t *testing.T) {
config := makeMockConfigWatcher()
config.responses = makeResponses()
Expand Down
13 changes: 12 additions & 1 deletion pkg/test/main/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package main

import (
"context"
cryptotls "crypto/tls"
"flag"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -51,6 +52,7 @@ var (
clusters int
httpListeners int
tcpListeners int
tls bool

nodeID string
)
Expand All @@ -70,6 +72,7 @@ func init() {
flag.IntVar(&httpListeners, "http", 2, "Number of HTTP listeners (and RDS configs)")
flag.IntVar(&tcpListeners, "tcp", 2, "Number of TCP pass-through listeners")
flag.StringVar(&nodeID, "nodeID", "test-id", "Node ID")
flag.BoolVar(&tls, "tls", false, "Enable TLS on all listeners and use SDS for secret delivery")
}

// main returns code 1 if any of the batches failed to pass all requests
Expand Down Expand Up @@ -98,6 +101,7 @@ func main() {
NumClusters: clusters,
NumHTTPListeners: httpListeners,
NumTCPListeners: tcpListeners,
TLS: tls,
}

// start the xDS server
Expand Down Expand Up @@ -164,8 +168,15 @@ func callEcho() (int, int) {
go func(i int) {
client := http.Client{
Timeout: 100 * time.Millisecond,
Transport: &http.Transport{
TLSClientConfig: &cryptotls.Config{InsecureSkipVerify: true},
},
}
req, err := client.Get(fmt.Sprintf("http://localhost:%d", basePort+uint(i)))
proto := "http"
if tls {
proto = "https"
}
req, err := client.Get(fmt.Sprintf("%s://127.0.0.1:%d", proto, basePort+uint(i)))
if err != nil {
ch <- err
return
Expand Down
Loading

0 comments on commit 77cedfa

Please sign in to comment.