Update handling of batch-norm in DCE optimization (#1591)
Addresses Issue #1338
gramalingam authored Jun 7, 2024
1 parent 1c154c9 commit 87618e8
Showing 2 changed files with 50 additions and 4 deletions.
20 changes: 16 additions & 4 deletions onnxscript/optimizer/remove_unused.py
@@ -26,11 +26,23 @@ def remove_unused_optional_outputs(
         op_schema = onnx.defs.get_schema(n.op_type, onnx_opset_version, domain=n.domain)
     except Exception:
         return
-    # TODO: If current node is a BatchNormalization node,
-    # based on training_mode atrribute, number of optional outputs and
-    # how they are handled varies, handle both training_modes
+
     if n.op_type == "BatchNormalization":
-        return
+        # BatchNormalization op has 3 outputs: Y, running_mean, running_var
+        # If running_mean and running_var are not used, remove them, and the training_mode attribute
+        def is_used_output(i: int) -> bool:
+            if i < len(n.output):
+                return n.output[i] in used
+            return False
+
+        if is_used_output(1) or is_used_output(2):
+            return
+        del n.output[1:]
+        for j, attr in enumerate(n.attribute):
+            if attr.name == "training_mode":
+                del n.attribute[j]
+                break
+
     optional_info = []
     for o in op_schema.outputs:
         # Current ops do not have optional outputs if they have variable number of outputs
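
For orientation, a minimal sketch (not part of the commit) of the new behavior end to end, reusing the model text from the tests below; it assumes onnxscript's optimizer.remove_unused_nodes entry point, which is what the test file exercises.

import onnx.parser
from onnxscript import optimizer

# mean_out and var_out are not graph outputs, so they are unused trailing outputs.
model = onnx.parser.parse_model(
    """
    <ir_version: 7, opset_import: [ "" : 17]>
    agraph (float[1, 3, 5, 5] x, float[3] scale, float[3] B) => (float[1, 3, 5, 5] z) {
        z, mean_out, var_out = BatchNormalization <training_mode=1> (x, scale, B, mean, var)
    }
    """
)
optimizer.remove_unused_nodes(model)
node = model.graph.node[0]
# The pass trims the node to its first output (Y) and drops the training_mode
# attribute, leaving an inference-mode BatchNormalization.
print(len(node.output))     # expected: 1
print(len(node.attribute))  # expected: 0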
34 changes: 34 additions & 0 deletions onnxscript/optimizer/remove_unused_test.py
@@ -170,6 +170,40 @@ def test_avoid_remove_non_trailing_unused_optional_outputs_layernorm(self):
         self.assertEqual(model.graph.node[2].op_type, "LayerNormalization")
         self.assertEqual(len(model.graph.node[2].output), 3)
 
+    def test_remove_trailing_unused_optional_outputs_batchnorm(self):
+        model = onnx.parser.parse_model(
+            """
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[1, 3, 5, 5] x, float[3] scale, float[3] B) => (float[1, 3, 5, 5] z) {
+                z, mean_out, var_out = BatchNormalization <training_mode=1> (x, scale, B, mean, var)
+            }
+            """
+        )
+        self.assertEqual(len(model.graph.node[0].attribute), 1)
+        optimizer.remove_unused_nodes(model)
+        self.assertEqual(len(model.graph.node), 1)
+        self.assertEqual(model.graph.node[0].op_type, "BatchNormalization")
+        # Check that both the mean/var outputs are removed, and training_mode attribute is removed.
+        self.assertEqual(len(model.graph.node[0].output), 1)
+        self.assertEqual(len(model.graph.node[0].attribute), 0)
+
+    def test_avoid_remove_used_optional_outputs_batchnorm(self):
+        model = onnx.parser.parse_model(
+            """
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[1, 3, 5, 5] x, float[3] scale, float[3] B) => (float[1, 3, 5, 5] z, float[3] mean_out) {
+                z, mean_out, var_out = BatchNormalization <training_mode=1> (x, scale, B, mean, var)
+            }
+            """
+        )
+        self.assertEqual(len(model.graph.node[0].attribute), 1)
+        optimizer.remove_unused_nodes(model)
+        self.assertEqual(len(model.graph.node), 1)
+        self.assertEqual(model.graph.node[0].op_type, "BatchNormalization")
+        # Check that the mean/var outputs are NOT removed, and training_mode attribute is NOT removed.
+        self.assertEqual(len(model.graph.node[0].output), 3)
+        self.assertEqual(len(model.graph.node[0].attribute), 1)
+
 
 if __name__ == "__main__":
     unittest.main()
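
The new tests follow the file's existing unittest style; they can be run on their own with a standard invocation such as `python -m unittest onnxscript.optimizer.remove_unused_test` (command assumed, not part of the commit).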
