Skip to content

Commit

Permalink
Pass ssa instead of stmt to visit_custom!()
Browse files Browse the repository at this point in the history
We need to pass the SSA index, rather than the statement itself so that
`visit_custom!()` can inspect the type of the statement under
visitation.

This is more consistent with the `transform!()` API as well.
  • Loading branch information
staticfloat committed Jun 27, 2023
1 parent b23337a commit 60a99fe
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/codegen/forward_demand.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ function forward_visit!(ir::IRCode, ssa::SSAValue, order::Int, ssa_orders::Vecto
inst = ir[ssa]
stmt = inst[:inst]
recurse(@nospecialize(val)) = forward_visit!(ir, val, order, ssa_orders, visit_custom!)
if visit_custom!(ir, stmt, order, recurse)
if visit_custom!(ir, ssa, order, recurse)
ssa_orders[ssa.id] = order => true
return
elseif isa(stmt, PiNode)
Expand Down Expand Up @@ -211,7 +211,7 @@ Internal method which generates the code for forward mode diffentiation
- `to_diff`: collection of all SSA values for which the derivative is to be taken,
paired with the order (first deriviative, second derivative etc)
- `visit_custom!(ir::IRCode, stmt, order::Int, recurse::Bool) -> Bool`:
- `visit_custom!(ir::IRCode, ssa, order::Int, recurse::Bool) -> Bool`:
decides if the custom `transform!` should be applied to a `stmt` or not
Default: `false` for all statements
- `transform!(ir::IRCode, ssa::SSAValue, order::Int)` mutates `ir` to do a custom tranformation.
Expand Down
9 changes: 6 additions & 3 deletions src/stage2/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@ function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1)
end
end

function visit_custom!(ir::IRCode, @nospecialize(stmt), order, recurse)
function visit_custom!(ir::IRCode, ssa::Union{SSAValue,Argument}, order, recurse)
if isa(ssa, Argument)
return true
end

stmt = ir[ssa][:inst]
if isa(stmt, ReturnNode)
recurse(stmt.val)
return true
elseif isa(stmt, Argument)
return true
else
return false
end
Expand Down

0 comments on commit 60a99fe

Please sign in to comment.