Skip to content

Commit

Permalink
Diffractor: Bump to julia 1.10, prepare for stage 2 work (#94)
Browse files Browse the repository at this point in the history
This:
- Bumps the diffractor version to 0.2.0 (in preparation, will be
tagged once new things are in).
- Bumps the julia dependency version to 1.10
- Gets rid of conditional code for older julia code.

This is in preparation for the stage2 code that will depend on lattice
extensions and other 1.10-only features. If there is interest, we
can branch off 2e571fe as release-0.1,
but if nobody complains, I'd rather save the effort to focus on current
development.
  • Loading branch information
Keno committed Dec 27, 2022
1 parent 2e571fe commit d7648f1
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 34 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ jobs:
fail-fast: false
matrix:
version:
- '1.7' # Lowest claimed support in Project.toml
# - '1' # Latest Release # Testing on 1.8 gives this message:
# ┌ Warning: ir verification broken. Either use 1.9 or 1.7
# └ @ Diffractor ~/work/Diffractor.jl/Diffractor.jl/src/stage1/recurse.jl:889
- 'nightly'
os:
- ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ Combinatorics = "1"
StaticArrays = "1"
StatsBase = "0.33"
StructArrays = "0.6"
julia = "1.7"
julia = "1.10"
2 changes: 1 addition & 1 deletion src/stage1/hacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Core.Compiler: count_added_node!, NewSSAValue, add_pending!,
StmtRange, BasicBlock

# Re-named in https://github.com/JuliaLang/julia/pull/47051
const add! = VERSION < v"1.9-" ? Core.Compiler.add! : Core.Compiler.add_inst!
const add! = Core.Compiler.add_inst!

Base.length(c::Core.Compiler.NewNodeStream) = Core.Compiler.length(c)
Base.setindex!(i::Instruction, args...) = Core.Compiler.setindex!(i, args...)
Expand Down
29 changes: 7 additions & 22 deletions src/stage1/recurse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,13 +238,8 @@ function split_critical_edges!(ir)
end

function make_opaque_closure(typ, name, meth_nargs, isva, lno, cis, revs...)
if VERSION >= v"1.8.0-DEV.1563"
Expr(:new_opaque_closure, typ, Union{}, Any,
Expr(:opaque_closure_method, name, meth_nargs, isva, lno, cis), revs...)
else
Expr(:new_opaque_closure, typ, isva, Union{}, Any,
Expr(:opaque_closure_method, name, meth_nargs, lno, cis), revs...)
end
Expr(:new_opaque_closure, typ, Union{}, Any,
Expr(:opaque_closure_method, name, meth_nargs, isva, lno, cis), revs...)
end

Base.iterate(c::IncrementalCompact, args...) = Core.Compiler.iterate(c, args...)
Expand All @@ -265,21 +260,17 @@ function transform!(ci, meth, nargs, sparams, N)
slotflags = UInt8[(0x00 for i = 1:2)..., ci.slotflags...]
slottypes = UInt8[(0x00 for i = 1:2)..., ci.slotflags...]

meta = VERSION < v"1.9.0-DEV.472" ? Any[] : Expr[]
meta = Expr[]
ir = IRCode(Core.Compiler.InstructionStream(code, Any[],
Any[nothing for i = 1:length(code)],
ci.codelocs, UInt8[0 for i = 1:length(code)]), cfg, Core.LineInfoNode[ci.linetable...],
Any[Any for i = 1:2], meta, Any[sparams...])

# SSA conversion
domtree = construct_domtree(ir.cfg.blocks)
defuse_insts = scan_slot_def_use(VERSION >= v"1.8.0-DEV.267" ? Int(meth.nargs) : meth.nargs-1, ci, ir.stmts.inst)
defuse_insts = scan_slot_def_use(Int(meth.nargs), ci, ir.stmts.inst)
ci.ssavaluetypes = Any[Any for i = 1:ci.ssavaluetypes]
if VERSION >= v"1.8.0-DEV.267"
ir = construct_ssa!(ci, ir, domtree, defuse_insts, Any[Any for i = 1:length(slotnames)])
else
ir = construct_ssa!(ci, ir, domtree, defuse_insts, nargs, Any[Any for i = 1:length(slotnames)])
end
ir = construct_ssa!(ci, ir, domtree, defuse_insts, Any[Any for i = 1:length(slotnames)], Core.Compiler.OptimizerLattice())
ir = compact!(ir)
cfg = ir.cfg

Expand Down Expand Up @@ -842,7 +833,7 @@ function transform!(ci, meth, nargs, sparams, N)
override = false
if has_terminator[active_bb]
terminator = compact[SSAValue(idx)]
terminator = VERSION < v"1.9.0-DEV.739" ? terminator : terminator.inst
terminator = terminator.inst
compact[SSAValue(idx)] = nothing
override = true
end
Expand Down Expand Up @@ -881,13 +872,7 @@ function transform!(ci, meth, nargs, sparams, N)
ir = complete(compact)
#@show ir
ir = compact!(ir)
if VERSION < v"1.8"
Core.Compiler.verify_ir(ir, true)
elseif VERSION >= v"1.9.0-DEV.854"
Core.Compiler.verify_ir(ir, true, true)
else
@warn "ir verification broken. Either use 1.9 or 1.7"
end
Core.Compiler.verify_ir(ir, true, true)

Core.Compiler.replace_code_newstyle!(ci, ir, nargs+1)
ci.ssavaluetypes = length(ci.code)
Expand Down
12 changes: 6 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ let var"'" = Diffractor.PrimeDerivativeBack
@test @inferred(sin''(1.0)) == -sin(1.0)
@test sin'''(1.0) == -cos(1.0)
@test sin''''(1.0) == sin(1.0)
@test sin'''''(1.0) == cos(1.0) # broken = VERSION >= v"1.8"
@test sin''''''(1.0) == -sin(1.0) # broken = VERSION >= v"1.8"
@test sin'''''(1.0) == cos(1.0)
@test sin''''''(1.0) == -sin(1.0)

f_getfield(x) = getfield((x,), 1)
@test f_getfield'(1) == 1
Expand Down Expand Up @@ -229,7 +229,7 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)

