Gradient of primitive and confusion #19514
ericmjonas asked this question in General

Hello! I'm filing this as a "discussion" because I'm sure it's an error in my understanding and not actually a bug. I'm trying to fit a function's gradient with machine learning, where the function contains calls to a custom primitive p(x) which is backed by a pile of C++/CUDA.

In pseudocode I have:
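A minimal sketch of that setup, assuming names like my_fwd_p, my_bwd_p, and my_func, and using sin/cos as stand-ins for the C++/CUDA kernels (this is a reconstruction from the description above, not the author's code):

```python
import jax
import jax.numpy as jnp
from jax import core  # newer JAX versions expose Primitive under jax.extend.core

# Two raw primitives standing in for the C++/CUDA-backed fwd and bwd kernels.
my_fwd_p = core.Primitive("my_fwd_p")
my_fwd_p.def_impl(lambda x: jnp.sin(x))   # stand-in for the CUDA forward kernel
my_fwd_p.def_abstract_eval(lambda x: core.ShapedArray(x.shape, x.dtype))

my_bwd_p = core.Primitive("my_bwd_p")
my_bwd_p.def_impl(lambda x: jnp.cos(x))   # stand-in for the CUDA backward kernel
my_bwd_p.def_abstract_eval(lambda x: core.ShapedArray(x.shape, x.dtype))

# my_func only ever exposes a first derivative, wired up with defvjp.
@jax.custom_vjp
def my_func(x):
    return my_fwd_p.bind(x)

def my_func_fwd(x):
    return my_fwd_p.bind(x), x            # keep x as the residual

def my_func_bwd(x, ct):
    return (ct * my_bwd_p.bind(x),)       # VJP: cotangent * d my_func / dx

my_func.defvjp(my_func_fwd, my_func_bwd)

# Training then takes jax.grad of a loss built from gradients that flow
# through my_func (fitting a model to the function's gradient).
```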
Technically I have two primitives, a fwd and a bwd, that I then set up with defvjp. As far as I can tell, this should only ever require the VJP of my_func. However, I am getting the error:

Differentiation rule for 'my_bwd_p' not implemented

which is confusing me, as I really don't think we should need anything beyond first derivatives for my_func.
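Continuing the sketch above (an illustration, not taken from the author's notebook): a single reverse-mode derivative only needs the registered VJP, but nesting differentiation through my_func asks JAX to differentiate the raw primitives that appear inside the fwd/bwd rules, which produces an error of exactly this shape:

```python
x = 0.7

print(jax.grad(my_func)(x))        # works: only the registered VJP is needed

jax.grad(jax.grad(my_func))(x)     # NotImplementedError: Differentiation rule
                                   # for '...' not implemented, naming whichever
                                   # raw primitive JAX reaches first
```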
I have constructed a full example for my own primitive that just implements sin, below. I've tried everything I can think of, including liberal application of jax.lax.stop_gradient. The real code I care about for my_func's fwd and bwd primitives is incredibly complicated, and the idea of implementing a higher-order gradient (that is, the derivative rule for bwd) is sort of soul-crushing. But I don't think it should be necessary?

I'm attaching a full example of what I'm running into below, and it is also available as a Colab notebook here: https://colab.research.google.com/drive/1T0HcQlELiUJVw6OhptNox6Ed_LsaCYOm?usp=sharing
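For context on what "the derivative rule for bwd" would entail, here is a hedged sketch for the sin-backed toy above: registering a JVP for my_bwd_p, which needs the second derivative (-sin), i.e. exactly the extra kernel the post is trying to avoid. This is an illustration of the shape of the work, not the author's code or a recommendation:

```python
from jax.interpreters import ad

def my_bwd_jvp(primals, tangents):
    # Sketch only: symbolic-zero tangent handling is omitted.
    (x,), (x_dot,) = primals, tangents
    primal_out = my_bwd_p.bind(x)          # cos(x)
    tangent_out = -jnp.sin(x) * x_dot      # d/dx cos(x) = -sin(x), i.e. a
                                           # second-derivative kernel
    return primal_out, tangent_out

ad.primitive_jvps[my_bwd_p] = my_bwd_jvp   # my_fwd_p typically needs a rule
                                           # too before nested grads go through
```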
Reply (1 comment, 4 replies):

Can I ask why you're defining primitives at all? A big reason that …