
Error message to explain that non-differentiable arguments should be marked as DoesNotExist(), not Zero() #112

@mzgubic

Description


Currently, if someone mistakenly returns Zero() instead of DoesNotExist() for a non-differentiable argument inside an rrule, for example:

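```julia
using ChainRulesCore

function partly_dne(a, i)
    return a[i]
end

function ChainRulesCore.rrule(::typeof(partly_dne), a, i)
    y = partly_dne(a, i)
    function partly_dne_pullback(ȳ)
        grad = zeros(size(a))
        grad[i] = ȳ
        # Mistake: the integer index `i` is not differentiable,
        # so its tangent should be DoesNotExist(), not Zero().
        return (NO_FIELDS, grad, Zero())
    end
    return y, partly_dne_pullback
end
```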

then the resulting test failure doesn't explain what went wrong or how to fix it:

```julia
julia> rrule_test(partly_dne, rand(), (rand(4), rand(4)), (1, nothing))
Test Failed at /Users/mzgubic/Projects/ChainRules.jl/dev/ChainRulesTestUtils/src/testers.jl:295
  Expression: x̄_ad isa DoesNotExist
   Evaluated: Zero() isa DoesNotExist
ERROR: There was an error during testing
```
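One possible shape for a clearer failure, as a minimal sketch rather than a proposed patch: instead of only reporting `x̄_ad isa DoesNotExist`, the check could say which argument was marked non-differentiable and, when the rule returned Zero(), explain the difference between the two types. The helper name `check_nondifferentiable` is hypothetical (it is not a ChainRulesTestUtils function), the message wording is only a suggestion, and the two empty structs are stand-ins for ChainRulesCore's `Zero` and `DoesNotExist`, defined here only so the block runs without the package.

```julia
# Stand-ins for ChainRulesCore's tangent types, defined only so this sketch is
# self-contained. Real code would use ChainRulesCore.Zero and
# ChainRulesCore.DoesNotExist instead.
struct Zero end
struct DoesNotExist end

# Hypothetical helper (not part of ChainRulesTestUtils): validate the tangent a
# pullback returned for an argument the test marked as non-differentiable
# (i.e. whose tangent was passed as `nothing` to the tester).
function check_nondifferentiable(x̄_ad, argnum::Integer)
    # The correct answer: nothing to report.
    x̄_ad isa DoesNotExist && return nothing

    msg = "Argument $argnum was marked as non-differentiable (its tangent was " *
          "given as `nothing`), so the pullback must return `DoesNotExist()` " *
          "for it, but it returned `$(repr(x̄_ad))`."

    # The common mistake from this issue gets an extra, targeted hint.
    if x̄_ad isa Zero
        msg *= " `Zero()` means the derivative exists and is zero; use " *
               "`DoesNotExist()` for arguments that cannot be differentiated."
    end
    throw(ArgumentError(msg))
end
```

With this, `check_nondifferentiable(DoesNotExist(), 3)` returns `nothing`, while `check_nondifferentiable(Zero(), 3)` throws an `ArgumentError` whose message names argument 3 and points at DoesNotExist(). Throwing (rather than a bare `@test`) is just one choice; inside the tester the same message could equally be attached to the failing test.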

Labels: documentation, good first issue

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions