Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support try/catch on the happy (nothrow) path #1474

Merged
merged 9 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1"
ForwardDiff = "0.10"
GPUArrays = "8.4.2, 9"
GPUArraysCore = "0.1.1"
IRTools = "0.4.11"
IRTools = "0.4.12"
LogExpFunctions = "0.3.1"
MacroTools = "0.5"
NaNMath = "0.3, 1"
Expand Down
35 changes: 19 additions & 16 deletions docs/src/limitations.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,27 +82,30 @@ julia> gradient(rand(3)) do y

## Try-catch statements

Any expressions involving `try`/`catch` statements is not supported.
```julia
function tryme(x)
try
2 * x
catch e
throw(e)
end
end
Code containting try-catch blocks can be differentiated as long as no exception is actually thrown.

julia> gradient(rand(3)) do x
sum(tryme(x))
```julia
julia> function safe_sqrt(x)
try
sqrt(x)
catch
0.
end
end
ERROR: Compiling Tuple{typeof(tryme), Vector{Float64}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations
safe_sqrt (generic function with 1 method)

julia> gradient(safe_sqrt, 4.)
(0.25,)

julia> val, pull = pullback(safe_sqrt, -1.)
(0.0, Zygote.var"#76#77"{Zygote.Pullback{Tuple{typeof(safe_sqrt), Float64}, Any}}(∂(safe_sqrt)))

julia> pull(1.)
ERROR: Can't differentiate function execution in catch block at #= REPL[2]:3 =#.
Stacktrace:
...
```
Here `tryme` uses a `try`/`catch` statement, and Zygote throws an error when trying to differentiate it as expected. `try`/`catch` expressions are used for error handling, but they are less common in Julia compared to some other languages.

Here, the `safe_sqrt` function catches DomainError from the sqrt call when the input is out of domain and safely returns 0. Zygote is able to differentiate the function when no error is thrown by the sqrt call, but fails to differentiate when the control flow goes through the catch block.

## Foreign call expressions

Expand Down
14 changes: 7 additions & 7 deletions src/compiler/emit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,23 @@ concrete(T::DataType) = T
concrete(::Type{Type{T}}) where T = typeof(T)
concrete(T) = Any

runonce(b) = b.id in (1, length(b.ir.blocks))
runonce(b) = b.id in (1, length(b.ir.blocks)) &&
!any(((_,stmt),) -> isexpr(stmt.expr, :catch), b)

function forward_stacks!(adj, F)
stks, recs = [], []
pr = adj.primal
for b in blocks(pr), α in alphauses(block(adj.adjoint, b.id))
if runonce(b)
not_stack = runonce(b)
if not_stack
push!(recs, Variable(α))
else
stk = pushfirst!(pr, xstack(Any))
push!(recs, stk)
push!(b, xcall(Zygote, :_push!, stk, Variable(α)))
end
push!(stks, (b.id, alpha(α)))
push!(stks, (b.id, alpha(α), not_stack))
end
args = arguments(pr)[3:end]
rec = push!(pr, xtuple(recs...))
P = length(pr.blocks) == 1 ? Pullback{F} : Pullback{F,Any}
# P = Pullback{F,Any} # reduce specialisation
Expand All @@ -68,11 +69,10 @@ function reverse_stacks!(adj, stks)
self = argument!(entry, at = 1)
t = pushfirst!(blocks(ir)[end], xcall(:getfield, self, QuoteNode(:t)))
repl = Dict()
runonce(b) = b.id in (1, length(ir.blocks))
for b in blocks(ir)
for (i, (b′, α)) in enumerate(stks)
for (i, (b′, α, not_stack)) in enumerate(stks)
b.id == b′ || continue
if runonce(b)
if not_stack
val = insertafter!(ir, t, xcall(:getindex, t, i))
else
stk = push!(entry, xcall(:getindex, t, i))
Expand Down
25 changes: 18 additions & 7 deletions src/compiler/reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,6 @@ function instrument(ir::IR)
ex = st.expr
if isexpr(ex, :foreigncall, :isdefined)
continue
elseif isexpr(ex, :enter, :leave)
error("""try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations
""")
Comment on lines -127 to -131
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believer we still need this on versions of julia that do not have enter and leave?
so maybe we keep this and add a && VERSION<v"1.10" (is that the right bounds?)

elseif isexpr(ex, :(=))
@assert ex.args[1] isa GlobalRef
pr[v] = xcall(Zygote, :global_set, QuoteNode(ex.args[1]), ex.args[2])
Expand Down Expand Up @@ -258,7 +253,7 @@ function adjointcfg(pr::Primal)
end
if isempty(preds) || (!isempty(branches(b)) && branches(b)[end] == IRTools.unreachable)
# If `b` is unreachable, then no context produced by the primal should end up branching to `rb`
push!(rb, xcall(Core, :throw, "unreachable")) # `throw` is necessary for inference not to hit the `unreachable`
push!(rb, xcall(Base, :error, "unreachable")) # `throw` is necessary for inference not to hit the `unreachable`
branch!(rb, 0)
end
end
Expand All @@ -279,7 +274,7 @@ xaccum(ir, xs...) = push!(ir, xcall(Zygote, :accum, xs...))

function passthrough_expr(ex::Expr)
# Metadata we want to preserve
isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta, :loopinfo) && return true
isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta, :loopinfo, :enter, :leave, :catch) && return true
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
# ccalls and more that are safe to preserve/required for proper operation:
# - jl_set_task_threadpoolid: added in 1.9 for @spawn
isexpr(ex, :foreigncall) && unwrapquote(ex.args[1]) in (:jl_set_task_threadpoolid,) && return true
Expand All @@ -297,9 +292,14 @@ function adjoint(pr::Primal)
for i = 1:length(sigs[b.id])
grad(sigs[b.id][i], arguments(rb)[i])
end

has_leave = false

# Backprop through statements
for v in reverse(keys(b))
ex = b[v].expr
has_leave |= isexpr(ex, :leave)

if haskey(pr.pullbacks, v)
g = push!(rb, stmt(Expr(:call, alpha(pr.pullbacks[v]), grad(v)),
line = b[v].line))
Expand All @@ -321,6 +321,17 @@ function adjoint(pr::Primal)
continue
end
end

# This is corresponds to a catch blocks which technically
# has predecessors but they are not modelled in the IRTools CFG.
# We put an error message at the beginning of said block.
if has_leave && isempty(predecessors(b)) && b.id != 1
_, f_stmt = first(b)
li = pr.ir.lines[f_stmt.line]
pushfirst!(rb, stmt(xcall(Base, :error,
"Can't differentiate function execution in catch block at $(li.file):$(li.line).")))
end

if b.id > 1 # Backprop through (predecessor) branch arguments
gs = grad.(arguments(b))
for br in branches(rb)
Expand Down
97 changes: 97 additions & 0 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,100 @@ end
@test_nowarn g = back(1.)
@test only(g) ∈ (1., 2.)
end

function throws_and_catches_if_x_negative(x,y)
z = x + y
try
if x < 0.
throw(DomainError("x is negative"))
end
z = 2z + x + y
catch err
@error "something went wrong" exception=(err,catch_backtrace())
end
return 3z
end

function try_catch_finally(cond, x)

try
x = 2x
cond && throw(DomainError())
catch
x = 2x
finally
x = 3x
end

x
end

if VERSION >= v"1.8"
# try/catch/else is invalid syntax prior to v1.8
eval(Meta.parse("""
function try_catch_else(cond, x)
x = 2x

try
x = 2x
cond && throw(nothing)
catch
x = 3x
else
x = 2x
end

x
end
"""))
end

@testset "try/catch" begin
@testset "happy path (nothrow)" begin
res, (dx,dy) = withgradient(throws_and_catches_if_x_negative, 1., 2.)
@test res == 3 * (2 * (1. + 2.) + 1. + 2.)
@test dx == 3. * (2. + 1.)
@test dy == 3. * (2. + 1.)
end

@testset "try/catch/finally" begin
res, (_, dx,) = withgradient(try_catch_finally, false, 1.)
@test res == 6.
@test dx == 6.

res, pull = pullback(try_catch_finally, true, 1.)
@test res == 12.
@test_throws ErrorException pull(1.)
err = try pull(1.) catch ex; ex end
@test occursin("Can't differentiate function execution in catch block",
string(err))
end

if VERSION >= v"1.8"
@testset "try/catch/else" begin
@test Zygote.gradient(try_catch_else, false, 1.0) == (nothing, 8.0)
@test_throws "Can't differentiate function execution in catch block" Zygote.gradient(try_catch_else, true, 1.0)
end
end

function foo_try(f)
y = 1
try
y = f()
catch
y
end
y
end

g, = gradient(x -> foo_try(() -> x), 1) # 1
@test g == 1.

vy, pull = pullback(foo_try, () -> 0//0) # bypass because of expr
@test vy === 1
@test_throws ErrorException pull(1.)

err = try pull(1.) catch ex; ex end
@test occursin("Can't differentiate function execution in catch block",
string(err))
end
Pangoraw marked this conversation as resolved.
Show resolved Hide resolved
3 changes: 1 addition & 2 deletions test/features.jl
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,7 @@ function pow_try(x)
end
end

@test_broken gradient(pow_try, 1) == (2,)
@test_throws Zygote.CompileError gradient(pow_try, 1)
@test gradient(pow_try, 1) == (2,)

function pow_simd(x, n)
r = 1
Expand Down
Loading