Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NFC cosmetic changes #128

Merged
merged 1 commit into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions src/codegen/forward.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
function transform_fwd!(ci, meth, nargs, sparams, N)
function fwd_transform(ci, args...)
newci = copy(ci)
fwd_transform!(newci, args...)
return newci
end

function fwd_transform!(ci, mi, nargs, N)
new_code = Any[]
new_codelocs = Any[]
ssa_mapping = Int[]
loc_mapping = Int[]

function emit!(stmt)
function emit!(@nospecialize stmt)
(isexpr(stmt, :call) || isexpr(stmt, :(=)) || isexpr(stmt, :new)) || return stmt
push!(new_code, stmt)
push!(new_codelocs, isempty(new_codelocs) ? 0 : new_codelocs[end])
SSAValue(length(new_code))
return SSAValue(length(new_code))
end

function mapstmt!(stmt)
function mapstmt!(@nospecialize stmt)
if isexpr(stmt, :(=))
return Expr(stmt.head, emit!(mapstmt!(stmt.args[1])), emit!(mapstmt!(stmt.args[2])))
elseif isexpr(stmt, :call)
Expand Down Expand Up @@ -44,7 +50,7 @@ function transform_fwd!(ci, meth, nargs, sparams, N)
elseif isa(stmt, GotoIfNot)
return GotoIfNot(emit!(Expr(:call, primal, emit!(mapstmt!(stmt.cond)))), stmt.dest)
elseif isexpr(stmt, :static_parameter)
return ZeroBundle{N}(sparams[stmt.args[1]])
return ZeroBundle{N}(mi.sparam_vals[stmt.args[1]::Int])
elseif isexpr(stmt, :foreigncall)
return Expr(:call, error, "Attempted to AD a foreigncall. Missing rule?")
elseif isexpr(stmt, :meta) || isexpr(stmt, :inbounds)
Expand All @@ -56,9 +62,11 @@ function transform_fwd!(ci, meth, nargs, sparams, N)
end
end

for i = 1:meth.nargs
if meth.isva && i == meth.nargs
args = map(i:(nargs+1)) do j
meth = mi.def::Method
nargs = Int(meth.nargs)
for i = 1:nargs
if meth.isva && i == nargs
args = map(i:(nargs+1)) do j::Int
emit!(Expr(:call, getfield, SlotNumber(2), j))
end
emit!(Expr(:(=), SlotNumber(2 + i), Expr(:call, ∂vararg{N}(), args...)))
Expand All @@ -83,7 +91,15 @@ function transform_fwd!(ci, meth, nargs, sparams, N)
end
end

ci.slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...]
ci.slotflags = UInt8[0x00, 0x00, ci.slotflags...]
ci.slottypes = ci.slottypes === nothing ? nothing : Any[Any, Any, ci.slottypes...]
ci.code = new_code
ci.codelocs = new_codelocs
ci
ci.ssavaluetypes = length(new_code)
ci.ssaflags = UInt8[0 for i=1:length(new_code)]
ci.method_for_inference_limit_heuristics = meth
ci.edges = MethodInstance[mi]

return ci
end
7 changes: 1 addition & 6 deletions src/stage1/generated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,7 @@ function perform_optic_transform(world::UInt, source::LineNumberNode,
mi = Core.Compiler.specialize_method(match)
ci = Core.Compiler.retrieve_code_info(mi, world)

ci′ = copy(ci)
ci′.edges = MethodInstance[mi]

ci′ = diffract_transform!(ci′, mi.def, length(args) - 1, match.sparams, N)

return ci′
return optic_transform(ci, mi, length(args)-1, N)
end

# This relies on PartialStruct to infer well
Expand Down
15 changes: 13 additions & 2 deletions src/stage1/recurse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,16 @@ function sptypes(sparams)
end
end

function diffract_transform!(ci, meth, nargs, sparams, N)
function optic_transform(ci, args...)
newci = copy(ci)
optic_transform!(newci, args...)
return newci
end

function optic_transform!(ci, mi, nargs, N)
code = ci.code
sparams = mi.sparam_vals

cfg = compute_basic_blocks(code)
ci.slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...]
ci.slotflags = UInt8[0x00, 0x00, ci.slotflags...]
Expand All @@ -270,13 +278,15 @@ function diffract_transform!(ci, meth, nargs, sparams, N)
Any[Any for i = 1:2], meta, sptypes(sparams))

# SSA conversion
meth = mi.def::Method
domtree = construct_domtree(ir.cfg.blocks)
defuse_insts = scan_slot_def_use(Int(meth.nargs), ci, ir.stmts.inst)
ci.ssavaluetypes = Any[Any for i = 1:ci.ssavaluetypes]
ir = construct_ssa!(ci, ir, domtree, defuse_insts, ci.slottypes, Core.Compiler.OptimizerLattice())
ir = compact!(ir)

nfixedargs = meth.isva ? meth.nargs - 1 : meth.nargs
nfixedargs = Int(meth.nargs)
meth.isva && (nfixedargs -= 1)
meth.isva || @assert nfixedargs == nargs+1

ir = diffract_ir!(ir, ci, meth, sparams, nargs, N)
Expand All @@ -286,6 +296,7 @@ function diffract_transform!(ci, meth, nargs, sparams, N)
ci.ssavaluetypes = length(ci.code)
ci.ssaflags = UInt8[0x00 for i=1:length(ci.code)]
ci.method_for_inference_limit_heuristics = meth
ci.edges = MethodInstance[mi]

return ci
end
17 changes: 1 addition & 16 deletions src/stage1/recurse_fwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,7 @@ function perform_fwd_transform(world::UInt, source::LineNumberNode,
mi = Core.Compiler.specialize_method(match)
ci = Core.Compiler.retrieve_code_info(mi, world)

ci′ = copy(ci)
ci′.edges = MethodInstance[mi]

transform_fwd!(ci′, mi.def, length(args) - 1, match.sparams, N)

ci′.ssavaluetypes = length(ci′.code)
ci′.ssaflags = UInt8[0 for i=1:length(ci′.code)]
ci′.method_for_inference_limit_heuristics = match.method
slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...]
slotflags = UInt8[(0x00 for i = 1:2)..., ci.slotflags...]
slottypes = ci.slottypes === nothing ? nothing : Any[(Any for i = 1:2)..., ci.slottypes...]
ci′.slotnames = slotnames
ci′.slotflags = slotflags
ci′.slottypes = slottypes

return ci′
return fwd_transform(ci, mi, length(args)-1, N)
end

let ex = :(function (ff::∂☆recurse)(args...)
Expand Down