Skip to content

Commit

Permalink
Bzlmod: Reintroduce toolchain/platform registration on module()
Browse files Browse the repository at this point in the history
(#13316)

We ended up deciding to go with the original approach anyway! The only functional difference of using this approach is that we won't allow conditional toolchain/platform registration (i.e. "register this toolchain iff this option is true"). A quick search through the GitHub codebase seems to suggest that it's only ever used by rules_foreign_cc, and even then can be worked around using tags & aliases, so we're not too worried.

PiperOrigin-RevId: 400936021
  • Loading branch information
Wyverald authored and copybara-github committed Oct 5, 2021
1 parent 6345c80 commit a3cec13
Show file tree
Hide file tree
Showing 7 changed files with 406 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,19 @@ public final String getCanonicalRepoName() {
*/
public abstract int getCompatibilityLevel();

/**
* Target patterns identifying execution platforms to register when this module is selected. Note
* that these are what was written in module files verbatim, and don't contain canonical repo
* names.
*/
public abstract ImmutableList<String> getExecutionPlatformsToRegister();

/**
* Target patterns identifying toolchains to register when this module is selected. Note that
* these are what was written in module files verbatim, and don't contain canonical repo names.
*/
public abstract ImmutableList<String> getToolchainsToRegister();

/**
* The direct dependencies of this module. The key type is the repo name of the dep, and the value
* type is the ModuleKey (name+version) of the dep.
Expand Down Expand Up @@ -125,7 +138,9 @@ public static Builder builder() {
.setName("")
.setVersion(Version.EMPTY)
.setKey(ModuleKey.ROOT)
.setCompatibilityLevel(0);
.setCompatibilityLevel(0)
.setExecutionPlatformsToRegister(ImmutableList.of())
.setToolchainsToRegister(ImmutableList.of());
}

/**
Expand Down Expand Up @@ -153,6 +168,12 @@ public abstract static class Builder {
/** Optional; defaults to {@code 0}. */
public abstract Builder setCompatibilityLevel(int value);

/** Optional; defaults to an empty list. */
public abstract Builder setExecutionPlatformsToRegister(ImmutableList<String> value);

/** Optional; defaults to an empty list. */
public abstract Builder setToolchainsToRegister(ImmutableList<String> value);

public abstract Builder setDeps(ImmutableMap<String, ModuleKey> value);

abstract ImmutableMap.Builder<String, ModuleKey> depsBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,38 @@ private void addRepoNameUsage(String repoName, String how, Location where) throw
named = true,
positional = false,
defaultValue = "0"),
@Param(
name = "execution_platforms_to_register",
doc =
"A list of already-defined execution platforms to be registered when this module is"
+ " selected. Should be a list of absolute target patterns (ie. beginning with"
+ " either <code>@</code> or <code>//</code>). See <a"
+ " href=\"../../toolchains.html\">toolchain resolution</a> for more"
+ " information.",
named = true,
positional = false,
allowedTypes = {@ParamType(type = Iterable.class, generic1 = String.class)},
defaultValue = "[]"),
@Param(
name = "toolchains_to_register",
doc =
"A list of already-defined toolchains to be registered when this module is"
+ " selected. Should be a list of absolute target patterns (ie. beginning with"
+ " either <code>@</code> or <code>//</code>). See <a"
+ " href=\"../../toolchains.html\">toolchain resolution</a> for more"
+ " information.",
named = true,
positional = false,
allowedTypes = {@ParamType(type = Iterable.class, generic1 = String.class)},
defaultValue = "[]"),
},
useStarlarkThread = true)
public void module(
String name,
String version,
StarlarkInt compatibilityLevel,
Iterable<?> executionPlatformsToRegister,
Iterable<?> toolchainsToRegister,
StarlarkThread thread)
throws EvalException {
if (moduleCalled) {
Expand All @@ -144,10 +170,29 @@ public void module(
module
.setName(name)
.setVersion(parsedVersion)
.setCompatibilityLevel(compatibilityLevel.toInt("compatibility_level"));
.setCompatibilityLevel(compatibilityLevel.toInt("compatibility_level"))
.setExecutionPlatformsToRegister(
checkAllAbsolutePatterns(
executionPlatformsToRegister, "execution_platforms_to_register"))
.setToolchainsToRegister(
checkAllAbsolutePatterns(toolchainsToRegister, "toolchains_to_register"));
addRepoNameUsage(name, "as the current module name", thread.getCallerLocation());
}

private static ImmutableList<String> checkAllAbsolutePatterns(Iterable<?> iterable, String where)
throws EvalException {
Sequence<String> list = Sequence.cast(iterable, String.class, where);
for (String item : list) {
if (!item.startsWith("//") && !item.startsWith("@")) {
throw Starlark.errorf(
"Expected absolute target patterns (must begin with '//' or '@') for '%s' argument, but"
+ " got '%s' as an argument",
where, item);
}
}
return list.getImmutableList();
}

@StarlarkMethod(
name = "bazel_dep",
doc = "Declares a direct dependency on another Bazel module.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
import com.google.devtools.build.lib.analysis.config.BuildConfiguration;
import com.google.devtools.build.lib.analysis.platform.PlatformInfo;
import com.google.devtools.build.lib.analysis.platform.PlatformProviderUtils;
import com.google.devtools.build.lib.bazel.bzlmod.ExternalDepsException;
import com.google.devtools.build.lib.bazel.bzlmod.BazelModuleResolutionValue;
import com.google.devtools.build.lib.bazel.bzlmod.Module;
import com.google.devtools.build.lib.cmdline.Label;
import com.google.devtools.build.lib.cmdline.LabelConstants;
import com.google.devtools.build.lib.cmdline.RepositoryName;
Expand All @@ -34,6 +35,7 @@
import com.google.devtools.build.lib.packages.Package;
import com.google.devtools.build.lib.packages.Target;
import com.google.devtools.build.lib.pkgcache.FilteringPolicy;
import com.google.devtools.build.lib.rules.repository.RepositoryDelegatorFunction;
import com.google.devtools.build.lib.server.FailureDetails.Analysis;
import com.google.devtools.build.lib.server.FailureDetails.Analysis.Code;
import com.google.devtools.build.lib.server.FailureDetails.FailureDetail;
Expand Down Expand Up @@ -96,6 +98,13 @@ public SkyValue compute(SkyKey skyKey, Environment env)
}
}

// Get registered execution platforms from bzlmod.
ImmutableList<TargetPattern> bzlmodExecutionPlatforms = getBzlmodExecutionPlatforms(env);
if (bzlmodExecutionPlatforms == null) {
return null;
}
targetPatternBuilder.addAll(TargetPatternUtil.toSigned(bzlmodExecutionPlatforms));

// Get the registered execution platforms from the WORKSPACE.
ImmutableList<TargetPattern> workspaceExecutionPlatforms = getWorkspaceExecutionPlatforms(env);
if (workspaceExecutionPlatforms == null) {
Expand Down Expand Up @@ -146,6 +155,36 @@ public static ImmutableList<TargetPattern> getWorkspaceExecutionPlatforms(Enviro
return externalPackage.getRegisteredExecutionPlatforms();
}

@Nullable
private static ImmutableList<TargetPattern> getBzlmodExecutionPlatforms(Environment env)
throws InterruptedException, RegisteredExecutionPlatformsFunctionException {
if (!RepositoryDelegatorFunction.ENABLE_BZLMOD.get(env)) {
return ImmutableList.of();
}
BazelModuleResolutionValue bazelModuleResolutionValue =
(BazelModuleResolutionValue) env.getValue(BazelModuleResolutionValue.KEY);
if (bazelModuleResolutionValue == null) {
return null;
}
ImmutableList.Builder<TargetPattern> executionPlatforms = ImmutableList.builder();
for (Module module : bazelModuleResolutionValue.getDepGraph().values()) {
TargetPattern.Parser parser =
new TargetPattern.Parser(
PathFragment.EMPTY_FRAGMENT,
RepositoryName.createFromValidStrippedName(module.getCanonicalRepoName()),
bazelModuleResolutionValue.getFullRepoMapping(module.getKey()));
for (String pattern : module.getExecutionPlatformsToRegister()) {
try {
executionPlatforms.add(parser.parse(pattern));
} catch (TargetParsingException e) {
throw new RegisteredExecutionPlatformsFunctionException(
new InvalidExecutionPlatformLabelException(pattern, e), Transience.PERSISTENT);
}
}
}
return executionPlatforms.build();
}

private static ImmutableList<ConfiguredTargetKey> configureRegisteredExecutionPlatforms(
Environment env, BuildConfiguration configuration, List<Label> labels)
throws InterruptedException, RegisteredExecutionPlatformsFunctionException {
Expand Down Expand Up @@ -246,10 +285,5 @@ private RegisteredExecutionPlatformsFunctionException(
InvalidPlatformException cause, Transience transience) {
super(cause, transience);
}

private RegisteredExecutionPlatformsFunctionException(
ExternalDepsException cause, Transience transience) {
super(cause, transience);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
import com.google.devtools.build.lib.analysis.config.BuildConfiguration;
import com.google.devtools.build.lib.analysis.platform.DeclaredToolchainInfo;
import com.google.devtools.build.lib.analysis.platform.PlatformProviderUtils;
import com.google.devtools.build.lib.bazel.bzlmod.BazelModuleResolutionValue;
import com.google.devtools.build.lib.bazel.bzlmod.ExternalDepsException;
import com.google.devtools.build.lib.bazel.bzlmod.Module;
import com.google.devtools.build.lib.cmdline.Label;
import com.google.devtools.build.lib.cmdline.LabelConstants;
import com.google.devtools.build.lib.cmdline.RepositoryName;
Expand All @@ -33,6 +35,7 @@
import com.google.devtools.build.lib.cmdline.TargetPattern;
import com.google.devtools.build.lib.packages.Package;
import com.google.devtools.build.lib.pkgcache.FilteringPolicies;
import com.google.devtools.build.lib.rules.repository.RepositoryDelegatorFunction;
import com.google.devtools.build.lib.server.FailureDetails.Toolchain.Code;
import com.google.devtools.build.lib.skyframe.TargetPatternUtil.InvalidTargetPatternException;
import com.google.devtools.build.lib.vfs.PathFragment;
Expand Down Expand Up @@ -85,6 +88,13 @@ public SkyValue compute(SkyKey skyKey, Environment env)
new InvalidToolchainLabelException(e), Transience.PERSISTENT);
}

// Get registered toolchains from bzlmod.
ImmutableList<TargetPattern> bzlmodToolchains = getBzlmodToolchains(env);
if (bzlmodToolchains == null) {
return null;
}
targetPatternBuilder.addAll(TargetPatternUtil.toSigned(bzlmodToolchains));

// Get the registered toolchains from the WORKSPACE.
ImmutableList<TargetPattern> workspaceToolchains = getWorkspaceToolchains(env);
if (workspaceToolchains == null) {
Expand Down Expand Up @@ -135,6 +145,36 @@ public static ImmutableList<TargetPattern> getWorkspaceToolchains(Environment en
return externalPackage.getRegisteredToolchains();
}

@Nullable
private static ImmutableList<TargetPattern> getBzlmodToolchains(Environment env)
throws InterruptedException, RegisteredToolchainsFunctionException {
if (!RepositoryDelegatorFunction.ENABLE_BZLMOD.get(env)) {
return ImmutableList.of();
}
BazelModuleResolutionValue bazelModuleResolutionValue =
(BazelModuleResolutionValue) env.getValue(BazelModuleResolutionValue.KEY);
if (bazelModuleResolutionValue == null) {
return null;
}
ImmutableList.Builder<TargetPattern> toolchains = ImmutableList.builder();
for (Module module : bazelModuleResolutionValue.getDepGraph().values()) {
TargetPattern.Parser parser =
new TargetPattern.Parser(
PathFragment.EMPTY_FRAGMENT,
RepositoryName.createFromValidStrippedName(module.getCanonicalRepoName()),
bazelModuleResolutionValue.getFullRepoMapping(module.getKey()));
for (String pattern : module.getToolchainsToRegister()) {
try {
toolchains.add(parser.parse(pattern));
} catch (TargetParsingException e) {
throw new RegisteredToolchainsFunctionException(
new InvalidToolchainLabelException(pattern, e), Transience.PERSISTENT);
}
}
}
return toolchains.build();
}

private static ImmutableList<DeclaredToolchainInfo> configureRegisteredToolchains(
Environment env, BuildConfiguration configuration, List<Label> labels)
throws InterruptedException, RegisteredToolchainsFunctionException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,13 @@ public void setup() throws Exception {
public void testRootModule() throws Exception {
scratch.file(
rootDirectory.getRelative("MODULE.bazel").getPathString(),
"module(name='A',version='0.1',compatibility_level=4)",
"module(",
" name='A',",
" version='0.1',",
" compatibility_level=4,",
" toolchains_to_register=['//my:toolchain', '//my:toolchain2'],",
" execution_platforms_to_register=['//my:platform', '//my:platform2'],",
")",
"bazel_dep(name='B',version='1.0')",
"bazel_dep(name='C',version='2.0',repo_name='see')",
"single_version_override(module_name='D',version='18')",
Expand All @@ -197,6 +203,9 @@ public void testRootModule() throws Exception {
.setVersion(Version.parse("0.1"))
.setKey(ModuleKey.ROOT)
.setCompatibilityLevel(4)
.setExecutionPlatformsToRegister(
ImmutableList.of("//my:platform", "//my:platform2"))
.setToolchainsToRegister(ImmutableList.of("//my:toolchain", "//my:toolchain2"))
.addDep("B", createModuleKey("B", "1.0"))
.addDep("see", createModuleKey("C", "2.0"))
.build());
Expand Down
Loading

0 comments on commit a3cec13

Please sign in to comment.