Adjoint for parent for LowerTriangular and UpperTriangular #1444

Merged: 2 commits into FluxML:master on Aug 8, 2023

torfjelde
Contributor

LowerTriangular and UpperTriangular are missing adjoints for parent, which can lead to downstream issues: TuringLang/Bijectors.jl#280 (comment)

This PR adds those, similar to the existing ones for Adjoint and Transpose.
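
For readers unfamiliar with these rules, here is a minimal sketch of the idea, using only the LinearAlgebra standard library. It assumes the new rules mirror the existing Adjoint and Transpose ones, i.e. the pullback re-wraps the incoming cotangent in the same triangular type. `parent_with_pullback` is a hypothetical helper written for illustration, not part of Zygote's API; in Zygote such a rule would presumably be registered with `@adjoint` instead.

```julia
using LinearAlgebra

# `parent` returns the full array that a triangular wrapper holds,
# including the entries the wrapper itself treats as zero.
A = [1.0 2.0; 3.0 4.0]
L = LowerTriangular(A)
@assert parent(L) === A

# Hypothetical helper (not Zygote's API): the forward value of `parent`
# plus a pullback that re-wraps the cotangent ȳ in the same triangular
# type, which is an assumption modelled on the Adjoint/Transpose rules.
parent_with_pullback(x::LowerTriangular) = parent(x), ȳ -> (LowerTriangular(ȳ),)
parent_with_pullback(x::UpperTriangular) = parent(x), ȳ -> (UpperTriangular(ȳ),)

y, back = parent_with_pullback(L)
x̄, = back(ones(2, 2))  # cotangent for L, wrapped as LowerTriangular
```

The cotangent returned for the wrapper is itself a LowerTriangular, so contributions from entries above the diagonal are masked out before they flow further back.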

It's unclear to me whether this should actually go into ChainRules.jl or not 😕

I'll add tests in a bit.

@CarloLucibello
Member

Ideally all of them should be moved to ChainRules, but that can be done in a future PR. Can you add a test?

@torfjelde
Contributor Author

Added tests 👍

@torfjelde
Contributor Author

Are the test failures related to this PR?

@ToucheSir
Member

They do not appear to be, but we should fix them. Unfortunately, I think that requires some coordination on the AbstractFFTs side, and I have no idea what's going on there or with the FFT rules here. Perhaps @gaurav-arya or someone else in the know could help?

@ToucheSir ToucheSir merged commit dffa378 into FluxML:master Aug 8, 2023
9 of 12 checks passed
@gaurav-arya

It looks to be erroring in the FFT ChainRules rules that have existed for quite a while in AbstractFFTs.jl, when an object of type Zygote.OneElement is passed to the pullback. The error seems to occur in the convert line added in JuliaMath/AbstractFFTs.jl#105.

Are objects of type Zygote.OneElement meant to be passed into a pullback, and how are they meant to be handled? I just tried to play a bit locally, and the object didn't seem to support many operations, e.g. broadcasting or copying or scaling.

@gaurav-arya

OK, I see that Zygote.OneElement does indeed support those things; I just messed up by constructing its axes with unit ranges rather than Base.OneTo ranges.

So the error is indeed due to the convert line introduced in JuliaMath/AbstractFFTs.jl#105.

@torfjelde
Contributor Author

Btw, what's the release schedule like for Zygote? Would it be possible to make a new release somewhat soonish? :)

@ToucheSir
Member

I was hoping to cut a patch release after merging this PR, but I need to figure out whether the current failures on CI (ref. #1446) require a fix on the Zygote side to ensure we're not breaking anyone's code. However, I am not the best person to investigate this, since I have little knowledge of the interaction between the Zygote and AbstractFFTs rules, so if anyone knows more about this, please chime in.

@torfjelde
Contributor Author

Gotcha; thanks @ToucheSir 🙏

@devmotion
Collaborator

The errors are due to the CUDA fix in JuliaMath/AbstractFFTs.jl#105 and similar conversions for adjoint plans. These changes in AbstractFFTs broke the compatibility of its rules with Zygote.OneElement. I don't think you can do anything in Zygote about it, unless you want to re-add the incorrect Zygote FFT rules. It has to be fixed downstream (there are open issues and PRs).

@ToucheSir
Member

@torfjelde
Contributor Author

Awesome; thanks!


5 participants