Zero is not nothing: getproperty adjoint issues #802

Open
ChrisRackauckas opened this issue Sep 28, 2020 · 0 comments

SciML/SciMLSensitivity.jl#340 (comment) highlights an issue with the current adjoint definition that is used in Zygote. Essentially it's as follows. ArrayPartition is a type that wraps a tuple of arrays, and the object acts like a vector that is the concatenation of those arrays.

using RecursiveArrayTools
ap = ArrayPartition([1.0, 2.0],[3.0, 4.0])
ap[3] # 3.0

Simple? Yes, but enough to break this. Zygote's type handling allows it to pull back on some functions:

using Zygote
function f(ap)
    sum(ArrayPartition(ap.x[1],ap.x[2]))
end
ap = ArrayPartition([0.0, 0.0],[0.0, 0.0])
Zygote.gradient(f,ap) # ((x = ([1.0, 1.0], [1.0, 1.0]),),)

Notice that this works like we'd expect: 4 values in, 4 values out. Now let's try something else:

function f(ap)
    ap.x[2][1] + ap.x[2][2]
end
ap = ArrayPartition([0.0, 0.0],[0.0, 0.0])
Zygote.gradient(f,ap) # ((x = (nothing, [1.0, 1.0]),),)

And 💥. There are multiple issues here. First, having nothing in an ArrayPartition doesn't work, and all operations on it will fail. Second, even if we manually interpret that nothing as zero to fix it for Zygote, Zygote isn't giving us back an object that is sized, so the resulting derivative only has 3 values! In some sense, the missing value isn't nothing: it's a Zero of a 2-element array, and without that size information we cannot reconstruct the vector to get the correct ArrayPartition!
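To make the size problem concrete, here is a minimal sketch in plain Julia, with no Zygote or RecursiveArrayTools involved. The helper names `flatten_naive` and `flatten_with_primal` are hypothetical and exist only for this illustration; the tuples stand in for the partitions of `ap` and the cotangent shown above.

```julia
# Stand-ins for the forward-pass partitions and the cotangent Zygote returned.
primal    = ([0.0, 0.0], [0.0, 0.0])
cotangent = (nothing, [1.0, 1.0])

# Hypothetical helper: read `nothing` as a scalar zero.
# The first block collapses to a single entry, so the size is wrong.
flatten_naive(d) = vcat(map(x -> x === nothing ? 0.0 : x, d)...)

# Hypothetical helper: use the primal to recover the shape of each missing block.
flatten_with_primal(p, d) =
    vcat(map((pi, di) -> di === nothing ? zero(pi) : di, p, d)...)

naive = flatten_naive(cotangent)                # 3 entries
full  = flatten_with_primal(primal, cotangent)  # 4 entries, matching the input
```

The only difference between the two helpers is access to the primal value, which is exactly the information a bare nothing fails to carry.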

Now can I define my way out of this? No. Let's say we wanted to define a literal_getproperty:

using ZygoteRules
ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::ArrayPartition, ::Val{:x})
  function literal_ArrayPartition_x_adjoint(d)
      @show d
      ArrayPartition(d)
  end
  A.x,literal_ArrayPartition_x_adjoint
end

If we then run the function that worked:

function f(ap)
    sum(ArrayPartition(ap.x[1],ap.x[2]))
end

we'll notice that we get an error. The reason is that the ap.x[2] call uses only the second value in the tuple, so the @show d shows d = (nothing, [1.0, 1.0]), and it errors because the ArrayPartition is malformed. Again, nothing does not give the user enough information to correct this! I can manually correct for it by using the forward pass value:

ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::ArrayPartition, ::Val{:x})
  function literal_ArrayPartition_x_adjoint(d)
      (ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...),)
  end
  A.x,literal_ArrayPartition_x_adjoint
end

and tada it's handled correctly now.

function f(ap)
    ap.x[2][1] + ap.x[2][2]
end
ap = ArrayPartition([0.0, 0.0],[0.0, 0.0])
Zygote.gradient(f,ap) # ArrayPartition([0.0, 0.0],[1.0, 1.0])

So the moral of the story: using nothing there is just wrong, and if this is fixed, a lot of getproperty fallback definitions will have a much better chance of working. As a reference, here it's an array type, but #510's definition was constrained to only work on single-partial dual numbers, and if you try to relax that constraint you'll see it's this same issue of having to handle nothing in odd ways. Somehow this should really be a zero.
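As a rough illustration of what "a zero that knows its shape" could look like, here is a sketch in plain Julia. `ShapedZero` and `materialize` are hypothetical names invented for this example; they are not part of Zygote, ZygoteRules, or RecursiveArrayTools, and this is not a proposal for the actual fix.

```julia
# Hypothetical type: a zero cotangent that remembers the element type and size
# of the array it stands in for.
struct ShapedZero{T,N}
    dims::NTuple{N,Int}
end
ShapedZero(x::AbstractArray{T,N}) where {T,N} = ShapedZero{T,N}(size(x))

# Hypothetical helper: build the dense zero array only when a consumer needs it.
materialize(z::ShapedZero{T}) where {T} = zeros(T, z.dims)
materialize(x) = x  # real cotangents pass through unchanged

# With a shaped zero in place of `nothing`, the pullback needs no primal value.
d = (ShapedZero([0.0, 0.0]), [1.0, 1.0])
rebuilt = vcat(map(materialize, d)...)
```

With something along these lines, the first literal_getproperty adjoint above would receive enough information to rebuild all 4 values without closing over A.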
