Skip to content

Commit

Permalink
provider: improve validation for subscription_id to allow `terrafor…
Browse files Browse the repository at this point in the history
…m validate` to work when it is unspecified
  • Loading branch information
manicminer committed Aug 23, 2024
1 parent ae3c272 commit 7117c91
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
12 changes: 9 additions & 3 deletions internal/provider/framework/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ func (p *ProviderConfig) Load(ctx context.Context, data *ProviderModel, tfVersio
env := &environments.Environment{}
var err error

subscriptionId := getEnvStringOrDefault(data.SubscriptionId, "ARM_SUBSCRIPTION_ID", "")
if subscriptionId == "" {
diags.Append(diag.NewErrorDiagnostic("Configuring subscription", "`subscription_id` is a required provider property when performing a plan/apply operation"))
return
}

if metadataHost := getEnvStringOrDefault(data.MetaDataHost, "ARM_METADATA_HOSTNAME", ""); metadataHost != "" {
env, err = environments.FromEndpoint(ctx, metadataHost)
if err != nil {
Expand Down Expand Up @@ -98,7 +104,7 @@ func (p *ProviderConfig) Load(ctx context.Context, data *ProviderModel, tfVersio

CustomManagedIdentityEndpoint: getEnvStringOrDefault(data.MSIEndpoint, "ARM_MSI_ENDPOINT", ""),

AzureCliSubscriptionIDHint: getEnvStringOrDefault(data.SubscriptionId, "ARM_SUBSCRIPTION_ID", ""),
AzureCliSubscriptionIDHint: subscriptionId,

EnableAuthenticatingUsingClientCertificate: true,
EnableAuthenticatingUsingClientSecret: true,
Expand Down Expand Up @@ -511,11 +517,11 @@ func (p *ProviderConfig) Load(ctx context.Context, data *ProviderModel, tfVersio
}
}

subscriptionId := commonids.NewSubscriptionID(client.Account.SubscriptionId)
subId := commonids.NewSubscriptionID(client.Account.SubscriptionId)
ctx2, cancel := context.WithTimeout(ctx, 30*time.Minute)
defer cancel()

if err = resourceproviders.EnsureRegistered(ctx2, client.Resource.ResourceProvidersClient, subscriptionId, requiredResourceProviders); err != nil {
if err = resourceproviders.EnsureRegistered(ctx2, client.Resource.ResourceProvidersClient, subId, requiredResourceProviders); err != nil {
diags.AddError("registering resource providers", err.Error())
return
}
Expand Down
3 changes: 1 addition & 2 deletions internal/provider/framework/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ func (p *azureRmFrameworkProvider) Schema(_ context.Context, _ provider.SchemaRe
"subscription_id": schema.StringAttribute{
// Note: There is no equivalent of `DefaultFunc` in the provider schema package. This property is Required, but can be
// set via env var instead of provider config, so needs to be toggled in schema based on the presence of that env var.
Required: getEnvStringOrDefault(types.StringUnknown(), "ARM_SUBSCRIPTION_ID", "") == "",
Optional: getEnvStringOrDefault(types.StringUnknown(), "ARM_SUBSCRIPTION_ID", "") != "",
Optional: true,
Description: "The Subscription ID which should be used.",
},

Expand Down
9 changes: 7 additions & 2 deletions internal/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func azureProvider(supportLegacyTestSuite bool) *schema.Provider {
Schema: map[string]*schema.Schema{
"subscription_id": {
Type: schema.TypeString,
Required: true,
Optional: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_SUBSCRIPTION_ID", nil),
Description: "The Subscription ID which should be used.",
},
Expand Down Expand Up @@ -385,6 +385,11 @@ func azureProvider(supportLegacyTestSuite bool) *schema.Provider {
// This separation allows us to robustly test different authentication scenarios.
func providerConfigure(p *schema.Provider) schema.ConfigureContextFunc {
return func(ctx context.Context, d *schema.ResourceData) (interface{}, diag.Diagnostics) {
subscriptionId := d.Get("subscription_id").(string)
if subscriptionId == "" {
return nil, diag.FromErr(fmt.Errorf("`subscription_id` is a required provider property when performing a plan/apply operation"))
}

var auxTenants []string
if v, ok := d.Get("auxiliary_tenant_ids").([]interface{}); ok && len(v) > 0 {
auxTenants = *utils.ExpandStringSlice(v)
Expand Down Expand Up @@ -467,7 +472,7 @@ func providerConfigure(p *schema.Provider) schema.ConfigureContextFunc {

CustomManagedIdentityEndpoint: d.Get("msi_endpoint").(string),

AzureCliSubscriptionIDHint: d.Get("subscription_id").(string),
AzureCliSubscriptionIDHint: subscriptionId,

EnableAuthenticatingUsingClientCertificate: true,
EnableAuthenticatingUsingClientSecret: true,
Expand Down

0 comments on commit 7117c91

Please sign in to comment.