Skip to content

Commit

Permalink
Toolchainize //scala_proto:{,deps_}toolchain_type
Browse files Browse the repository at this point in the history
Adds scala_proto toolchains to `scala_toolchains()`. Part of #1482.

The most significant part of the change is moving all the toolchain
rules from `scala_proto/BUILD` to `setup_scala_toolchains()` in
`scala_proto/toolchains.bzl`.

Adds the `scala_proto_deps_providers()` macro to replace
`//scala_proto:scalapb_{compile,grpc,worker}_deps_provider` targets in
the `dep_providers` parameter of `scala_proto_deps_toolchain()`.
Examples of this are in `test/proto/custom_generator/BUILD`.

Excludes `@scala_proto_rules_scalapb_protoc_gen` from
`DEFAULT_SCALAPB_WORKER_DEPS` in `scala_proto/default/default_deps.bzl`
for Scala 2.11. For other Scala versions, this repo name will have the
Scala version appended. This is to avoid build failures under Bzlmod,
since:

- This repo is required by ScalaPB 0.11.17, but Scala 2.11 is capped at
  ScalaPB 0.9.8.

- Importing the nonexistent `scala_proto_rules_scalapb_protoc_gen` under
  Scala 2.11 results in an error under Bzlmod, as does importing it
  multiple times when configuring multiple Scala versions.

- `MODULE.bazel` can iterate over a list of Scala versions, filtering
  out Scala 2.11, and call `use_repo()` on each version specific repo.

A lot of the other changes are more opportunistic removals of
`@io_bazel_rules_scala` label prefixes and application of `Label()`
where appropriate. Doing this will allow Bzlmod users to use
`rules_scala` without setting `repo_name = "@io_bazel_rules_scala"` in
`bazel_dep()`.
  • Loading branch information
mbland committed Feb 3, 2025
1 parent f30237b commit 62c17f1
Show file tree
Hide file tree
Showing 13 changed files with 177 additions and 138 deletions.
5 changes: 1 addition & 4 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ load("//scala:toolchains.bzl", "scala_toolchains")

scala_toolchains(
fetch_sources = True,
scala_proto = True,
scalafmt = True,
testing = True,
)
Expand All @@ -68,10 +69,6 @@ load("//jmh:jmh.bzl", "jmh_repositories")

jmh_repositories()

load("//scala_proto:scala_proto.bzl", "scala_proto_repositories")

scala_proto_repositories()

