Skip to content

Commit

Permalink
Deduplicate extra_providers against built-in providers (#147)
Browse files Browse the repository at this point in the history
This avoids an error if users specify a provider in `extra_providers`
that's already in the list of built-in providers.

Work towards #146
  • Loading branch information
fmeum authored Jan 7, 2025
1 parent 5970101 commit 3ebbe33
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 4 deletions.
6 changes: 5 additions & 1 deletion examples/cc_define_test/cc_define_test.bzl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
load("@with_cfg.bzl", "with_cfg")

_builder = with_cfg(native.cc_test)
_builder = with_cfg(
native.cc_test,
# Verify that duplicated providers are handled gracefully.
extra_providers = [DefaultInfo, CcInfo],
)
_builder.extend(
"copt",
select({
Expand Down
9 changes: 8 additions & 1 deletion with_cfg/private/rule_defaults.bzl
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
visibility("private")

DEFAULT_PROVIDERS = [
SPECIAL_CASED_PROVIDERS = [
DefaultInfo,
# Forwarding is handled by coverage_common.instrumented_files_info.
InstrumentedFilesInfo,
# RunEnvironmentInfo can't be returned from a non-executable, non-test rule and thus requires
# special handling so that it isn't returned by the transitioning alias.
RunEnvironmentInfo,
]

DEFAULT_PROVIDERS = [
AnalysisTestResultInfo,
CcInfo,
CcToolchainConfigInfo,
Expand Down
16 changes: 14 additions & 2 deletions with_cfg/private/with_cfg.bzl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
load(":builder.bzl", "make_builder")
load(":providers.bzl", "RuleInfo")
load(":rule_defaults.bzl", "DEFAULT_PROVIDERS", "IMPLICIT_TARGETS")
load(":rule_defaults.bzl", "DEFAULT_PROVIDERS", "IMPLICIT_TARGETS", "SPECIAL_CASED_PROVIDERS")

visibility("//with_cfg/...")

Expand Down Expand Up @@ -125,7 +125,7 @@ def with_cfg(
executable = executable,
test = test,
implicit_targets = implicit_targets,
providers = DEFAULT_PROVIDERS + extra_providers,
providers = _all_providers(extra_providers),
native = _is_native(kind),
supports_inheritance = _supports_inheritance(kind),
supports_extension = _supports_extension(kind),
Expand Down Expand Up @@ -174,3 +174,15 @@ def _supports_extension(kind):

def get_implicit_targets(rule_name):
return IMPLICIT_TARGETS.get(rule_name, [])

def _all_providers(extra_providers):
if not extra_providers:
return DEFAULT_PROVIDERS
all_providers = list(DEFAULT_PROVIDERS)

# Providers aren't hashable.
# TODO: Improve this after https://github.com/bazelbuild/bazel/pull/24848.
for p in extra_providers:
if p not in all_providers and p not in SPECIAL_CASED_PROVIDERS:
all_providers.append(p)
return all_providers

0 comments on commit 3ebbe33

Please sign in to comment.