diff --git a/src/host/gpunumber.jl b/src/host/gpunumber.jl index e1ab585e..5dde4e7b 100644 --- a/src/host/gpunumber.jl +++ b/src/host/gpunumber.jl @@ -15,9 +15,16 @@ AN.number(g::GPUNumber) = @allowscalar g.val[] maybe_number(g::GPUNumber) = AN.number(g) maybe_number(g) = g +number_type(::GPUNumber{T}) where T = eltype(T) + # When operations involve other `::Number` types, # do not convert back to `GPUNumber`. AN.like(::Type{<: GPUNumber}, x) = x # When broadcasting, just pass the array itself. Base.broadcastable(g::GPUNumber) = g.val + +# Overload to avoid copies. +Base.one(g::GPUNumber) = one(number_type(g)) +Base.zero(g::GPUNumber) = zero(number_type(g)) +Base.identity(g::GPUNumber) = g