diff --git a/src/codegen/forward.jl b/src/codegen/forward.jl index 21c4f8ae..a2169466 100644 --- a/src/codegen/forward.jl +++ b/src/codegen/forward.jl @@ -1,17 +1,23 @@ -function transform_fwd!(ci, meth, nargs, sparams, N) +function fwd_transform(ci, args...) + newci = copy(ci) + fwd_transform!(newci, args...) + return newci +end + +function fwd_transform!(ci, mi, nargs, N) new_code = Any[] new_codelocs = Any[] ssa_mapping = Int[] loc_mapping = Int[] - function emit!(stmt) + function emit!(@nospecialize stmt) (isexpr(stmt, :call) || isexpr(stmt, :(=)) || isexpr(stmt, :new)) || return stmt push!(new_code, stmt) push!(new_codelocs, isempty(new_codelocs) ? 0 : new_codelocs[end]) - SSAValue(length(new_code)) + return SSAValue(length(new_code)) end - function mapstmt!(stmt) + function mapstmt!(@nospecialize stmt) if isexpr(stmt, :(=)) return Expr(stmt.head, emit!(mapstmt!(stmt.args[1])), emit!(mapstmt!(stmt.args[2]))) elseif isexpr(stmt, :call) @@ -44,7 +50,7 @@ function transform_fwd!(ci, meth, nargs, sparams, N) elseif isa(stmt, GotoIfNot) return GotoIfNot(emit!(Expr(:call, primal, emit!(mapstmt!(stmt.cond)))), stmt.dest) elseif isexpr(stmt, :static_parameter) - return ZeroBundle{N}(sparams[stmt.args[1]]) + return ZeroBundle{N}(mi.sparam_vals[stmt.args[1]::Int]) elseif isexpr(stmt, :foreigncall) return Expr(:call, error, "Attempted to AD a foreigncall. Missing rule?") elseif isexpr(stmt, :meta) || isexpr(stmt, :inbounds) @@ -56,9 +62,11 @@ function transform_fwd!(ci, meth, nargs, sparams, N) end end - for i = 1:meth.nargs - if meth.isva && i == meth.nargs - args = map(i:(nargs+1)) do j + meth = mi.def::Method + nargs = Int(meth.nargs) + for i = 1:nargs + if meth.isva && i == nargs + args = map(i:(nargs+1)) do j::Int emit!(Expr(:call, getfield, SlotNumber(2), j)) end emit!(Expr(:(=), SlotNumber(2 + i), Expr(:call, ∂vararg{N}(), args...))) @@ -83,7 +91,15 @@ function transform_fwd!(ci, meth, nargs, sparams, N) end end + ci.slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...] + ci.slotflags = UInt8[0x00, 0x00, ci.slotflags...] + ci.slottypes = ci.slottypes === nothing ? nothing : Any[Any, Any, ci.slottypes...] ci.code = new_code ci.codelocs = new_codelocs - ci + ci.ssavaluetypes = length(new_code) + ci.ssaflags = UInt8[0 for i=1:length(new_code)] + ci.method_for_inference_limit_heuristics = meth + ci.edges = MethodInstance[mi] + + return ci end diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index bb127c6e..5ccd1a81 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -23,12 +23,7 @@ function perform_optic_transform(world::UInt, source::LineNumberNode, mi = Core.Compiler.specialize_method(match) ci = Core.Compiler.retrieve_code_info(mi, world) - ci′ = copy(ci) - ci′.edges = MethodInstance[mi] - - ci′ = diffract_transform!(ci′, mi.def, length(args) - 1, match.sparams, N) - - return ci′ + return optic_transform(ci, mi, length(args)-1, N) end # This relies on PartialStruct to infer well diff --git a/src/stage1/recurse.jl b/src/stage1/recurse.jl index 478fc174..08571289 100644 --- a/src/stage1/recurse.jl +++ b/src/stage1/recurse.jl @@ -256,8 +256,16 @@ function sptypes(sparams) end end -function diffract_transform!(ci, meth, nargs, sparams, N) +function optic_transform(ci, args...) + newci = copy(ci) + optic_transform!(newci, args...) + return newci +end + +function optic_transform!(ci, mi, nargs, N) code = ci.code + sparams = mi.sparam_vals + cfg = compute_basic_blocks(code) ci.slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...] ci.slotflags = UInt8[0x00, 0x00, ci.slotflags...] @@ -270,13 +278,15 @@ function diffract_transform!(ci, meth, nargs, sparams, N) Any[Any for i = 1:2], meta, sptypes(sparams)) # SSA conversion + meth = mi.def::Method domtree = construct_domtree(ir.cfg.blocks) defuse_insts = scan_slot_def_use(Int(meth.nargs), ci, ir.stmts.inst) ci.ssavaluetypes = Any[Any for i = 1:ci.ssavaluetypes] ir = construct_ssa!(ci, ir, domtree, defuse_insts, ci.slottypes, Core.Compiler.OptimizerLattice()) ir = compact!(ir) - nfixedargs = meth.isva ? meth.nargs - 1 : meth.nargs + nfixedargs = Int(meth.nargs) + meth.isva && (nfixedargs -= 1) meth.isva || @assert nfixedargs == nargs+1 ir = diffract_ir!(ir, ci, meth, sparams, nargs, N) @@ -286,6 +296,7 @@ function diffract_transform!(ci, meth, nargs, sparams, N) ci.ssavaluetypes = length(ci.code) ci.ssaflags = UInt8[0x00 for i=1:length(ci.code)] ci.method_for_inference_limit_heuristics = meth + ci.edges = MethodInstance[mi] return ci end diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index c9c7172b..58572e6c 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -47,22 +47,7 @@ function perform_fwd_transform(world::UInt, source::LineNumberNode, mi = Core.Compiler.specialize_method(match) ci = Core.Compiler.retrieve_code_info(mi, world) - ci′ = copy(ci) - ci′.edges = MethodInstance[mi] - - transform_fwd!(ci′, mi.def, length(args) - 1, match.sparams, N) - - ci′.ssavaluetypes = length(ci′.code) - ci′.ssaflags = UInt8[0 for i=1:length(ci′.code)] - ci′.method_for_inference_limit_heuristics = match.method - slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...] - slotflags = UInt8[(0x00 for i = 1:2)..., ci.slotflags...] - slottypes = ci.slottypes === nothing ? nothing : Any[(Any for i = 1:2)..., ci.slottypes...] - ci′.slotnames = slotnames - ci′.slotflags = slotflags - ci′.slottypes = slottypes - - return ci′ + return fwd_transform(ci, mi, length(args)-1, N) end let ex = :(function (ff::∂☆recurse)(args...)