
always unthunk results #79

Merged · 2 commits into JuliaDiff:main · Jul 25, 2022
Conversation

@oscardssmith (Member)

In theory this shouldn't be necessary, but it's a good fallback to make sure we don't return thunks to the user. Also, it adds zero overhead as long as the preceding calculation is correctly inferred.
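For context, `unthunk` from ChainRulesCore is the identity on ordinary values and only does work when handed an actual thunk, which is why wrapping the returned result is free once everything is inferred. A minimal illustration (not the PR diff itself):

```julia
using ChainRulesCore: unthunk, @thunk

unthunk(2.0)                # 2.0 (identity on plain values, so zero overhead)
unthunk(@thunk(1.0 + 1.0))  # 2.0 (forces the delayed computation)
```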

@mcabbott (Member)

This is intended to only affect the outermost level, the thing returned to the user who types `f'(x)`, right?

Should the same be done to `gradient(f, xs...)`?

@oscardssmith (Member, Author)

This should be done to `gradient` too. Diffractor would actually prefer if thunks didn't exist in the first place, so this is just a fallback to prevent the user from seeing them. Our ideal scenario would be a way to get derivatives from ChainRules that don't have thunks in the first place, but that is a lot harder than this PR.

@mcabbott (Member) commented Jul 17, 2022

Some way of not computing things which will be discarded does seem desirable. With the example of JuliaDiff/ChainRulesCore.jl#558 (but picturing `a`, `b` as huge arrays):
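(For reference, a rough reconstruction of the setup from that issue; the definitions of `f`, `a`, `b` and the printed messages are assumptions inferred from the output below, not copied from #558:)

```julia
using ChainRulesCore

f(x, y) = x * y  # stand-in definition

function ChainRulesCore.rrule(::typeof(f), x, y)
    println("rrule is called")
    function f_pullback(dz)
        dx = @thunk (println("∇a is called"); dz * y)  # gradient w.r.t. x
        dy = @thunk (println("∇b is called"); dz * x)  # gradient w.r.t. y
        return (NoTangent(), dx, dy)
    end
    return f(x, y), f_pullback
end

a, b = 1.0f0, 2.0f0  # scalars here; picture huge arrays
```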

```julia
julia> Diffractor.PrimeDerivativeBack(x -> f(x, b) + 10f(2x, b))(a)
rrule is called
rrule is called
∇a is called
∇a is called
42.0f0  # with this PR, seems ideal? ∇b never run despite accumulation

julia> gradient(x -> f(x, b), a)  # should unthunk answer, but not run ∇b
rrule is called
(Thunk(var"#28#31"{Float32, Float32, Float32}(1.0f0, 1.0f0, 2.0f0)),)
```

Surely a sufficiently smart compiler could notice and eliminate the `∇b` branch without `@thunk`, but is this likely to happen soon?

@oscardssmith (Member, Author)

The goal for Diffractor is to use escape analysis to remove the computation entirely, which is made easier by simpler types. We aren't there yet, but once stage 2 is integrated, we'll be close-ish.

@mcabbott (Member)

This would be great. If it works, it's possible all thunks could be stripped out of ChainRules, since (IIRC) nothing else uses them anyway?

Although, besides delayed calculation, they are intended one day to save memory too... xref #69 I guess.
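For illustration, the memory-saving variant is `InplaceableThunk`, which pairs a lazy `Thunk` with an `add!` function so a consumer can accumulate in place instead of materialising the gradient first. A sketch (the function name `dA_lazy` is invented for illustration, not from ChainRules):

```julia
using ChainRulesCore
using LinearAlgebra: mul!

# Sketch of how a matrix-gradient rrule can offer in-place accumulation:
# the add! closure writes dy * x' into existing storage; the Thunk is the
# fallback used when the value is forced via unthunk.
dA_lazy(dy, x) = InplaceableThunk(
    dA -> mul!(dA, dy, x', true, true),  # dA .= dy * x' .+ dA, no temporary
    @thunk(dy * x'),
)
```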

@oscardssmith (Member, Author)

As I understand it, ReverseDiff.jl likes thunks, but I'm not sure.

@mcabbott (Member)

Yes, I suppose, although you have to opt in. More directly, I forgot that Yota doesn't unthunk internally, only the final result:

```julia
julia> using Yota

julia> grad(x -> f(x, b), a)  # unthunks final result
rrule is called
∇a is called
(2.0f0, (ZeroTangent(), 2.0f0))

julia> using Zygote

julia> Zygote.gradient(x -> f(x, b), a)  # unthunks all
rrule is called
∇a is called
∇b is called
(2.0f0,)

julia> using ReverseDiff

julia> ReverseDiff.@grad_from_chainrules f(x::TrackedReal, y::Real)

julia> ReverseDiff.gradient(x -> f(x[1], b), [a])  # maybe this relies on thunks inside?
rrule is called
∇a is called
1-element Vector{Float32}:
 2.0
```

(Resolved review thread on src/interface.jl, now outdated.)
@codecov-commenter commented Jul 19, 2022

Codecov Report

Merging #79 (9356c2d) into main (82096ee) will decrease coverage by 1.20%.
The diff coverage is 100.00%.

```diff
@@            Coverage Diff             @@
##             main      #79      +/-   ##
==========================================
- Coverage   52.62%   51.41%   -1.21%
==========================================
  Files          21       21
  Lines        2172     2118      -54
==========================================
- Hits         1143     1089      -54
  Misses       1029     1029
```
| Impacted Files | Coverage Δ |
|---|---|
| src/interface.jl | 70.90% <100.00%> (ø) |
| src/stage1/forward.jl | 69.38% <0.00%> (-7.49%) ⬇️ |
| src/tangent.jl | 32.97% <0.00%> (-1.07%) ⬇️ |
| src/stage1/recurse.jl | 91.48% <0.00%> (-0.85%) ⬇️ |
| src/jet.jl | 40.00% <0.00%> (-0.50%) ⬇️ |
| src/stage1/recurse_fwd.jl | 94.11% <0.00%> (-0.17%) ⬇️ |
| src/stage1/generated.jl | 73.30% <0.00%> (-0.13%) ⬇️ |
| src/stage2/interpreter.jl | 0.00% <0.00%> (ø) |
| src/stage2/abstractinterpret.jl | 0.00% <0.00%> (ø) |


@oscardssmith (Member, Author)

I think this is good to merge. (Jacobians have been extricated.)

@mcabbott (Member)

Can this branch's unthunk apply to `gradient` too, and perhaps within `Tangent`? Examples from above:

```julia
julia> Diffractor.PrimeDerivativeBack(x -> sum(sum(x.a) .+ x.b))((a=[1,2], b=[3,4], c=[5,6]))  # maybe this should unthunk within Tangent?
Tangent{NamedTuple{(:a, :b, :c), Tuple{Vector{Int64}, Vector{Int64}, Vector{Int64}}}}(b = [1.0, 1.0], a = InplaceableThunk(ChainRules.var"#..., Thunk(ChainRules.var"#...)))

julia> Diffractor.gradient(x -> sum(sum(x.a) .+ x.b), (a=[1,2], b=[3,4], c=[5,6]))[1]  # ditto?
Tangent{NamedTuple{(:a, :b, :c), Tuple{Vector{Int64}, Vector{Int64}, Vector{Int64}}}}(b = [1.0, 1.0], a = InplaceableThunk(ChainRules.var"#..., Thunk(ChainRules.var"#...)))
```

@oscardssmith (Member, Author)

I think this should be merged as is. I haven't found a great way to unthunk within Tangent, and these are more of a fallback anyway until we find a better way to get rrules that don't have Thunks in the first place.
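(For the record, one untested direction for unthunking within `Tangent` might be to recurse over its backing; a sketch under that assumption, not part of this PR:)

```julia
using ChainRulesCore
using ChainRulesCore: backing

deep_unthunk(x) = unthunk(x)  # forces thunks, identity on everything else
deep_unthunk(t::Tangent{P,<:Tuple}) where {P} =
    Tangent{P}(map(deep_unthunk, backing(t))...)
deep_unthunk(t::Tangent{P,<:NamedTuple}) where {P} =
    Tangent{P}(; map(deep_unthunk, backing(t))...)
```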

@mcabbott (Member) left a review:

Seems ok by me. Would be nice if tests passed, though...

For disabling thunks, xref JuliaDiff/ChainRulesCore.jl#568 (inspired by difficulties with Zygote over Zygote).

@mcabbott (Member)

Master has the same errors now on Julia nightly: `((((sin')')')')(1.0) == sin(1.0)` fails with `MethodError: no method matching fieldnames(::Nothing)`, so it's not this PR's fault.

@mcabbott mcabbott merged commit a2ea087 into JuliaDiff:main Jul 25, 2022
@oscardssmith oscardssmith deleted the always-unthunk branch July 25, 2022 22:25