diff --git a/pkg/api/client_options.go b/pkg/api/client_options.go index 8ed5937..3464aac 100644 --- a/pkg/api/client_options.go +++ b/pkg/api/client_options.go @@ -89,7 +89,7 @@ func optionsNeedResolution(opts ClientOptions) bool { } func resolveOptions(opts ClientOptions) (ClientOptions, error) { - cfg, _ := config.Read() + cfg, _ := config.Read(nil) if opts.Host == "" { opts.Host, _ = auth.DefaultHost() } diff --git a/pkg/api/http_client_test.go b/pkg/api/http_client_test.go index 3308b39..dee6b02 100644 --- a/pkg/api/http_client_test.go +++ b/pkg/api/http_client_test.go @@ -247,7 +247,7 @@ func defaultHeaders() http.Header { func stubConfig(t *testing.T, cfgStr string) { t.Helper() old := config.Read - config.Read = func() (*config.Config, error) { + config.Read = func(_ *config.Config) (*config.Config, error) { return config.ReadFromString(cfgStr), nil } t.Cleanup(func() { diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 50ac4c7..ce39d58 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -55,7 +55,7 @@ func TokenForHost(host string) (string, string) { // file as fallback, but does not support reading the token from system keyring. Most consumers // should use TokenForHost. func TokenFromEnvOrConfig(host string) (string, string) { - cfg, _ := config.Read() + cfg, _ := config.Read(nil) return tokenForHost(cfg, host) } @@ -105,7 +105,7 @@ func tokenFromGh(path string, host string) (string, string) { // or from the configuration file. // Returns an empty string slice if no hosts are found. func KnownHosts() []string { - cfg, _ := config.Read() + cfg, _ := config.Read(nil) return knownHosts(cfg) } @@ -131,7 +131,7 @@ func knownHosts(cfg *config.Config) []string { // configuration file. // Returns "github.com", "default" if no viable host is found. func DefaultHost() (string, string) { - cfg, _ := config.Read() + cfg, _ := config.Read(nil) return defaultHost(cfg) } diff --git a/pkg/browser/browser.go b/pkg/browser/browser.go index 8e8d36f..4d56710 100644 --- a/pkg/browser/browser.go +++ b/pkg/browser/browser.go @@ -70,7 +70,7 @@ func resolveLauncher() string { if ghBrowser := os.Getenv("GH_BROWSER"); ghBrowser != "" { return ghBrowser } - cfg, err := config.Read() + cfg, err := config.Read(nil) if err == nil { if cfgBrowser, _ := cfg.Get([]string{"browser"}); cfgBrowser != "" { return cfgBrowser diff --git a/pkg/browser/browser_test.go b/pkg/browser/browser_test.go index 5c7375d..5514134 100644 --- a/pkg/browser/browser_test.go +++ b/pkg/browser/browser_test.go @@ -89,7 +89,7 @@ func TestResolveLauncher(t *testing.T) { } if tt.config != nil { old := config.Read - config.Read = func() (*config.Config, error) { + config.Read = func(_ *config.Config) (*config.Config, error) { return tt.config, nil } defer func() { config.Read = old }() diff --git a/pkg/config/config.go b/pkg/config/config.go index bdf15ed..e2b1fc8 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -119,18 +119,23 @@ func (c *Config) Set(keys []string, value string) { m.SetEntry(keys[len(keys)-1], yamlmap.StringValue(value)) } +func (c *Config) deepCopy() *Config { + return ReadFromString(c.entries.String()) +} + // Read gh configuration files from the local file system and -// return a Config. -var Read = func() (*Config, error) { +// returns a Config. A copy of the fallback configuration will +// be returned when there are no configuration files to load. +// If there are no configuration files and no fallback configuration +// an empty configuration will be returned. +var Read = func(fallback *Config) (*Config, error) { once.Do(func() { - cfg, loadErr = load(generalConfigFile(), hostsConfigFile()) + cfg, loadErr = load(generalConfigFile(), hostsConfigFile(), fallback) }) return cfg, loadErr } // ReadFromString takes a yaml string and returns a Config. -// Note: This is only used for testing, and should not be -// relied upon in production. func ReadFromString(str string) *Config { m, _ := mapFromString(str) if m == nil { @@ -174,7 +179,7 @@ func Write(c *Config) error { return nil } -func load(generalFilePath, hostsFilePath string) (*Config, error) { +func load(generalFilePath, hostsFilePath string, fallback *Config) (*Config, error) { generalMap, err := mapFromFile(generalFilePath) if err != nil && !os.IsNotExist(err) { if errors.Is(err, yamlmap.ErrInvalidYaml) || @@ -184,8 +189,8 @@ func load(generalFilePath, hostsFilePath string) (*Config, error) { return nil, err } - if generalMap == nil || generalMap.Empty() { - generalMap, _ = mapFromString(defaultGeneralEntries) + if generalMap == nil { + generalMap = yamlmap.MapValue() } hostsMap, err := mapFromFile(hostsFilePath) @@ -201,6 +206,10 @@ func load(generalFilePath, hostsFilePath string) (*Config, error) { generalMap.AddEntry("hosts", hostsMap) } + if generalMap.Empty() && fallback != nil { + return fallback.deepCopy(), nil + } + return &Config{entries: generalMap}, nil } @@ -302,21 +311,3 @@ func writeFile(filename string, data []byte) (writeErr error) { _, writeErr = file.Write(data) return } - -var defaultGeneralEntries = ` -# What protocol to use when performing git operations. Supported values: ssh, https -git_protocol: https -# What editor gh should run when creating issues, pull requests, etc. If blank, will refer to environment. -editor: -# When to interactively prompt. This is a global config that cannot be overridden by hostname. Supported values: enabled, disabled -prompt: enabled -# A pager program to send command output to, e.g. "less". Set the value to "cat" to disable the pager. -pager: -# Aliases allow you to create nicknames for gh commands -aliases: - co: pr checkout -# The path to a unix socket through which send HTTP connections. If blank, HTTP traffic will be handled by net/http.DefaultTransport. -http_unix_socket: -# What web browser gh should use when opening URLs. If blank, will refer to environment. -browser: -` diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 00e88bb..b24e3c4 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -244,11 +244,11 @@ func TestLoad(t *testing.T) { name string globalConfigPath string hostsConfigPath string + fallback *Config wantGitProtocol string wantToken string wantErr bool wantErrMsg string - wantGetErr bool }{ { name: "global and hosts files exist", @@ -274,7 +274,7 @@ func TestLoad(t *testing.T) { name: "global file does not exist and hosts file exist", globalConfigPath: "", hostsConfigPath: hostsFilePath, - wantGitProtocol: "https", + wantGitProtocol: "", wantToken: "yyyyyyyyyyyyyyyyyyyy", }, { @@ -282,28 +282,51 @@ func TestLoad(t *testing.T) { globalConfigPath: globalFilePath, hostsConfigPath: "", wantGitProtocol: "ssh", - wantGetErr: true, + wantToken: "", + }, + { + name: "global file does not exist and hosts file does not exist with no fallback", + globalConfigPath: "", + hostsConfigPath: "", + wantGitProtocol: "", + wantToken: "", + }, + { + name: "global file does not exist and hosts file does not exist with fallback", + globalConfigPath: "", + hostsConfigPath: "", + fallback: ReadFromString(testFullConfig()), + wantGitProtocol: "ssh", + wantToken: "yyyyyyyyyyyyyyyyyyyy", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cfg, err := load(tt.globalConfigPath, tt.hostsConfigPath) + cfg, err := load(tt.globalConfigPath, tt.hostsConfigPath, tt.fallback) if tt.wantErr { assert.EqualError(t, err, tt.wantErrMsg) return } assert.NoError(t, err) - protocol, err := cfg.Get([]string{"git_protocol"}) - assert.NoError(t, err) - assert.Equal(t, tt.wantGitProtocol, protocol) - token, err := cfg.Get([]string{"hosts", "enterprise.com", "oauth_token"}) - if tt.wantGetErr { - assert.EqualError(t, err, `could not find key "hosts"`) + + if tt.wantGitProtocol == "" { + assertNoKey(t, cfg, []string{"git_protocol"}) } else { - assert.NoError(t, err) + assertKeyWithValue(t, cfg, []string{"git_protocol"}, tt.wantGitProtocol) + } + + if tt.wantToken == "" { + assertNoKey(t, cfg, []string{"hosts", "enterprise.com", "oauth_token"}) + } else { + assertKeyWithValue(t, cfg, []string{"hosts", "enterprise.com", "oauth_token"}, tt.wantToken) + } + + if tt.fallback != nil { + // Assert that load returns an equivalent copy of fallvback. + assert.Equal(t, tt.fallback.entries.String(), cfg.entries.String()) + assert.False(t, tt.fallback == cfg) } - assert.Equal(t, tt.wantToken, token) }) } } @@ -324,6 +347,14 @@ func TestWrite(t *testing.T) { cfg.Set([]string{"hosts", "github.com", "git_protocol"}, "https") return cfg }, + wantConfig: func() *Config { + // Same as created config as both a global property and host property has + // been edited. + cfg := ReadFromString(testFullConfig()) + cfg.Set([]string{"editor"}, "vim") + cfg.Set([]string{"hosts", "github.com", "git_protocol"}, "https") + return cfg + }, }, { name: "only writes hosts file", @@ -333,9 +364,8 @@ func TestWrite(t *testing.T) { return cfg }, wantConfig: func() *Config { - // The hosts file is writen but not the general config file. - // When we use Read in the test the defaultGeneralEntries are used. - cfg := ReadFromString(defaultGeneralEntries) + // The hosts file is writen but not the global config file. + cfg := ReadFromString("") cfg.Set([]string{"hosts", "github.com", "user"}, "user1") cfg.Set([]string{"hosts", "github.com", "oauth_token"}, "xxxxxxxxxxxxxxxxxxxx") cfg.Set([]string{"hosts", "github.com", "git_protocol"}, "ssh") @@ -346,26 +376,16 @@ func TestWrite(t *testing.T) { }, }, { - name: "only writes config file", + name: "only writes global config file", createConfig: func() *Config { cfg := ReadFromString(testFullConfig()) cfg.Set([]string{"editor"}, "vim") return cfg }, wantConfig: func() *Config { - // The general config file is written but not the hosts config file. - // When we use Read in the test there will not be any hosts entries. - cfg := ReadFromString(testFullConfig()) + // The global config file is written but not the hosts config file. + cfg := ReadFromString(testGlobalData()) cfg.Set([]string{"editor"}, "vim") - _ = cfg.Remove([]string{"hosts"}) - return cfg - }, - }, - { - name: "write default config file keeps comments", - createConfig: func() *Config { - cfg := ReadFromString(defaultGeneralEntries) - cfg.entries.SetModified() return cfg }, }, @@ -378,12 +398,9 @@ func TestWrite(t *testing.T) { cfg := tt.createConfig() err := Write(cfg) assert.NoError(t, err) - loadedCfg, err := load(generalConfigFile(), hostsConfigFile()) + loadedCfg, err := load(generalConfigFile(), hostsConfigFile(), nil) assert.NoError(t, err) - wantCfg := cfg - if tt.wantConfig != nil { - wantCfg = tt.wantConfig() - } + wantCfg := tt.wantConfig() assert.Equal(t, wantCfg.entries.String(), loadedCfg.entries.String()) }) } @@ -391,11 +408,10 @@ func TestWrite(t *testing.T) { func TestGet(t *testing.T) { tests := []struct { - name string - keys []string - wantValue string - wantErr bool - wantErrMsg string + name string + keys []string + wantValue string + wantErr bool }{ { name: "get git_protocol value", @@ -418,11 +434,9 @@ func TestGet(t *testing.T) { wantValue: "less", }, { - name: "non-existant key", - keys: []string{"unknown"}, - wantErr: true, - wantErrMsg: `could not find key "unknown"`, - wantValue: "", + name: "non-existant key", + keys: []string{"unknown"}, + wantErr: true, }, { name: "nested key", @@ -435,24 +449,20 @@ func TestGet(t *testing.T) { wantValue: "more", }, { - name: "nested non-existant key", - keys: []string{"nested", "invalid"}, - wantErr: true, - wantErrMsg: `could not find key "invalid"`, - wantValue: "", + name: "nested non-existant key", + keys: []string{"nested", "invalid"}, + wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := testConfig() - value, err := cfg.Get(tt.keys) if tt.wantErr { - assert.EqualError(t, err, tt.wantErrMsg) + assertNoKey(t, cfg, tt.keys) } else { - assert.NoError(t, err) + assertKeyWithValue(t, cfg, tt.keys, tt.wantValue) } - assert.Equal(t, tt.wantValue, value) assert.False(t, cfg.entries.IsModified()) }) } @@ -544,8 +554,7 @@ func TestRemove(t *testing.T) { assert.NoError(t, err) assert.True(t, cfg.entries.IsModified()) } - _, getErr := cfg.Get(tt.keys) - assert.Error(t, getErr) + assertNoKey(t, cfg, tt.keys) }) } } @@ -593,45 +602,11 @@ func TestSet(t *testing.T) { cfg := testConfig() cfg.Set(tt.keys, tt.value) assert.True(t, cfg.entries.IsModified()) - value, err := cfg.Get(tt.keys) - assert.NoError(t, err) - assert.Equal(t, tt.value, value) + assertKeyWithValue(t, cfg, tt.keys, tt.value) }) } } -func TestDefaultGeneralEntries(t *testing.T) { - cfg := ReadFromString(defaultGeneralEntries) - - protocol, err := cfg.Get([]string{"git_protocol"}) - assert.NoError(t, err) - assert.Equal(t, "https", protocol) - - editor, err := cfg.Get([]string{"editor"}) - assert.NoError(t, err) - assert.Equal(t, "", editor) - - prompt, err := cfg.Get([]string{"prompt"}) - assert.NoError(t, err) - assert.Equal(t, "enabled", prompt) - - pager, err := cfg.Get([]string{"pager"}) - assert.NoError(t, err) - assert.Equal(t, "", pager) - - socket, err := cfg.Get([]string{"http_unix_socket"}) - assert.NoError(t, err) - assert.Equal(t, "", socket) - - browser, err := cfg.Get([]string{"browser"}) - assert.NoError(t, err) - assert.Equal(t, "", browser) - - unknown, err := cfg.Get([]string{"unknown"}) - assert.EqualError(t, err, `could not find key "unknown"`) - assert.Equal(t, "", unknown) -} - func testConfig() *Config { var data = ` git_protocol: ssh @@ -687,3 +662,17 @@ hosts: ` return data } + +func assertNoKey(t *testing.T, cfg *Config, keys []string) { + t.Helper() + _, err := cfg.Get(keys) + var keyNotFoundError *KeyNotFoundError + assert.ErrorAs(t, err, &keyNotFoundError) +} + +func assertKeyWithValue(t *testing.T, cfg *Config, keys []string, value string) { + t.Helper() + actual, err := cfg.Get(keys) + assert.NoError(t, err) + assert.Equal(t, value, actual) +} diff --git a/pkg/repository/repository_test.go b/pkg/repository/repository_test.go index 54c9918..dfcdca7 100644 --- a/pkg/repository/repository_test.go +++ b/pkg/repository/repository_test.go @@ -193,7 +193,7 @@ func TestParseWithHost(t *testing.T) { func stubConfig(t *testing.T, cfgStr string) { t.Helper() old := config.Read - config.Read = func() (*config.Config, error) { + config.Read = func(_ *config.Config) (*config.Config, error) { return config.ReadFromString(cfgStr), nil } t.Cleanup(func() {