Skip to content

Commit

Permalink
Add ability to reconcile bootstrap data between datastore and disk (k…
Browse files Browse the repository at this point in the history
…3s-io#3398)

Signed-off-by: Brian Downs <brian.downs@gmail.com>
  • Loading branch information
briandowns committed Oct 23, 2021
1 parent 63bcc30 commit 8113342
Show file tree
Hide file tree
Showing 15 changed files with 1,144 additions and 150 deletions.
47 changes: 35 additions & 12 deletions pkg/bootstrap/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"os"
"path/filepath"
"time"

"github.com/pkg/errors"
"github.com/rancher/k3s/pkg/daemons/config"
Expand All @@ -15,17 +16,19 @@ import (
func Handler(bootstrap *config.ControlRuntimeBootstrap) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("Content-Type", "application/json")
Write(rw, bootstrap)
ReadFromDisk(rw, bootstrap)
})
}

func Write(w io.Writer, bootstrap *config.ControlRuntimeBootstrap) error {
paths, err := objToMap(bootstrap)
// ReadFromDisk reads the bootstrap data from the files on disk and
// writes their content in JSON form to the given io.Writer.
func ReadFromDisk(w io.Writer, bootstrap *config.ControlRuntimeBootstrap) error {
paths, err := ObjToMap(bootstrap)
if err != nil {
return nil
}

dataMap := map[string][]byte{}
dataMap := make(map[string]File)
for pathKey, path := range paths {
if path == "" {
continue
Expand All @@ -35,24 +38,45 @@ func Write(w io.Writer, bootstrap *config.ControlRuntimeBootstrap) error {
return errors.Wrapf(err, "failed to read %s", path)
}

dataMap[pathKey] = data
info, err := os.Stat(path)
if err != nil {
return err
}

dataMap[pathKey] = File{
Timestamp: info.ModTime(),
Content: data,
}
}

return json.NewEncoder(w).Encode(dataMap)
}

func Read(r io.Reader, bootstrap *config.ControlRuntimeBootstrap) error {
paths, err := objToMap(bootstrap)
// File is a representation of a certificate
// or key file within the bootstrap context that contains
// the contents of the file as well as a timestamp from
// when the file was last modified.
type File struct {
Timestamp time.Time
Content []byte
}

type PathsDataformat map[string]File

// WriteToDiskFromStorage writes the contents of the given reader to the paths
// derived from within the ControlRuntimeBootstrap.
func WriteToDiskFromStorage(r io.Reader, bootstrap *config.ControlRuntimeBootstrap) error {
paths, err := ObjToMap(bootstrap)
if err != nil {
return err
}

files := map[string][]byte{}
files := make(PathsDataformat)
if err := json.NewDecoder(r).Decode(&files); err != nil {
return err
}

for pathKey, data := range files {
for pathKey, bsf := range files {
path, ok := paths[pathKey]
if !ok {
continue
Expand All @@ -61,16 +85,15 @@ func Read(r io.Reader, bootstrap *config.ControlRuntimeBootstrap) error {
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
return errors.Wrapf(err, "failed to mkdir %s", filepath.Dir(path))
}

if err := ioutil.WriteFile(path, data, 0600); err != nil {
if err := ioutil.WriteFile(path, bsf.Content, 0600); err != nil {
return errors.Wrapf(err, "failed to write to %s", path)
}
}

return nil
}

func objToMap(obj interface{}) (map[string]string, error) {
func ObjToMap(obj interface{}) (map[string]string, error) {
bytes, err := json.Marshal(obj)
if err != nil {
return nil, err
Expand Down
46 changes: 46 additions & 0 deletions pkg/bootstrap/bootstrap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package bootstrap

import (
"testing"

"github.com/rancher/k3s/pkg/daemons/config"
)

func TestObjToMap(t *testing.T) {
type args struct {
obj interface{}
}
tests := []struct {
name string
args args
want map[string]string
wantErr bool
}{
{
name: "Minimal Valid",
args: args{
obj: &config.ControlRuntimeBootstrap{
ServerCA: "/var/lib/rancher/k3s/server/tls/server-ca.crt",
ServerCAKey: "/var/lib/rancher/k3s/server/tls/server-ca.key",
},
},
wantErr: false,
},
{
name: "Minimal Invalid",
args: args{
obj: 1,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := ObjToMap(tt.args.obj)
if (err != nil) != tt.wantErr {
t.Errorf("ObjToMap() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
11 changes: 11 additions & 0 deletions pkg/cli/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,17 @@ func run(app *cli.Context, cfg *cmds.Server, leaderControllers server.CustomCont
// delete local loadbalancers state for apiserver and supervisor servers
loadbalancer.ResetLoadBalancer(filepath.Join(cfg.DataDir, "agent"), loadbalancer.SupervisorServiceName)
loadbalancer.ResetLoadBalancer(filepath.Join(cfg.DataDir, "agent"), loadbalancer.APIServerServiceName)

// at this point we're doing a restore. Check to see if we've
// passed in a token and if not, check if the token file exists.
// If it doesn't, return an error indicating the token is necessary.
if cfg.Token == "" {
if _, err := os.Stat(filepath.Join(cfg.DataDir, "server/token")); err != nil {
if os.IsNotExist(err) {
return errors.New("")
}
}
}
}

serverConfig.ControlConfig.ClusterReset = cfg.ClusterReset
Expand Down
2 changes: 1 addition & 1 deletion pkg/clientaccess/kubeconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

// WriteClientKubeConfig generates a kubeconfig at destFile that can be used to connect to a server at url with the given certs and keys
func WriteClientKubeConfig(destFile string, url string, serverCAFile string, clientCertFile string, clientKeyFile string) error {
func WriteClientKubeConfig(destFile, url, serverCAFile, clientCertFile, clientKeyFile string) error {
serverCA, err := ioutil.ReadFile(serverCAFile)
if err != nil {
return errors.Wrapf(err, "failed to read %s", serverCAFile)
Expand Down
62 changes: 31 additions & 31 deletions pkg/clientaccess/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,15 @@ import (
"github.com/sirupsen/logrus"
)

var (
const (
tokenPrefix = "K10"
tokenFormat = "%s%s::%s:%s"
caHashLength = sha256.Size * 2

defaultClientTimeout = 10 * time.Second
)

var (
defaultClient = &http.Client{
Timeout: defaultClientTimeout,
}
Expand All @@ -32,12 +38,6 @@ var (
}
)

const (
tokenPrefix = "K10"
tokenFormat = "%s%s::%s:%s"
caHashLength = sha256.Size * 2
)

type OverrideURLCallback func(config []byte) (*url.URL, error)

type Info struct {
Expand All @@ -49,8 +49,8 @@ type Info struct {
}

// String returns the token data, templated according to the token format
func (info *Info) String() string {
return fmt.Sprintf(tokenFormat, tokenPrefix, hashCA(info.CACerts), info.Username, info.Password)
func (i *Info) String() string {
return fmt.Sprintf(tokenFormat, tokenPrefix, hashCA(i.CACerts), i.Username, i.Password)
}

// ParseAndValidateToken parses a token, downloads and validates the server's CA bundle,
Expand All @@ -70,7 +70,7 @@ func ParseAndValidateToken(server string, token string) (*Info, error) {

// ParseAndValidateToken parses a token with user override, downloads and
// validates the server's CA bundle, and validates it according to the caHash from the token if set.
func ParseAndValidateTokenForUser(server string, token string, username string) (*Info, error) {
func ParseAndValidateTokenForUser(server, token, username string) (*Info, error) {
info, err := parseToken(token)
if err != nil {
return nil, err
Expand All @@ -86,11 +86,11 @@ func ParseAndValidateTokenForUser(server string, token string, username string)
}

// setAndValidateServer updates the remote server's cert info, and validates it against the provided hash
func (info *Info) setAndValidateServer(server string) error {
if err := info.setServer(server); err != nil {
func (i *Info) setAndValidateServer(server string) error {
if err := i.setServer(server); err != nil {
return err
}
return info.validateCAHash()
return i.validateCAHash()
}

// validateCACerts returns a boolean indicating whether or not a CA bundle matches the provided hash,
Expand Down Expand Up @@ -118,7 +118,7 @@ func ParseUsernamePassword(token string) (string, string, bool) {

// parseToken parses a token into an Info struct
func parseToken(token string) (*Info, error) {
var info = &Info{}
var info Info

if len(token) == 0 {
return nil, errors.New("token must not be empty")
Expand Down Expand Up @@ -150,7 +150,7 @@ func parseToken(token string) (*Info, error) {
info.Username = parts[0]
info.Password = parts[1]

return info, nil
return &info, nil
}

// GetHTTPClient returns a http client that validates TLS server certificates using the provided CA bundle.
Expand All @@ -177,25 +177,25 @@ func GetHTTPClient(cacerts []byte) *http.Client {
}

// Get makes a request to a subpath of info's BaseURL
func (info *Info) Get(path string) ([]byte, error) {
u, err := url.Parse(info.BaseURL)
func (i *Info) Get(path string) ([]byte, error) {
u, err := url.Parse(i.BaseURL)
if err != nil {
return nil, err
}
u.Path = path
return get(u.String(), GetHTTPClient(info.CACerts), info.Username, info.Password)
return get(u.String(), GetHTTPClient(i.CACerts), i.Username, i.Password)
}

// setServer sets the BaseURL and CACerts fields of the Info by connecting to the server
// and storing the CA bundle.
func (info *Info) setServer(server string) error {
func (i *Info) setServer(server string) error {
url, err := url.Parse(server)
if err != nil {
return errors.Wrapf(err, "Invalid server url, failed to parse: %s", server)
}

if url.Scheme != "https" {
return fmt.Errorf("only https:// URLs are supported, invalid scheme: %s", server)
return errors.New("only https:// URLs are supported, invalid scheme: " + server)
}

for strings.HasSuffix(url.Path, "/") {
Expand All @@ -207,25 +207,25 @@ func (info *Info) setServer(server string) error {
return err
}

info.BaseURL = url.String()
info.CACerts = cacerts
i.BaseURL = url.String()
i.CACerts = cacerts
return nil
}

// ValidateCAHash validates that info's caHash matches the CACerts hash.
func (info *Info) validateCAHash() error {
if len(info.caHash) > 0 && len(info.CACerts) == 0 {
func (i *Info) validateCAHash() error {
if len(i.caHash) > 0 && len(i.CACerts) == 0 {
// Warn if the user provided a CA hash but we're not going to validate because it's already trusted
logrus.Warn("Cluster CA certificate is trusted by the host CA bundle. " +
"Token CA hash will not be validated.")
} else if len(info.caHash) == 0 && len(info.CACerts) > 0 {
} else if len(i.caHash) == 0 && len(i.CACerts) > 0 {
// Warn if the CA is self-signed but the user didn't provide a hash to validate it against
logrus.Warn("Cluster CA certificate is not trusted by the host CA bundle, but the token does not include a CA hash. " +
"Use the full token from the server's node-token file to enable Cluster CA validation.")
} else if len(info.CACerts) > 0 && len(info.caHash) > 0 {
} else if len(i.CACerts) > 0 && len(i.caHash) > 0 {
// only verify CA hash if the server cert is not trusted by the OS CA bundle
if ok, serverHash := validateCACerts(info.CACerts, info.caHash); !ok {
return fmt.Errorf("token CA hash does not match the Cluster CA certificate hash: %s != %s", info.caHash, serverHash)
if ok, serverHash := validateCACerts(i.CACerts, i.caHash); !ok {
return fmt.Errorf("token CA hash does not match the Cluster CA certificate hash: %s != %s", i.caHash, serverHash)
}
}
return nil
Expand Down Expand Up @@ -288,18 +288,18 @@ func get(u string, client *http.Client, username, password string) ([]byte, erro
return ioutil.ReadAll(resp.Body)
}

func FormatToken(token string, certFile string) (string, error) {
func FormatToken(token, certFile string) (string, error) {
if len(token) == 0 {
return token, nil
}

certHash := ""
if len(certFile) > 0 {
bytes, err := ioutil.ReadFile(certFile)
b, err := ioutil.ReadFile(certFile)
if err != nil {
return "", nil
}
digest := sha256.Sum256(bytes)
digest := sha256.Sum256(b)
certHash = tokenPrefix + hex.EncodeToString(digest[:]) + "::"
}
return certHash + token, nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/clientaccess/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ func newTLSServer(t *testing.T, username, password string, sendWrongCA bool) *ht
}
bootstrapData := &config.ControlRuntimeBootstrap{}
w.Header().Set("Content-Type", "application/json")
if err := bootstrap.Write(w, bootstrapData); err != nil {
if err := bootstrap.ReadFromDisk(w, bootstrapData); err != nil {
t.Errorf("failed to write bootstrap: %v", err)
}
return
Expand Down
Loading

0 comments on commit 8113342

Please sign in to comment.