Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support env loading for all string fields #3019

Merged
merged 4 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/db/branch/switch_/switch__test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func TestSwitchCommand(t *testing.T) {
// Run test
err := Run(context.Background(), "target", fsys)
// Check error
assert.ErrorContains(t, err, "toml: line 0: unexpected EOF; expected key separator '='")
assert.ErrorContains(t, err, "toml: expected = after a key, but the document ends there")
})

t.Run("throws error on missing database", func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion internal/start/start_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func TestStartCommand(t *testing.T) {
// Run test
err := Run(context.Background(), fsys, []string{}, false)
// Check error
assert.ErrorContains(t, err, "toml: line 0: unexpected EOF; expected key separator '='")
assert.ErrorContains(t, err, "toml: expected = after a key, but the document ends there")
})

t.Run("throws error on missing docker", func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion internal/status/status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestStatusCommand(t *testing.T) {
// Run test
err := Run(context.Background(), CustomName{}, utils.OutputPretty, fsys)
// Check error
assert.ErrorContains(t, err, "toml: line 0: unexpected EOF; expected key separator '='")
assert.ErrorContains(t, err, "toml: expected = after a key, but the document ends there")
})

t.Run("throws error on missing docker", func(t *testing.T) {
Expand Down
176 changes: 110 additions & 66 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"os"
"path"
"path/filepath"
"reflect"
"regexp"
"sort"
"strconv"
Expand Down Expand Up @@ -378,9 +379,64 @@ func (c *config) Eject(w io.Writer) error {
return nil
}

// Loads custom config file to struct fields tagged with toml.
func (c *config) loadFromFile(filename string, fsys fs.FS) error {
v := viper.New()
v.SetConfigType("toml")
// Load default values
var buf bytes.Buffer
if err := initConfigTemplate.Option("missingkey=zero").Execute(&buf, c); err != nil {
return errors.Errorf("failed to initialise template config: %w", err)
} else if err := c.loadFromReader(v, &buf); err != nil {
return err
}
// Load custom config
if ext := filepath.Ext(filename); len(ext) > 0 {
v.SetConfigType(ext[1:])
}
f, err := fsys.Open(filename)
if err != nil {
return errors.Errorf("failed to read file config: %w", err)
}
defer f.Close()
return c.loadFromReader(v, f)
}

func (c *config) loadFromReader(v *viper.Viper, r io.Reader) error {
if err := v.MergeConfig(r); err != nil {
return errors.Errorf("failed to merge config: %w", err)
}
// Manually parse [functions.*] to empty struct for backwards compatibility
for key, value := range v.GetStringMap("functions") {
if m, ok := value.(map[string]any); ok && len(m) == 0 {
v.Set("functions."+key, function{})
}
}
if err := v.UnmarshalExact(c, viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToIPHookFunc(),
mapstructure.StringToSliceHookFunc(","),
mapstructure.TextUnmarshallerHookFunc(),
LoadEnvHook,
// TODO: include decrypt secret hook
)), func(dc *mapstructure.DecoderConfig) {
dc.TagName = "toml"
dc.Squash = true
}); err != nil {
return errors.Errorf("failed to parse config: %w", err)
}
return nil
}

// Loads envs prefixed with supabase_ to struct fields tagged with mapstructure.
func (c *config) loadFromEnv() error {
// Allow overriding base config object with automatic env
// Ref: https://github.com/spf13/viper/issues/761
v := viper.New()
v.SetEnvPrefix("SUPABASE")
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
v.AutomaticEnv()
// Viper does not parse env vars automatically. Instead of calling viper.BindEnv
// per key, we decode all keys from an existing struct, and merge them to viper.
// Ref: https://github.com/spf13/viper/issues/761#issuecomment-859306364
envKeysMap := map[string]interface{}{}
if dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
Result: &envKeysMap,
Expand All @@ -389,47 +445,32 @@ func (c *config) loadFromEnv() error {
return errors.Errorf("failed to create decoder: %w", err)
} else if err := dec.Decode(c.baseConfig); err != nil {
return errors.Errorf("failed to decode env: %w", err)
}
v := viper.New()
v.SetEnvPrefix("SUPABASE")
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
v.AutomaticEnv()
if err := v.MergeConfigMap(envKeysMap); err != nil {
return errors.Errorf("failed to merge config: %w", err)
} else if err := v.Unmarshal(c); err != nil {
return errors.Errorf("failed to parse env to config: %w", err)
} else if err := v.MergeConfigMap(envKeysMap); err != nil {
return errors.Errorf("failed to merge env config: %w", err)
}
// Writes viper state back to config struct, with automatic env substitution
if err := v.UnmarshalExact(c, viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToIPHookFunc(),
mapstructure.StringToSliceHookFunc(","),
mapstructure.TextUnmarshallerHookFunc(),
// TODO: include decrypt secret hook
))); err != nil {
return errors.Errorf("failed to parse env override: %w", err)
}
return nil
}

