Skip to content

Commit

Permalink
Cache external CodeInstances (#43990)
Browse files Browse the repository at this point in the history
Prior to this PR, Julia's precompiled `*.ji` files saved just two
categories of code: unspecialized method definitions and
type-specialized code for the methods defined by the package.  Any
novel specializations of methods from Base or previously-loaded
packages were not saved, and therefore effectively thrown away.

This PR caches all the code---internal or external---called during
package definition that hadn't been previously inferred, as long
as there is a backedge linking it back to a method owned by
a module being precompiled. (The latter condition ensures it will
actually be called by package methods, and not merely transiently
generated for the purpose of, e.g., metaprogramming or variable
initialization.) This makes precompilation more intuitive (now it
saves all relevant inference results), and substantially reduces
latency for inference-bound packages.

Closes #42016
Fixes #35972

Issue #35972 arose because codegen got started without re-inferring
some discarded CodeInstances. This forced the compiler to insert a
`jl_invoke`. This PR fixes the issue because needed CodeInstances are
no longer discarded by precompilation.
  • Loading branch information
timholy authored Feb 24, 2022
1 parent fb85cb3 commit df81bf9
Show file tree
Hide file tree
Showing 13 changed files with 724 additions and 96 deletions.
19 changes: 18 additions & 1 deletion base/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ end_base_include = time_ns()
const _sysimage_modules = PkgId[]
in_sysimage(pkgid::PkgId) = pkgid in _sysimage_modules

# Precompiles for Revise
# Precompiles for Revise and other packages
# TODO: move these to contrib/generate_precompile.jl
# The problem is they don't work there
for match = _methods(+, (Int, Int), -1, get_world_counter())
Expand Down Expand Up @@ -461,6 +461,23 @@ for match = _methods(+, (Int, Int), -1, get_world_counter())

# Code loading uses this
sortperm(mtime.(readdir(".")), rev=true)
# JLLWrappers uses these
Dict{UUID,Set{String}}()[UUID("692b3bcd-3c85-4b1f-b108-f13ce0eb3210")] = Set{String}()
get!(Set{String}, Dict{UUID,Set{String}}(), UUID("692b3bcd-3c85-4b1f-b108-f13ce0eb3210"))
eachindex(IndexLinear(), Expr[])
push!(Expr[], Expr(:return, false))
vcat(String[], String[])
k, v = (:hello => nothing)
precompile(indexed_iterate, (Pair{Symbol, Union{Nothing, String}}, Int))
precompile(indexed_iterate, (Pair{Symbol, Union{Nothing, String}}, Int, Int))
# Preferences uses these
precompile(get_preferences, (UUID,))
precompile(record_compiletime_preference, (UUID, String))
get(Dict{String,Any}(), "missing", nothing)
delete!(Dict{String,Any}(), "missing")
for (k, v) in Dict{String,Any}()
println(k)
end

break # only actually need to do this once
end
Expand Down
61 changes: 36 additions & 25 deletions base/binaryplatforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ struct Platform <: AbstractPlatform
# The "compare strategy" allows selective overriding on how a tag is compared
compare_strategies::Dict{String,Function}

function Platform(arch::String, os::String;
# Passing `tags` as a `Dict` avoids the need to infer different NamedTuple specializations
function Platform(arch::String, os::String, _tags::Dict{String};
validate_strict::Bool = false,
compare_strategies::Dict{String,<:Function} = Dict{String,Function}(),
kwargs...)
compare_strategies::Dict{String,<:Function} = Dict{String,Function}())
# A wee bit of normalization
os = lowercase(os)
arch = CPUID.normalize_arch(arch)
Expand All @@ -52,8 +52,9 @@ struct Platform <: AbstractPlatform
"arch" => arch,
"os" => os,
)
for (tag, value) in kwargs
tag = lowercase(string(tag::Symbol))
for (tag, value) in _tags
value = value::Union{String,VersionNumber,Nothing}
tag = lowercase(tag)
if tag ("arch", "os")
throw(ArgumentError("Cannot double-pass key $(tag)"))
end
Expand All @@ -70,8 +71,8 @@ struct Platform <: AbstractPlatform
if tag ("libgfortran_version", "libstdcxx_version", "os_version")
if isa(value, VersionNumber)
value = string(value)
elseif isa(value, AbstractString)
v = tryparse(VersionNumber, String(value)::String)
elseif isa(value, String)
v = tryparse(VersionNumber, value)
if isa(v, VersionNumber)
value = string(v)
end
Expand Down Expand Up @@ -110,6 +111,19 @@ struct Platform <: AbstractPlatform
end
end

# Keyword interface (to avoid inference of specialized NamedTuple methods, use the Dict interface for `tags`)
function Platform(arch::String, os::String;
validate_strict::Bool = false,
compare_strategies::Dict{String,<:Function} = Dict{String,Function}(),
kwargs...)
tags = Dict{String,Any}(String(tag)::String=>tagvalue(value) for (tag, value) in kwargs)
return Platform(arch, os, tags; validate_strict, compare_strategies)
end

tagvalue(v::Union{String,VersionNumber,Nothing}) = v
tagvalue(v::Symbol) = String(v)
tagvalue(v::AbstractString) = convert(String, v)::String

