Skip to content

Commit

Permalink
Get toolchain context and toolchain info from automatic exec groups i…
Browse files Browse the repository at this point in the history
…f needed

Before this, toolchain context returned context from the default exec group, which is invalid once the AEG are implemented (https://docs.google.com/document/d/1-rbP_hmKs9D639YWw5F_JyxPxL2bi6dSmmvj_WXak9M).

I've added getToolchainContext(toolchainType) and getToolchainInfo(toolchainType) functions which cover automatic exec groups.

PiperOrigin-RevId: 502816645
Change-Id: Ife34f97700c28b0a1a64f2663e1a8182de5ca44c
  • Loading branch information
kotlaja authored and copybara-github committed Jan 18, 2023
1 parent 130b444 commit 928e0fe
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import com.google.devtools.build.lib.analysis.config.Fragment;
import com.google.devtools.build.lib.analysis.platform.ConstraintValueInfo;
import com.google.devtools.build.lib.analysis.platform.PlatformInfo;
import com.google.devtools.build.lib.analysis.platform.ToolchainInfo;
import com.google.devtools.build.lib.analysis.starlark.StarlarkRuleContext;
import com.google.devtools.build.lib.analysis.stringtemplate.TemplateContext;
import com.google.devtools.build.lib.cmdline.Label;
Expand Down Expand Up @@ -1216,6 +1217,10 @@ public boolean useAutoExecGroups() {
}
}

