From 1984a7199c34c6327ba51e40acc68c4ee3456bf0 Mon Sep 17 00:00:00 2001 From: Bor Kae Hwang Date: Fri, 25 Oct 2019 10:54:14 -0600 Subject: [PATCH] reduce changes --- scala/private/common.bzl | 25 ++++++++++++++++++------- scala/private/phases/phase_init.bzl | 18 ++++-------------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/scala/private/common.bzl b/scala/private/common.bzl index 34d4244ea4..7a32ba171f 100644 --- a/scala/private/common.bzl +++ b/scala/private/common.bzl @@ -1,6 +1,10 @@ load("@io_bazel_rules_scala//scala:jars_to_labels.bzl", "JarsToLabelsInfo") load("@io_bazel_rules_scala//scala:plusone.bzl", "PlusOneDeps") +def write_manifest(ctx): + main_class = getattr(ctx.attr, "main_class", None) + write_manifest_file(ctx.actions, ctx.outputs.manifest, main_class) + def write_manifest_file(actions, output_file, main_class): # TODO(bazel-team): I don't think this classpath is what you want manifest = "Class-Path: \n" @@ -9,6 +13,13 @@ def write_manifest_file(actions, output_file, main_class): actions.write(output = output_file, content = manifest) +def collect_srcjars(targets): + srcjars = [] + for target in targets: + if hasattr(target, "srcjars"): + srcjars.append(target.srcjars.srcjar) + return depset(srcjars) + def collect_jars( dep_targets, dependency_analyzer_is_off = True, @@ -37,7 +48,7 @@ def collect_plugin_paths(plugins): # which breaks scala macros elif hasattr(p, "files"): - paths.extend([f for f in p.files.to_list() if _not_sources_jar(f.basename)]) + paths.extend([f for f in p.files.to_list() if not_sources_jar(f.basename)]) return depset(paths) def _collect_jars_when_dependency_analyzer_is_off( @@ -60,7 +71,7 @@ def _collect_jars_when_dependency_analyzer_is_off( runtime_jars.append(java_provider.transitive_runtime_jars) if not unused_dependency_checker_is_off: - _add_labels_of_jars_to( + add_labels_of_jars_to( jars2labels, dep_target, [], @@ -99,7 +110,7 @@ def _collect_jars_when_dependency_analyzer_is_on(dep_targets): compile_jars.append(current_dep_compile_jars) transitive_compile_jars.append(current_dep_transitive_compile_jars) - _add_labels_of_jars_to( + add_labels_of_jars_to( jars2labels, dep_target, current_dep_transitive_compile_jars.to_list(), @@ -122,15 +133,15 @@ def _collect_jars_when_dependency_analyzer_is_on(dep_targets): # one of them needs to be removed from classpath # import cats.implicits._ -def _not_sources_jar(name): +def not_sources_jar(name): return "-sources.jar" not in name -def _filter_not_sources(deps): +def filter_not_sources(deps): return depset( - [dep for dep in deps.to_list() if _not_sources_jar(dep.basename)], + [dep for dep in deps.to_list() if not_sources_jar(dep.basename)], ) -def _add_labels_of_jars_to(jars2labels, dependency, all_jars, direct_jars): +def add_labels_of_jars_to(jars2labels, dependency, all_jars, direct_jars): for jar in direct_jars: jars2labels[jar.path] = dependency.label for jar in all_jars: diff --git a/scala/private/phases/phase_init.bzl b/scala/private/phases/phase_init.bzl index 4889fa407f..400df5f38f 100644 --- a/scala/private/phases/phase_init.bzl +++ b/scala/private/phases/phase_init.bzl @@ -10,13 +10,14 @@ load( load( "@io_bazel_rules_scala//scala/private:common.bzl", "collect_jars", - "write_manifest_file", + "collect_srcjars", + "write_manifest", ) def phase_library_init(ctx, p): # This will be used to pick up srcjars from non-scala library # targets (like thrift code generation) - srcjars = _collect_srcjars(ctx.attr.deps) + srcjars = collect_srcjars(ctx.attr.deps) # Add information from exports (is key that AFTER all build actions/runfiles analysis) # Since after, will not show up in deploy_jar or old jars runfiles @@ -32,21 +33,10 @@ def phase_library_init(ctx, p): ) def phase_common_init(ctx, p): - _write_manifest(ctx) + write_manifest(ctx) return struct( scalac_provider = _get_scalac_provider(ctx), ) -def _write_manifest(ctx): - main_class = getattr(ctx.attr, "main_class", None) - write_manifest_file(ctx.actions, ctx.outputs.manifest, main_class) - def _get_scalac_provider(ctx): return ctx.toolchains["@io_bazel_rules_scala//scala:toolchain_type"].scalac_provider_attr[_ScalacProvider] - -def _collect_srcjars(targets): - srcjars = [] - for target in targets: - if hasattr(target, "srcjars"): - srcjars.append(target.srcjars.srcjar) - return depset(srcjars)