# Simple tag insertion that performs a little bit of validation
function add_tag!(tags::Dict{String,String}, tag::String, value::String)
# I know we said only alphanumeric and dots, but let's be generous so that we can expand
Expand Down Expand Up @@ -699,21 +713,22 @@ function Base.parse(::Type{Platform}, triplet::AbstractString; validate_strict::
end

# Extract the information we're interested in:
tags = Dict{String,Any}()
arch = get_field(m, arch_mapping)
os = get_field(m, os_mapping)
libc = get_field(m, libc_mapping)
call_abi = get_field(m, call_abi_mapping)
libgfortran_version = get_field(m, libgfortran_version_mapping)
libstdcxx_version = get_field(m, libstdcxx_version_mapping)
cxxstring_abi = get_field(m, cxxstring_abi_mapping)
tags["libc"] = get_field(m, libc_mapping)
tags["call_abi"] = get_field(m, call_abi_mapping)
tags["libgfortran_version"] = get_field(m, libgfortran_version_mapping)
tags["libstdcxx_version"] = get_field(m, libstdcxx_version_mapping)
tags["cxxstring_abi"] = get_field(m, cxxstring_abi_mapping)
function split_tags(tagstr)
tag_fields = split(tagstr, "-"; keepempty=false)
if isempty(tag_fields)
return Pair{String,String}[]
end
return map(v -> Symbol(v[1]) => v[2], split.(tag_fields, "+"))
return map(v -> String(v[1]) => String(v[2]), split.(tag_fields, "+"))
end
tags = split_tags(m["tags"])
merge!(tags, Dict(split_tags(m["tags"])))

# Special parsing of os version number, if any exists
function extract_os_version(os_name, pattern)
Expand All @@ -730,18 +745,9 @@ function Base.parse(::Type{Platform}, triplet::AbstractString; validate_strict::
if os == "freebsd"
os_version = extract_os_version("freebsd", r".*freebsd([\d.]+)")
end
tags["os_version"] = os_version

return Platform(
arch, os;
validate_strict,
libc,
call_abi,
libgfortran_version,
cxxstring_abi,
libstdcxx_version,
os_version,
tags...,
)
return Platform(arch, os, tags; validate_strict)
end
throw(ArgumentError("Platform `$(triplet)` is not an officially supported platform"))
end
Expand Down Expand Up @@ -1068,4 +1074,9 @@ function select_platform(download_info::Dict, platform::AbstractPlatform = HostP
return download_info[p]
end

# precompiles to reduce latency (see https://github.com/JuliaLang/julia/pull/43990#issuecomment-1025692379)
Dict{Platform,String}()[HostPlatform()] = ""
Platform("x86_64", "linux", Dict{String,Any}(); validate_strict=true)
Platform("x86_64", "linux", Dict{String,String}(); validate_strict=false) # called this way from Artifacts.unpack_platform

end # module
10 changes: 10 additions & 0 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

# Tracking of newly-inferred MethodInstances during precompilation
const track_newly_inferred = RefValue{Bool}(false)
const newly_inferred = MethodInstance[]

# build (and start inferring) the inference frame for the top-level MethodInstance
function typeinf(interp::AbstractInterpreter, result::InferenceResult, cache::Symbol)
frame = InferenceState(result, cache, interp)
Expand Down Expand Up @@ -389,6 +393,12 @@ function cache_result!(interp::AbstractInterpreter, result::InferenceResult)
if !already_inferred
inferred_result = transform_result_for_cache(interp, linfo, valid_worlds, result.src)
code_cache(interp)[linfo] = CodeInstance(result, inferred_result, valid_worlds)
if track_newly_inferred[]
m = linfo.def
if isa(m, Method)
m.module != Core && push!(newly_inferred, linfo)
end
end
end
unlock_mi_inference(interp, linfo)
nothing
Expand Down
8 changes: 6 additions & 2 deletions base/loading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1395,13 +1395,17 @@ function include_package_for_output(pkg::PkgId, input::String, depot_path::Vecto
task_local_storage()[:SOURCE_PATH] = source
end

Core.Compiler.track_newly_inferred.x = true
try
Base.include(Base.__toplevel__, input)
catch ex
precompilableerror(ex) || rethrow()
@debug "Aborting `create_expr_cache'" exception=(ErrorException("Declaration of __precompile__(false) not allowed"), catch_backtrace())
exit(125) # we define status = 125 means PrecompileableError
finally
Core.Compiler.track_newly_inferred.x = false
end
ccall(:jl_set_newly_inferred, Cvoid, (Any,), Core.Compiler.newly_inferred)
end

const PRECOMPILE_TRACE_COMPILE = Ref{String}()
Expand Down Expand Up @@ -2033,12 +2037,12 @@ end
Compile the given function `f` for the argument tuple (of types) `args`, but do not execute it.
"""
function precompile(@nospecialize(f), args::Tuple)
function precompile(@nospecialize(f), @nospecialize(args::Tuple))
precompile(Tuple{Core.Typeof(f), args...})
end

const ENABLE_PRECOMPILE_WARNINGS = Ref(false)
function precompile(argt::Type)
function precompile(@nospecialize(argt::Type))
ret = ccall(:jl_compile_hint, Int32, (Any,), argt) != 0
if !ret && ENABLE_PRECOMPILE_WARNINGS[]
@warn "Inactive precompile statement" maxlog=100 form=argt _module=nothing _file=nothing _line=0
Expand Down
2 changes: 1 addition & 1 deletion src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7870,7 +7870,7 @@ jl_compile_result_t jl_emit_codeinst(
// don't delete inlineable code, unless it is constant
(codeinst->invoke == jl_fptr_const_return_addr || !jl_ir_flag_inlineable((jl_array_t*)codeinst->inferred)) &&
// don't delete code when generating a precompile file
!imaging_mode) {
!(imaging_mode || jl_options.incremental)) {
// if not inlineable, code won't be needed again
codeinst->inferred = jl_nothing;
}
Expand Down
Loading

0 comments on commit df81bf9

Please sign in to comment.