Skip to content

Commit

Permalink
Remove marktype (#2233)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Dec 28, 2024
1 parent 378dc21 commit f46e44d
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 100 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Enzyme"
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
version = "0.13.25"
version = "0.13.26"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Expand Down
74 changes: 1 addition & 73 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ export autodiff,
make_zero!

export jacobian, gradient, gradient!, hvp, hvp!, hvp_and_gradient!
export markType, batch_size, onehot, chunkedonehot
export batch_size, onehot, chunkedonehot

using LinearAlgebra
import SparseArrays
Expand Down Expand Up @@ -1446,78 +1446,6 @@ result, ∂v, ∂A
aug_thunk, adj_thunk
end

# White lie, should be `Core.LLVMPtr{Cvoid, 0}` but that's not supported by ccallable
Base.@ccallable function __enzyme_float(x::Ptr{Cvoid})::Cvoid
return nothing
end

Base.@ccallable function __enzyme_double(x::Ptr{Cvoid})::Cvoid
return nothing
end

@inline function markType(::Type{T}, ptr::Ptr{Cvoid}) where {T}
markType(Base.unsafe_convert(Ptr{T}, ptr))
end

@inline function markType(data::Array{T}) where {T}
GC.@preserve data markType(pointer(data))
end

# TODO(WM): We record the type of a single index here, we could give it a range
@inline function markType(data::SubArray)
GC.@preserve data markType(pointer(data))
end

@inline function markType(data::Ptr{Float32})
@static if sizeof(Int) == sizeof(Int64)
Base.llvmcall(
(
"declare void @__enzyme_float(i8* nocapture) nounwind define void @c(i64 %q) nounwind alwaysinline { %p = inttoptr i64 %q to i8* call void @__enzyme_float(i8* %p) ret void }",
"c",
),
Cvoid,
Tuple{Ptr{Float32}},
data,
)
else
Base.llvmcall(
(
"declare void @__enzyme_float(i8* nocapture) nounwind define void @c(i32 %q) nounwind alwaysinline { %p = inttoptr i32 %q to i8* call void @__enzyme_float(i8* %p) ret void }",
"c",
),
Cvoid,
Tuple{Ptr{Float32}},
data,
)
end
nothing
end

@inline function markType(data::Ptr{Float64})
@static if sizeof(Int) == sizeof(Int64)
Base.llvmcall(
(
"declare void @__enzyme_double(i8* nocapture) nounwind define void @c(i64 %q) nounwind alwaysinline { %p = inttoptr i64 %q to i8* call void @__enzyme_double(i8* %p) ret void }",
"c",
),
Cvoid,
Tuple{Ptr{Float64}},
data,
)
else
Base.llvmcall(
(
"declare void @__enzyme_double(i8* nocapture) nounwind define void @c(i32 %q) nounwind alwaysinline { %p = inttoptr i32 %q to i8* call void @__enzyme_double(i8* %p) ret void }",
"c",
),
Cvoid,
Tuple{Ptr{Float64}},
data,
)
end
nothing
end

include("sugar.jl")

function _import_frule end # defined in EnzymeChainRulesCoreExt extension
Expand Down
26 changes: 0 additions & 26 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2374,32 +2374,6 @@ end
@test Enzyme.autodiff(Forward, timsteploop, Duplicated(2.0, 1.0))[1] 1.0
end

@testset "Type" begin
function foo(in::Ptr{Cvoid}, out::Ptr{Cvoid})
markType(Float64, in)
ccall(:memcpy,Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Csize_t), out, in, 8)
end

x = [2.0]
y = [3.0]
dx = [5.0]
dy = [7.0]

@test markType(x) === nothing
@test markType(zeros(Float32, 64)) === nothing
@test markType(view(zeros(64), 16:32)) === nothing

GC.@preserve x y begin
foo(Base.unsafe_convert(Ptr{Cvoid}, x), Base.unsafe_convert(Ptr{Cvoid}, y))
end

GC.@preserve x y dx dy begin
autodiff(Reverse, foo,
Duplicated(Base.unsafe_convert(Ptr{Cvoid}, x), Base.unsafe_convert(Ptr{Cvoid}, dx)),
Duplicated(Base.unsafe_convert(Ptr{Cvoid}, y), Base.unsafe_convert(Ptr{Cvoid}, dy)))
end
end

function bc0_test_function(ps)
z = view(ps, 26:30)
C = Matrix{Float64}(undef, 5, 1)
Expand Down

0 comments on commit f46e44d

Please sign in to comment.