Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inplace getindex rrule #240

Merged
merged 9 commits into from
Oct 19, 2020
Merged

Inplace getindex rrule #240

merged 9 commits into from
Oct 19, 2020

Conversation

oxinabox
Copy link
Member

@oxinabox oxinabox commented Jul 17, 2020

Something @willtebbutt pointed out to me the other day
We can use ChainRules's ability to define an inplace addition operation (the InplaceThunk) when defining the gradient of getindex.
Which would be much more efficient than e.g retuning a onehot dense matrix to be summed, and marginally more efficient than returning a onehot sparse matrix to be summed.
It would also mean maybe we would not need to say that the differential for a primal type depends on the primal type and the operation.
Though we still might want to.

Right now Zygote won't use ChainRules's inplace accumulation stuff, but idk how hard it would be to enable it.
I think it might be fine or it might (only) break nesting Zygote.
It would definitely require the deeper change over to ChainRules's types.

I am tempted to leave this WIP for a while and improve our abstractions, while trying to get Zygote to actually use this. (Though that is a bigger project as need Zygote to use ChainRules's trypes)

  • Test single indexes
  • Make it so can test this with ChainRulesTestUtils (improve abstractions first?)

@oxinabox oxinabox changed the title Inplace getindex rrule WIP: Inplace getindex rrule Jul 17, 2020
@oxinabox
Copy link
Member Author

For Zygote to benifit it needs FluxML/Zygote.jl#603
and indeed this is itself another good motivation to do that.

@oxinabox oxinabox changed the title WIP: Inplace getindex rrule Inplace getindex rrule Oct 15, 2020
@oxinabox oxinabox closed this Oct 16, 2020
@oxinabox oxinabox reopened this Oct 16, 2020
test/rulesets/Base/array.jl Outdated Show resolved Hide resolved
test/rulesets/Base/array.jl Outdated Show resolved Hide resolved
src/rulesets/Base/array.jl Outdated Show resolved Hide resolved
@oxinabox
Copy link
Member Author

Ok, seems like i need to do a bunch more methods of getindex to make this work for Nabla so this PR is going to get bigger.

@willtebbutt
Copy link
Member

This broadly LGTM. It would be good to add some tests involving arrays with an element type that's not a number -- perhaps an array of arrays?

@oxinabox
Copy link
Member Author

This broadly LGTM. It would be good to add some tests involving arrays with an element type that's not a number -- perhaps an array of arrays?

Good Idea
Probably don't need to many of them since most of the tests are about indexing behavour.
Just some to make sure it goes ok

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants