besseli (and friends?) #945

Closed
willtebbutt opened this issue Apr 14, 2021 · 5 comments

@willtebbutt
Member

willtebbutt commented Apr 14, 2021

using SpecialFunctions
using Zygote

Zygote.gradient(besseli, 4, 0.3)

yields

ERROR: not implemented
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::SpecialFunctions.var"#27#30")()
    @ SpecialFunctions ~/.julia/packages/SpecialFunctions/mFAQ4/src/chainrules.jl:41
  [3] (::ChainRulesCore.Thunk{SpecialFunctions.var"#27#30"})()
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/D0go7/src/differentials/thunks.jl:98
  [4] unthunk
    @ ~/.julia/packages/ChainRulesCore/D0go7/src/differentials/thunks.jl:99 [inlined]
  [5] (::ChainRulesCore.var"#11#12"{ChainRulesCore.Thunk{SpecialFunctions.var"#27#30"}})()
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/D0go7/src/differentials/thunks.jl:40
  [6] (::ChainRulesCore.Thunk{ChainRulesCore.var"#11#12"{ChainRulesCore.Thunk{SpecialFunctions.var"#27#30"}}})()
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/D0go7/src/differentials/thunks.jl:98
  [7] unthunk
    @ ~/.julia/packages/ChainRulesCore/D0go7/src/differentials/thunks.jl:99 [inlined]
  [8] *(a::ChainRulesCore.Thunk{ChainRulesCore.var"#11#12"{ChainRulesCore.Thunk{SpecialFunctions.var"#27#30"}}}, b::Float64)
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/D0go7/src/differential_arithmetic.jl:111
  [9] besseli_pullback
    @ ~/.julia/packages/SpecialFunctions/mFAQ4/src/chainrules.jl:38 [inlined]
 [10] ZBack
    @ ~/.julia/packages/Zygote/RxTZu/src/compiler/chainrules.jl:77 [inlined]
 [11] (::Zygote.var"#41#42"{Zygote.ZBack{SpecialFunctions.var"#besseli_pullback#29"{Int64, Float64}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:41
 [12] gradient(::Function, ::Int64, ::Vararg{Any, N} where N)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:59
 [13] top-level scope
    @ REPL[4]:1

Perhaps @oxinabox or @mzgubic could confirm whether the following explanation of the problem is accurate:

Zygote doesn't support thunks, so when converting between Zygote and ChainRules types we strip out any thunks by calling `unthunk` on them.

At the same time, some rules in the ChainRules ecosystem don't implement cotangents w.r.t. certain arguments. For such arguments, the rule returns a thunk that throws the above error if it is ever evaluated. besseli is one such function: its rule, defined in SpecialFunctions, doesn't implement the cotangent w.r.t. the order.

The idea is that if you never need the gradient w.r.t. the argument whose derivative is unknown, you never have to compute it, so the missing rule does no harm. Unfortunately, unthunking to get a type that Zygote can handle means the thunk is always evaluated, so the error is always thrown.
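To make the mechanism concrete, here is a toy model in plain Julia that doesn't use ChainRulesCore at all. `LazyThunk` and `force` are hypothetical stand-ins for `Thunk` and `unthunk`, and `toy_pullback` is a made-up pullback whose first cotangent is unknown. It is a sketch of the behaviour described above, not how ChainRulesCore or Zygote are actually implemented.

```julia
# Hypothetical stand-in for ChainRulesCore.Thunk: wraps a zero-argument closure
# that only runs when explicitly forced.
struct LazyThunk{F}
    f::F
end

# Hypothetical stand-in for unthunk: plain values pass through unchanged,
# thunks are evaluated.
force(x) = x
force(t::LazyThunk) = t.f()

# Toy pullback for some f(ν, x): the cotangent w.r.t. ν is unknown, so it is
# deferred in a thunk that only errors when forced; the cotangent w.r.t. x is
# a placeholder value.
function toy_pullback(Δ)
    ∂ν = LazyThunk(() -> error("not implemented"))
    ∂x = 2.0 * Δ
    return (∂ν, ∂x)
end

# Lazy consumer: only asks for ∂x, so the thunk is never run and nothing fails.
_, ∂x = toy_pullback(1.0)

# Eager consumer: forces every cotangent, like stripping thunks when converting
# to another AD system's types, and therefore hits the error for ∂ν.
eager_result = try
    map(force, toy_pullback(1.0))
catch err
    err
end
```

The lazy consumer gets its gradient; the eager one fails even though it may never have needed the cotangent for ν.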

Any thoughts on good resolutions?

This is going to cause problems for some TemporalGPs.jl work in the next couple of days. I can add a workaround for now, but that's obviously not much use for anyone else who runs into this.

edit: cross-referenced relevant TemporalGPs issue.

@willtebbutt changed the title from "besseli" to "besseli (and friends?)" Apr 14, 2021
@mzgubic
Collaborator

mzgubic commented Apr 15, 2021

The description sounds right to me. Probably the best way to solve this is to support ChainRules types internally in Zygote, as described in #603. I guess an alternative is to add an @adjoint for the function, but that moves us in the opposite direction to the one we want (i.e. more rrules, fewer adjoints).
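For illustration, a minimal sketch of the @adjoint route as a stop-gap workaround. It assumes an integer order, returns `nothing` for the order, and differentiates only w.r.t. `x` using the standard identity for the modified Bessel function of the first kind, d/dx I_ν(x) = (I_{ν-1}(x) + I_{ν+1}(x)) / 2. This is not code taken from SpecialFunctions or Zygote, and whether it is picked over the generic ChainRules path depends on Zygote's dispatch (it should be, as the more specific method).

```julia
using SpecialFunctions
using Zygote

# Workaround sketch: a Zygote-specific adjoint for besseli with an integer order.
Zygote.@adjoint function besseli(ν::Integer, x::Real)
    y = besseli(ν, x)
    function besseli_back(Δ)
        # d/dx I_ν(x) = (I_{ν-1}(x) + I_{ν+1}(x)) / 2
        ∂x = Δ * (besseli(ν - 1, x) + besseli(ν + 1, x)) / 2
        # No gradient is returned for the integer order.
        return (nothing, ∂x)
    end
    return y, besseli_back
end

# Should now return (nothing, ∂x) instead of erroring.
Zygote.gradient(besseli, 4, 0.3)
```

Defining such adjoints outside SpecialFunctions is type piracy, which is why it only makes sense as a temporary workaround.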

@DhairyaLGandhi
Member

Hey I like adjoints :p

@oxinabox
Member

oxinabox commented Apr 15, 2021

@devmotion proposed a shorter-term solution yesterday:
introduce a type specifically for NotImplemented.
JuliaDiff/ChainRulesCore.jl#334

We probably want that even once we do #603.
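One way to picture the proposal: instead of an erroring thunk, the pullback returns an inert marker value, so eager conversion is harmless and the error is deferred until the marker is actually used in a computation. The sketch below models this with a hypothetical `NotImplementedTangent` type in self-contained plain Julia; it is not the design discussed in JuliaDiff/ChainRulesCore.jl#334.

```julia
# Hypothetical marker for a cotangent that has no implemented rule. Unlike an
# erroring thunk, it can be stored, copied and converted without side effects.
struct NotImplementedTangent
    info::String
end

# The error only surfaces if the marker is used in arithmetic.
Base.:*(a::NotImplementedTangent, ::Number) = error("cotangent not implemented: ", a.info)
Base.:*(x::Number, a::NotImplementedTangent) = a * x

# Toy pullback for some f(ν, x): unknown w.r.t. ν, placeholder value w.r.t. x.
toy_pullback(Δ) = (NotImplementedTangent("derivative w.r.t. ν"), 2.0 * Δ)

# Hypothetical eager conversion step (standing in for thunk stripping); it
# succeeds because there is nothing to evaluate.
to_plain(x) = x
grads = map(to_plain, toy_pullback(1.0))

# Using the marker in a computation is what raises the error.
misuse = try
    3.0 * grads[1]
catch err
    err
end
```

The design choice is to make "unknown" a value rather than a deferred computation, so code that never touches it pays nothing, while code that does use it still fails loudly.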

@devmotion
Collaborator

The Zygote issue is also discussed in #873.

@willtebbutt
Member Author

Closing in favour of #873
