diff --git a/cmd/snaptel/commands.go b/cmd/snaptel/commands.go index 3bf9b1181..6c2c22c16 100644 --- a/cmd/snaptel/commands.go +++ b/cmd/snaptel/commands.go @@ -101,12 +101,13 @@ var ( Subcommands: []cli.Command{ { Name: "load", - Usage: "load [--plugin-cert= --plugin-key=]", + Usage: "load [--plugin-cert= --plugin-key= --plugin-root-certs=]", Action: loadPlugin, Flags: []cli.Flag{ flPluginAsc, flPluginCert, flPluginKey, + flPluginRootCerts, }, }, { @@ -116,7 +117,7 @@ var ( }, { Name: "swap", - Usage: "swap :: or swap -t -n -v [--plugin-cert= --plugin-key=]", + Usage: "swap :: or swap -t -n -v [--plugin-cert= --plugin-key= --plugin-root-certs=]", Action: swapPlugins, Flags: []cli.Flag{ flPluginAsc, @@ -125,6 +126,7 @@ var ( flPluginVersion, flPluginCert, flPluginKey, + flPluginRootCerts, }, }, { diff --git a/cmd/snaptel/flags.go b/cmd/snaptel/flags.go index c59b71d54..fd764a226 100644 --- a/cmd/snaptel/flags.go +++ b/cmd/snaptel/flags.go @@ -79,6 +79,10 @@ var ( Name: "plugin-key, k", Usage: "The plugin key", } + flPluginRootCerts = cli.StringFlag{ + Name: "plugin-root-certs, r", + Usage: "List of root cert paths for TLS to use (folder/file)", + } flPluginType = cli.StringFlag{ Name: "plugin-type, t", Usage: "The plugin type", diff --git a/cmd/snaptel/plugin.go b/cmd/snaptel/plugin.go index a0a600b1f..20d9b9691 100644 --- a/cmd/snaptel/plugin.go +++ b/cmd/snaptel/plugin.go @@ -205,12 +205,13 @@ func listPlugins(ctx *cli.Context) error { return nil } -// storeTLSPaths extracts paths related to TLS (certificate, key) from command -// line context into temporary files. Those files are appended to list of paths -// returned from this function. +// storeTLSPaths extracts paths related to TLS (certificate, key, root certs) +// from command line context into temporary files. Those files are appended to +// list of paths returned from this function. func storeTLSPaths(ctx *cli.Context, paths []string) ([]string, error) { pCert := ctx.String("plugin-cert") pKey := ctx.String("plugin-key") + pRootCertPaths := ctx.String("plugin-root-certs") if pCert != pKey && (pCert == "" || pKey == "") { return paths, fmt.Errorf("Error processing plugin TLS arguments - one of (certificate, key) arguments is missing") } @@ -236,5 +237,16 @@ func storeTLSPaths(ctx *cli.Context, paths []string) ([]string, error) { } paths = append(paths, tmpFile.Name()) } + if pRootCertPaths != "" { + tmpFile, err := ioutil.TempFile("", v1.TLSRootCertsPrefix) + if err != nil { + return paths, fmt.Errorf("Error processing plugin TLS root certificates - unable to create link:\n%v", err.Error()) + } + _, err = tmpFile.WriteString(pRootCertPaths) + if err != nil { + return paths, fmt.Errorf("Error processing plugin TLS root certificates - unable to write link:\n%v", err.Error()) + } + paths = append(paths, tmpFile.Name()) + } return paths, nil } diff --git a/control/config.go b/control/config.go index 1ae34c31c..529fc99bf 100644 --- a/control/config.go +++ b/control/config.go @@ -48,6 +48,9 @@ var ( defaultCacheExpiration = 500 * time.Millisecond defaultPprof = false defaultTempDirPath = os.TempDir() + defaultTLSCertPath = "" + defaultTLSKeyPath = "" + defaultRootCertPaths = "" ) type pluginConfig struct { @@ -88,6 +91,7 @@ type Config struct { TempDirPath string `json:"temp_dir_path"yaml:"temp_dir_path"` TLSCertPath string `json:"tls_cert_path"yaml:"tls_cert_path"` TLSKeyPath string `json:"tls_key_path"yaml:"tls_key_path"` + RootCertPaths string `json:"root_cert_paths"yaml:"root_cert_paths"` } const ( @@ -148,6 +152,9 @@ const ( }, "tls_key_path": { "type": "string" + }, + "root_cert_paths": { + "type": "string" } }, "additionalProperties": false @@ -171,6 +178,9 @@ func GetDefaultConfig() *Config { Pprof: defaultPprof, MaxPluginRestarts: MaxPluginRestartCount, TempDirPath: defaultTempDirPath, + TLSCertPath: defaultTLSCertPath, + TLSKeyPath: defaultTLSKeyPath, + RootCertPaths: defaultRootCertPaths, } } diff --git a/control/control.go b/control/control.go index c629c2020..236f24f11 100644 --- a/control/control.go +++ b/control/control.go @@ -32,13 +32,12 @@ import ( "sync" "time" - "google.golang.org/grpc" - log "github.com/Sirupsen/logrus" - "github.com/intelsdi-x/gomit" + "google.golang.org/grpc" "github.com/intelsdi-x/snap/control/plugin" + "github.com/intelsdi-x/snap/control/plugin/client" "github.com/intelsdi-x/snap/control/strategy" "github.com/intelsdi-x/snap/core" "github.com/intelsdi-x/snap/core/cdata" @@ -92,6 +91,7 @@ type pluginControl struct { wg sync.WaitGroup subscriptionGroups ManagesSubscriptionGroups + grpcSecurity client.GRPCSecurity } type subscribedPlugin struct { @@ -223,7 +223,15 @@ func New(cfg *Config) *pluginControl { OptSetTempDirPath(cfg.TempDirPath), } if cfg.IsTLSEnabled() { - managerOpts = append(managerOpts, OptEnableManagerTLS(cfg.TLSCertPath, cfg.TLSKeyPath)) + if cfg.RootCertPaths != "" { + certPaths := filepath.SplitList(cfg.RootCertPaths) + c.grpcSecurity = client.SecurityTLSExtended(cfg.TLSCertPath, cfg.TLSKeyPath, client.SecureClient, certPaths) + } else { + c.grpcSecurity = client.SecurityTLSEnabled(cfg.TLSCertPath, cfg.TLSKeyPath, client.SecureClient) + } + } + if cfg.IsTLSEnabled() { + managerOpts = append(managerOpts, OptEnableManagerTLS(c.grpcSecurity)) } c.pluginManager = newPluginManager(managerOpts...) controlLogger.WithFields(log.Fields{ @@ -240,7 +248,7 @@ func New(cfg *Config) *pluginControl { // Plugin Runner if cfg.IsTLSEnabled() { - c.pluginRunner = newRunner(OptEnableRunnerTLS(cfg.TLSCertPath, cfg.TLSKeyPath)) + c.pluginRunner = newRunner(OptEnableRunnerTLS(c.grpcSecurity)) } else { c.pluginRunner = newRunner() } @@ -596,6 +604,7 @@ func (p *pluginControl) returnPluginDetails(rp *core.RequestedPlugin) (*pluginDe details.Signature = rp.Signature() details.CertPath = rp.CertPath() details.KeyPath = rp.KeyPath() + details.RootCertPaths = rp.RootCertPaths() details.TLSEnabled = rp.TLSEnabled() if filepath.Ext(rp.Path()) == ".aci" { diff --git a/control/control_security_test.go b/control/control_security_test.go new file mode 100644 index 000000000..839247910 --- /dev/null +++ b/control/control_security_test.go @@ -0,0 +1,526 @@ +// +build medium + +/* +http://www.apache.org/licenses/LICENSE-2.0.txt + + +Copyright 2017 Intel Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package control + +import ( + "fmt" + "io/ioutil" + "math/rand" + "os" + "path/filepath" + "strings" + "testing" + "time" + + log "github.com/Sirupsen/logrus" + "github.com/intelsdi-x/gomit" + . "github.com/smartystreets/goconvey/convey" + + "github.com/intelsdi-x/snap/control/fixtures" + "github.com/intelsdi-x/snap/control/plugin" + "github.com/intelsdi-x/snap/control/plugin/client" + "github.com/intelsdi-x/snap/core" + "github.com/intelsdi-x/snap/core/cdata" + "github.com/intelsdi-x/snap/core/ctypes" + "github.com/intelsdi-x/snap/plugin/helper" +) + +type MockEmitter struct{} + +const ( + tlsTestCAFn = "snaptest-CA" + tlsTestSrvFn = "snaptest-srv" + tlsTestCliFn = "snaptest-cli" +) + +var tlsTestCA, tlsTestSrv, tlsTestCli string + +var testFilesToRemove []string + +func (memitter *MockEmitter) Emit(gomit.EventBody) (int, error) { return 0, nil } + +type configTLSMock Config + +func (m *configTLSMock) export() *Config { + return (*Config)(m) +} + +func (m *configTLSMock) setTLSCertPath(certPath string) *configTLSMock { + m.TLSCertPath = certPath + return m +} + +func (m *configTLSMock) setTLSKeyPath(keyPath string) *configTLSMock { + m.TLSKeyPath = keyPath + return m +} + +func TestMain(m *testing.M) { + setUpTestMain() + retCode := m.Run() + tearDownTestMain() + os.Exit(retCode) +} + +func TestSecureCollector(t *testing.T) { + log.SetLevel(log.DebugLevel) + Convey("Having a secure collector", t, func() { + var ap *availablePlugin + Convey("framework should establish secure connection", func() { + security := client.SecurityTLSExtended(tlsTestCli+fixtures.TestCrtFileExt, tlsTestCli+fixtures.TestKeyFileExt, client.SecureClient, []string{tlsTestCA + fixtures.TestCrtFileExt}) + var err error + ap, err = runPlugin(plugin.Arg{}. + SetCertPath(tlsTestSrv+fixtures.TestCrtFileExt). + SetKeyPath(tlsTestSrv+fixtures.TestKeyFileExt). + SetRootCertPaths(tlsTestCA+fixtures.TestCrtFileExt). + SetTLSEnabled(true), helper.PluginFilePath("snap-plugin-collector-mock2-grpc"), + security) + So(err, ShouldBeNil) + Convey("and valid plugin client should be obtained", func() { + cli, isCollector := ap.client.(client.PluginCollectorClient) + So(isCollector, ShouldBeTrue) + Convey("Ping should not fail", func() { + err := cli.Ping() + So(err, ShouldBeNil) + }) + Convey("GetConfigPolicy should not fail", func() { + _, err := cli.GetConfigPolicy() + So(err, ShouldBeNil) + }) + Convey("GetMetricTypes should not fail", func() { + cfg := plugin.ConfigType{ConfigDataNode: cdata.NewNode()} + _, err := cli.GetMetricTypes(cfg) + So(err, ShouldBeNil) + }) + Convey("CollectMetrics should not fail", func() { + _, err := cli.CollectMetrics([]core.Metric{}) + So(err, ShouldBeNil) + }) + }) + Reset(func() { + ap.Kill("end-of-test") + }) + }) + }) +} + +func TestSecureProcessor(t *testing.T) { + log.SetLevel(log.DebugLevel) + Convey("Having a secure processor", t, func() { + var ap *availablePlugin + Convey("framework should establish secure connection", func() { + security := client.SecurityTLSExtended(tlsTestCli+fixtures.TestCrtFileExt, tlsTestCli+fixtures.TestKeyFileExt, client.SecureClient, []string{tlsTestCA + fixtures.TestCrtFileExt}) + var err error + ap, err = runPlugin(plugin.Arg{}. + SetCertPath(tlsTestSrv+fixtures.TestCrtFileExt). + SetKeyPath(tlsTestSrv+fixtures.TestKeyFileExt). + SetRootCertPaths(tlsTestCA+fixtures.TestCrtFileExt). + SetTLSEnabled(true), helper.PluginFilePath("snap-plugin-processor-passthru-grpc"), + security) + So(err, ShouldBeNil) + Convey("and valid plugin client should be obtained", func() { + cli, isProcessor := ap.client.(client.PluginProcessorClient) + So(isProcessor, ShouldBeTrue) + Convey("Ping should not fail", func() { + err := cli.Ping() + So(err, ShouldBeNil) + }) + Convey("GetConfigPolicy should not fail", func() { + _, err := cli.GetConfigPolicy() + So(err, ShouldBeNil) + }) + Convey("Process should not fail", func() { + cfg := map[string]ctypes.ConfigValue{} + _, err := cli.Process([]core.Metric{}, cfg) + So(err, ShouldBeNil) + }) + }) + Reset(func() { + ap.Kill("end-of-test") + }) + }) + }) +} + +func TestSecurePublisher(t *testing.T) { + log.SetLevel(log.DebugLevel) + Convey("Having a secure publisher", t, func() { + var ap *availablePlugin + Convey("framework should establish secure connection", func() { + security := client.SecurityTLSExtended(tlsTestCli+fixtures.TestCrtFileExt, tlsTestCli+fixtures.TestKeyFileExt, client.SecureClient, []string{tlsTestCA + fixtures.TestCrtFileExt}) + var err error + ap, err = runPlugin(plugin.NewArg(int(log.DebugLevel), false). + SetCertPath(tlsTestSrv+fixtures.TestCrtFileExt). + SetKeyPath(tlsTestSrv+fixtures.TestKeyFileExt). + SetRootCertPaths(tlsTestCA+fixtures.TestCrtFileExt). + SetTLSEnabled(true), helper.PluginFilePath("snap-plugin-publisher-mock-file-grpc"), + security) + So(err, ShouldBeNil) + Convey("and valid plugin client should be obtained", func() { + cli, isPublisher := ap.client.(client.PluginPublisherClient) + So(isPublisher, ShouldBeTrue) + Convey("Ping should not fail", func() { + err := cli.Ping() + So(err, ShouldBeNil) + }) + Convey("GetConfigPolicy should not fail", func() { + _, err := cli.GetConfigPolicy() + So(err, ShouldBeNil) + }) + Convey("Publish should not fail", func() { + cfg := map[string]ctypes.ConfigValue{} + tf, err := ioutil.TempFile("", "mock-file-publisher-output") + if err != nil { + panic(err) + } + testFilesToRemove = append(testFilesToRemove, tf.Name()) + tf.Close() + cfg["file"] = ctypes.ConfigValueStr{Value: tf.Name()} + err = cli.Publish([]core.Metric{}, cfg) + So(err, ShouldBeNil) + }) + }) + Reset(func() { + ap.Kill("end-of-test") + }) + }) + }) +} + +func TestSecureStreamingCollector(t *testing.T) { + log.SetLevel(log.DebugLevel) + Convey("Having a secure streaming collector", t, func() { + var ap *availablePlugin + Convey("framework should establish secure connection", func() { + security := client.SecurityTLSExtended(tlsTestCli+fixtures.TestCrtFileExt, tlsTestCli+fixtures.TestKeyFileExt, client.SecureClient, []string{tlsTestCA + fixtures.TestCrtFileExt}) + var err error + ap, err = runPlugin(plugin.Arg{}. + SetCertPath(tlsTestSrv+fixtures.TestCrtFileExt). + SetKeyPath(tlsTestSrv+fixtures.TestKeyFileExt). + SetRootCertPaths(tlsTestCA+fixtures.TestCrtFileExt). + SetTLSEnabled(true), helper.PluginFilePath("snap-plugin-stream-collector-rand1"), + security) + So(err, ShouldBeNil) + Convey("and valid plugin client should be obtained", func() { + cli, isStreamer := ap.client.(client.PluginStreamCollectorClient) + So(isStreamer, ShouldBeTrue) + Convey("Ping should not fail", func() { + err := cli.Ping() + So(err, ShouldBeNil) + }) + Convey("GetConfigPolicy should not fail", func() { + _, err := cli.GetConfigPolicy() + So(err, ShouldBeNil) + }) + Convey("GetMetricTypes should not fail", func() { + cfg := plugin.ConfigType{ConfigDataNode: cdata.NewNode()} + _, err := cli.GetMetricTypes(cfg) + So(err, ShouldBeNil) + }) + Convey("StreamMetrics should not fail", func() { + cli.UpdateCollectDuration(time.Second) + cli.UpdateMetricsBuffer(1) + mtsin := []core.Metric{} + m := plugin.MetricType{Namespace_: core.NewNamespace(strings.Fields("a b integer")...)} + mtsin = append(mtsin, m) + mch, errch, err := cli.StreamMetrics(mtsin) + So(err, ShouldBeNil) + Convey("streaming should deliver metrics rather than error", func() { + select { + case mtsout := <-mch: + So(mtsout, ShouldNotBeNil) + break + case err := <-errch: + t.Fatal(err) + case <-time.After(5 * time.Second): + t.Fatal("failed to receive response from stream collector") + } + }) + }) + Convey("UpdateCollectedMetrics should not fail", func() { + err := cli.UpdateCollectedMetrics([]core.Metric{}) + So(err, ShouldBeNil) + }) + Convey("UpdateCollectDuration should not fail", func() { + err := cli.UpdateCollectDuration(5 * time.Second) + So(err, ShouldBeNil) + err = cli.UpdateCollectDuration(1 * time.Second) + So(err, ShouldBeNil) + }) + Convey("UpdateMetricsBuffer should not fail", func() { + err := cli.UpdateMetricsBuffer(100) + So(err, ShouldBeNil) + err = cli.UpdateMetricsBuffer(10) + So(err, ShouldBeNil) + }) + }) + Reset(func() { + ap.Kill("end-of-test") + }) + }) + }) +} + +func TestInsecureConfigurationFails(t *testing.T) { + log.SetLevel(log.DebugLevel) + tcs := []struct { + name string + msg func(*plugin.Arg, *client.GRPCSecurity, func(string)) + }{ + { + name: "SecureFrameworkInsecurePlugin_Fail", + msg: func(srv *plugin.Arg, cli *client.GRPCSecurity, f func(string)) { + // note: server root certs are used to validate client certs (and vice versa) + *srv = srv.SetTLSEnabled(false) + f("Attempting TLS connection between secure framework and insecure plugin") + }, + }, + { + name: "InvalidPluginCertKeyPair_Fail", + msg: func(srv *plugin.Arg, cli *client.GRPCSecurity, f func(string)) { + *srv = srv.SetCertPath(tlsTestCA + fixtures.TestCrtFileExt) + f("Attempting TLS connection between secure framework and plugin with invalid cert-key pair") + }, + }, + { + name: "BadRootCertInFramework_Fail", + msg: func(srv *plugin.Arg, cli *client.GRPCSecurity, f func(string)) { + cli.RootCertPaths = []string{tlsTestCA + fixtures.TestBadCrtFileExt} + f("Attempting TLS connection between framework and plugin using incompatible root certs, bad cert in framework") + }, + }, + { + name: "PluginRootCertsUnknownToFramework_Fail", + msg: func(srv *plugin.Arg, cli *client.GRPCSecurity, f func(string)) { + cli.RootCertPaths = []string{} + f("Attempting TLS connection between secure framework and plugin with certificate without root certs known to framework") + }, + }, + { + name: "InsecureFrameworkSecurePlugin_Fail", + msg: func(srv *plugin.Arg, cli *client.GRPCSecurity, f func(string)) { + cli.TLSEnabled = false + f("Attempting TLS connection between insecure framework and secure plugin") + }, + }, + { + name: "InvalidFrameworkCertKeyPair_Fail", + msg: func(srv *plugin.Arg, cli *client.GRPCSecurity, f func(string)) { + cli.TLSCertPath = tlsTestCA + fixtures.TestCrtFileExt + f("Attempting TLS connection between invalid framework with invalid cert-key pair and secure plugin") + }, + }, + { + name: "FrameworkRootCertsUnknownToPlugin_Fail", + msg: func(srv *plugin.Arg, cli *client.GRPCSecurity, f func(string)) { + srv.RootCertPaths = "" + f("Attempting TLS connection between invalid framework with no root certs known to plugin and secure plugin") + }, + }, + { + name: "BadRootCertInPlugin_Fail", + msg: func(srv *plugin.Arg, cli *client.GRPCSecurity, f func(string)) { + srv.RootCertPaths = tlsTestCA + fixtures.TestBadCrtFileExt + f("Attempting TLS connection between framework and plugin using incompatible root certs, bad cert in plugin") + }, + }, + } + for _, tc := range tcs { + security := client.SecurityTLSExtended(tlsTestCli+fixtures.TestCrtFileExt, tlsTestCli+fixtures.TestKeyFileExt, client.SecureClient, []string{tlsTestCA + fixtures.TestCrtFileExt}) + pluginArgs := plugin.Arg{}. + SetCertPath(tlsTestSrv + fixtures.TestCrtFileExt). + SetKeyPath(tlsTestSrv + fixtures.TestKeyFileExt). + SetRootCertPaths(tlsTestCA + fixtures.TestCrtFileExt). + SetTLSEnabled(true) + runThisCase := func(f func(msg string)) { + t.Run(tc.name, func(_ *testing.T) { + tc.msg(&pluginArgs, &security, f) + }) + } + runThisCase(func(msg string) { + Convey(msg, t, func() { + var ap *availablePlugin + Convey("should fail", func() { + So(func() { + var err error + ap, err = runPlugin(pluginArgs, helper.PluginFilePath("snap-plugin-collector-mock2-grpc"), + security) + // currently grpc may not return error immediately; attempt to ping + if err != nil { + panic(err) + } + cli, isCollector := ap.client.(client.PluginCollectorClient) + So(isCollector, ShouldBeTrue) + err = cli.Ping() + if err != nil { + panic(err) + } + }, ShouldPanic) + }) + Reset(func() { + if ap != nil { + ap.Kill("end-of-test") + } + }) + }) + }) + } +} + +func (m *configTLSMock) setRootCertPaths(rootCertPaths string) *configTLSMock { + m.RootCertPaths = rootCertPaths + return m +} + +func TestSecuritySetupFromConfig(t *testing.T) { + var ( + fakeSampleCert = "/fake-samples/certs/server-cert" + fakeSampleKey = "/fake-samples/keys/server-key" + fakeSampleRootCertsSplit = []string{"/fake-samples/root-ca/ca-one", "/fake-samples/root-ca/ca-two"} + fakeSampleRootCerts = strings.Join(fakeSampleRootCertsSplit, string(filepath.ListSeparator)) + ) + tcs := []struct { + name string + msg func(func(string)) + cfg *Config + wantError bool + wantRunnersec client.GRPCSecurity + wantManagersec client.GRPCSecurity + }{ + { + name: "DefaultEmptyConfig", + msg: func(f func(string)) { + f("passing default (empty) config values, initialization should succeed and result in security disabled") + }, + cfg: GetDefaultConfig(), + wantError: false, + wantRunnersec: client.SecurityTLSOff(), + wantManagersec: client.SecurityTLSOff(), + }, + { + name: "TLSEnabledForwardedToSubmodules", + msg: func(f func(string)) { + f("having TLS enabled in config, plugin runner and manager receive same security values") + }, + cfg: (*configTLSMock)(GetDefaultConfig()). + setTLSCertPath(fakeSampleCert). + setTLSKeyPath(fakeSampleKey). + export(), + wantError: false, + wantRunnersec: client.SecurityTLSEnabled(fakeSampleCert, fakeSampleKey, client.SecureClient), + wantManagersec: client.SecurityTLSEnabled(fakeSampleCert, fakeSampleKey, client.SecureClient), + }, + { + name: "TLSEnabledRootCertsForwardedToSubmodules", + msg: func(f func(string)) { + f("having TLS enabled with root cert paths in config, plugin runner and manager receive same security values") + }, + cfg: (*configTLSMock)(GetDefaultConfig()). + setTLSCertPath(fakeSampleCert). + setTLSKeyPath(fakeSampleKey). + setRootCertPaths(fakeSampleRootCerts). + export(), + wantError: false, + wantRunnersec: client.SecurityTLSExtended(fakeSampleCert, fakeSampleKey, client.SecureClient, fakeSampleRootCertsSplit), + wantManagersec: client.SecurityTLSExtended(fakeSampleCert, fakeSampleKey, client.SecureClient, fakeSampleRootCertsSplit), + }, + } + var gotRunner *runner + var gotManager *pluginManager + + for _, tc := range tcs { + oldRunnerOpts, oldManagerOpts := append([]pluginRunnerOpt{}, defaultRunnerOpts...), append([]pluginManagerOpt{}, defaultManagerOpts...) + defaultRunnerOpts = append(defaultRunnerOpts, func(r *runner) { + gotRunner = r + }) + defaultManagerOpts = append(defaultManagerOpts, func(m *pluginManager) { + gotManager = m + }) + runThisCase := func(f func(msg string)) { + t.Run(tc.name, func(_ *testing.T) { + Convey("Initializing plugin control module", t, func() { + tc.msg(f) + }) + }) + } + runThisCase(func(msg string) { + Convey(msg, func() { + if tc.wantError { + So(func() { + New(tc.cfg) + }, ShouldPanic) + return + } + So(func() { + New(tc.cfg) + }, ShouldNotPanic) + So(gotRunner.grpcSecurity, ShouldResemble, tc.wantRunnersec) + So(gotManager.grpcSecurity, ShouldResemble, tc.wantManagersec) + }) + Reset(func() { + defaultRunnerOpts, defaultManagerOpts = oldRunnerOpts, oldManagerOpts + }) + }) + } +} + +func runPlugin(args plugin.Arg, pluginPath string, security client.GRPCSecurity) (*availablePlugin, error) { + ep, err := fixtures.NewExecutablePlugin(args, pluginPath) + if err != nil { + panic(err) + } + var r *runner + if security.TLSEnabled { + r = newRunner(OptEnableRunnerTLS(security)) + } else { + r = newRunner() + } + r.SetEmitter(new(MockEmitter)) + ap, err := r.startPlugin(ep) + if err != nil { + return nil, err + } + return ap, nil +} + +func setUpTestMain() { + rand.Seed(time.Now().Unix()) + cwd, err := os.Getwd() + if err != nil { + panic(fmt.Errorf("unable to reach current dir for generating TLS certificates: %v", err)) + } + u := fixtures.CertTestUtil{Prefix: cwd} + if tlsTestFiles, err := u.StoreTLSCerts(tlsTestCAFn, tlsTestSrvFn, tlsTestCliFn); err != nil { + panic(err) + } else { + testFilesToRemove = append(testFilesToRemove, tlsTestFiles...) + } + tlsTestCA = filepath.Join(cwd, tlsTestCAFn) + tlsTestSrv = filepath.Join(cwd, tlsTestSrvFn) + tlsTestCli = filepath.Join(cwd, tlsTestCliFn) +} + +func tearDownTestMain() { + for _, fn := range testFilesToRemove { + os.Remove(fn) + } +} diff --git a/control/fixtures/fixtures.go b/control/fixtures/fixtures.go index 9fe5a3ec0..c622d0256 100644 --- a/control/fixtures/fixtures.go +++ b/control/fixtures/fixtures.go @@ -24,6 +24,7 @@ import ( "encoding/json" "time" + "github.com/intelsdi-x/snap/control/plugin" "github.com/intelsdi-x/snap/core" "github.com/intelsdi-x/snap/core/cdata" "github.com/intelsdi-x/snap/plugin/helper" @@ -155,3 +156,20 @@ func (m MockRequestedMetric) Version() int { func (m MockRequestedMetric) Namespace() core.Namespace { return m.namespace } + +func NewExecutablePlugin(a plugin.Arg, path string) (*plugin.ExecutablePlugin, error) { + // Travis optimization: Try starting the plugin three times before finally + // returning an error + var e error + var ep *plugin.ExecutablePlugin + for i := 0; i < 3; i++ { + ep, e = plugin.NewExecutablePlugin(a, path) + if e == nil { + break + } + if e != nil && i == 2 { + return nil, e + } + } + return ep, nil +} diff --git a/control/fixtures/tls_cert_util.go b/control/fixtures/tls_cert_util.go new file mode 100644 index 000000000..f9e77db13 --- /dev/null +++ b/control/fixtures/tls_cert_util.go @@ -0,0 +1,196 @@ +// +build legacy small medium large + +/* +http://www.apache.org/licenses/LICENSE-2.0.txt + + +Copyright 2017 Intel Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package fixtures + +import ( + "bufio" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + mrand "math/rand" + "net" + "os" + "path/filepath" + "strings" + "time" +) + +const ( + TestCrtFileExt = ".crt" + TestBadCrtFileExt = "-BAD.crt" + TestKeyFileExt = ".key" +) + +const ( + keyBitsDefault = 2048 + defaultKeyValidPeriod = 6 * time.Hour + rsaKeyPEMHeader = "RSA PRIVATE KEY" + certificatePEMHeader = "CERTIFICATE" + defaultSignatureAlgorithm = x509.SHA256WithRSA + defaultPublicKeyAlgorithm = x509.RSA +) + +// CertTestUtil offers a few methods to generate a few self-signed certificates +// suitable only for test. +type CertTestUtil struct { + Prefix string +} + +func (u CertTestUtil) WritePEMFile(fn string, pemHeader string, b []byte) error { + f, err := os.Create(fn) + if err != nil { + return err + } + defer f.Close() + w := bufio.NewWriter(f) + pem.Encode(w, &pem.Block{ + Type: pemHeader, + Bytes: b, + }) + w.Flush() + return nil +} + +func (u CertTestUtil) MakeCACertKeyPair(caName, ouName string, keyValidPeriod time.Duration) (caCertTpl *x509.Certificate, caCertBytes []byte, caPrivKey *rsa.PrivateKey, err error) { + caPrivKey, err = rsa.GenerateKey(rand.Reader, keyBitsDefault) + if err != nil { + return nil, nil, nil, err + } + caPubKey := caPrivKey.Public() + caPubBytes, err := x509.MarshalPKIXPublicKey(caPubKey) + if err != nil { + return nil, nil, nil, err + } + caPubSha256 := sha256.Sum256(caPubBytes) + caCertTpl = &x509.Certificate{ + SignatureAlgorithm: defaultSignatureAlgorithm, + PublicKeyAlgorithm: defaultPublicKeyAlgorithm, + Version: 3, + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: caName, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(keyValidPeriod), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + MaxPathLenZero: true, + IsCA: true, + SubjectKeyId: caPubSha256[:], + } + caCertBytes, err = x509.CreateCertificate(rand.Reader, caCertTpl, caCertTpl, caPubKey, caPrivKey) + if err != nil { + return nil, nil, nil, err + } + return caCertTpl, caCertBytes, caPrivKey, nil +} + +func (u CertTestUtil) MakeSubjCertKeyPair(cn, ou string, keyValidPeriod time.Duration, caCertTpl *x509.Certificate, caPrivKey *rsa.PrivateKey) (subjCertBytes []byte, subjPrivKey *rsa.PrivateKey, err error) { + subjPrivKey, err = rsa.GenerateKey(rand.Reader, keyBitsDefault) + if err != nil { + return nil, nil, err + } + subjPubBytes, err := x509.MarshalPKIXPublicKey(subjPrivKey.Public()) + if err != nil { + return nil, nil, err + } + subjPubSha256 := sha256.Sum256(subjPubBytes) + subjCertTpl := x509.Certificate{ + SignatureAlgorithm: defaultSignatureAlgorithm, + PublicKeyAlgorithm: defaultPublicKeyAlgorithm, + Version: 3, + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + OrganizationalUnit: []string{ou}, + CommonName: cn, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(keyValidPeriod), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageDataEncipherment | x509.KeyUsageKeyAgreement, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + SubjectKeyId: subjPubSha256[:], + } + subjCertTpl.DNSNames = strings.Fields("localhost") + subjCertTpl.IPAddresses = []net.IP{net.ParseIP("127.0.0.1")} + subjCertBytes, err = x509.CreateCertificate(rand.Reader, &subjCertTpl, caCertTpl, subjPrivKey.Public(), caPrivKey) + return subjCertBytes, subjPrivKey, err +} + +// StoreTLSCerts builds a set of certificates and private keys for testing TLS. +// Generated files include: CA certificate, server certificate and private key, +// client certificate and private key, and alternate (BAD) CA certificate. +// Certificate and key files are named after given common names (e.g.: srvCN). +func (u CertTestUtil) StoreTLSCerts(caCN, srvCN, cliCN string) (resFiles []string, err error) { + ou := fmt.Sprintf("%06x", mrand.Intn(1<<24)) + caCertTpl, caCert, caPrivKey, err := u.MakeCACertKeyPair(caCN, ou, defaultKeyValidPeriod) + if err != nil { + return nil, err + } + caCertFn := filepath.Join(u.Prefix, caCN+TestCrtFileExt) + if err := u.WritePEMFile(caCertFn, certificatePEMHeader, caCert); err != nil { + return nil, err + } + resFiles = append(resFiles, caCertFn) + _, caBadCert, _, err := u.MakeCACertKeyPair(caCN, ou, defaultKeyValidPeriod) + if err != nil { + return resFiles, err + } + badCaCertFn := caCN + TestBadCrtFileExt + if err := u.WritePEMFile(badCaCertFn, certificatePEMHeader, caBadCert); err != nil { + return resFiles, err + } + resFiles = append(resFiles, badCaCertFn) + srvCert, srvPrivKey, err := u.MakeSubjCertKeyPair(srvCN, ou, defaultKeyValidPeriod, caCertTpl, caPrivKey) + if err != nil { + return resFiles, err + } + srvCertFn := filepath.Join(u.Prefix, srvCN+TestCrtFileExt) + srvKeyFn := filepath.Join(u.Prefix, srvCN+TestKeyFileExt) + if err := u.WritePEMFile(srvCertFn, certificatePEMHeader, srvCert); err != nil { + return resFiles, err + } + resFiles = append(resFiles, srvCertFn) + if err := u.WritePEMFile(srvKeyFn, rsaKeyPEMHeader, x509.MarshalPKCS1PrivateKey(srvPrivKey)); err != nil { + return resFiles, err + } + resFiles = append(resFiles, srvKeyFn) + cliCert, cliPrivKey, err := u.MakeSubjCertKeyPair(cliCN, ou, defaultKeyValidPeriod, caCertTpl, caPrivKey) + if err != nil { + return resFiles, err + } + cliCertFn := filepath.Join(u.Prefix, cliCN+TestCrtFileExt) + cliKeyFn := filepath.Join(u.Prefix, cliCN+TestKeyFileExt) + if err := u.WritePEMFile(cliCertFn, certificatePEMHeader, cliCert); err != nil { + return resFiles, err + } + resFiles = append(resFiles, cliCertFn) + if err := u.WritePEMFile(cliKeyFn, rsaKeyPEMHeader, x509.MarshalPKCS1PrivateKey(cliPrivKey)); err != nil { + return resFiles, err + } + resFiles = append(resFiles, cliKeyFn) + return resFiles, nil +} diff --git a/control/flags.go b/control/flags.go index 70ab3e596..5259b7a6a 100644 --- a/control/flags.go +++ b/control/flags.go @@ -66,6 +66,10 @@ var ( Name: "tls-key", Usage: "A path to PEM-encoded private key file to use for TLS channels", } + flRootCertPaths = cli.StringFlag{ + Name: "root-cert-paths", + Usage: "A list of paths to root certificates or their parent directories, separated with OS path separator", + } flControlRpcPort = cli.StringFlag{ Name: "control-listen-port", @@ -85,5 +89,5 @@ var ( EnvVar: "SNAP_TEMP_DIR_PATH", } - Flags = []cli.Flag{flNumberOfPLs, flPluginLoadTimeout, flAutoDiscover, flPluginTrust, flKeyringPaths, flCache, flControlRpcPort, flControlRpcAddr, flTempDirPath, flTLSCert, flTLSKey} + Flags = []cli.Flag{flNumberOfPLs, flPluginLoadTimeout, flAutoDiscover, flPluginTrust, flKeyringPaths, flCache, flControlRpcPort, flControlRpcAddr, flTempDirPath, flTLSCert, flTLSKey, flRootCertPaths} ) diff --git a/control/plugin/client/grpc.go b/control/plugin/client/grpc.go index f2230932e..9b7df2d3d 100644 --- a/control/plugin/client/grpc.go +++ b/control/plugin/client/grpc.go @@ -24,15 +24,17 @@ import ( "crypto/x509" "errors" "fmt" + "io/ioutil" + "os" + "path/filepath" "strconv" "strings" "time" - "google.golang.org/grpc" - - "golang.org/x/net/context" - log "github.com/Sirupsen/logrus" + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "github.com/intelsdi-x/snap/control/plugin" "github.com/intelsdi-x/snap/control/plugin/cpolicy" @@ -42,7 +44,6 @@ import ( "github.com/intelsdi-x/snap/core/cdata" "github.com/intelsdi-x/snap/core/ctypes" "github.com/intelsdi-x/snap/pkg/rpcutil" - "google.golang.org/grpc/credentials" ) // SecureSide identifies security mode to apply in securing gRPC @@ -81,22 +82,34 @@ type grpcClient struct { // GRPCSecurity contains data necessary to setup secure gRPC communication type GRPCSecurity struct { - TLSEnabled bool - SecureSide SecureSide - TLSCertPath string - TLSKeyPath string + TLSEnabled bool + SecureSide SecureSide + TLSCertPath string + TLSKeyPath string + RootCertPaths []string } // SecurityTLSEnabled generates setup object for securing gRPC communication -func SecurityTLSEnabled(certPath, keyPath string, SecureSide SecureSide) GRPCSecurity { +func SecurityTLSEnabled(certPath, keyPath string, secureSide SecureSide) GRPCSecurity { return GRPCSecurity{ TLSEnabled: true, - SecureSide: SecureSide, + SecureSide: secureSide, TLSCertPath: certPath, TLSKeyPath: keyPath, } } +// SecurityTLSExtended generates setup object for securing gRPC communication +func SecurityTLSExtended(certPath, keyPath string, secureSide SecureSide, rootCertPaths []string) GRPCSecurity { + return GRPCSecurity{ + TLSEnabled: true, + SecureSide: secureSide, + TLSCertPath: certPath, + TLSKeyPath: keyPath, + RootCertPaths: rootCertPaths, + } +} + // SecurityTLSOff generates setup object deactivating gRPC security func SecurityTLSOff() GRPCSecurity { return GRPCSecurity{ @@ -141,6 +154,52 @@ func NewPublisherGrpcClient(address string, timeout time.Duration, security GRPC return p.(PluginPublisherClient), err } +func loadRootCerts(certPaths []string) (rootCAs *x509.CertPool, err error) { + var path string + var filepaths []string + // list potential certificate files + for _, path := range certPaths { + var stat os.FileInfo + if stat, err = os.Stat(path); err != nil { + return nil, fmt.Errorf("unable to process CA cert source path %s: %v", path, err) + } + if !stat.IsDir() { + filepaths = append(filepaths, path) + continue + } + var subfiles []os.FileInfo + if subfiles, err = ioutil.ReadDir(path); err != nil { + return nil, fmt.Errorf("unable to process CA cert source directory %s: %v", path, err) + } + for _, subfile := range subfiles { + subpath := filepath.Join(path, subfile.Name()) + if subfile.IsDir() { + log.WithField("path", subpath).Debug("Skipping second level directory found among certificate files") + continue + } + filepaths = append(filepaths, subpath) + } + } + rootCAs = x509.NewCertPool() + numread := 0 + for _, path = range filepaths { + b, err := ioutil.ReadFile(path) + if err != nil { + log.WithFields(log.Fields{"path": path, "error": err}).Debug("Unable to read cert file") + continue + } + if !rootCAs.AppendCertsFromPEM(b) { + log.WithField("path", path).Debug("Didn't find any usable certificates in cert file") + continue + } + numread++ + } + if numread == 0 { + return nil, fmt.Errorf("found no usable certificates in given locations") + } + return rootCAs, nil +} + func buildCredentials(security GRPCSecurity) (creds credentials.TransportCredentials, err error) { if !security.TLSEnabled { return nil, nil @@ -149,9 +208,19 @@ func buildCredentials(security GRPCSecurity) (creds credentials.TransportCredent if err != nil { return nil, fmt.Errorf("unable to load TLS key pair: %v", err) } - rootCAs, err := x509.SystemCertPool() - if err != nil { - return nil, fmt.Errorf("unable to load system-wide root TLS certificates: %v", err) + var rootCAs *x509.CertPool + if len(security.RootCertPaths) > 0 { + log.Debug("Loading root certificates given explicitly") + rootCAs, err = loadRootCerts(security.RootCertPaths) + if err != nil { + return nil, err + } + } else { + log.Debug("Loading root certificates from operating system") + rootCAs, err = x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("unable to load system-wide root TLS certificates: %v", err) + } } switch security.SecureSide { case SecureClient: @@ -168,7 +237,7 @@ func buildCredentials(security GRPCSecurity) (creds credentials.TransportCredent tls.TLS_RSA_WITH_AES_256_GCM_SHA384, }, ClientAuth: tls.RequireAndVerifyClientCert, - RootCAs: rootCAs, + ClientCAs: rootCAs, }) case DisabledSecurity: creds = nil diff --git a/control/plugin/plugin.go b/control/plugin/plugin.go index 2a84d34bc..0244da821 100644 --- a/control/plugin/plugin.go +++ b/control/plugin/plugin.go @@ -25,6 +25,7 @@ import ( "time" log "github.com/Sirupsen/logrus" + "github.com/intelsdi-x/snap/control/plugin/cpolicy" ) @@ -149,9 +150,10 @@ type Arg struct { // enable pprof Pprof bool - CertPath string - KeyPath string - TLSEnabled bool + CertPath string `json:"CertPath"` + KeyPath string `json:"KeyPath"` + RootCertPaths string `json:"RootCertPaths"` + TLSEnabled bool `json:"TLSEnabled"` } // SetCertPath sets path to TLS certificate in plugin arguments @@ -172,6 +174,12 @@ func (a Arg) SetTLSEnabled(tlsEnabled bool) Arg { return a } +// SetRootCertPaths sets list of certificate paths for client verification +func (a Arg) SetRootCertPaths(rootCertPaths string) Arg { + a.RootCertPaths = rootCertPaths + return a +} + // NewArg returns new plugin arguments structure func NewArg(logLevel int, pprof bool) Arg { return Arg{ diff --git a/control/plugin_manager.go b/control/plugin_manager.go index b8d4af2ae..3f34b24d1 100644 --- a/control/plugin_manager.go +++ b/control/plugin_manager.go @@ -65,6 +65,8 @@ var ( ErrPluginNotInLoadedState = errors.New("Plugin must be in a LoadedState") pmLogger = log.WithField("_module", "control-plugin-mgr") + + defaultManagerOpts = []pluginManagerOpt{optDefaultManagerSecurity()} ) type pluginState string @@ -156,17 +158,18 @@ func (l *loadedPlugins) findLatest(typeName, name string) (*loadedPlugin, error) // the struct representing a plugin that is loaded into snap type pluginDetails struct { - CheckSum [sha256.Size]byte - Exec []string - ExecPath string - IsPackage bool - Manifest *schema.ImageManifest - Path string - Signed bool - Signature []byte - CertPath string - KeyPath string - TLSEnabled bool + CheckSum [sha256.Size]byte + Exec []string + ExecPath string + IsPackage bool + Manifest *schema.ImageManifest + Path string + Signed bool + Signature []byte + CertPath string + KeyPath string + RootCertPaths string + TLSEnabled bool } type loadedPlugin struct { @@ -240,9 +243,7 @@ type pluginManager struct { pluginTags map[string]map[string]string pprof bool tempDirPath string - tlsCertPath string - tlsKeyPath string - tlsEnabled bool + grpcSecurity client.GRPCSecurity } func newPluginManager(opts ...pluginManagerOpt) *pluginManager { @@ -257,8 +258,9 @@ func newPluginManager(opts ...pluginManagerOpt) *pluginManager { pluginConfig: newPluginConfig(), pluginTags: newPluginTags(), } - - for _, opt := range opts { + mergedOpts := append([]pluginManagerOpt{}, defaultManagerOpts...) + mergedOpts = append(mergedOpts, opts...) + for _, opt := range mergedOpts { opt(p) } @@ -282,11 +284,9 @@ func OptSetPprof(pprof bool) pluginManagerOpt { } // OptEnableManagerTLS enables the TLS configuration in plugin manager. -func OptEnableManagerTLS(tlsCertPath, tlsKeyPath string) pluginManagerOpt { +func OptEnableManagerTLS(grpcSecurity client.GRPCSecurity) pluginManagerOpt { return func(p *pluginManager) { - p.tlsCertPath = tlsCertPath - p.tlsKeyPath = tlsKeyPath - p.tlsEnabled = true + p.grpcSecurity = grpcSecurity } } @@ -304,6 +304,12 @@ func OptSetPluginTags(tags map[string]map[string]string) pluginManagerOpt { } } +func optDefaultManagerSecurity() pluginManagerOpt { + return func(p *pluginManager) { + p.grpcSecurity = client.SecurityTLSOff() + } +} + // SetPluginLoadTimeout sets plugin load timeout func (p *pluginManager) SetPluginLoadTimeout(to int) { p.pluginLoadTimeout = to @@ -353,6 +359,7 @@ func (p *pluginManager) LoadPlugin(details *pluginDetails, emitter gomit.Emitter p.GenerateArgs(int(log.GetLevel())). SetCertPath(details.CertPath). SetKeyPath(details.KeyPath). + SetRootCertPaths(details.RootCertPaths). SetTLSEnabled(details.TLSEnabled), commands...) if err != nil { @@ -386,13 +393,7 @@ func (p *pluginManager) LoadPlugin(details *pluginDetails, emitter gomit.Emitter "plugin-type": resp.Type.String(), }) } - var grpcSecurity client.GRPCSecurity - if p.tlsEnabled { - grpcSecurity = client.SecurityTLSEnabled(p.tlsCertPath, p.tlsKeyPath, client.SecureClient) - } else { - grpcSecurity = client.SecurityTLSOff() - } - ap, err := newAvailablePlugin(resp, emitter, ePlugin, grpcSecurity) + ap, err := newAvailablePlugin(resp, emitter, ePlugin, p.grpcSecurity) if err != nil { pmLogger.WithFields(log.Fields{ "_block": "load-plugin", diff --git a/control/runner.go b/control/runner.go index 6910b8974..df4fccd9c 100644 --- a/control/runner.go +++ b/control/runner.go @@ -58,6 +58,8 @@ var ( // MaximumRestartOnDeadPluginEvent is the maximum count of restarting a plugin // after the event of control_event.DeadAvailablePluginEvent MaxPluginRestartCount = 3 + + defaultRunnerOpts = []pluginRunnerOpt{optDefaultRunnerSecurity()} ) type executablePlugin interface { @@ -73,9 +75,7 @@ type runner struct { availablePlugins *availablePlugins metricCatalog catalogsMetrics pluginManager managesPlugins - tlsCertPath string - tlsKeyPath string - tlsEnabled bool + grpcSecurity client.GRPCSecurity } func newRunner(opts ...pluginRunnerOpt) *runner { @@ -83,7 +83,9 @@ func newRunner(opts ...pluginRunnerOpt) *runner { monitor: newMonitor(), availablePlugins: newAvailablePlugins(), } - for _, opt := range opts { + mergedOpts := append([]pluginRunnerOpt{}, defaultRunnerOpts...) + mergedOpts = append(mergedOpts, opts...) + for _, opt := range append(mergedOpts) { opt(r) } return r @@ -92,11 +94,15 @@ func newRunner(opts ...pluginRunnerOpt) *runner { type pluginRunnerOpt func(*runner) // OptEnableRunnerTLS enables the TLS configuration in runner -func OptEnableRunnerTLS(tlsCertPath, tlsKeyPath string) pluginRunnerOpt { +func OptEnableRunnerTLS(grpcSecurity client.GRPCSecurity) pluginRunnerOpt { return func(r *runner) { - r.tlsCertPath = tlsCertPath - r.tlsKeyPath = tlsKeyPath - r.tlsEnabled = true + r.grpcSecurity = grpcSecurity + } +} + +func optDefaultRunnerSecurity() pluginRunnerOpt { + return func(r *runner) { + r.grpcSecurity = client.SecurityTLSOff() } } @@ -193,14 +199,7 @@ func (r *runner) startPlugin(p executablePlugin) (*availablePlugin, error) { return nil, e } - // build availablePlugin - var grpcSecurity client.GRPCSecurity - if r.tlsEnabled { - grpcSecurity = client.SecurityTLSEnabled(r.tlsCertPath, r.tlsKeyPath, client.SecureClient) - } else { - grpcSecurity = client.SecurityTLSOff() - } - ap, err := newAvailablePlugin(resp, r.emitter, p, grpcSecurity) + ap, err := newAvailablePlugin(resp, r.emitter, p, r.grpcSecurity) if err != nil { return nil, err } @@ -348,6 +347,7 @@ func (r *runner) runPlugin(name string, details *pluginDetails) error { ePlugin, err := plugin.NewExecutablePlugin(r.pluginManager.GenerateArgs(int(log.GetLevel())). SetCertPath(details.CertPath). SetKeyPath(details.KeyPath). + SetRootCertPaths(details.RootCertPaths). SetTLSEnabled(details.TLSEnabled), commands...) if err != nil { runnerLog.WithFields(log.Fields{ diff --git a/control/runner_test.go b/control/runner_test.go index 1be286408..eb6279eaa 100644 --- a/control/runner_test.go +++ b/control/runner_test.go @@ -324,23 +324,6 @@ func TestRunnerState(t *testing.T) { }) } -func newExecutablePlugin(a plugin.Arg, path string) (*plugin.ExecutablePlugin, error) { - // Travis optimization: Try starting the plugin three times before finally - // returning an error - var e error - var ep *plugin.ExecutablePlugin - for i := 0; i < 3; i++ { - ep, e = plugin.NewExecutablePlugin(a, path) - if e == nil { - break - } - if e != nil && i == 2 { - return nil, e - } - } - return ep, nil -} - func TestRunnerPluginRunning(t *testing.T) { // log.SetLevel(log.DebugLevel) Convey("snap/control", t, func() { @@ -353,7 +336,7 @@ func TestRunnerPluginRunning(t *testing.T) { r := newRunner() r.SetEmitter(new(MockEmitter)) a := plugin.Arg{} - exPlugin, err := newExecutablePlugin(a, fixtures.PluginPathMock2) + exPlugin, err := fixtures.NewExecutablePlugin(a, fixtures.PluginPathMock2) if err != nil { panic(err) } @@ -375,7 +358,7 @@ func TestRunnerPluginRunning(t *testing.T) { r := newRunner() r.SetEmitter(new(MockEmitter)) a := plugin.Arg{} - exPlugin, err := newExecutablePlugin(a, fixtures.PluginPathMock2) + exPlugin, err := fixtures.NewExecutablePlugin(a, fixtures.PluginPathMock2) if err != nil { panic(err) } @@ -393,7 +376,7 @@ func TestRunnerPluginRunning(t *testing.T) { r := newRunner() r.SetEmitter(new(MockEmitter)) a := plugin.Arg{} - exPlugin, err := newExecutablePlugin(a, fixtures.PluginPathMock2) + exPlugin, err := fixtures.NewExecutablePlugin(a, fixtures.PluginPathMock2) if err != nil { panic(err) } @@ -410,7 +393,7 @@ func TestRunnerPluginRunning(t *testing.T) { r := newRunner() r.SetEmitter(new(MockEmitter)) a := plugin.Arg{} - exPlugin, err := newExecutablePlugin(a, fixtures.PluginPathMock2) + exPlugin, err := fixtures.NewExecutablePlugin(a, fixtures.PluginPathMock2) if err != nil { panic(err) } @@ -427,7 +410,7 @@ func TestRunnerPluginRunning(t *testing.T) { r := newRunner() r.SetEmitter(new(MockEmitter)) a := plugin.Arg{} - exPlugin, err := newExecutablePlugin(a, fixtures.PluginPathMock2) + exPlugin, err := fixtures.NewExecutablePlugin(a, fixtures.PluginPathMock2) if err != nil { panic(err) } @@ -448,7 +431,7 @@ func TestRunnerPluginRunning(t *testing.T) { r := newRunner() r.SetEmitter(new(MockEmitter)) a := plugin.Arg{} - exPlugin, err := newExecutablePlugin(a, fixtures.PluginPathMock2) + exPlugin, err := fixtures.NewExecutablePlugin(a, fixtures.PluginPathMock2) if err != nil { panic(err) } @@ -481,7 +464,7 @@ func TestRunnerPluginRunning(t *testing.T) { r := newRunner() r.SetEmitter(new(MockEmitter)) a := plugin.Arg{} - exPlugin, err := newExecutablePlugin(a, fixtures.PluginPathMock2) + exPlugin, err := fixtures.NewExecutablePlugin(a, fixtures.PluginPathMock2) if err != nil { panic(err) } diff --git a/core/plugin.go b/core/plugin.go index f67a4aaa8..44bed4894 100644 --- a/core/plugin.go +++ b/core/plugin.go @@ -122,12 +122,13 @@ type SubscribedPlugin interface { } type RequestedPlugin struct { - path string - checkSum [sha256.Size]byte - signature []byte - certPath string - keyPath string - tlsEnabled bool + path string + checkSum [sha256.Size]byte + signature []byte + certPath string + keyPath string + rootCertPaths string + tlsEnabled bool } // NewRequestedPlugin returns a Requested Plugin which represents the plugin path and signature @@ -198,6 +199,11 @@ func (p *RequestedPlugin) KeyPath() string { return p.keyPath } +// RootCertPaths returns the list of TLS root cert paths for plugin to use +func (p *RequestedPlugin) RootCertPaths() string { + return p.rootCertPaths +} + // TLSEnabled returns the TLS enabled flag for requested plugin func (p *RequestedPlugin) TLSEnabled() bool { return p.tlsEnabled @@ -225,6 +231,11 @@ func (p *RequestedPlugin) SetKeyPath(keyPath string) { p.keyPath = keyPath } +// SetRootCertPaths sets the list of paths to TLS root certificate for plugin to use +func (p *RequestedPlugin) SetRootCertPaths(rootCertPaths string) { + p.rootCertPaths = rootCertPaths +} + // SetTLSEnabled sets the TLS flag on requested plugin func (p *RequestedPlugin) SetTLSEnabled(tlsEnabled bool) { p.tlsEnabled = tlsEnabled diff --git a/docs/SNAPTEL.md b/docs/SNAPTEL.md index b52dfe5d5..f748759a4 100644 --- a/docs/SNAPTEL.md +++ b/docs/SNAPTEL.md @@ -87,9 +87,9 @@ help, h Shows a list of commands or help for one command $ snaptel plugin command [command options] [arguments...] ``` ``` -load load +load load [--plugin-cert= --plugin-key=] unload unload -swap swap :: or swap -t -n -v +swap swap :: or swap -t -n -v [--plugin-cert= --plugin-key= list list help, h Shows a list of commands or help for one command ``` diff --git a/docs/SNAPTELD.md b/docs/SNAPTELD.md index 9b65987ed..9f62b471d 100644 --- a/docs/SNAPTELD.md +++ b/docs/SNAPTELD.md @@ -44,6 +44,8 @@ $ snapteld [global options] command [command options] [arguments...] --control-listen-port value Listen port for control RPC server (default: 8082) [$SNAP_CONTROL_LISTEN_PORT] --control-listen-addr value Listen address for control RPC server [$SNAP_CONTROL_LISTEN_ADDR] --temp_dir_path value Temporary path for loading plugins [$SNAP_TEMP_DIR_PATH] +--tls-cert value A path to PEM-encoded certificate to use for TLS channels +--tls-key value A path to PEM-encoded private key file to use for TLS channels --work-manager-queue-size value Size of the work manager queue (default: 25) [$WORK_MANAGER_QUEUE_SIZE] --work-manager-pool-size value Size of the work manager pool (default: 4) [$WORK_MANAGER_POOL_SIZE] --disable-api, -d Disable the agent REST API diff --git a/glide.lock b/glide.lock index 522b2ca96..d7ee21634 100644 --- a/glide.lock +++ b/glide.lock @@ -34,7 +34,7 @@ imports: - name: github.com/intelsdi-x/gomit version: db68f6fda248706a71980abc58e969fcd63f5ea6 - name: github.com/intelsdi-x/snap-plugin-lib-go - version: af03307f16e97c41f966b1e00d340d3393d6675d + version: 3527311f5c8e6fe9b55fc1f44e1b515a30845e18 subpackages: - v1/plugin - v1/plugin/rpc diff --git a/mgmt/rest/client/client.go b/mgmt/rest/client/client.go index e86f83fe1..3642afe79 100644 --- a/mgmt/rest/client/client.go +++ b/mgmt/rest/client/client.go @@ -319,7 +319,8 @@ func (c *Client) pluginUploadRequest(pluginPaths []string) (*rbody.APIResponse, defer file.Close() bufin := bufio.NewReader(file) bufins = append(bufins, bufin) - if baseName := filepath.Base(pluginPath); strings.HasPrefix(baseName, v1.TLSCertPrefix) || strings.HasPrefix(baseName, v1.TLSKeyPrefix) { + if baseName := filepath.Base(pluginPath); strings.HasPrefix(baseName, v1.TLSCertPrefix) || + strings.HasPrefix(baseName, v1.TLSKeyPrefix) || strings.HasPrefix(baseName, v1.TLSRootCertsPrefix) { defer os.Remove(pluginPath) } paths = append(paths, filepath.Base(pluginPath)) diff --git a/mgmt/rest/v1/api.go b/mgmt/rest/v1/api.go index 549abc783..86cd5bfdc 100644 --- a/mgmt/rest/v1/api.go +++ b/mgmt/rest/v1/api.go @@ -12,6 +12,8 @@ const ( TLSCertPrefix = "crt." // TLSKeyPrefix defines a prefix for file fragment carrying path to TLS private key TLSKeyPrefix = "key." + // TLSRootCertsPrefix defines a prefix for file fragment carrying paths to TLS root certificates + TLSRootCertsPrefix = "root." version = "v1" prefix = "/" + version diff --git a/mgmt/rest/v1/plugin.go b/mgmt/rest/v1/plugin.go index f83754b76..af456dd18 100644 --- a/mgmt/rest/v1/plugin.go +++ b/mgmt/rest/v1/plugin.go @@ -76,6 +76,7 @@ func (s *apiV1) loadPlugin(w http.ResponseWriter, r *http.Request, _ httprouter. if strings.HasPrefix(mediaType, "multipart/") { var certPath string var keyPath string + var rootCertPaths string var signature []byte var checkSum [sha256.Size]byte lp := &rbody.PluginsLoaded{} @@ -132,7 +133,7 @@ func (s *apiV1) loadPlugin(w http.ResponseWriter, r *http.Request, _ httprouter. return } checkSum = sha256.Sum256(b) - case i < 4: + case i < 5: if filepath.Ext(p.FileName()) == ".asc" { signature = b } else if strings.HasPrefix(p.FileName(), TLSCertPrefix) { @@ -149,13 +150,16 @@ func (s *apiV1) loadPlugin(w http.ResponseWriter, r *http.Request, _ httprouter. rbody.Write(500, rbody.FromError(e), w) return } + } else if strings.HasPrefix(p.FileName(), TLSRootCertsPrefix) { + rootCertPaths = string(b) + // validation will take place later; take it as it is } else { e := errors.New("Error: unrecognized file was passed") rbody.Write(500, rbody.FromError(e), w) return } - case i == 4: - e := errors.New("Error: More than four files passed to the load plugin API") + case i == 5: + e := errors.New("Error: More than five files passed to the load plugin API") rbody.Write(500, rbody.FromError(e), w) return } @@ -171,6 +175,7 @@ func (s *apiV1) loadPlugin(w http.ResponseWriter, r *http.Request, _ httprouter. rp.SetSignature(signature) rp.SetCertPath(certPath) rp.SetKeyPath(keyPath) + rp.SetRootCertPaths(rootCertPaths) if certPath != "" && keyPath != "" { rp.SetTLSEnabled(true) } else if certPath != "" || keyPath != "" { diff --git a/plugin/collector/snap-plugin-collector-mock2-grpc/main_small_test.go b/plugin/collector/snap-plugin-collector-mock2-grpc/main_small_test.go index 045c93574..1385d72a9 100644 --- a/plugin/collector/snap-plugin-collector-mock2-grpc/main_small_test.go +++ b/plugin/collector/snap-plugin-collector-mock2-grpc/main_small_test.go @@ -22,6 +22,7 @@ limitations under the License. package main import ( + "os" "testing" . "github.com/smartystreets/goconvey/convey" @@ -29,6 +30,7 @@ import ( func TestMain(t *testing.T) { Convey("ensure plugin loads and responds", t, func() { + os.Args = []string{"", "{\"NoDaemon\": true}"} So(func() { main() }, ShouldNotPanic) }) } diff --git a/plugin/processor/snap-plugin-processor-passthru-grpc/main_small_test.go b/plugin/processor/snap-plugin-processor-passthru-grpc/main_small_test.go index 045c93574..1385d72a9 100644 --- a/plugin/processor/snap-plugin-processor-passthru-grpc/main_small_test.go +++ b/plugin/processor/snap-plugin-processor-passthru-grpc/main_small_test.go @@ -22,6 +22,7 @@ limitations under the License. package main import ( + "os" "testing" . "github.com/smartystreets/goconvey/convey" @@ -29,6 +30,7 @@ import ( func TestMain(t *testing.T) { Convey("ensure plugin loads and responds", t, func() { + os.Args = []string{"", "{\"NoDaemon\": true}"} So(func() { main() }, ShouldNotPanic) }) } diff --git a/snapteld.go b/snapteld.go index 9c75d82de..1c2ce02a6 100644 --- a/snapteld.go +++ b/snapteld.go @@ -34,11 +34,10 @@ import ( "syscall" "time" - "golang.org/x/crypto/ssh/terminal" - log "github.com/Sirupsen/logrus" "github.com/urfave/cli" "github.com/vrischmann/jsonutil" + "golang.org/x/crypto/ssh/terminal" "github.com/intelsdi-x/snap/control" "github.com/intelsdi-x/snap/core/serror" @@ -196,6 +195,13 @@ type managesTribe interface { GetMember(name string) *agreement.Member } +type runtimeFlagsContext interface { + String(key string) string + Int(key string) int + Bool(key string) bool + IsSet(key string) bool +} + func main() { // Add a check to see if gitversion is blank from the build process @@ -318,6 +324,9 @@ func action(ctx *cli.Context) error { setMaxProcs(cfg.GoMaxProcs) c := control.New(cfg.Control) + if c.Config.AutoDiscoverPath != "" && c.Config.IsTLSEnabled() { + log.Fatal("TLS security is not supported in autodiscovery mode") + } coreModules = []coreModule{} @@ -337,7 +346,9 @@ func action(ctx *cli.Context) error { } cfg.RestAPI.RestAuthPassword = string(password) } - + if cfg.Tribe.Enable && c.Config.IsTLSEnabled() { + log.Fatal("TLS security is not supported in tribe mode") + } var tr managesTribe if cfg.Tribe.Enable { cfg.Tribe.RestAPIPort = cfg.RestAPI.Port @@ -590,7 +601,7 @@ func defaultConfigFile() bool { // used to set fields in the configuration to values from the // command line context if the corresponding flagName is set // in that context -func setBoolVal(field bool, ctx *cli.Context, flagName string, inverse ...bool) bool { +func setBoolVal(field bool, ctx runtimeFlagsContext, flagName string, inverse ...bool) bool { // check to see if a value was set (either on the command-line or via the associated // environment variable, if any); if so, use that as value for the input field val := ctx.Bool(flagName) @@ -603,7 +614,7 @@ func setBoolVal(field bool, ctx *cli.Context, flagName string, inverse ...bool) return field } -func setStringVal(field string, ctx *cli.Context, flagName string) string { +func setStringVal(field string, ctx runtimeFlagsContext, flagName string) string { // check to see if a value was set (either on the command-line or via the associated // environment variable, if any); if so, use that as value for the input field val := ctx.String(flagName) @@ -613,7 +624,7 @@ func setStringVal(field string, ctx *cli.Context, flagName string) string { return field } -func setIntVal(field int, ctx *cli.Context, flagName string) int { +func setIntVal(field int, ctx runtimeFlagsContext, flagName string) int { // check to see if a value was set (either on the command-line or via the associated // environment variable, if any); if so, use that as value for the input field val := ctx.String(flagName) @@ -629,7 +640,7 @@ func setIntVal(field int, ctx *cli.Context, flagName string) int { return field } -func setUIntVal(field uint, ctx *cli.Context, flagName string) uint { +func setUIntVal(field uint, ctx runtimeFlagsContext, flagName string) uint { // check to see if a value was set (either on the command-line or via the associated // environment variable, if any); if so, use that as value for the input field val := ctx.String(flagName) @@ -645,7 +656,7 @@ func setUIntVal(field uint, ctx *cli.Context, flagName string) uint { return field } -func setDurationVal(field time.Duration, ctx *cli.Context, flagName string) time.Duration { +func setDurationVal(field time.Duration, ctx runtimeFlagsContext, flagName string) time.Duration { // check to see if a value was set (either on the command-line or via the associated // environment variable, if any); if so, use that as value for the input field val := ctx.String(flagName) @@ -755,7 +766,7 @@ func checkHostPortVals(addr string, port *int, errPrefix string) (bool, error) { // appropriately; returns the port read from the command-line arguments, a flag // indicating whether or not a port was detected in the address read from the // command-line arguments, and an error if one is detected -func checkCmdLineFlags(ctx *cli.Context) (int, bool, error) { +func checkCmdLineFlags(ctx runtimeFlagsContext) (int, bool, error) { tlsCert := ctx.String("tls-cert") tlsKey := ctx.String("tls-key") if _, err := checkTLSEnabled(tlsCert, tlsKey, commandLineErrorPrefix); err != nil { @@ -813,7 +824,7 @@ func checkTLSEnabled(certPath, keyPath, errPrefix string) (tlsEnabled bool, err // Apply the command line flags set (if any) to override the values // in the input configuration -func applyCmdLineFlags(cfg *Config, ctx *cli.Context) { +func applyCmdLineFlags(cfg *Config, ctx runtimeFlagsContext) { // check the settings for the command-line arguments included in the cli.Context cmdLinePort, cmdLinePortInAddr, cmdLineErr := checkCmdLineFlags(ctx) if cmdLineErr != nil { @@ -845,6 +856,7 @@ func applyCmdLineFlags(cfg *Config, ctx *cli.Context) { cfg.Control.TempDirPath = setStringVal(cfg.Control.TempDirPath, ctx, "temp_dir_path") cfg.Control.TLSCertPath = setStringVal(cfg.Control.TLSCertPath, ctx, "tls-cert") cfg.Control.TLSKeyPath = setStringVal(cfg.Control.TLSKeyPath, ctx, "tls-key") + cfg.Control.RootCertPaths = setStringVal(cfg.Control.RootCertPaths, ctx, "root-cert-paths") // next for the RESTful server related flags cfg.RestAPI.Enable = setBoolVal(cfg.RestAPI.Enable, ctx, "disable-api", invertBoolean) cfg.RestAPI.Port = setIntVal(cfg.RestAPI.Port, ctx, "api-port") diff --git a/snapteld_test.go b/snapteld_test.go index 2e23120e9..411bd9ff0 100644 --- a/snapteld_test.go +++ b/snapteld_test.go @@ -23,12 +23,103 @@ package main import ( "encoding/json" + "strconv" "testing" + "time" - "github.com/intelsdi-x/snap/pkg/cfgfile" . "github.com/smartystreets/goconvey/convey" + "github.com/vrischmann/jsonutil" + + "github.com/intelsdi-x/snap/control" + "github.com/intelsdi-x/snap/mgmt/rest" + "github.com/intelsdi-x/snap/mgmt/tribe" + "github.com/intelsdi-x/snap/pkg/cfgfile" + "github.com/intelsdi-x/snap/scheduler" ) +var validCmdlineFlags_input = mockFlags{ + "max-procs": "11", + "log-level": "1", + "log-path": "/no/logs/allowed", + "log-truncate": "true", + "log-colors": "true", + "max-running-plugins": "12", + "plugin-load-timeout": "20", + "plugin-trust": "1", + "auto-discover": "/no/plugins/here", + "keyring-paths": "/no/keyrings/here", + "cache-expiration": "30ms", + "control-listen-addr": "100.101.102.103", + "control-listen-port": "10400", + "pprof": "true", + "temp_dir_path": "/no/temp/files", + "tls-cert": "/no/cert/here", + "tls-key": "/no/key/here", + "root-cert-paths": "/no/root/certs", + "disable-api": "false", + "api-port": "12400", + "api-addr": "120.121.122.123", + "rest-https": "true", + "rest-cert": "/no/rest/cert", + "rest-key": "/no/rest/key", + "rest-auth": "true", + "rest-auth-pwd": "noway", + "allowed_origins": "140.141.142.143", + "work-manager-queue-size": "70", + "work-manager-pool-size": "71", + "tribe-node-name": "bonk", + "tribe": "true", + "tribe-addr": "160.161.162.163", + "tribe-port": "16400", + "tribe-seed": "180.181.182.183", +} + +var validCmdlineFlags_expected = &Config{ + Control: &control.Config{ + MaxRunningPlugins: 12, + PluginLoadTimeout: 20, + PluginTrust: 1, + AutoDiscoverPath: "/no/plugins/here", + KeyringPaths: "/no/keyrings/here", + CacheExpiration: jsonutil.Duration{30 * time.Millisecond}, + ListenAddr: "100.101.102.103", + ListenPort: 10400, + Pprof: true, + TempDirPath: "/no/temp/files", + TLSCertPath: "/no/cert/here", + TLSKeyPath: "/no/key/here", + RootCertPaths: "/no/root/certs", + }, + RestAPI: &rest.Config{ + Enable: true, + Port: 12400, + Address: "120.121.122.123:12400", + HTTPS: true, + RestCertificate: "/no/rest/cert", + RestKey: "/no/rest/key", + RestAuth: true, + RestAuthPassword: "noway", + Pprof: true, + Corsd: "140.141.142.143", + }, + Tribe: &tribe.Config{ + Name: "bonk", + Enable: true, + BindAddr: "160.161.162.163", + BindPort: 16400, + Seed: "180.181.182.183", + }, + Scheduler: &scheduler.Config{ + WorkManagerQueueSize: 70, + WorkManagerPoolSize: 71, + }, + GoMaxProcs: 11, + LogLevel: 1, + LogPath: "/no/logs/allowed", + LogTruncate: true, + LogColors: true, +} + func TestSnapConfig(t *testing.T) { Convey("Test Config", t, func() { Convey("with defaults", func() { @@ -39,3 +130,268 @@ func TestSnapConfig(t *testing.T) { }) }) } + +type mockFlags map[string]string + +func (m mockFlags) String(key string) string { + return m[key] +} + +func (m mockFlags) Int(key string) int { + if v, err := strconv.Atoi(m[key]); err == nil { + return v + } + return 0 +} + +func (m mockFlags) Bool(key string) bool { + if v, err := strconv.ParseBool(m[key]); err == nil { + return v + } + return false +} + +func (m mockFlags) IsSet(key string) bool { + _, gotIt := m[key] + return gotIt +} + +func (m mockFlags) getCopy() mockFlags { + r := mockFlags{} + for k, v := range m { + r[k] = v + } + return r +} + +func (m mockFlags) copyWithout(keys ...string) mockFlags { + r := m.getCopy() + for _, k := range keys { + delete(r, k) + } + return r +} + +func (m mockFlags) update(key, value string) mockFlags { + m[key] = value + return m +} + +type mockCfg Config + +func (c *mockCfg) setTLSCert(tlsCertPath string) *mockCfg { + c.Control.TLSCertPath = tlsCertPath + return c +} + +func (c *mockCfg) setTLSKey(tlsKeyPath string) *mockCfg { + c.Control.TLSKeyPath = tlsKeyPath + return c +} + +func (c *mockCfg) setRootCertPaths(rootCertPaths string) *mockCfg { + c.Control.RootCertPaths = rootCertPaths + return c +} + +func (c *mockCfg) setApiAddr(apiAddr string) *mockCfg { + c.RestAPI.Address = apiAddr + return c +} + +func (c *mockCfg) getCopy() (r *mockCfg) { + r = &mockCfg{} + b, err := json.Marshal(*c) + if err != nil { + panic(err) + } + err = json.Unmarshal(b, r) + if err != nil { + panic(err) + } + return r +} + +func (c *mockCfg) export() *Config { + return (*Config)(c) +} + +func Test_checkCmdLineFlags(t *testing.T) { + testCtx := mockFlags{ + "tls-cert": "mock-cli.crt", + "tls-key": "mock-cli.key", + "root-cert-paths": "mock-ca.crt", + "api-addr": "localhost", + "api-port": "9000"} + tests := []struct { + name string + msg func(func(string)) + ctx runtimeFlagsContext + wantErr bool + wantPort int + wantPortInAddr bool + }{ + {name: "CmdlineArgsParseWell", + msg: func(f func(string)) { + f("Having valid command line flags, parsing suceeds") + }, + ctx: testCtx.getCopy(), + wantErr: false, + wantPort: 9000, + wantPortInAddr: false}, + {name: "CmdlineArgsWithoutTLSConfigParseWell", + msg: func(f func(string)) { + f("Having valid command line flags without any TLS parameters, parsing suceeds") + }, + ctx: testCtx. + copyWithout("tls-cert", "tls-key", "root-cert-paths", "api-port"). + update("api-addr", "127.0.0.1:9002"), + wantErr: false, + wantPort: 9002, + wantPortInAddr: true}, + {name: "ArgsWithTLSCertWithoutKey_Fail", + msg: func(f func(string)) { + f("Having command line flags with TLS cert without key, parsing fails") + }, + ctx: testCtx.copyWithout("tls-key"), + wantErr: true, + }, + {name: "ArgsWithTLSKeyWithoutCert_Fail", + msg: func(f func(string)) { + f("Having command line flags with TLS key without cert, parsing fails") + }, + ctx: testCtx.copyWithout("tls-cert"), + wantErr: true, + }, + } + for _, tc := range tests { + runThisCase := func(f func(msg string)) { + t.Run(tc.name, func(_ *testing.T) { + tc.msg(f) + }) + } + runThisCase(func(msg string) { + Convey(msg, t, func() { + gotPort, gotPortInAddr, err := checkCmdLineFlags(tc.ctx) + if tc.wantErr { + So(err, ShouldNotBeNil) + return + } + So(err, ShouldBeNil) + So(gotPort, ShouldEqual, tc.wantPort) + So(gotPortInAddr, ShouldEqual, tc.wantPortInAddr) + }) + }) + } +} + +func Test_checkCfgSettings(t *testing.T) { + const DontCheckInt = -99 + testCfg := &mockCfg{ + Control: &control.Config{}, + RestAPI: &rest.Config{}, + } + tests := []struct { + name string + msg func(func(string)) + cfg *Config + wantErr bool + wantPort int + wantPortInAddr bool + }{ + {name: "DefaultConfigSettingsValidateWell", + msg: func(f func(string)) { + f("Having all default (empty) values for config, validation succeeds") + }, + cfg: (&mockCfg{Control: control.GetDefaultConfig(), RestAPI: rest.GetDefaultConfig()}).export(), + wantErr: false, + wantPort: DontCheckInt, + wantPortInAddr: false}, + {name: "ConfigSettingsValidateWell", + msg: func(f func(string)) { + f("Having correct values, config validation succeeds") + }, + cfg: testCfg.getCopy(). + setApiAddr("localhost:9000"). + setTLSCert("mock-cli.crt"). + setTLSKey("mock-cli.key"). + setRootCertPaths("mock-ca.crt"). + export(), + wantErr: false, + wantPort: 9000, + wantPortInAddr: true}, + {name: "ConfigSettingsWithoutTLSConfigValidateWell", + msg: func(f func(string)) { + f("Having correct values without any TLS parameters, config validation succeeds") + }, + cfg: testCfg.getCopy(). + setApiAddr("localhost:9000"). + export(), + wantErr: false, + wantPort: 9000, + wantPortInAddr: true}, + {name: "ConfigSettingsWithTLSCertWithoutKey_Fail", + msg: func(f func(string)) { + f("Having config with TLS cert without key, config fails to validate") + }, + cfg: testCfg.getCopy(). + setApiAddr("localhost:9000"). + setTLSCert("mock-cli.crt"). + export(), + wantErr: true, + wantPort: 9000, + wantPortInAddr: true}, + {name: "ConfigSettingsWithTLSKeyWithoutCert_Fail", + msg: func(f func(string)) { + f("Having config with TLS key without cert, config fails to validate") + }, + cfg: testCfg.getCopy(). + setApiAddr("localhost:9000"). + setTLSKey("mock-cli.crt"). + export(), + wantErr: true, + wantPort: 9000, + wantPortInAddr: true}, + } + + for _, tc := range tests { + runThisCase := func(f func(msg string)) { + t.Run(tc.name, func(_ *testing.T) { + tc.msg(f) + }) + } + runThisCase(func(msg string) { + Convey(msg, t, func() { + gotPort, gotPortInAddr, err := checkCfgSettings(tc.cfg) + if tc.wantErr { + So(err, ShouldNotBeNil) + return + } + So(err, ShouldBeNil) + if tc.wantPort != DontCheckInt { + So(gotPort, ShouldEqual, tc.wantPort) + } + So(gotPortInAddr, ShouldEqual, tc.wantPortInAddr) + }) + }) + } +} + +func Test_applyCmdLineFlags(t *testing.T) { + Convey("Having arguments given on command line", t, func() { + gotConfig := Config{ + Control: &control.Config{}, + RestAPI: &rest.Config{}, + Tribe: &tribe.Config{}, + Scheduler: &scheduler.Config{}, + } + applyCmdLineFlags(&gotConfig, validCmdlineFlags_input) + Convey("config should be filled with correct values", func() { + So(*gotConfig.Control, ShouldResemble, *validCmdlineFlags_expected.Control) + So(*gotConfig.RestAPI, ShouldResemble, *validCmdlineFlags_expected.RestAPI) + So(*gotConfig.Tribe, ShouldResemble, *validCmdlineFlags_expected.Tribe) + So(*gotConfig.Scheduler, ShouldResemble, *validCmdlineFlags_expected.Scheduler) + So(gotConfig, ShouldResemble, *validCmdlineFlags_expected) + }) + }) +}