ODEAdjointProblem doesn't work with DynamicalODEProblem #340
I think this is another instance of the same Zygote issue, @DhairyaLGandhi.

```julia
import DiffEqBase: DynamicalODEProblem
import DiffEqSensitivity:
    solve,
    ODEAdjointProblem,
    InterpolatingAdjoint
import DifferentialEquations

sol = solve(
    DynamicalODEProblem(
        (v, x, p, t) -> x,
        # ERROR: LoadError: type Nothing has no field x
        # (v, x, p, t) -> [0.0, 0.0],
        # ERROR: LoadError: MethodError: no method matching ndims(::Type{Nothing})
        (v, x, p, t) -> v,
        [0.0, 0.0],
        [0.0, 0.0],
        (0.0, 1.0),
    )
)
solve(
    ODEAdjointProblem(
        sol,
        InterpolatingAdjoint(),
        (out, x, p, t, i) -> (out .= 0),
        [sol.t[end]],
    )
)
```

This works. The difference is just that I made sure the first ODE function had

Isolating this a little bit more, here's the MWE.

```julia
using RecursiveArrayTools
import Zygote

function f(ap)
    sum(ArrayPartition(ap.x[1], ap.x[2]))
end
ap = ArrayPartition([0.0, 0.0], [0.0, 0.0])
Zygote.gradient(f, ap) # ((x = ([1.0, 1.0], [1.0, 1.0]),),)
```

Notice this one works fine! But...

```julia
function f(ap)
    ap.x[2][1] + ap.x[2][2]
end
ap = ArrayPartition([0.0, 0.0], [0.0, 0.0])
Zygote.gradient(f, ap) # ((x = (nothing, [1.0, 1.0]),),)
```

This confuses

This is the same issue as #339 (comment), but even simpler. Let me escalate this to Zygote.jl itself.
Yeah it seems like conflating

I just realized that a potential workaround here is to "use" the naughty input parts:

```julia
import RecursiveArrayTools: ArrayPartition
import Zygote

function f(ap)
    # The 0.0 * ap.x[1][1] term makes ap.x[1] part of the computation,
    # so its gradient comes back as [0.0, 0.0] instead of nothing.
    ap.x[2][1] + ap.x[2][2] + 0.0 * ap.x[1][1]
end
ap = ArrayPartition([0.0, 0.0], [0.0, 0.0])
Zygote.gradient(f, ap) # ((x = ([0.0, 0.0], [1.0, 1.0]),),)
```

Not satisfying, but hopefully it unblocks others as well!
SciML/RecursiveArrayTools.jl#115 works around this for ArrayPartition, so this issue will be closed when that is merged and tagged. More generally, though, Zygote's fallback here requires that we do this same trick for every single type with this property, so it's a bad fallback and yeah, we'll need to get a better one.
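One possible shape for a more general, type-agnostic fallback is sketched below under stated assumptions, in plain Base Julia. Zygote reports an input part that has no influence on the output as `nothing` (a structural zero), as in `(x = (nothing, [1.0, 1.0]),)` in the MWE. A post-processing pass can walk the gradient alongside the primal value and swap each `nothing` for `zero` of the matching part. `materialize_zeros` and the `Parts` stand-in struct are hypothetical illustrations. They are not APIs of Zygote or RecursiveArrayTools, and this is not the code in SciML/RecursiveArrayTools.jl#115.

```julia
# Hypothetical helper (not part of Zygote or RecursiveArrayTools):
# replace every `nothing` in a Zygote-style gradient with a zero of the
# matching primal value, recursing through Tuples and NamedTuples.

# Leaves: `nothing` becomes a zero shaped like the primal.
materialize_zeros(::Nothing, primal::Number) = zero(primal)
materialize_zeros(::Nothing, primal::AbstractArray) = zero(primal)
materialize_zeros(::Nothing, primal::Tuple) = map(p -> materialize_zeros(nothing, p), primal)

# Leaves that already carry a gradient are kept as they are.
materialize_zeros(g::Number, primal) = g
materialize_zeros(g::AbstractArray, primal) = g

# Tuples: pair each gradient entry with its primal entry.
materialize_zeros(g::Tuple, primal::Tuple) = map(materialize_zeros, g, primal)

# NamedTuples (Zygote's gradient for a struct): recurse field by field,
# reading the matching field of the primal with getfield.
function materialize_zeros(g::NamedTuple{names}, primal) where {names}
    vals = map(n -> materialize_zeros(getfield(g, n), getfield(primal, n)), names)
    return NamedTuple{names}(vals)
end

# Hypothetical stand-in for ArrayPartition: a struct whose field `x` is a tuple of arrays.
struct Parts{T<:Tuple}
    x::T
end

ap = Parts(([0.0, 0.0], [0.0, 0.0]))
g = (x = (nothing, [1.0, 1.0]),)   # same shape as the second gradient in the MWE
g_dense = materialize_zeros(g, ap) # (x = ([0.0, 0.0], [1.0, 1.0]),)
```

Because ArrayPartition stores its parts in a field named `x`, the same helper would presumably apply to a real gradient as `materialize_zeros(Zygote.gradient(f, ap)[1], ap)`. Doing this on the caller side avoids editing `f` with the `0.0 * ap.x[1][1]` trick. The cost is allocating zero arrays that Zygote deliberately skipped.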
Possibly related to #339. However, in this case I'm not nesting ArrayPartitions in any interesting way AFAIU. This is another 2-for-1:

gives