Split out a double-check-cache job for jvm/rsc compile. (#8221)
### Problem

#8190 moved cache writing out of the completion of the zinc and rsc jobs and into a dependent job. But it also left multiple "double check" cache lookups running concurrently, because both the zinc and rsc jobs performed the check, and that race could lead to partial cache entries being extracted.

### Solution

Since we can't actually cancel or coordinate the concurrent work, we can't safely double check the cache once either job has started. So instead, this change extracts the cache double-check into its own job that both the zinc and rsc tasks will depend on.
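
A minimal sketch of the per-target job graph this produces, using a simplified stand-in for the execution-graph `Job`. The tiny `Job` class and the helper below are illustrative, not the real Pants API; only the `double_check_cache(...)` key format comes from the change itself:

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class Job:
    """Simplified stand-in for the execution-graph Job used by the jvm compile tasks."""
    key: str
    fn: Callable[[], None]
    dependencies: List[str] = field(default_factory=list)


def jobs_for_target(spec: str, dep_keys: List[str],
                    check_cache: Callable[[], None],
                    run_rsc: Callable[[], None],
                    run_zinc: Callable[[], None]) -> List[Job]:
    # One double-check job per target. It is given the same dependencies as the
    # compile jobs so it runs as late as possible, but before either of them.
    double_check = Job(f'double_check_cache({spec})', check_cache, list(dep_keys))
    # Both compile jobs depend on the single double-check job, so the cache is
    # consulted exactly once per target and the two jobs can no longer race.
    rsc = Job(f'rsc({spec})', run_rsc, [double_check.key] + list(dep_keys))
    zinc = Job(f'zinc({spec})', run_zinc, [double_check.key] + list(dep_keys))
    return [double_check, rsc, zinc]
```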
stuhood authored Aug 29, 2019
1 parent dbf5851 commit e3cf637
Showing 3 changed files with 135 additions and 74 deletions.
61 changes: 38 additions & 23 deletions src/python/pants/backend/jvm/tasks/jvm_compile/jvm_compile.py
@@ -711,53 +711,64 @@ def _upstream_analysis(self, compile_contexts, classpath_entries):
else:
yield compile_context.classes_dir.path, compile_context.analysis_file

def exec_graph_double_check_cache_key_for_target(self, target):
return 'double_check_cache({})'.format(target.address.spec)

def exec_graph_key_for_target(self, compile_target):
return "compile({})".format(compile_target.address.spec)

def _create_compile_jobs(self, compile_contexts, invalid_targets, invalid_vts, classpath_product):
class Counter:
def __init__(self, size, initial=0):
def __init__(self, size=0):
self.size = size
self.count = initial
self.count = 0

def __call__(self):
self.count += 1
return self.count

def increment_size(self, by=1):
self.size += by

def format_length(self):
return len(str(self.size))
counter = Counter(len(invalid_vts))

jobs = []
counter = Counter()

jobs.extend(self.pre_compile_jobs(counter))
invalid_target_set = set(invalid_targets)
for ivts in invalid_vts:
# Invalidated targets are a subset of relevant targets: get the context for this one.
compile_target = ivts.target
invalid_dependencies = self._collect_invalid_compile_dependencies(compile_target,
invalid_target_set)

jobs.extend(
self.create_compile_jobs(compile_target, compile_contexts, invalid_dependencies, ivts,
counter, classpath_product))
new_jobs, new_count = self.create_compile_jobs(
compile_target, compile_contexts, invalid_dependencies, ivts, counter, classpath_product)
jobs.extend(new_jobs)
counter.increment_size(by=new_count)

counter.size = len(jobs)
return jobs

def pre_compile_jobs(self, counter):
"""Override this to provide jobs that are not related to particular targets.
This is only called when there are invalid targets."""
return []

def create_compile_jobs(self, compile_target, all_compile_contexts, invalid_dependencies, ivts,
counter, classpath_product):
"""Return a list of jobs, and a count of those jobs that represent meaningful ("countable") work."""

context_for_target = all_compile_contexts[compile_target]
compile_context = self.select_runtime_context(context_for_target)

job = Job(self.exec_graph_key_for_target(compile_target),
compile_deps = [self.exec_graph_key_for_target(target) for target in invalid_dependencies]

# The cache checking job doesn't technically have any dependencies, but we want to delay it
# until immediately before we would otherwise try compiling, so we indicate that it depends on
# all compile dependencies.
double_check_cache_job = Job(self.exec_graph_double_check_cache_key_for_target(compile_target),
functools.partial(self._default_double_check_cache_for_vts, ivts),
compile_deps)
# The compile job depends on the cache check job. This decomposition is necessary in order to
# support more complex situations where compilation runs multiple jobs in parallel, and wants to
# double check the cache before starting any of them.
compile_job = Job(self.exec_graph_key_for_target(compile_target),
functools.partial(
self._default_work_for_vts,
ivts,
@@ -766,15 +777,15 @@ def create_compile_jobs(self, compile_target, all_compile_contexts, invalid_depe
counter,
all_compile_contexts,
classpath_product),
[self.exec_graph_key_for_target(target) for target in invalid_dependencies],
[double_check_cache_job.key] + compile_deps,
self._size_estimator(compile_context.sources),
# If compilation and analysis work succeeds, validate the vts.
# Otherwise, fail it.
on_success=ivts.update,
on_failure=ivts.force_invalidate)
return [job]
return ([double_check_cache_job, compile_job], 1)

def check_cache(self, vts, counter):
def check_cache(self, vts):
"""Manually checks the artifact cache (usually immediately before compilation.)
Returns true if the cache was hit successfully, indicating that no compilation is necessary.
@@ -790,7 +801,6 @@ def check_cache(self, vts, counter):
'Cache returned unexpected target: {} vs {}'.format(cached_vts, [vts])
)
self.context.log.info('Hit cache during double check for {}'.format(vts.target.address.spec))
counter()
return True

def should_compile_incrementally(self, vts, ctx):
@@ -916,13 +926,18 @@ def _get_jvm_distribution(self):
self.HERMETIC: lambda: self._HermeticDistribution('.jdk', local_distribution),
})()

def _default_double_check_cache_for_vts(self, vts):
# Double check the cache before beginning compilation
if self.check_cache(vts):
vts.update()

def _default_work_for_vts(self, vts, ctx, input_classpath_product_key, counter, all_compile_contexts, output_classpath_product):
progress_message = ctx.target.address.spec

# Double check the cache before beginning compilation
hit_cache = self.check_cache(vts, counter)

if not hit_cache:
# See whether the cache-doublecheck job hit the cache: if so, noop: otherwise, compile.
if vts.valid:
counter()
else:
# Compute the compile classpath for this target.
dependency_cp_entries = self._zinc.compile_classpath_entries(
input_classpath_product_key,
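To make the handoff in `jvm_compile.py` concrete: on a cache hit the new double-check job marks the versioned target set (vts) valid, and the dependent compile job then only consults `vts.valid` rather than re-checking the cache. A rough sketch of that split, with `vts`, `check_cache`, `counter`, and `do_compile` standing in for the real objects:

```python
def double_check_cache_for_vts(vts, check_cache):
    # Runs as its own job, scheduled just before compilation would otherwise start.
    if check_cache(vts):
        # A hit marks the vts valid so the dependent compile job can no-op.
        vts.update()


def work_for_vts(vts, counter, do_compile):
    if vts.valid:
        # The double-check job already extracted the cached artifact: just count it.
        counter()
    else:
        do_compile()
```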
99 changes: 52 additions & 47 deletions src/python/pants/backend/jvm/tasks/jvm_compile/rsc/rsc_compile.py
@@ -292,25 +292,6 @@ def _zinc_key_for_target(self, target, workflow):
def _write_to_cache_key_for_target(self, target):
return 'write_to_cache({})'.format(target.address.spec)

def _check_cache_before_work(self, work_str, vts, ctx, counter, debug = False, work_fn = lambda: None):
hit_cache = self.check_cache(vts, counter)

if not hit_cache:
counter_val = str(counter()).rjust(counter.format_length(), ' ')
counter_str = '[{}/{}] '.format(counter_val, counter.size)
log_fn = self.context.log.debug if debug else self.context.log.info
log_fn(
counter_str,
f'{work_str} ',
items_to_report_element(ctx.sources, '{} source'.format(self.name())),
' in ',
items_to_report_element([t.address.reference() for t in vts.targets], 'target'),
' (',
ctx.target.address.spec,
').')

work_fn()

def create_compile_jobs(self,
compile_target,
compile_contexts,
@@ -323,7 +304,19 @@ def work_for_vts_rsc(vts, ctx):
target = ctx.target
tgt, = vts.targets

def work_fn():
# If we didn't hit the cache in the cache job, run rsc.
if not vts.valid:
counter_val = str(counter()).rjust(counter.format_length(), ' ')
counter_str = '[{}/{}] '.format(counter_val, counter.size)
self.context.log.info(
counter_str,
'Rsc-ing ',
items_to_report_element(ctx.sources, '{} source'.format(self.name())),
' in ',
items_to_report_element([t.address.reference() for t in vts.targets], 'target'),
' (',
ctx.target.address.spec,
').')
# This does the following
# - Collect the rsc classpath elements, including zinc compiles of rsc incompatible targets
# and rsc compiles of rsc compatible targets.
@@ -391,16 +384,11 @@ def nonhermetic_digest_classpath():
'rsc'
)

# Double check the cache before beginning compilation
self._check_cache_before_work('Rsc-ing', vts, ctx, counter, work_fn=work_fn)

# Update the products with the latest classes.
self.register_extra_products_from_contexts([ctx.target], compile_contexts)

def work_for_vts_write_to_cache(vts, ctx):
self._check_cache_before_work('Writing to cache for', vts, ctx, counter, debug=True)

### Create Jobs for ExecutionGraph
cache_doublecheck_jobs = []
rsc_jobs = []
zinc_jobs = []

@@ -410,6 +398,8 @@ def work_for_vts_write_to_cache(vts, ctx):
rsc_compile_context = merged_compile_context.rsc_cc
zinc_compile_context = merged_compile_context.zinc_cc

cache_doublecheck_key = self.exec_graph_double_check_cache_key_for_target(compile_target)

def all_zinc_rsc_invalid_dep_keys(invalid_deps):
"""Get the rsc key for an rsc-and-zinc target, or the zinc key for a zinc-only target."""
for tgt in invalid_deps:
@@ -420,6 +410,14 @@ def all_zinc_rsc_invalid_dep_keys(invalid_deps):
# Rely on the results of zinc compiles for zinc-compatible targets
yield self._key_for_target_as_dep(tgt, tgt_rsc_cc.workflow)

def make_cache_doublecheck_job(dep_keys):
# As in JvmCompile.create_compile_jobs, we create a cache-double-check job that all "real" work
# depends on. It depends on completion of the same dependencies as the rsc job in order to run
# as late as possible, while still running before rsc or zinc.
return Job(cache_doublecheck_key,
functools.partial(self._default_double_check_cache_for_vts, ivts),
dependencies=list(dep_keys))

def make_rsc_job(target, dep_targets):
return Job(
key=self._rsc_key_for_target(target),
@@ -432,7 +430,7 @@ def make_rsc_job(target, dep_targets):
),
# The rsc jobs depend on other rsc jobs, and on zinc jobs for targets that are not
# processed by rsc.
dependencies=list(all_zinc_rsc_invalid_dep_keys(dep_targets)),
dependencies=[cache_doublecheck_key] + list(all_zinc_rsc_invalid_dep_keys(dep_targets)),
size=self._size_estimator(rsc_compile_context.sources),
)

@@ -453,7 +451,7 @@ def make_zinc_job(target, input_product_key, output_products, dep_keys):
counter,
compile_contexts,
CompositeProductAdder(*output_products)),
dependencies=list(dep_keys),
dependencies=[cache_doublecheck_key] + list(dep_keys),
size=self._size_estimator(zinc_compile_context.sources),
)

@@ -470,6 +468,19 @@ def record(k, v):
record('workflow', workflow.value)
record('execution_strategy', self.execution_strategy)

# Create the cache doublecheck job.
workflow.resolve_for_enum_variant({
'zinc-only': lambda: cache_doublecheck_jobs.append(
make_cache_doublecheck_job(list(all_zinc_rsc_invalid_dep_keys(invalid_dependencies)))
),
'zinc-java': lambda: cache_doublecheck_jobs.append(
make_cache_doublecheck_job(list(only_zinc_invalid_dep_keys(invalid_dependencies)))
),
'rsc-and-zinc': lambda: cache_doublecheck_jobs.append(
make_cache_doublecheck_job(list(all_zinc_rsc_invalid_dep_keys(invalid_dependencies)))
),
})()

# Create the rsc job.
# Currently, rsc only supports outlining scala.
workflow.resolve_for_enum_variant({
@@ -519,25 +530,19 @@ def record(k, v):
)),
})()

all_jobs = rsc_jobs + zinc_jobs

if all_jobs:
write_to_cache_job = Job(
key=self._write_to_cache_key_for_target(compile_target),
fn=functools.partial(
work_for_vts_write_to_cache,
ivts,
rsc_compile_context,
),
dependencies=[job.key for job in all_jobs],
run_asap=True,
# If compilation and analysis work succeeds, validate the vts.
# Otherwise, fail it.
on_success=ivts.update,
on_failure=ivts.force_invalidate)
all_jobs.append(write_to_cache_job)

return all_jobs
compile_jobs = rsc_jobs + zinc_jobs

# Create a job that depends on all real work having completed that will eagerly write to the
# cache by calling `vt.update()`.
write_to_cache_job = Job(
key=self._write_to_cache_key_for_target(compile_target),
fn=ivts.update,
dependencies=[job.key for job in compile_jobs],
run_asap=True,
on_failure=ivts.force_invalidate)

all_jobs = cache_doublecheck_jobs + rsc_jobs + zinc_jobs + [write_to_cache_job]
return (all_jobs, len(compile_jobs))

class RscZincMergedCompileContexts(datatype([
('rsc_cc', RscCompileContext),
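Putting the `rsc_compile.py` pieces together for one rsc-and-zinc target `T` with a single invalid dependency `D`, the per-target dependency edges now look roughly as follows. Only the `double_check_cache(...)`, `rsc(...)`, and `write_to_cache(...)` key formats come from the diff; the remaining details are abbreviated:

```python
# Hypothetical per-target dependency edges after this change. Only the compile
# jobs are reported as "countable" work; the double-check and write-to-cache
# jobs are bookkeeping around them.
job_dependencies = {
    'double_check_cache(T)': ['rsc(D)'],                # same deps as the rsc job, so it runs as late as possible
    'rsc(T)': ['double_check_cache(T)', 'rsc(D)'],
    'zinc(T)': ['double_check_cache(T)'],               # plus its usual compile dep keys, omitted here
    'write_to_cache(T)': ['rsc(T)', 'zinc(T)'],         # run_asap; its fn is ivts.update
}
```

Because `write_to_cache(T)` depends on every compile job for the target and its function is just `ivts.update`, the cache entry is written only after all of the target's real work has succeeded, while the double-check job guarantees the cache is read at most once per target.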