# needed for the cross repo proto test
local_repository(
name = "proto_cross_repo_boundary",
Expand Down
6 changes: 3 additions & 3 deletions scala/private/toolchain_deps/toolchain_dep_rules.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@ load(
"expose_toolchain_deps",
)

_toolchain_type = "@io_bazel_rules_scala//scala:toolchain_type"
_TOOLCHAIN_TYPE = Label("//scala:toolchain_type")

def _common_toolchain_deps(ctx):
return expose_toolchain_deps(ctx, _toolchain_type)
return expose_toolchain_deps(ctx, _TOOLCHAIN_TYPE)

common_toolchain_deps = rule(
implementation = _common_toolchain_deps,
attrs = {
"deps_id": attr.string(mandatory = True),
},
toolchains = [_toolchain_type],
toolchains = [_TOOLCHAIN_TYPE],
incompatible_use_toolchain_transition = True,
)
16 changes: 15 additions & 1 deletion scala/toolchains.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ load(
)
load("//scala:scala_cross_version.bzl", "default_maven_server_urls")
load("//scala:toolchains_repo.bzl", "scala_toolchains_repo")
load("//scala_proto/default:repositories.bzl", "scala_proto_artifact_ids")
load("//scalatest:scalatest.bzl", "scalatest_artifact_ids")
load("//specs2:specs2.bzl", "specs2_artifact_ids")
load("//specs2:specs2_junit.bzl", "specs2_junit_artifact_ids")
Expand All @@ -28,7 +29,9 @@ def scala_toolchains(
specs2 = False,
testing = False,
scalafmt = False,
scalafmt_default_config_path = ".scalafmt.conf"):
scalafmt_default_config_path = ".scalafmt.conf",
scala_proto = False,
scala_proto_enable_all_options = False):
"""Instantiates @io_bazel_rules_scala_toolchains and all its dependencies.
Provides a unified interface to configuring rules_scala both directly in a
Expand Down Expand Up @@ -76,6 +79,10 @@ def scala_toolchains(
scalafmt: whether to instantiate the Scalafmt toolchain
scalafmt_default_config_path: the relative path to the default Scalafmt
config file within the repository
scala_proto: whether to instantiate the scala_proto toolchain
scala_proto_enable_all_options: whether to instantiate the scala_proto
toolchain with all options enabled; `scala_proto` must also be
`True` for this to take effect
"""
scala_repositories(
maven_servers = maven_servers,
Expand Down Expand Up @@ -119,6 +126,11 @@ def scala_toolchains(
for scala_version in SCALA_VERSIONS:
version_specific_artifact_ids = {}

if scala_proto:
version_specific_artifact_ids.update({
id: True
for id in scala_proto_artifact_ids(scala_version)
})
if scalafmt:
version_specific_artifact_ids.update({
id: fetch_sources
Expand All @@ -145,6 +157,8 @@ def scala_toolchains(
specs2 = specs2,
testing = testing,
scalafmt = scalafmt,
scala_proto = scala_proto,
scala_proto_enable_all_options = scala_proto_enable_all_options,
)

def scala_register_toolchains():
Expand Down
45 changes: 45 additions & 0 deletions scala/toolchains_repo.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,14 @@ def _scala_toolchains_repo_impl(repository_ctx):
repo_attr = repository_ctx.attr
format_args = {
"rules_scala_repo": Label("//:all").repo_name,
"proto_enable_all_options": repo_attr.scala_proto_enable_all_options,
}
toolchains = {}

if repo_attr.scala:
toolchains["scala"] = _SCALA_TOOLCHAIN_BUILD
if repo_attr.scala_proto:
toolchains["scala_proto"] = _SCALA_PROTO_TOOLCHAIN_BUILD

testing_build_args = _generate_testing_toolchain_build_file_args(repo_attr)
if testing_build_args != None:
Expand Down Expand Up @@ -73,6 +76,8 @@ _scala_toolchains_repo = repository_rule(
"specs2": attr.bool(),
"testing": attr.bool(),
"scalafmt": attr.bool(),
"scala_proto": attr.bool(),
"scala_proto_enable_all_options": attr.bool(),
},
)

Expand Down Expand Up @@ -156,3 +161,43 @@ load(
setup_scalafmt_toolchains()
"""

_SCALA_PROTO_TOOLCHAIN_BUILD = """
load("@@{rules_scala_repo}//scala:providers.bzl", "declare_deps_provider")
load(
"@@{rules_scala_repo}//scala_proto/default:default_deps.bzl",
"DEFAULT_SCALAPB_COMPILE_DEPS",
"DEFAULT_SCALAPB_GRPC_DEPS",
"DEFAULT_SCALAPB_WORKER_DEPS",
)
load(
"@@{rules_scala_repo}//scala_proto:toolchains.bzl",
"setup_scala_proto_toolchains",
)
setup_scala_proto_toolchains(
name = "scala_proto",
enable_all_options = {proto_enable_all_options},
)
declare_deps_provider(
name = "scalapb_compile_deps_provider",
deps_id = "scalapb_compile_deps",
visibility = ["//visibility:public"],
deps = DEFAULT_SCALAPB_COMPILE_DEPS,
)
declare_deps_provider(
name = "scalapb_grpc_deps_provider",
deps_id = "scalapb_grpc_deps",
visibility = ["//visibility:public"],
deps = DEFAULT_SCALAPB_GRPC_DEPS,
)
declare_deps_provider(
name = "scalapb_worker_deps_provider",
deps_id = "scalapb_worker_deps",
visibility = ["//visibility:public"],
deps = DEFAULT_SCALAPB_WORKER_DEPS,
)
"""
74 changes: 0 additions & 74 deletions scala_proto/BUILD
Original file line number Diff line number Diff line change
@@ -1,16 +1,4 @@
load("//scala:providers.bzl", "declare_deps_provider")
load(
"//scala_proto/default:default_deps.bzl",
"DEFAULT_SCALAPB_COMPILE_DEPS",
"DEFAULT_SCALAPB_GRPC_DEPS",
"DEFAULT_SCALAPB_WORKER_DEPS",
)
load("//scala_proto/private:toolchain_deps.bzl", "export_scalapb_toolchain_deps")
load(
"//scala_proto:scala_proto_toolchain.bzl",
"scala_proto_deps_toolchain",
"scala_proto_toolchain",
)

toolchain_type(
name = "toolchain_type",
Expand All @@ -22,68 +10,6 @@ toolchain_type(
visibility = ["//visibility:public"],
)

scala_proto_deps_toolchain(
name = "default_deps_toolchain_impl",
visibility = ["//visibility:public"],
)

scala_proto_toolchain(
name = "default_toolchain_impl",
visibility = ["//visibility:public"],
with_flat_package = False,
with_grpc = True,
with_single_line_to_string = False,
)

toolchain(
name = "default_toolchain",
toolchain = ":default_toolchain_impl",
toolchain_type = "//scala_proto:toolchain_type",
visibility = ["//visibility:public"],
)

toolchain(
name = "default_deps_toolchain",
toolchain = ":default_deps_toolchain_impl",
toolchain_type = ":deps_toolchain_type",
)

scala_proto_toolchain(
name = "enable_all_options_toolchain_impl",
visibility = ["//visibility:public"],
with_flat_package = True,
with_grpc = True,
with_single_line_to_string = True,
)

toolchain(
name = "enable_all_options_toolchain",
toolchain = ":enable_all_options_toolchain_impl",
toolchain_type = "//scala_proto:toolchain_type",
visibility = ["//visibility:public"],
)

declare_deps_provider(
name = "scalapb_compile_deps_provider",
deps_id = "scalapb_compile_deps",
visibility = ["//visibility:public"],
deps = DEFAULT_SCALAPB_COMPILE_DEPS,
)

declare_deps_provider(
name = "scalapb_grpc_deps_provider",
deps_id = "scalapb_grpc_deps",
visibility = ["//visibility:public"],
deps = DEFAULT_SCALAPB_GRPC_DEPS,
)

declare_deps_provider(
name = "scalapb_worker_deps_provider",
deps_id = "scalapb_worker_deps",
visibility = ["//visibility:public"],
deps = DEFAULT_SCALAPB_WORKER_DEPS,
)

export_scalapb_toolchain_deps(
name = "scalapb_worker_deps",
deps_id = "scalapb_worker_deps",
Expand Down
18 changes: 16 additions & 2 deletions scala_proto/default/default_deps.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,22 @@
# dependency lists. This needs to be the unrolled transitive path to be used
# without such a facility.

load("//scala:scala_cross_version.bzl", "repositories")
load("//scala:scala_cross_version_select.bzl", "select_for_scala_version")
load("@io_bazel_rules_scala_config//:config.bzl", "SCALA_VERSION")

_DEFAULT_DEP_PROVIDER_FORMAT = (
"@io_bazel_rules_scala_toolchains//scala_proto:scalapb_%s_deps_provider"
)

def scala_proto_deps_providers(
compile = _DEFAULT_DEP_PROVIDER_FORMAT % "compile",
grpc = _DEFAULT_DEP_PROVIDER_FORMAT % "grpc",
worker = _DEFAULT_DEP_PROVIDER_FORMAT % "worker"):
return [compile, grpc, worker]

DEFAULT_SCALAPB_COMPILE_DEPS = [
"//scala/private/toolchain_deps:scala_library_classpath",
Label("//scala/private/toolchain_deps:scala_library_classpath"),
"@com_google_protobuf//:protobuf_java",
"@com_lihaoyi_fastparse",
"@scala_proto_rules_scalapb_lenses",
Expand Down Expand Up @@ -51,5 +63,7 @@ DEFAULT_SCALAPB_WORKER_DEPS = [
"@scala_proto_rules_scalapb_protoc_bridge",
] + select_for_scala_version(
any_2_11 = [],
since_2_12 = ["@scala_proto_rules_scalapb_protoc_gen"],
since_2_12 = repositories(SCALA_VERSION, [
"@scala_proto_rules_scalapb_protoc_gen",
]),
)
22 changes: 9 additions & 13 deletions scala_proto/private/scala_proto_aspect.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,10 @@ load(
)
load("//scala/private/toolchain_deps:toolchain_deps.bzl", "find_deps_info_on")
load(
"@io_bazel_rules_scala//scala_proto/private:scala_proto_aspect_provider.bzl",
"//scala_proto/private:scala_proto_aspect_provider.bzl",
"ScalaProtoAspectInfo",
)
load(
"@io_bazel_rules_scala//scala/private:phases/api.bzl",
"extras_phases",
"run_aspect_phases",
)
load("//scala/private:phases/api.bzl", "extras_phases", "run_aspect_phases")
load("@bazel_skylib//lib:dicts.bzl", "dicts")

def _import_paths(proto, ctx):
Expand Down Expand Up @@ -47,7 +43,7 @@ def _code_should_be_generated(ctx, toolchain):
return toolchain.blacklisted_protos.get(target_absolute_label) == None

def _compile_deps(ctx, toolchain):
deps_toolchain_type_label = "@io_bazel_rules_scala//scala_proto:deps_toolchain_type"
deps_toolchain_type_label = Label("//scala_proto:deps_toolchain_type")
return [
dep[JavaInfo]
for id in toolchain.compile_dep_ids
Expand Down Expand Up @@ -149,14 +145,14 @@ def _phase_deps(ctx, p):
return [d[ScalaProtoAspectInfo].java_info for d in ctx.rule.attr.deps]

def _phase_scalacopts(ctx, p):
return ctx.toolchains["@io_bazel_rules_scala//scala:toolchain_type"].scalacopts
return ctx.toolchains[Label("//scala:toolchain_type")].scalacopts

def _phase_generate_and_compile(ctx, p):
proto = p.proto_info
deps = p.deps
scalacopts = p.scalacopts
stamp_label = p.stamp_label
toolchain = ctx.toolchains["@io_bazel_rules_scala//scala_proto:toolchain_type"]
toolchain = ctx.toolchains[Label("//scala_proto:toolchain_type")]

if proto.direct_sources and _code_should_be_generated(ctx, toolchain):
src_jars = _generate_sources(ctx, toolchain, proto)
Expand All @@ -181,7 +177,7 @@ def _strip_suffix(str, suffix):

def _phase_stamp_label(ctx, p):
rule_label = str(p.target.label)
toolchain = ctx.toolchains["@io_bazel_rules_scala//scala_proto:toolchain_type"]
toolchain = ctx.toolchains[Label("//scala_proto:toolchain_type")]

if toolchain.stamp_by_convention and rule_label.endswith("_proto"):
return _strip_suffix(rule_label, "_proto") + "_scala_proto"
Expand Down Expand Up @@ -228,9 +224,9 @@ def make_scala_proto_aspect(*extras):
*[extra["attrs"] for extra in extras if "attrs" in extra]
),
toolchains = [
"@io_bazel_rules_scala//scala:toolchain_type",
"@io_bazel_rules_scala//scala_proto:toolchain_type",
"@io_bazel_rules_scala//scala_proto:deps_toolchain_type",
Label("//scala:toolchain_type"),
Label("//scala_proto:toolchain_type"),
Label("//scala_proto:deps_toolchain_type"),
"@bazel_tools//tools/jdk:toolchain_type",
],
)
Expand Down
8 changes: 5 additions & 3 deletions scala_proto/private/toolchain_deps.bzl
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
load(
"@io_bazel_rules_scala//scala/private/toolchain_deps:toolchain_deps.bzl",
"//scala/private/toolchain_deps:toolchain_deps.bzl",
"expose_toolchain_deps",
)

_DEPS_TOOLCHAIN_TYPE = Label("//scala_proto:deps_toolchain_type")

def _export_scalapb_toolchain_deps(ctx):
return expose_toolchain_deps(ctx, "@io_bazel_rules_scala//scala_proto:deps_toolchain_type")
return expose_toolchain_deps(ctx, _DEPS_TOOLCHAIN_TYPE)

export_scalapb_toolchain_deps = rule(
_export_scalapb_toolchain_deps,
Expand All @@ -14,5 +16,5 @@ export_scalapb_toolchain_deps = rule(
),
},
incompatible_use_toolchain_transition = True,
toolchains = ["@io_bazel_rules_scala//scala_proto:deps_toolchain_type"],
toolchains = [_DEPS_TOOLCHAIN_TYPE],
)
12 changes: 7 additions & 5 deletions scala_proto/scala_proto_toolchain.bzl
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
load("//scala:providers.bzl", "DepsInfo")
load(
"//scala_proto/default:default_deps.bzl",
_scala_proto_deps_providers = "scala_proto_deps_providers",
)

def _generators(ctx):
return dict(
Expand Down Expand Up @@ -138,13 +142,11 @@ scala_proto_deps_toolchain = rule(
implementation = _scala_proto_deps_toolchain,
attrs = {
"dep_providers": attr.label_list(
default = [
Label("//scala_proto:scalapb_compile_deps_provider"),
Label("//scala_proto:scalapb_grpc_deps_provider"),
Label("//scala_proto:scalapb_worker_deps_provider"),
],
default = _scala_proto_deps_providers(),
cfg = "target",
providers = [DepsInfo],
),
},
)

scala_proto_deps_providers = _scala_proto_deps_providers
Loading

0 comments on commit 62c17f1

Please sign in to comment.