func (c *config) Load(path string, fsys fs.FS) error {
builder := NewPathBuilder(path)
// Load default values
var buf bytes.Buffer
if err := initConfigTemplate.Option("missingkey=zero").Execute(&buf, c); err != nil {
return errors.Errorf("failed to initialise config template: %w", err)
}
dec := toml.NewDecoder(&buf)
if _, err := dec.Decode(c); err != nil {
return errors.Errorf("failed to decode config template: %w", err)
}
if metadata, err := toml.DecodeFS(fsys, builder.ConfigPath, c); err != nil {
cwd, osErr := os.Getwd()
if osErr != nil {
cwd = "current directory"
}
return errors.Errorf("cannot read config in %s: %w", cwd, err)
} else if undecoded := metadata.Undecoded(); len(undecoded) > 0 {
for _, key := range undecoded {
if key[0] != "remotes" {
fmt.Fprintf(os.Stderr, "Unknown config field: [%s]\n", key)
}
}
}
// Load secrets from .env file
if err := loadDefaultEnv(); err != nil {
return err
} else if err := c.loadFromEnv(); err != nil {
}
if err := c.loadFromFile(builder.ConfigPath, fsys); err != nil {
return err
}
if err := c.loadFromEnv(); err != nil {
return err
}
// Generate JWT tokens
Expand Down Expand Up @@ -619,17 +660,16 @@ func (c *baseConfig) Validate(fsys fs.FS) error {
case 15:
if len(c.Experimental.OrioleDBVersion) > 0 {
c.Db.Image = "supabase/postgres:orioledb-" + c.Experimental.OrioleDBVersion
var err error
if c.Experimental.S3Host, err = maybeLoadEnv(c.Experimental.S3Host); err != nil {
if err := assertEnvLoaded(c.Experimental.S3Host); err != nil {
return err
}
if c.Experimental.S3Region, err = maybeLoadEnv(c.Experimental.S3Region); err != nil {
if err := assertEnvLoaded(c.Experimental.S3Region); err != nil {
return err
}
if c.Experimental.S3AccessKey, err = maybeLoadEnv(c.Experimental.S3AccessKey); err != nil {
if err := assertEnvLoaded(c.Experimental.S3AccessKey); err != nil {
return err
}
if c.Experimental.S3SecretKey, err = maybeLoadEnv(c.Experimental.S3SecretKey); err != nil {
if err := assertEnvLoaded(c.Experimental.S3SecretKey); err != nil {
return err
}
}
Expand Down Expand Up @@ -666,7 +706,6 @@ func (c *baseConfig) Validate(fsys fs.FS) error {
} else if parsed.Host == "" || parsed.Host == c.Hostname {
c.Studio.ApiUrl = c.Api.ExternalUrl
}
c.Studio.OpenaiApiKey, _ = maybeLoadEnv(c.Studio.OpenaiApiKey)
}
// Validate smtp config
if c.Inbucket.Enabled {
Expand All @@ -679,12 +718,11 @@ func (c *baseConfig) Validate(fsys fs.FS) error {
if c.Auth.SiteUrl == "" {
return errors.New("Missing required field in config: auth.site_url")
}
var err error
if c.Auth.SiteUrl, err = maybeLoadEnv(c.Auth.SiteUrl); err != nil {
if err := assertEnvLoaded(c.Auth.SiteUrl); err != nil {
return err
}
for i, url := range c.Auth.AdditionalRedirectUrls {
if c.Auth.AdditionalRedirectUrls[i], err = maybeLoadEnv(url); err != nil {
if err := assertEnvLoaded(url); err != nil {
return errors.Errorf("Invalid config for auth.additional_redirect_urls[%d]: %v", i, err)
}
}
Expand Down Expand Up @@ -749,18 +787,24 @@ func (c *baseConfig) Validate(fsys fs.FS) error {
return nil
}

func maybeLoadEnv(s string) (string, error) {
matches := envPattern.FindStringSubmatch(s)
if len(matches) == 0 {
return s, nil
func assertEnvLoaded(s string) error {
if matches := envPattern.FindStringSubmatch(s); len(matches) > 1 {
return errors.Errorf(`Error evaluating "%s": environment variable %s is unset.`, s, matches[1])
}
return nil
}

envName := matches[1]
if value := os.Getenv(envName); value != "" {
return value, nil
func LoadEnvHook(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
if f != reflect.String || t != reflect.String {
return data, nil
}

return "", errors.Errorf(`Error evaluating "%s": environment variable %s is unset.`, s, envName)
value := data.(string)
if matches := envPattern.FindStringSubmatch(value); len(matches) > 1 {
if v, exists := os.LookupEnv(matches[1]); exists {
value = v
}
}
return value, nil
}

func truncateText(text string, maxLen int) string {
Expand Down Expand Up @@ -874,7 +918,7 @@ func (e *email) validate(fsys fs.FS) (err error) {
if len(e.Smtp.AdminEmail) == 0 {
return errors.New("Missing required field in config: auth.email.smtp.admin_email")
}
if e.Smtp.Pass, err = maybeLoadEnv(e.Smtp.Pass); err != nil {
if err := assertEnvLoaded(e.Smtp.Pass); err != nil {
return err
}
}
Expand All @@ -893,7 +937,7 @@ func (s *sms) validate() (err error) {
if len(s.Twilio.AuthToken) == 0 {
return errors.New("Missing required field in config: auth.sms.twilio.auth_token")
}
if s.Twilio.AuthToken, err = maybeLoadEnv(s.Twilio.AuthToken); err != nil {
if err := assertEnvLoaded(s.Twilio.AuthToken); err != nil {
return err
}
case s.TwilioVerify.Enabled:
Expand All @@ -906,7 +950,7 @@ func (s *sms) validate() (err error) {
if len(s.TwilioVerify.AuthToken) == 0 {
return errors.New("Missing required field in config: auth.sms.twilio_verify.auth_token")
}
if s.TwilioVerify.AuthToken, err = maybeLoadEnv(s.TwilioVerify.AuthToken); err != nil {
if err := assertEnvLoaded(s.TwilioVerify.AuthToken); err != nil {
return err
}
case s.Messagebird.Enabled:
Expand All @@ -916,7 +960,7 @@ func (s *sms) validate() (err error) {
if len(s.Messagebird.AccessKey) == 0 {
return errors.New("Missing required field in config: auth.sms.messagebird.access_key")
}
if s.Messagebird.AccessKey, err = maybeLoadEnv(s.Messagebird.AccessKey); err != nil {
if err := assertEnvLoaded(s.Messagebird.AccessKey); err != nil {
return err
}
case s.Textlocal.Enabled:
Expand All @@ -926,7 +970,7 @@ func (s *sms) validate() (err error) {
if len(s.Textlocal.ApiKey) == 0 {
return errors.New("Missing required field in config: auth.sms.textlocal.api_key")
}
if s.Textlocal.ApiKey, err = maybeLoadEnv(s.Textlocal.ApiKey); err != nil {
if err := assertEnvLoaded(s.Textlocal.ApiKey); err != nil {
return err
}
case s.Vonage.Enabled:
Expand All @@ -939,10 +983,10 @@ func (s *sms) validate() (err error) {
if len(s.Vonage.ApiSecret) == 0 {
return errors.New("Missing required field in config: auth.sms.vonage.api_secret")
}
if s.Vonage.ApiKey, err = maybeLoadEnv(s.Vonage.ApiKey); err != nil {
if err := assertEnvLoaded(s.Vonage.ApiKey); err != nil {
return err
}
if s.Vonage.ApiSecret, err = maybeLoadEnv(s.Vonage.ApiSecret); err != nil {
if err := assertEnvLoaded(s.Vonage.ApiSecret); err != nil {
return err
}
case s.EnableSignup:
Expand All @@ -969,16 +1013,16 @@ func (e external) validate() (err error) {
if !sliceContains([]string{"apple", "google"}, ext) && provider.Secret == "" {
return errors.Errorf("Missing required field in config: auth.external.%s.secret", ext)
}
if provider.ClientId, err = maybeLoadEnv(provider.ClientId); err != nil {
if err := assertEnvLoaded(provider.ClientId); err != nil {
return err
}
if provider.Secret, err = maybeLoadEnv(provider.Secret); err != nil {
if err := assertEnvLoaded(provider.Secret); err != nil {
return err
}
if provider.RedirectUri, err = maybeLoadEnv(provider.RedirectUri); err != nil {
if err := assertEnvLoaded(provider.RedirectUri); err != nil {
return err
}
if provider.Url, err = maybeLoadEnv(provider.Url); err != nil {
if err := assertEnvLoaded(provider.Url); err != nil {
return err
}
e[ext] = provider
Expand Down Expand Up @@ -1033,7 +1077,7 @@ func (h *hookConfig) validate(hookType string) (err error) {
case "http", "https":
if len(h.Secrets) == 0 {
return errors.Errorf("Missing required field in config: auth.hook.%s.secrets", hookType)
} else if h.Secrets, err = maybeLoadEnv(h.Secrets); err != nil {
} else if err := assertEnvLoaded(h.Secrets); err != nil {
return err
}
for _, secret := range strings.Split(h.Secrets, "|") {
Expand Down Expand Up @@ -1119,13 +1163,13 @@ func (c *tpaCognito) issuerURL() string {
func (c *tpaCognito) validate() (err error) {
if c.UserPoolID == "" {
return errors.New("Invalid config: auth.third_party.cognito is enabled but without a user_pool_id.")
} else if c.UserPoolID, err = maybeLoadEnv(c.UserPoolID); err != nil {
} else if err := assertEnvLoaded(c.UserPoolID); err != nil {
return err
}

if c.UserPoolRegion == "" {
return errors.New("Invalid config: auth.third_party.cognito is enabled but without a user_pool_region.")
} else if c.UserPoolRegion, err = maybeLoadEnv(c.UserPoolRegion); err != nil {
} else if err := assertEnvLoaded(c.UserPoolRegion); err != nil {
return err
}

Expand Down
Loading