Use ChainRules types #603

Open
oxinabox opened this issue Apr 20, 2020 · 5 comments
Labels
ChainRules adjoint -> rrule, and further integration help wanted Extra attention is needed up for grabs anyone is welcome to contribute with a PR to fix the issue

Comments

@oxinabox
Member

This issue is about swapping Zygote over to use ChainRules' types by default.
When #366 is merged, rules coming out of ChainRules will use its types, like Composite and AbstractZero,
but things created via Source Code Transform (SCT) will still use NamedTuple and nothing.

This is fine as they are mutually compatible because accum falls back to + and ChainRules overloads +.
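
A minimal sketch of why the two representations can coexist, using toy stand-ins rather than the real Zygote or ChainRules definitions. The names `ToyAbstractZero`, `ToyZero` and `toy_accum` are hypothetical and exist only for this illustration: the accumulator special-cases `nothing` and otherwise falls back to `+`, which the zero type overloads.

```julia
# Toy stand-ins; NOT the actual ChainRules or Zygote types.
abstract type ToyAbstractZero end
struct ToyZero <: ToyAbstractZero end

# The zero type defines addition itself: adding it changes nothing.
Base.:+(::ToyZero, x) = x
Base.:+(x, ::ToyZero) = x
Base.:+(z::ToyZero, ::ToyZero) = z

# Accumulation skips `nothing` and otherwise falls back to `+`.
toy_accum(x, y) = x + y
toy_accum(::Nothing, y) = y
toy_accum(x, ::Nothing) = x
toy_accum(::Nothing, ::Nothing) = nothing

toy_accum(nothing, 1.5)     # `nothing` is dropped
toy_accum(ToyZero(), 1.5)   # handled by the `+` overload
toy_accum(nothing, ToyZero())
```

Both kinds of "no gradient" flow through the same accumulator without either side knowing about the other.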

The use of Base types in Zygote tends to cause issues, as it's hard to add methods for NamedTuple and nothing without committing type piracy.
This especially affects defining linear operations on them (e.g. overloading things from LinearAlgebra), as well as defining addition (+).
Particular discussion of Composite and structured differential types is in #462.

Related to:
#454 #419 #329

@oxinabox
Member Author

oxinabox commented Apr 20, 2020

Discussion @MikeInnes and I had today

Lyndon White:ox:
I am now going through and fixing broken tests
or rather fixing issues that the tests reveal
One of the tests fails unless I define
Base.:*(::Int, ::Nothing) = nothing
which presumably means that somewhere in Zygote (or perhaps inside ZygoteRules?) there was something removing that nothing before it could reach that multiplication.
If I just went fully over to ChainRules types this would not be an issue, since AbstractZero does define multiplication.
Does this mean there is some code somewhere that would be able to be deleted?
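
For illustration, the failure mode described above can be reproduced in isolation. This is a sketch with hypothetical toy names (`ToyAbstractZero`, `ToyZero`, `scale_pullback`), not the real ChainRules types: a pullback that treats its input as a numeric gradient throws on `nothing`, but works unchanged with a zero type that defines multiplication.

```julia
# Toy stand-ins; NOT the actual ChainRules types.
abstract type ToyAbstractZero end
struct ToyZero <: ToyAbstractZero end

# Multiplying by the zero type yields the zero type.
Base.:*(z::ToyAbstractZero, ::Any) = z
Base.:*(::Any, z::ToyAbstractZero) = z
Base.:*(z::ToyAbstractZero, ::ToyAbstractZero) = z

# A pullback that uses its input like a numeric gradient.
scale_pullback(dy) = 2 * dy

numeric_result = scale_pullback(3.0)
zero_result = scale_pullback(ToyZero())

# `nothing` has no `*` method, so the same pullback throws.
nothing_failed = try
    scale_pullback(nothing)
    false
catch err
    err isa MethodError
end
```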

Mike Innes
By default @adjoint rules ignore nothing entirely
So I guess it’s just that we’re passing nothing to an rrule which tries to use it like a numeric gradient

Lyndon White:ox:
yep

Mike Innes
We can just add that dispatch to the CR rule wrapper function though, right?

Lyndon White:ox:
Yeah we can
What do you do when 1 input is nothing and other is non-nothing?
and so on and so forth for nested tuples of input?

Mike Innes
Since it’s an rrule there’s only one input

Lyndon White:ox:
but that input could be a structure (inc tuple)

Mike Innes
The tuple/struct functions can just carry on using the Zygote definitions for now

Lyndon White:ox:
indeed, I am just curious

Mike Innes
If you have an adjoint that takes a struct, all you usually do with the struct elements is forward them on to other gradient functions

Lyndon White:ox:
That seems like it could mean missing out on some CSE.

Mike Innes
I can imagine that in some cases you’d need to explicitly cast to a reasonable zero type, but that doesn’t seem to have come up too much in practice

Lyndon White:ox:
Yeah, I have only seen a few structure pullbacks, so my instincts are not great
I think Will probably has a ton of them in his GP via Kalman filter package though

Mike Innes
It would be useful to know if he runs into issues. I imagine we could implement a generic “cast zero” that supports AbstractZero and nothing

Lyndon White:ox:
I think we do just want to get rid of nothing
there are just too many random issues that show up from nothing not supporting +

Mike Innes
Perhaps; my impression has been that most of those kinds of issues have highlighted things you want to fix anyway

Lyndon White:ox:
That is true
at the very least to an extent

Mike Innes
Julia gives you a lot of ease-of-use/performance tradeoffs in cases like this, what with generic code often being slow

Lyndon White:ox:
I don’t think this is one of those cases
Also there are a bunch of things that are simplified if you can define things on them (beyond + and *)
Apparently adding Base.:/(z::AbstractZero, ::Any) = z massively simplified one of the most complicated rules, relating to, I think, a derivative of a structure (probably SVD or Cholesky)
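
A small sketch of the kind of simplification meant here, again with hypothetical toy names (`ToyAbstractZero`, `ToyZero`, `div_pullback`) rather than the real rule in question: once division is defined on the zero type, a pullback needs no branch for the "no gradient" case.

```julia
# Toy stand-ins; NOT the actual ChainRules types or rules.
abstract type ToyAbstractZero end
struct ToyZero <: ToyAbstractZero end

# Dividing a zero by anything is still zero.
Base.:/(z::ToyAbstractZero, ::Any) = z

# Pullback of y = x / s with respect to x. No special-casing needed.
div_pullback(dy, s) = dy / s

div_pullback(4.0, 2.0)        # ordinary gradient
div_pullback(ToyZero(), 2.0)  # zero passes straight through
```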

Mike Innes
That’s interesting

Lyndon White:ox:
I would like to create a package that adds traits to all linear operators,
so we could easily go and make them propagate zeros like that.
(idea not fully formed)

Mike Innes
I’m thinking of cases like getindex, where AbstractZero would propagate ‘for free’, but actually having them in your matrices is terrible for performance compared to casting to Float64
Since then you lose BLAS for the rest of the computation

Lyndon White:ox:
Oh yeah that is true

Mike Innes
Of course we can just remove them; but then it’s equivalent to nothing while being easier to get wrong
That’s just one case of course

Lyndon White:ox:
In a matrix, if you are in that case, you should probably be using a sparse type
You can’t use nothing there either for same reason
I conceptualize AbstractZero as equivalent to the type of a structural zero in a sparse matrix
It acts the same, including in presence of NaN
I don’t see how it is easier to get wrong in that case?

Mike Innes
e.g. getindex has given us a lot of pain because nothing keeps making it throw errors. But it’s actually good that we cast them away where possible. If we had AbstractZero, those cases would have worked fine but been slow
So I think in that case the errors were a good motivator to fix the problem (not that you couldn’t write the same code over AbstractZero)

Lyndon White:ox:
That matrix case might actually be a reason to use AbstractZero
Since can define convert(T, ::AbstractZero) = zero(T)
and then setindex! will do the convert for us.
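
A sketch of that idea, assuming toy stand-in types (`ToyAbstractZero` and `ToyZero` are hypothetical names, not the ChainRules ones). Because `setindex!` on an `Array{T}` converts the assigned value to `T`, a `convert` method on the zero type is enough for it to be stored as an ordinary numeric zero, so the array keeps its concrete element type.

```julia
# Toy stand-ins; NOT the actual ChainRules types.
abstract type ToyAbstractZero end
struct ToyZero <: ToyAbstractZero end

# Converting the zero type to a numeric type gives that type's zero.
Base.convert(::Type{T}, ::ToyAbstractZero) where {T<:Number} = zero(T)

A = [1.0, 2.0, 3.0]
A[2] = ToyZero()   # setindex! calls convert(Float64, ToyZero())

# The same assignment with `nothing` throws, since there is no conversion.
nothing_failed = try
    A[3] = nothing
    false
catch err
    err isa MethodError
end
```

After this runs, `A` is still a `Vector{Float64}`, so later BLAS-backed operations on it are unaffected.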

Mike Innes
That is true

Mike Innes
My other concern is about the type of the expression gradient(f, x)
With x a struct, it seems wrong that this is Union{Number,Struct}
because you can write code that treats this as a number, and it’ll break when x actually has a gradient

Lyndon White:ox:
AbstractZero does not subtype Number

Mike Innes
Really I think the answer here is that the type is Differential where Zero is a differential, but in that case the name is a bit misleading
Not that I have a better idea
This may or may not really be worth worrying about, but I think you appreciate these kinds of issues

Thinking about AbstractZero as ‘the identity differential’ rather than zero(x) for all x does make me feel a bit warmer to it though

@DhairyaLGandhi
Member

Fwiw I think the concept of an AbstractGradient type, with ZeroGrad as an identity, sounds reasonable. This was the implicit definition of having nothing as a valid gradient, I feel.

@oxinabox
Member Author

One important thing to do as part of this PR is to make sure to have a clean deprecation path,
so we don't break all existing custom rules that assume NamedTuple and nothing work.
Part of this might be adding stuff to ZygoteRules to do conversion.

@oxinabox
Member Author

Advantages of this:

  • Clearer, more explicit types
  • Allows overloading things like + and linear operators on them
  • Avoids the overhead of converting to and from ChainRules types (this normally compiles away, but not always, especially when nesting AD)
  • Allows use of Thunk to defer and potentially eliminate computations that are never needed
  • Allows use of InplaceableThunk for in-place accumulation. Particularly relevant for getindex, see Inplace getindex rrule JuliaDiff/ChainRules.jl#240. It would still need to disable itself and use the non-in-place path during nested AD (this can be done via an rrule)
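
To illustrate the deferral point in the list above, here is a minimal sketch of a thunk, assuming hypothetical toy names (`ToyThunk`, `toy_unthunk`, `toy_pullback`) and not the actual ChainRules implementation. Each gradient is wrapped in a closure, so work for a gradient that is never requested is never done.

```julia
# Toy stand-in; NOT the actual ChainRules Thunk.
struct ToyThunk{F}
    f::F
end

# Forcing a thunk runs the deferred computation; other values pass through.
toy_unthunk(t::ToyThunk) = t.f()
toy_unthunk(x) = x

# Counter to observe how often the expensive work actually runs.
calls = Ref(0)
function expensive()
    calls[] += 1
    return sum(abs2, 1:10)
end

# A pullback returning one deferred gradient per input.
function toy_pullback(dy)
    da = ToyThunk(() -> dy * expensive())
    db = ToyThunk(() -> dy * 2)
    return da, db
end

da, db = toy_pullback(1.0)
db_value = toy_unthunk(db)      # only the second gradient is needed
calls_after_db = calls[]        # expensive() has not run yet
da_value = toy_unthunk(da)      # forcing the first gradient runs it once
```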

@nickrobinson251

Another issue we should make sure is resolved by/when changing to ChainRules types: #802

@ToucheSir ToucheSir added help wanted Extra attention is needed up for grabs anyone is welcome to contribute with a PR to fix the issue ChainRules adjoint -> rrule, and further integration labels Dec 15, 2022
ToucheSir added a commit that referenced this issue Mar 5, 2023
This will give us more flexibility to implement internal changes such as
#603 without changing the user-facing API.

4 participants