@test gradient(x -> sum((explog).(x)), [1,2,3]) == ([1,1,1],) # frule_via_ad
exp_log(x) = exp(log(x))
@test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],)
@test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],)
@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75])
@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], 5) == ([0.2 0.2; 0.2 0.2], -0.4)
@test gradient(x -> sum((y -> y/x).([1,2,3])), 4) == (-0.375,) # closure
Expand Down Expand Up @@ -260,7 +260,7 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
@test tup_adj[2] [0.6666666666666666 0.5 0.4]
@test tup_adj[2] isa Transpose
@test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal

@test gradient(x -> sum((y -> (x*y)).([1,2,3])), 4.0) == (6.0,) # closure
end

Expand All @@ -272,12 +272,12 @@ end
@test_broken gradient(x -> sum(gradient(x -> sum(x .^ 2 .+ x'), x)[1]), [1,2,3.0])[1] == [6,6,6] # BoundsError: attempt to access 18-element Vector{Core.Compiler.BasicBlock} at index [0]
@test_broken gradient(x -> sum(gradient(x -> sum((x .+ 1) .* x .- x), x)[1]), [1,2,3.0])[1] == [2,2,2]
@test_broken gradient(x -> sum(gradient(x -> sum(x .* x ./ 2), x)[1]), [1,2,3.0])[1] == [1,1,1]

@test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3])[1] exp.(1:3) # MethodError: no method matching copy(::Nothing)
@test_broken gradient(x -> sum(gradient(x -> sum(atan.(x, x')), x)[1]), [1,2,3.0])[1] [0,0,0]
@test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) .* x), x)[1]), [1,2,3]) == ([6,6,6],) # accum(a::Transpose{Float64, Vector{Float64}}, b::ChainRulesCore.Tangent{Transpose{Int64, Vector{Int64}}, NamedTuple{(:parent,), Tuple{ChainRulesCore.NoTangent}}})
@test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) ./ x.^2), x)[1]), [1,2,3])[1] [27.675925925925927, -0.824074074074074, -2.1018518518518516]

@test_broken gradient(z -> gradient(x -> sum((y -> (x^2*y)).([1,2,3])), z)[1], 5.0) == (12.0,)
end

Expand Down

0 comments on commit d7648f1

Please sign in to comment.