Skip to content

Commit

Permalink
Canonicalize use_extension label
Browse files Browse the repository at this point in the history
Canonicalize the label by adding the current module's repo_name if the
label doesn't specify a repository name. This is necessary as
ModuleExtensionUsages are grouped by the string value of this label, but
later mapped to their Label representation. If multiple strings map to
the same Label, this would result in a crash.

Also enforce that `module()` is called first (if at all).

Closes #17920.

PiperOrigin-RevId: 520890201
Change-Id: Ice8e2feb0da591e3ba953f4a85284766ba599ebf
  • Loading branch information
fmeum authored and copybara-github committed Mar 31, 2023
1 parent e97f62d commit dd82239
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,8 @@ public Builder addExtensionUsage(ModuleExtensionUsage value) {
return this;
}

abstract ModuleKey getKey();

abstract String getName();

abstract Optional<String> getRepoName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public class ModuleFileGlobals {
Pattern.compile("(>|<|-|<=|>=)(\\d+\\.){2}\\d+");

private boolean moduleCalled = false;
private boolean hadNonModuleCall = false;
private final boolean ignoreDevDeps;
private final Module.Builder module;
private final Map<String, ModuleKey> deps = new LinkedHashMap<>();
Expand Down Expand Up @@ -208,6 +209,9 @@ public void module(
if (moduleCalled) {
throw Starlark.errorf("the module() directive can only be called once");
}
if (hadNonModuleCall) {
throw Starlark.errorf("if module() is called, it must be called before any other functions");
}
moduleCalled = true;
if (!name.isEmpty()) {
validateModuleName(name);
Expand Down Expand Up @@ -298,6 +302,7 @@ private static ImmutableList<String> checkAllCompatibilityVersions(
public void bazelDep(
String name, String version, String repoName, boolean devDependency, StarlarkThread thread)
throws EvalException {
hadNonModuleCall = true;
if (repoName.isEmpty()) {
repoName = name;
}
Expand Down Expand Up @@ -330,6 +335,7 @@ public void bazelDep(
allowedTypes = {@ParamType(type = Sequence.class, generic1 = String.class)},
doc = "The labels of the platforms to register."))
public void registerExecutionPlatforms(Sequence<?> platformLabels) throws EvalException {
hadNonModuleCall = true;
module.addExecutionPlatformsToRegister(
checkAllAbsolutePatterns(platformLabels, "register_execution_platforms"));
}
Expand All @@ -347,6 +353,7 @@ public void registerExecutionPlatforms(Sequence<?> platformLabels) throws EvalEx
allowedTypes = {@ParamType(type = Sequence.class, generic1 = String.class)},
doc = "The labels of the toolchains to register."))
public void registerToolchains(Sequence<?> toolchainLabels) throws EvalException {
hadNonModuleCall = true;
module.addToolchainsToRegister(
checkAllAbsolutePatterns(toolchainLabels, "register_toolchains"));
}
Expand Down Expand Up @@ -376,7 +383,14 @@ public void registerToolchains(Sequence<?> toolchainLabels) throws EvalException
},
useStarlarkThread = true)
public ModuleExtensionProxy useExtension(
String extensionBzlFile, String extensionName, boolean devDependency, StarlarkThread thread) {
String rawExtensionBzlFile,
String extensionName,
boolean devDependency,
StarlarkThread thread) {
hadNonModuleCall = true;

String extensionBzlFile = normalizeLabelString(rawExtensionBzlFile);

ModuleExtensionUsageBuilder newUsageBuilder =
new ModuleExtensionUsageBuilder(
extensionBzlFile, extensionName, thread.getCallerLocation());
Expand All @@ -399,6 +413,22 @@ public ModuleExtensionProxy useExtension(
return newUsageBuilder.getProxy(devDependency);
}

private String normalizeLabelString(String rawExtensionBzlFile) {
// Normalize the label by adding the current module's repo_name if the label doesn't specify a
// repository name. This is necessary as ModuleExtensionUsages are grouped by the string value
// of this label, but later mapped to their Label representation. If multiple strings map to the
// same Label, this would result in a crash.
// ownName can't change anymore as calling module() after this results in an error.
String ownName = module.getRepoName().orElse(module.getName());
if (module.getKey().equals(ModuleKey.ROOT) && rawExtensionBzlFile.startsWith("@//")) {
return "@" + ownName + rawExtensionBzlFile.substring(1);
} else if (rawExtensionBzlFile.startsWith("//")) {
return "@" + ownName + rawExtensionBzlFile;
} else {
return rawExtensionBzlFile;
}
}

class ModuleExtensionUsageBuilder {
private final String extensionBzlFile;
private final String extensionName;
Expand Down Expand Up @@ -516,6 +546,7 @@ public void useRepo(
Dict<String, Object> kwargs,
StarlarkThread thread)
throws EvalException {
hadNonModuleCall = true;
Location location = thread.getCallerLocation();
for (String arg : Sequence.cast(args, String.class, "args")) {
extensionProxy.addImport(arg, arg, location);
Expand Down Expand Up @@ -598,6 +629,7 @@ public void singleVersionOverride(
Iterable<?> patchCmds,
StarlarkInt patchStrip)
throws EvalException {
hadNonModuleCall = true;
Version parsedVersion;
try {
parsedVersion = Version.parse(version);
Expand Down Expand Up @@ -652,6 +684,7 @@ public void singleVersionOverride(
})
public void multipleVersionOverride(String moduleName, Iterable<?> versions, String registry)
throws EvalException {
hadNonModuleCall = true;
ImmutableList.Builder<Version> parsedVersionsBuilder = new ImmutableList.Builder<>();
try {
for (String version : Sequence.cast(versions, String.class, "versions").getImmutableList()) {
Expand Down Expand Up @@ -735,6 +768,7 @@ public void archiveOverride(
Iterable<?> patchCmds,
StarlarkInt patchStrip)
throws EvalException {
hadNonModuleCall = true;
ImmutableList<String> urlList =
urls instanceof String
? ImmutableList.of((String) urls)
Expand Down Expand Up @@ -806,6 +840,7 @@ public void gitOverride(
Iterable<?> patchCmds,
StarlarkInt patchStrip)
throws EvalException {
hadNonModuleCall = true;
addOverride(
moduleName,
GitOverride.create(
Expand Down Expand Up @@ -835,6 +870,7 @@ public void gitOverride(
positional = false),
})
public void localPathOverride(String moduleName, String path) throws EvalException {
hadNonModuleCall = true;
addOverride(moduleName, LocalPathOverride.create(path));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,88 @@ public void simpleExtension() throws Exception {
assertThat(result.get(skyKey).getModule().getGlobal("data")).isEqualTo("foo:fu bar:ba");
}

@Test
public void simpleExtension_nonCanonicalLabel() throws Exception {
scratch.file(
workspaceRoot.getRelative("MODULE.bazel").getPathString(),
"module(name='my_module', version = '1.0')",
"bazel_dep(name='data_repo', version='1.0')",
"ext1 = use_extension('//:defs.bzl', 'ext')",
"ext1.tag(name='foo', data='fu')",
"use_repo(ext1, 'foo')",
"ext2 = use_extension('@my_module//:defs.bzl', 'ext')",
"ext2.tag(name='bar', data='ba')",
"use_repo(ext2, 'bar')",
"ext3 = use_extension('@//:defs.bzl', 'ext')",
"ext3.tag(name='quz', data='qu')",
"use_repo(ext3, 'quz')");
scratch.file(
workspaceRoot.getRelative("defs.bzl").getPathString(),
"load('@data_repo//:defs.bzl','data_repo')",
"tag = tag_class(attrs = {'name':attr.string(),'data':attr.string()})",
"def _ext_impl(ctx):",
" for mod in ctx.modules:",
" for tag in mod.tags.tag:",
" data_repo(name=tag.name,data=tag.data)",
"ext = module_extension(implementation=_ext_impl, tag_classes={'tag':tag})");
scratch.file(workspaceRoot.getRelative("BUILD").getPathString());
scratch.file(
workspaceRoot.getRelative("data.bzl").getPathString(),
"load('@foo//:data.bzl', foo_data='data')",
"load('@bar//:data.bzl', bar_data='data')",
"load('@quz//:data.bzl', quz_data='data')",
"data = 'foo:'+foo_data+' bar:'+bar_data+' quz:'+quz_data");

SkyKey skyKey = BzlLoadValue.keyForBuild(Label.parseCanonical("//:data.bzl"));
EvaluationResult<BzlLoadValue> result =
evaluator.evaluate(ImmutableList.of(skyKey), evaluationContext);
if (result.hasError()) {
throw result.getError().getException();
}
assertThat(result.get(skyKey).getModule().getGlobal("data")).isEqualTo("foo:fu bar:ba quz:qu");
}

@Test
public void simpleExtension_nonCanonicalLabel_repoName() throws Exception {
scratch.file(
workspaceRoot.getRelative("MODULE.bazel").getPathString(),
"module(name='my_module', version = '1.0', repo_name='my_name')",
"bazel_dep(name='data_repo', version='1.0')",
"ext1 = use_extension('//:defs.bzl', 'ext')",
"ext1.tag(name='foo', data='fu')",
"use_repo(ext1, 'foo')",
"ext2 = use_extension('@my_name//:defs.bzl', 'ext')",
"ext2.tag(name='bar', data='ba')",
"use_repo(ext2, 'bar')",
"ext3 = use_extension('@//:defs.bzl', 'ext')",
"ext3.tag(name='quz', data='qu')",
"use_repo(ext3, 'quz')");
scratch.file(
workspaceRoot.getRelative("defs.bzl").getPathString(),
"load('@data_repo//:defs.bzl','data_repo')",
"tag = tag_class(attrs = {'name':attr.string(),'data':attr.string()})",
"def _ext_impl(ctx):",
" for mod in ctx.modules:",
" for tag in mod.tags.tag:",
" data_repo(name=tag.name,data=tag.data)",
"ext = module_extension(implementation=_ext_impl, tag_classes={'tag':tag})");
scratch.file(workspaceRoot.getRelative("BUILD").getPathString());
scratch.file(
workspaceRoot.getRelative("data.bzl").getPathString(),
"load('@foo//:data.bzl', foo_data='data')",
"load('@bar//:data.bzl', bar_data='data')",
"load('@quz//:data.bzl', quz_data='data')",
"data = 'foo:'+foo_data+' bar:'+bar_data+' quz:'+quz_data");

SkyKey skyKey = BzlLoadValue.keyForBuild(Label.parseCanonical("//:data.bzl"));
EvaluationResult<BzlLoadValue> result =
evaluator.evaluate(ImmutableList.of(skyKey), evaluationContext);
if (result.hasError()) {
throw result.getError().getException();
}
assertThat(result.get(skyKey).getModule().getGlobal("data")).isEqualTo("foo:fu bar:ba quz:qu");
}

@Test
public void multipleModules() throws Exception {
scratch.file(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ public void testModuleExtensions_good() throws Exception {
.setRegistry(registry)
.addExtensionUsage(
ModuleExtensionUsage.builder()
.setExtensionBzlFile("//:defs.bzl")
.setExtensionBzlFile("@mymod//:defs.bzl")
.setExtensionName("myext1")
.setLocation(Location.fromFileLineColumn("mymod@1.0/MODULE.bazel", 2, 23))
.setImports(ImmutableBiMap.of("repo1", "repo1"))
Expand All @@ -491,7 +491,7 @@ public void testModuleExtensions_good() throws Exception {
.build())
.addExtensionUsage(
ModuleExtensionUsage.builder()
.setExtensionBzlFile("//:defs.bzl")
.setExtensionBzlFile("@mymod//:defs.bzl")
.setExtensionName("myext2")
.setLocation(Location.fromFileLineColumn("mymod@1.0/MODULE.bazel", 5, 23))
.setImports(ImmutableBiMap.of("other_repo1", "repo1", "repo2", "repo2"))
Expand Down Expand Up @@ -582,7 +582,7 @@ public void testModuleExtensions_duplicateProxy_asRoot() throws Exception {
.setKey(ModuleKey.ROOT)
.addExtensionUsage(
ModuleExtensionUsage.builder()
.setExtensionBzlFile("//:defs.bzl")
.setExtensionBzlFile("@//:defs.bzl")
.setExtensionName("myext")
.setLocation(Location.fromFileLineColumn("<root>/MODULE.bazel", 1, 23))
.setImports(
Expand Down Expand Up @@ -672,7 +672,7 @@ public void testModuleExtensions_duplicateProxy_asDep() throws Exception {
.setRegistry(registry)
.addExtensionUsage(
ModuleExtensionUsage.builder()
.setExtensionBzlFile("//:defs.bzl")
.setExtensionBzlFile("@mymod//:defs.bzl")
.setExtensionName("myext")
.setLocation(Location.fromFileLineColumn("mymod@1.0/MODULE.bazel", 5, 23))
.setImports(ImmutableBiMap.of("beta", "beta", "delta", "delta"))
Expand Down Expand Up @@ -956,4 +956,34 @@ public void moduleRepoName_conflict() throws Exception {

assertContainsEvent("The repo name 'bbb' is already being used as the module's own repo name");
}

@Test
public void module_calledTwice() throws Exception {
scratch.file(
rootDirectory.getRelative("MODULE.bazel").getPathString(),
"module(name='aaa',version='0.1',repo_name='bbb')",
"module(name='aaa',version='0.1',repo_name='bbb')");
FakeRegistry registry = registryFactory.newFakeRegistry("/foo");
ModuleFileFunction.REGISTRIES.set(differencer, ImmutableList.of(registry.getUrl()));

reporter.removeHandler(failFastHandler); // expect failures
evaluator.evaluate(ImmutableList.of(ModuleFileValue.KEY_FOR_ROOT_MODULE), evaluationContext);

assertContainsEvent("the module() directive can only be called once");
}

@Test
public void module_calledLate() throws Exception {
scratch.file(
rootDirectory.getRelative("MODULE.bazel").getPathString(),
"use_extension('//:extensions.bzl', 'my_ext')",
"module(name='aaa',version='0.1',repo_name='bbb')");
FakeRegistry registry = registryFactory.newFakeRegistry("/foo");
ModuleFileFunction.REGISTRIES.set(differencer, ImmutableList.of(registry.getUrl()));

reporter.removeHandler(failFastHandler); // expect failures
evaluator.evaluate(ImmutableList.of(ModuleFileValue.KEY_FOR_ROOT_MODULE), evaluationContext);

assertContainsEvent("if module() is called, it must be called before any other functions");
}
}

0 comments on commit dd82239

Please sign in to comment.