Skip to content

Commit

Permalink
another hack
Browse files Browse the repository at this point in the history
  • Loading branch information
willow-ahrens committed Jan 9, 2025
1 parent bb0195d commit 1145e57
Show file tree
Hide file tree
Showing 13 changed files with 47 additions and 55 deletions.
4 changes: 2 additions & 2 deletions ext/SparseArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ function Finch.unfurl(ctx, tns::VirtualSparseMatrixCSCColumn, ext, mode, ::Union
end
$dirty = false
end,
body = (ctx) -> Finch.instantiate(ctx, Finch.VirtualSparseScalar(nothing, arr.Tv, zero(arr.Tv), fgensym(), :($(arr.val)[$(ctx(qos))]), dirty), mode),
body = (ctx) -> Finch.instantiate(ctx, Finch.VirtualSparseScalar(nothing, arr.Tv, zero(arr.Tv), gensym(), :($(arr.val)[$(ctx(qos))]), dirty), mode),
epilogue = quote
if $dirty
$(arr.idx)[$qos] = $(ctx(idx))
Expand Down Expand Up @@ -463,7 +463,7 @@ function Finch.unfurl(ctx, arr::VirtualSparseVector, ext, mode::Updater, ::Union
end
$dirty = false
end,
body = (ctx) -> Finch.instantiate(ctx, Finch.VirtualSparseScalar(nothing, arr.Tv, zero(arr.Tv), fgensym(), :($(arr.val)[$(ctx(qos))]), dirty), mode),
body = (ctx) -> Finch.instantiate(ctx, Finch.VirtualSparseScalar(nothing, arr.Tv, zero(arr.Tv), gensym(), :($(arr.val)[$(ctx(qos))]), dirty), mode),
epilogue = quote
if $dirty
$(arr.idx)[$qos] = $(ctx(idx))
Expand Down
2 changes: 1 addition & 1 deletion src/execute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ See also: [`@finch`](@ref)
"""
function finch_kernel(fname, args, prgm; algebra = DefaultAlgebra(), mode = :safe, ctx = FinchCompiler(algebra=algebra, mode=mode))
maybe_typeof(x) = x isa Type ? x : typeof(x)
unreachable = fgensym(:unreachable)
unreachable = gensym(:unreachable)
code = contain(ctx) do ctx_2
foreach(args) do (key, val)
set_binding!(ctx_2, variable(key), finch_leaf(virtualize(ctx_2.code, key, maybe_typeof(val), key)))
Expand Down
8 changes: 4 additions & 4 deletions src/interface/abstract_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ function unfurl(ctx, tns::VirtualAbstractArraySlice, ext, mode, proto)
preamble = quote
$val = $(arr.ex)[$(map(ctx, idx_2)...)]
end,
body = (ctx) -> instantiate(ctx, VirtualScalar(nothing, arr.eltype, nothing#=We don't know what init is, but it won't be used here =#, fgensym(), val), mode)
body = (ctx) -> instantiate(ctx, VirtualScalar(nothing, arr.eltype, nothing#=We don't know what init is, but it won't be used here =#, gensym(), val), mode)
)
else
Thunk(
body = (ctx,) -> instantiate(ctx, VirtualScalar(nothing, arr.eltype, nothing#=We don't know what init is, but it won't be used here=#, fgensym(), :($(arr.ex)[$(map(ctx, idx_2)...)])), mode)
body = (ctx,) -> instantiate(ctx, VirtualScalar(nothing, arr.eltype, nothing#=We don't know what init is, but it won't be used here=#, gensym(), :($(arr.ex)[$(map(ctx, idx_2)...)])), mode)
)
end
else
Expand All @@ -81,11 +81,11 @@ function instantiate(ctx::AbstractCompiler, arr::VirtualAbstractArray, mode)
preamble = quote
$val = $(arr.ex)[]
end,
body = (ctx) -> instantiate(ctx, VirtualScalar(nothing, arr.eltype, nothing#=We don't know what init is, but it won't be used here =#, fgensym(), val), mode)
body = (ctx) -> instantiate(ctx, VirtualScalar(nothing, arr.eltype, nothing#=We don't know what init is, but it won't be used here =#, gensym(), val), mode)
)
else
Thunk(
body = (ctx,) -> instantiate(ctx, VirtualScalar(nothing, arr.eltype, nothing#=We don't know what init is, but it won't be used here=#, fgensym(), :($(arr.ex)[])), mode)
body = (ctx,) -> instantiate(ctx, VirtualScalar(nothing, arr.eltype, nothing#=We don't know what init is, but it won't be used here=#, gensym(), :($(arr.ex)[])), mode)
)
end
else
Expand Down
30 changes: 15 additions & 15 deletions src/interface/lazy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ function expanddims(arr::LazyTensor{T}, dims) where {T}
@assert allunique(dims)
@assert issubset(dims,1:ndims(arr) + length(dims))
antidims = setdiff(1:ndims(arr) + length(dims), dims)
idxs_1 = [field(fgensym(:i)) for _ in 1:ndims(arr)]
idxs_2 = [field(fgensym(:i)) for _ in 1:ndims(arr) + length(dims)]
idxs_1 = [field(gensym(:i)) for _ in 1:ndims(arr)]
idxs_2 = [field(gensym(:i)) for _ in 1:ndims(arr) + length(dims)]
idxs_2[antidims] .= idxs_1
data_2 = reorder(relabel(arr.data, idxs_1...), idxs_2...)
extrude_2 = [false for _ in 1:ndims(arr) + length(dims)]
Expand All @@ -48,25 +48,25 @@ function expanddims(arr::LazyTensor{T}, dims) where {T}
end

function identify(data)
lhs = alias(fgensym(:A))
lhs = alias(gensym(:A))
subquery(lhs, data)
end

LazyTensor(data::Number) = LazyTensor{typeof(data), 0}(immediate(data), (), data)
LazyTensor{T}(data::Number) where {T} = LazyTensor{T, 0}(immediate(data), (), data)
LazyTensor(arr::Base.AbstractArrayOrBroadcasted) = LazyTensor{eltype(arr)}(arr)
function LazyTensor{T}(arr::Base.AbstractArrayOrBroadcasted) where {T}
name = alias(fgensym(:A))
idxs = [field(fgensym(:i)) for _ in 1:ndims(arr)]
name = alias(gensym(:A))
idxs = [field(gensym(:i)) for _ in 1:ndims(arr)]
extrude = ntuple(n -> size(arr, n) == 1, ndims(arr))
tns = subquery(name, table(immediate(arr), idxs...))
LazyTensor{eltype(arr), ndims(arr)}(tns, extrude, fill_value(arr))
end
LazyTensor(arr::AbstractTensor) = LazyTensor{eltype(arr)}(arr)
LazyTensor(swizzle_arr::SwizzleArray{dims, <:Tensor}) where {dims} = permutedims(LazyTensor(swizzle_arr.body), dims)
function LazyTensor{T}(arr::AbstractTensor) where {T}
name = alias(fgensym(:A))
idxs = [field(fgensym(:i)) for _ in 1:ndims(arr)]
name = alias(gensym(:A))
idxs = [field(gensym(:i)) for _ in 1:ndims(arr)]
extrude = ntuple(n -> size(arr)[n] == 1, ndims(arr))
tns = subquery(name, table(immediate(arr), idxs...))
LazyTensor{eltype(arr), ndims(arr)}(tns, extrude, fill_value(arr))
Expand All @@ -90,7 +90,7 @@ end
function Base.map(f, src::LazyTensor, args...)
largs = map(LazyTensor, (src, args...))
extrude = largs[something(findfirst(arg -> length(arg.extrude) > 0, largs), 1)].extrude
idxs = [field(fgensym(:i)) for _ in src.extrude]
idxs = [field(gensym(:i)) for _ in src.extrude]
ldatas = map(largs) do larg
if larg.extrude == extrude
return relabel(larg.data, idxs...)
Expand Down Expand Up @@ -135,7 +135,7 @@ end
function Base.reduce(op, arg::LazyTensor{T, N}; dims=:, init = initial_value(op, T)) where {T, N}
dims = dims == Colon() ? (1:N) : collect(dims)
extrude = ((arg.extrude[n] for n in 1:N if !(n in dims))...,)
fields = [field(fgensym(:i)) for _ in 1:N]
fields = [field(gensym(:i)) for _ in 1:N]
S = fixpoint_type(op, init, eltype(arg))
data = aggregate(immediate(op), immediate(init), relabel(arg.data, fields), fields[dims]...)
LazyTensor{S}(identify(data), extrude, init)
Expand All @@ -160,8 +160,8 @@ function tensordot(A::LazyTensor{T1, N1}, B::LazyTensor{T2, N2}, idxs; mult_op=*

extrude = ((A.extrude[n] for n in 1:N1 if !(n in A_idxs))...,
(B.extrude[n] for n in 1:N2 if !(n in B_idxs))...,)
A_fields = [field(fgensym(:i)) for _ in 1:N1]
B_fields = [field(fgensym(:i)) for _ in 1:N2]
A_fields = [field(gensym(:i)) for _ in 1:N1]
B_fields = [field(gensym(:i)) for _ in 1:N2]
reduce_fields = []
for i in eachindex(A_idxs)
B_fields[B_idxs[i]] = A_fields[A_idxs[i]]
Expand Down Expand Up @@ -197,7 +197,7 @@ function broadcast_to_query(bc::Broadcast.Broadcasted, idxs)
end

function broadcast_to_query(tns::LazyTensor{T, N}, idxs) where {T, N}
idxs_2 = [tns.extrude[i] ? field(fgensym(:idx)) : idxs[i] for i in 1:N]
idxs_2 = [tns.extrude[i] ? field(gensym(:idx)) : idxs[i] for i in 1:N]
data_2 = relabel(tns.data, idxs_2...)
reorder(data_2, idxs[findall(!, tns.extrude)]...)
end
Expand Down Expand Up @@ -232,7 +232,7 @@ Base.copyto!(out, bc::Broadcasted{LazyStyle{N}}) where {N} = copyto!(out, copy(b

function Base.copy(bc::Broadcasted{LazyStyle{N}}) where {N}
bc_lgc = broadcast_to_logic(bc)
idxs = [field(fgensym(:i)) for _ in 1:N]
idxs = [field(gensym(:i)) for _ in 1:N]
data = reorder(broadcast_to_query(bc_lgc, idxs), idxs)
extrude = ntuple(n -> broadcast_to_extrude(bc_lgc, n), N)
def = broadcast_to_default(bc_lgc)
Expand All @@ -252,7 +252,7 @@ function Base.permutedims(arg::LazyTensor{T, N}, perm) where {T, N}
length(perm) == N || throw(ArgumentError("permutedims given wrong number of dimensions"))
isperm(perm) || throw(ArgumentError("permutedims given invalid permutation"))
perm = collect(perm)
idxs = [field(fgensym(:i)) for _ in 1:N]
idxs = [field(gensym(:i)) for _ in 1:N]
return LazyTensor{T, N}(reorder(relabel(arg.data, idxs...), idxs[perm]...), arg.extrude[perm], arg.fill_value)
end
Base.permutedims(arr::SwizzleArray, perm) = swizzle(arr, perm...)
Expand Down Expand Up @@ -536,7 +536,7 @@ compute(arg; ctx=get_scheduler(), kwargs...) = compute_parse(set_options(ctx; kw
compute(args::Tuple; ctx=get_scheduler(), kwargs...) = compute_parse(set_options(ctx; kwargs...), map(lazy, args))
function compute_parse(ctx, args::Tuple)
args = collect(args)
vars = map(arg -> alias(fgensym(:A)), args)
vars = map(arg -> alias(gensym(:A)), args)
bodies = map((arg, var) -> query(var, arg.data), args, vars)
prgm = plan(bodies, produces(vars))

Expand Down
6 changes: 3 additions & 3 deletions src/interface/morgue.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ end
#=
root = Rewrite(Fixpoint(Postwalk(Chain([
(@rule plan(~a1..., query(~b, relabel(~c, ~i...)), ~a2...) => begin
d = alias(fgensym(:A))
d = alias(gensym(:A))
bindings[d] = c
rw = Rewrite(Postwalk(@rule b => relabel(d, i...)))
plan(a1..., query(d, c), map(rw, a2)...)
end),
(@rule plan(~a1..., query(~b, reorder(~c, ~i...)), ~a2...) => begin
d = alias(fgensym(:A))
d = alias(gensym(:A))
bindings[d] = c
rw = Rewrite(Postwalk(@rule b => reorder(d, i...)))
plan(a1..., query(d, c), map(rw, a2)...)
Expand All @@ -54,7 +54,7 @@ end
function push_reorders(root, bindings)
Rewrite(Fixpoint(Postwalk(Chain([
(@rule plan(~a1..., query(~b, reorder(~c, ~i...)), ~a2...) => begin
d = alias(fgensym(:A))
d = alias(gensym(:A))
bindings[d] = c
rw = Rewrite(Postwalk(@rule b => reorder(d, i...)))
plan(a1..., query(d, c), map(rw, a2)...)
Expand Down
6 changes: 3 additions & 3 deletions src/interface/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ function map_rep_def(::MapRepHollowStyle, f, args)
end
for (n, arg) in enumerate(args)
if arg isa HollowData
args_2 = map(arg -> value(fgensym(), eltype(arg)), collect(args))
args_2 = map(arg -> value(gensym(), eltype(arg)), collect(args))
args_2[n] = literal(fill_value(arg))
if finch_leaf(simplify(FinchCompiler(), call(f, args_2...))) == literal(fill_value(lvl))
return HollowData(lvl)
Expand All @@ -231,7 +231,7 @@ function map_rep_def(::MapRepSparseStyle, f, args)
end
for (n, arg) in enumerate(args)
if arg isa SparseData
args_2 = map(arg -> value(fgensym(), eltype(arg)), collect(args))
args_2 = map(arg -> value(gensym(), eltype(arg)), collect(args))
args_2[n] = literal(fill_value(arg))
if finch_leaf(simplify(FinchCompiler(), call(f, args_2...))) == literal(fill_value(lvl))
return SparseData(lvl)
Expand All @@ -248,7 +248,7 @@ function map_rep_def(::MapRepRepeatStyle, f, args)
end
for (n, arg) in enumerate(args)
if arg isa RepeatData
args_2 = map(arg -> value(fgensym(), eltype(arg)), collect(args))
args_2 = map(arg -> value(gensym(), eltype(arg)), collect(args))
args_2[n] = literal(fill_value(arg))
if finch_leaf(simplify(FinchCompiler(), call(f, args_2...))) == literal(fill_value(lvl))
return RepeatData(lvl)
Expand Down
2 changes: 1 addition & 1 deletion src/scheduler/LogicExecutor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ function logic_executor_code(ctx, prgm)
ctx(prgm)
end
code = pretty(code)
fname = fgensym(:compute)
fname = gensym(Symbol(:compute, hash(get_structure(prgm)))) #The fact that we need this hash is worrisome
return :(function $fname(prgm)
$code
end) |> striplines
Expand Down
12 changes: 6 additions & 6 deletions src/scheduler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,21 @@ flatten_plans = Rewrite(Postwalk(Fixpoint(Chain([

isolate_aggregates = Rewrite(Postwalk(
@rule aggregate(~op, ~init, ~arg, ~idxs...) => begin
name = alias(fgensym(:A))
name = alias(gensym(:A))
subquery(name, aggregate(~op, ~init, ~arg, ~idxs...))
end
))

isolate_reformats = Rewrite(Postwalk(
@rule reformat(~tns, ~arg) => begin
name = alias(fgensym(:A))
name = alias(gensym(:A))
subquery(name, reformat(tns, arg))
end
))

isolate_tables = Rewrite(Postwalk(
@rule table(~tns, ~idxs...) => begin
name = alias(fgensym(:A))
name = alias(gensym(:A))
subquery(name, table(tns, idxs...))
end
))
Expand Down Expand Up @@ -193,7 +193,7 @@ function materialize_squeeze_expand_productions(root)
preamble = []
args_2 = map(args) do arg
if (@capture arg reorder(relabel(~tns::isalias, ~idxs_1...), ~idxs_2...)) && Set(idxs_1) != Set(idxs_2)
tns_2 = alias(fgensym(:A))
tns_2 = alias(gensym(:A))
idxs_3 = withsubsequence(intersect(idxs_1, idxs_2), idxs_2)
push!(preamble, query(tns_2, reorder(relabel(tns, idxs_1), idxs_3)))
if idxs_3 == idxs_2
Expand Down Expand Up @@ -541,7 +541,7 @@ function set_loop_order(node, perms = Dict(), reps = Dict())
reps[lhs] = SuitableRep(reps)(rhs_2)
query(lhs, reformat(tns, rhs_2))
elseif @capture node query(~lhs, reformat(~tns, ~rhs))
arg = alias(fgensym(:A))
arg = alias(gensym(:A))
set_loop_order(plan(
query(A, rhs),
query(lhs, reformat(tns, A))
Expand Down Expand Up @@ -593,7 +593,7 @@ function optimize(prgm)
prgm = isolate_tables(prgm)
prgm = lift_subqueries(prgm)

#I shouldn't use fgensym but I do, so this cleans up the names
#I shouldn't use gensym but I do, so this cleans up the names
prgm = pretty_labels(prgm)

#These steps fuse copy, permutation, and mapjoin statements
Expand Down
2 changes: 1 addition & 1 deletion src/tensors/levels/atomic_element_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ function instantiate(ctx, fbr::VirtualSubFiber{VirtualAtomicElementLevel}, mode:
preamble = quote
$val = $(lvl.val)[$(ctx(pos))]
end,
body = (ctx) -> VirtualScalar(nothing, lvl.Tv, lvl.Vf, fgensym(), val)
body = (ctx) -> VirtualScalar(nothing, lvl.Tv, lvl.Vf, gensym(), val)
)
end

Expand Down
10 changes: 5 additions & 5 deletions src/tensors/levels/element_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,26 +170,26 @@ function instantiate(ctx, fbr::VirtualSubFiber{VirtualElementLevel}, mode::Reade
preamble = quote
$val = $(lvl.val)[$(ctx(pos))]
end,
body = (ctx) -> VirtualScalar(nothing, lvl.Tv, lvl.Vf, fgensym(), val)
body = (ctx) -> VirtualScalar(nothing, lvl.Tv, lvl.Vf, gensym(), val)
)
end

function instantiate(ctx, fbr::VirtualSubFiber{VirtualElementLevel}, mode::Updater)
(lvl, pos) = (fbr.lvl, fbr.pos)
VirtualScalar(nothing, lvl.Tv, lvl.Vf, fgensym(), :($(lvl.val)[$(ctx(pos))]))
VirtualScalar(nothing, lvl.Tv, lvl.Vf, gensym(), :($(lvl.val)[$(ctx(pos))]))
end

function instantiate(ctx, fbr::VirtualHollowSubFiber{VirtualElementLevel}, mode::Updater)
(lvl, pos) = (fbr.lvl, fbr.pos)
VirtualSparseScalar(nothing, lvl.Tv, lvl.Vf, fgensym(), :($(lvl.val)[$(ctx(pos))]), fbr.dirty)
VirtualSparseScalar(nothing, lvl.Tv, lvl.Vf, gensym(), :($(lvl.val)[$(ctx(pos))]), fbr.dirty)
end

function lower_assign(ctx, fbr::VirtualHollowSubFiber{VirtualElementLevel}, mode::Updater, op, rhs)
(lvl, pos) = (fbr.lvl, fbr.pos)
lower_assign(ctx, VirtualSparseScalar(nothing, lvl.Tv, lvl.Vf, fgensym(), :($(lvl.val)[$(ctx(pos))]), fbr.dirty), mode, op, rhs)
lower_assign(ctx, VirtualSparseScalar(nothing, lvl.Tv, lvl.Vf, gensym(), :($(lvl.val)[$(ctx(pos))]), fbr.dirty), mode, op, rhs)
end

function lower_assign(ctx, fbr::VirtualSubFiber{VirtualElementLevel}, mode::Updater, op, rhs)
(lvl, pos) = (fbr.lvl, fbr.pos)
lower_assign(ctx, VirtualScalar(nothing, lvl.Tv, lvl.Vf, fgensym(), :($(lvl.val)[$(ctx(pos))])), mode, op, rhs)
lower_assign(ctx, VirtualScalar(nothing, lvl.Tv, lvl.Vf, gensym(), :($(lvl.val)[$(ctx(pos))])), mode, op, rhs)
end
4 changes: 2 additions & 2 deletions src/tensors/levels/pattern_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ instantiate(ctx, ::VirtualSubFiber{VirtualPatternLevel}, mode::Reader) = FillLea
function instantiate(ctx, fbr::VirtualSubFiber{VirtualPatternLevel}, mode::Updater)
val = freshen(ctx, :null)
push_preamble!(ctx, :($val = false))
VirtualScalar(nothing, Bool, false, fgensym(), val)
VirtualScalar(nothing, Bool, false, gensym(), val)
end

function instantiate(ctx, fbr::VirtualHollowSubFiber{VirtualPatternLevel}, mode::Updater)
VirtualScalar(nothing, Bool, false, fgensym(), fbr.dirty)
VirtualScalar(nothing, Bool, false, gensym(), fbr.dirty)
end
2 changes: 1 addition & 1 deletion src/util/shims.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ ensuring that the variables in `ex` are not mutated by the arguments.
"""
macro barrier(args_ex...)
(args, ex) = args_ex[1:end-1], args_ex[end]
f = fgensym()
f = gensym()
esc(quote
$f = Finch.@closure ($(args...),) -> $ex
$f()
Expand Down
14 changes: 3 additions & 11 deletions src/util/staging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ This macro does not support type parameters, varargs, or keyword arguments.
macro staged(def)
(@capture def :function(:call(~name, ~args...), ~body)) || throw(ArgumentError("unrecognized function definition in @staged"))

name_generator = fgensym(Symbol(name, :_generator))
name_invokelatest = fgensym(Symbol(name, :_invokelatest))
name_eval_invokelatest = fgensym(Symbol(name, :_eval_invokelatest))
name_generator = gensym(Symbol(name, :_generator))
name_invokelatest = gensym(Symbol(name, :_invokelatest))
name_eval_invokelatest = gensym(Symbol(name, :_eval_invokelatest))

def = quote
function $name_generator($(args...))
Expand Down Expand Up @@ -84,11 +84,3 @@ function refresh()
@eval $def
end
end

"""
fgensym([tag])
Generate a new fgensym symbol with the given name, for use in Finch.
"""
fgensym(tag) = eval(Finch, :(gensym($(QuoteNode(tag)))))
fgensym() = eval(Finch, :(gensym()))

0 comments on commit 1145e57

Please sign in to comment.