/**
* Returns the toolchain context from the default exec group. Important note: In case automatic
* exec groups are enabled, use `getToolchainContext(Label toolchainType)` function.
*/
@Nullable
public ResolvedToolchainContext getToolchainContext() {
return toolchainContexts == null ? null : toolchainContexts.getDefaultToolchainContext();
Expand All @@ -1226,6 +1231,22 @@ private ResolvedToolchainContext getToolchainContext(String execGroup) {
return toolchainContexts == null ? null : toolchainContexts.getToolchainContext(execGroup);
}

/**
* Returns the toolchain info from the default exec group in case automatic exec groups are not
* enabled. If they are enabled, retrieves toolchain info from the corresponding automatic exec
* group.
*/
@Nullable
public ToolchainInfo getToolchainInfo(Label toolchainType) {
ResolvedToolchainContext toolchainContext;
if (useAutoExecGroups()) {
toolchainContext = toolchainContexts.getToolchainContext(toolchainType.toString());
} else {
toolchainContext = getToolchainContext();
}
return toolchainContext == null ? null : toolchainContext.forToolchainType(toolchainType);
}

public boolean hasToolchainContext(String execGroup) {
return toolchainContexts != null && toolchainContexts.hasToolchainContext(execGroup);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import com.google.devtools.build.lib.analysis.AnalysisUtils;
import com.google.devtools.build.lib.analysis.Expander;
import com.google.devtools.build.lib.analysis.FileProvider;
import com.google.devtools.build.lib.analysis.ResolvedToolchainContext;
import com.google.devtools.build.lib.analysis.RuleConfiguredTargetBuilder;
import com.google.devtools.build.lib.analysis.RuleContext;
import com.google.devtools.build.lib.analysis.RuleErrorConsumer;
Expand Down Expand Up @@ -368,17 +367,11 @@ public static Label getToolchainTypeFromRuleClass(RuleContext ruleContext) {

private static CcToolchainProvider getToolchainFromPlatformConstraints(
RuleContext ruleContext, Label toolchainType) throws RuleErrorException {
ResolvedToolchainContext toolchainContext = ruleContext.getToolchainContext();
ToolchainInfo toolchainInfo = toolchainContext.forToolchainType(toolchainType);
ToolchainInfo toolchainInfo = ruleContext.getToolchainInfo(toolchainType);
if (toolchainInfo == null) {
throw ruleContext.throwWithRuleError(
String.format(
"Unable to find a CC toolchain using toolchain resolution"
+ " (target %s, target platform %s, exec platform %s)."
+ " Did you properly set --platforms?",
ruleContext.getLabel(),
toolchainContext.targetPlatform().label(),
toolchainContext.executionPlatform().label()));
"Unable to find a CC toolchain using toolchain resolution. Did you properly set"
+ " --platforms?");
}
try {
return (CcToolchainProvider) toolchainInfo.getValue("cc");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,10 @@ public static JavaRuntimeInfo forHost(RuleContext ruleContext) {

public static JavaRuntimeInfo from(RuleContext ruleContext) {
ToolchainInfo toolchainInfo =
ruleContext
.getToolchainContext()
.forToolchainType(
ruleContext
.getPrerequisite(JavaRuleClasses.JAVA_RUNTIME_TOOLCHAIN_TYPE_ATTRIBUTE_NAME)
.getLabel());
ruleContext.getToolchainInfo(
ruleContext
.getPrerequisite(JavaRuleClasses.JAVA_RUNTIME_TOOLCHAIN_TYPE_ATTRIBUTE_NAME)
.getLabel());
return from(ruleContext, toolchainInfo);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,10 @@ public final class JavaToolchainProvider extends NativeInfo
/** Returns the Java Toolchain associated with the rule being analyzed or {@code null}. */
public static JavaToolchainProvider from(RuleContext ruleContext) {
ToolchainInfo toolchainInfo =
ruleContext
.getToolchainContext()
.forToolchainType(
ruleContext
.getPrerequisite(JavaRuleClasses.JAVA_TOOLCHAIN_TYPE_ATTRIBUTE_NAME)
.getLabel());
ruleContext.getToolchainInfo(
ruleContext
.getPrerequisite(JavaRuleClasses.JAVA_TOOLCHAIN_TYPE_ATTRIBUTE_NAME)
.getLabel());
return from(toolchainInfo, ruleContext);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ private static PyRuntimeInfo initRuntimeFromToolchain(
return null;
}
Label toolchainType = ruleContext.attributes().get("$py_toolchain_type", BuildType.NODEP_LABEL);
ToolchainInfo toolchainInfo = ruleContext.getToolchainContext().forToolchainType(toolchainType);
ToolchainInfo toolchainInfo = ruleContext.getToolchainInfo(toolchainType);
Preconditions.checkArgument(
toolchainInfo != null,
"Could not retrieve a Python toolchain for '%s' rule",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ObjectArrays;
import com.google.devtools.build.lib.analysis.platform.ToolchainInfo;
import com.google.devtools.build.lib.analysis.platform.ToolchainTypeInfo;
import com.google.devtools.build.lib.analysis.util.BuildViewTestCase;
import com.google.devtools.build.lib.cmdline.Label;
import com.google.devtools.build.lib.packages.ExecGroup;
import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -232,4 +235,41 @@ public void automaticExecutionGroups_enabledAndAttributeNotSet_enabled() throws

assertThat(execGroups).isNotEmpty();
}

@Test
public void getToolchainInfoAndContext_automaticExecGroupsEnabled() throws Exception {
createCustomRule(
/* actionParameters= */ "toolchain = '//rule:toolchain_type_1',",
/* extraAttributes= */ "",
/* toolchains= */ "['//rule:toolchain_type_1']");
useConfiguration("--incompatible_auto_exec_groups");

ConfiguredTarget target = getConfiguredTarget("//test:custom_rule_name");
RuleContext ruleContext = getRuleContext(target);
ImmutableMap<ToolchainTypeInfo, ToolchainInfo> defaultExecGroupToolchains =
ruleContext.getToolchainContext().toolchains();
ToolchainInfo toolchainInfo =
ruleContext.getToolchainInfo(Label.parseCanonical("//rule:toolchain_type_1"));

assertThat(defaultExecGroupToolchains).isEmpty();
assertThat(toolchainInfo).isNotNull();
}

@Test
public void getToolchainInfoAndContext_automaticExecGroupsDisabled() throws Exception {
createCustomRule(
/* actionParameters= */ "toolchain = '//rule:toolchain_type_1',",
/* extraAttributes= */ "",
/* toolchains= */ "['//rule:toolchain_type_1']");

ConfiguredTarget target = getConfiguredTarget("//test:custom_rule_name");
RuleContext ruleContext = getRuleContext(target);
ImmutableMap<ToolchainTypeInfo, ToolchainInfo> defaultExecGroupToolchains =
ruleContext.getToolchainContext().toolchains();
ToolchainInfo toolchainInfo =
ruleContext.getToolchainInfo(Label.parseCanonical("//rule:toolchain_type_1"));

assertThat(defaultExecGroupToolchains).isNotEmpty();
assertThat(toolchainInfo).isNotNull();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,7 @@ public void testResolvedCcToolchain() throws Exception {
.setList("srcs", "a.cc")
.write();
ToolchainInfo toolchainInfo =
getRuleContext(target)
.getToolchainContext()
.forToolchainType(Label.parseCanonical(CPP_TOOLCHAIN_TYPE));
getRuleContext(target).getToolchainInfo(Label.parseCanonical(CPP_TOOLCHAIN_TYPE));
CcToolchainProvider toolchain = (CcToolchainProvider) toolchainInfo.getValue("cc");
assertThat(toolchain.getToolchainIdentifier()).endsWith("k8");
}
Expand All @@ -70,9 +68,7 @@ public void testToolchainSelectionWithPlatforms() throws Exception {
.setList("srcs", "a.cc")
.write();
ToolchainInfo toolchainInfo =
getRuleContext(target)
.getToolchainContext()
.forToolchainType(Label.parseCanonical(CPP_TOOLCHAIN_TYPE));
getRuleContext(target).getToolchainInfo(Label.parseCanonical(CPP_TOOLCHAIN_TYPE));
CcToolchainProvider toolchain = (CcToolchainProvider) toolchainInfo.getValue("cc");
assertThat(toolchain.getToolchainIdentifier()).endsWith("k8");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ public void testToolchains() throws Exception {

assertThat(ruleContext.getToolchainContext()).hasToolchainType("//toolchain:test_toolchain");
ToolchainInfo toolchain =
ruleContext
.getToolchainContext()
.forToolchainType(Label.parseCanonical("//toolchain:test_toolchain"));
ruleContext.getToolchainInfo(Label.parseCanonical("//toolchain:test_toolchain"));
assertThat(toolchain.getValue("data")).isEqualTo("foo");
}

Expand Down

0 comments on commit 928e0fe

Please sign in to comment.