Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use on-the-fly rewrite (no more file writes) #11

Merged
merged 6 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PackageExtensionCompat"
uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930"
authors = ["Christopher Doris <github.com/cjdoris>"]
version = "1.0.1"
version = "1.0.2"

[deps]
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Expand Down
81 changes: 51 additions & 30 deletions src/PackageExtensionCompat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,37 +7,63 @@ const HAS_NATIVE_EXTENSIONS = isdefined(Base, :get_extension)
@static if !HAS_NATIVE_EXTENSIONS
using Requires, TOML

function rewrite_import(str, pkgs)
parts = split(strip(str))
if length(parts) == 1 || (length(parts) ≥ 2 && parts[2] == "as")
if parts[1] ∈ pkgs
parts[1] = string("..", parts[1])
@static if hasmethod(Base.include, Tuple{Function, Module, String})
function _include(mapexpr::Function, m::Module, path::AbstractString)
Base.include(mapexpr, m, path)
end
else
function _include(mapexpr::Function, m::Module, path::AbstractString)
path = abspath(path)
cd(dirname(path)) do
str = read(basename(path), String)
pos = 1
while true
cjdoris marked this conversation as resolved.
Show resolved Hide resolved
(expr, pos) = Meta.parse(str, pos; raise=false)
expr !== nothing || break
m.eval(mapexpr(expr))
end
end
end
join(parts, " ")
end

function rewrite_imports(str, pkgs)
parts = split(str, ",")
parts = map(part -> rewrite_import(part, pkgs), parts)
join(parts, ", ")
# a simplified variant of MacroTools.postwalk
postwalk(f, x) = f(x)
postwalk(f, x::Expr) = f(Expr(x.head, postwalk.(f, x.args)...))

function rewrite(top_pkg::Module, pkgs)
Base.Fix1(postwalk, block -> rewrite_block(block, top_pkg, pkgs))
end

function rewrite_line(line, pkgs)
pat = r"^(\s*(using|import)\s+)([^;:#$]*[^;:#$\s])(.*)$"
m = match(pat, line)
if m === nothing
line
function rewrite_block(block, top_pkg::Module, pkgs)
if Meta.isexpr(block, :call) && length(block.args) == 2 && block.args[1] == :include
# inner include, rewrite it recursively
dhanak marked this conversation as resolved.
Show resolved Hide resolved
local_mod = Expr(:macrocall, Symbol("@__MODULE__"), @__LINE__)
Expr(:call, _include, rewrite(top_pkg, pkgs), local_mod, block.args[2])
elseif Meta.isexpr(block, [:using, :import])
# using or import block, replace references to pkgs
imports = map(block.args) do use
Meta.isexpr(use, [:(:), :as]) ?
Expr(use.head,
rewrite_use(use.args[1], top_pkg, pkgs),
use.args[2:end]...) :
rewrite_use(use, top_pkg, pkgs)
end
Expr(block.head, imports...)
else
string(m.captures[1], rewrite_imports(m.captures[3], pkgs), m.captures[4])
# leave everything else alone
block
end
end

function rewrite(srcfile, trgfile, pkgs)
lines = readlines(srcfile)
lines = map(line -> rewrite_line(line, pkgs), lines)
code = join(lines, "\n")
write(trgfile, code)
function rewrite_use(use::Expr, top_pkg::Module, pkgs)::Expr
@assert Meta.isexpr(use, :.)
if string(use.args[1]) ∈ pkgs
# rewrite `using/import WeakDep` as `using/import TopPkg.WeakDep`
Expr(:., nameof(top_pkg), use.args...)
else
# leave every other package import alone
use
end
end

macro require_extensions()
Expand All @@ -56,26 +82,21 @@ const HAS_NATIVE_EXTENSIONS = isdefined(Base, :get_extension)
extensions = get(toml, "extensions", [])
isempty(extensions) && error("no extensions defined in $tomlpath")
exprs = []
rm(joinpath(rootdir, "ext_compat"), force=true, recursive=true)
for (name, pkgs) in extensions
if pkgs isa String
pkgs = [pkgs]
end
extpath = nothing
for path in [joinpath(rootdir, "ext", "$name.jl"), joinpath(rootdir, "ext", "$name", "$name.jl")]
for path in [joinpath(rootdir, "ext", "$name.jl"),
joinpath(rootdir, "ext", "$name", "$name.jl")]
if isfile(path)
extpath = path
end
end
extpath === nothing && error("Expecting ext/$name.jl or ext/$name/$name.jl in $rootdir for extension $name.")
# rewrite the extension code
# TODO: there may be other files to copy/rewrite
__module__.include_dependency(extpath)
extpath2 = joinpath(rootdir, "ext_compat", relpath(extpath, joinpath(rootdir, "ext")))
mkpath(dirname(extpath2))
rewrite(extpath, extpath2, pkgs)
# include the extension code
expr = :($(__module__.include)($extpath2))
# include and rewrite the extension code
expr = :($(_include)($(rewrite(__module__, pkgs)), $__module__, $extpath))
for pkg in pkgs
uuid = get(get(Dict, toml, "weakdeps"), pkg, nothing)
uuid === nothing && error("Expecting a weakdep for $pkg in $tomlpath.")
Expand Down
39 changes: 36 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Test, Pkg, Random, UUIDs, PackageExtensionCompat

function make_package(dir; name=nothing, uuid=nothing, src="", deps=[], weakdeps=[], extensions=[])
function make_package(dir; name=nothing, uuid=nothing, src="", deps=[], weakdeps=[], extensions=[], includes=[])
if name === nothing
name = "TestPackage_$(randstring())"
end
Expand Down Expand Up @@ -36,7 +36,7 @@ function make_package(dir; name=nothing, uuid=nothing, src="", deps=[], weakdeps
end
extpath = joinpath(rootpath, "ext")
mkpath(extpath)
for ext in extensions
for ext in [extensions; includes]
open(joinpath(extpath, "$(ext.name).jl"), "w") do io
print(io, """
module $(ext.name)
Expand All @@ -48,7 +48,7 @@ function make_package(dir; name=nothing, uuid=nothing, src="", deps=[], weakdeps
(name=name, uuid=uuid, path=rootpath)
end

function test_extension(; dir, slug, extsrc)
function test_extension(; dir, slug, extsrc, incsrc = nothing)
# a secret value embedded into the extension which can only be recovered if the
# extension is loaded correctly
secret = rand(Int)
Expand Down Expand Up @@ -79,6 +79,12 @@ function test_extension(; dir, slug, extsrc)
deps = [pkg1.name],
src = replace(replace(extsrc, "PKG1NAME" => pkg1.name), "PKG2NAME" => "PKGNAME"),
)
],
includes = incsrc === nothing ? [] : [
(
name = "TestExtSub",
src = replace(replace(incsrc, "PKG1NAME" => pkg1.name), "PKG2NAME" => "PKGNAME"),
)
]
)
# add these packages to the project
Expand Down Expand Up @@ -223,6 +229,33 @@ end
)
end

@testset "submodule" begin
test_extension(
dir = dir,
slug = "9",
extsrc = """
module SubMod
using PKG2NAME, PKG1NAME
PKG2NAME.secret() = PKG1NAME.SECRET
end
"""
)
end

@testset "inner-include" begin
test_extension(
dir = dir,
slug = "10",
extsrc = """
include("TestExtSub.jl")
@assert TestExtSub isa Module
""",
incsrc = """
using PKG2NAME, PKG1NAME
PKG2NAME.secret() = PKG1NAME.SECRET
"""
)
end
end

end