Skip to content

Commit

Permalink
Add utility to remove unused values/nodes in IR (#1617)
Browse files Browse the repository at this point in the history
Mostly a re-implementation, using the IR, of the existing proto-based
optimization that removes unused values/nodes.
  • Loading branch information
gramalingam authored Jun 15, 2024
1 parent 4a9b04e commit dc31a6e
Show file tree
Hide file tree
Showing 5 changed files with 297 additions and 156 deletions.
141 changes: 8 additions & 133 deletions onnxscript/optimizer/remove_unused.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,140 +2,15 @@
# Licensed under the MIT License.
from __future__ import annotations

import logging
from typing import Sequence

import onnx
from google.protobuf.internal.containers import ( # type: ignore
RepeatedCompositeFieldContainer,
)

logger = logging.getLogger(__name__)


def remove_unused_optional_outputs(
    n: onnx.NodeProto, used: set, opset_import: Sequence[onnx.OperatorSetIdProto]
) -> None:
    """Drop trailing optional outputs of *n* that are not in *used*.

    Mutates ``n`` in place. Silently does nothing when the node's op schema
    cannot be resolved.
    """
    try:
        if n.domain not in {"", "onnx.ai"}:
            return
        # Resolve the opset version imported for this node's domain (default 1).
        onnx_opset_version = 1
        for opset in opset_import:
            if opset.domain == n.domain:
                onnx_opset_version = opset.version
        op_schema = onnx.defs.get_schema(n.op_type, onnx_opset_version, domain=n.domain)
    except Exception:
        # Unknown op/opset: be conservative and leave the node untouched.
        return

    if n.op_type == "BatchNormalization":
        # BatchNormalization op has 3 outputs: Y, running_mean, running_var
        # If running_mean and running_var are not used, remove them, and the training_mode attribute
        def is_used_output(i: int) -> bool:
            if i < len(n.output):
                return n.output[i] in used
            return False

        if is_used_output(1) or is_used_output(2):
            return
        del n.output[1:]
        for j, attr in enumerate(n.attribute):
            if attr.name == "training_mode":
                del n.attribute[j]
                break

    optional_info = []
    for o in op_schema.outputs:
        # Current ops do not have optional outputs if they have variable number of outputs
        if o.option == onnx.defs.OpSchema.FormalParameterOption.Variadic:
            return
        optional_info.append(o.option == onnx.defs.OpSchema.FormalParameterOption.Optional)
    # If no optional outputs in spec, skip delete operations
    # NOTE(review): this condition is only true when the op has no outputs at
    # all; it does not detect "no optional outputs" (`not any(optional_info)`
    # would). Harmless, since the blanking loop below is then a no-op.
    if len([o == 1 for o in optional_info]) == 0:
        return

    # Blank unused optional outputs first, then pop only the *trailing* blanks
    # so the positional meaning of the remaining outputs is preserved.
    for i, out in enumerate(n.output):
        if out not in used and optional_info[i] is True:
            n.output[i] = ""
    # Only delete trailing unused optional outputs
    for o in n.output[::-1]:  # type: ignore[assignment]
        if o == "":
            n.output.pop()
        else:
            return


def compute_used_in_node(n: onnx.NodeProto) -> set[str]:
    """Return every value name consumed by *n*, including inside subgraph attributes."""
    names = {inp for inp in n.input if inp}
    for attr in n.attribute:
        if attr.HasField("g"):
            names.update(compute_used_in_graph(attr.g))
        else:
            for subgraph in attr.graphs:
                names.update(compute_used_in_graph(subgraph))
    return names


def compute_used_in_graph(g: onnx.GraphProto) -> set[str]:
    """Return the union of value names consumed by all nodes of *g*."""
    result: set[str] = set()
    for node in g.node:
        result.update(compute_used_in_node(node))
    return result


def process_nodes(
    nodes: RepeatedCompositeFieldContainer[onnx.NodeProto],
    used: set,
    opset_import: Sequence[onnx.OperatorSetIdProto],
) -> int:
    """Delete nodes whose outputs are all unused; return the number removed.

    Iterates in reverse so that a deleted consumer no longer keeps its
    producers alive within the same pass, and grows *used* in place with the
    inputs of every node that is kept.
    """
    count = 0
    i = len(nodes) - 1
    while i >= 0:
        node = nodes[i]
        # May blank/pop trailing optional outputs before the liveness check.
        remove_unused_optional_outputs(node, used, opset_import)
        used_outputs = [x for x in node.output if x in used]
        if not used_outputs:
            del nodes[i]
            count += 1
            i -= 1
            continue
        # Recurse into graph-valued attributes (subgraphs). Note: removals
        # inside subgraphs are not added to this function's count.
        for attr in node.attribute:
            if attr.HasField("g"):
                process_graph(attr.g, opset_import)
            elif len(attr.graphs) > 0:
                for graph in attr.graphs:
                    process_graph(graph, opset_import)
        used |= compute_used_in_node(node)
        i -= 1
    return count


def process_graph(
    graph: onnx.GraphProto, opset_import: Sequence[onnx.OperatorSetIdProto]
) -> int:
    """Remove unused nodes and initializers from *graph*; return the count removed."""
    live = {graph_output.name for graph_output in graph.output}
    removed = process_nodes(graph.node, live, opset_import)
    # Walk initializers backwards so deletion does not shift unvisited indices.
    for idx in reversed(range(len(graph.initializer))):
        if graph.initializer[idx].name not in live:
            del graph.initializer[idx]
            removed += 1
    return removed


def process_function(
    function: onnx.FunctionProto, opset_import: Sequence[onnx.OperatorSetIdProto]
) -> int:
    """Remove unused nodes from *function*; return the number removed."""
    live = set(function.output)
    removed = process_nodes(function.node, live, opset_import)
    return removed

import onnxscript.optimizer.remove_unused_ir
import onnxscript.optimizer.remove_unused_proto
from onnxscript import ir

def remove_unused_nodes(model: onnx.ModelProto) -> None:
    """Removes unused nodes from the model."""
    removed = process_graph(model.graph, model.opset_import)
    for fn in model.functions:
        removed += process_function(fn, model.opset_import)
    logger.info("Removed %s unused nodes", removed)
def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None:
    """Remove unused nodes from *model*, dispatching on its representation.

    IR models go to the IR-based implementation; ModelProto goes to the
    proto-based one.
    """
    if not isinstance(model, ir.Model):
        onnxscript.optimizer.remove_unused_proto.remove_unused_nodes(model)
    else:
        onnxscript.optimizer.remove_unused_ir.remove_unused_nodes(model)
93 changes: 93 additions & 0 deletions onnxscript/optimizer/remove_unused_ir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import logging

import onnx

from onnxscript import ir

logger = logging.getLogger(__name__)


def remove_unused_optional_outputs(
    node: ir.Node, graph_outputs: frozenset[ir.Value], onnx_opset_version: int
) -> None:
    """Clear the names of unused optional outputs of *node*.

    An output counts as used when it is a graph/function output or has at
    least one consumer. Mutates ``node`` in place; silently does nothing when
    the node's op schema cannot be resolved.
    """
    try:
        # "ai.onnx" is the official spelling of the default ONNX domain; the
        # previous check used the reversed "onnx.ai", which never matches.
        if node.domain not in {"", "ai.onnx"}:
            return
        op_schema = onnx.defs.get_schema(node.op_type, onnx_opset_version, domain=node.domain)
    except Exception:
        # Unknown op/opset: be conservative and leave the node untouched.
        return

    if node.op_type == "BatchNormalization":
        # BatchNormalization op has 3 outputs: Y, running_mean, running_var.
        # If running_mean and running_var are not used, remove them together
        # with the training_mode attribute.
        def is_used_output(i: int) -> bool:
            if i < len(node.outputs):
                val = node.outputs[i]
                return val in graph_outputs or bool(val.uses())
            return False

        if is_used_output(1) or is_used_output(2):
            return
        # Slice guards against nodes declared with fewer than 3 outputs (the
        # previous code indexed outputs[1]/outputs[2] unconditionally, which
        # raises IndexError for such nodes).
        for extra_output in node.outputs[1:3]:
            extra_output.name = ""
        node.attributes.pop("training_mode", None)
        return

    optional_info = []
    for o in op_schema.outputs:
        # Current ops do not have optional outputs if they have a variable
        # number of outputs.
        if o.option == onnx.defs.OpSchema.FormalParameterOption.Variadic:
            return
        optional_info.append(o.option == onnx.defs.OpSchema.FormalParameterOption.Optional)
    # If the schema declares no optional outputs there is nothing to clear.
    # (The previous check, len([o == 1 for o in optional_info]) == 0, was only
    # true for ops with no outputs at all.)
    if not any(optional_info):
        return

    for i, out in enumerate(node.outputs):
        if out not in graph_outputs and not out.uses() and optional_info[i]:
            out.name = ""


def process_function_or_graph(function_or_graph: ir.Function | ir.Graph) -> int:
    """Remove nodes whose outputs are all unused; return the number removed.

    Recurses into graph-valued attributes. Removals happen in place via the
    IR's safe-removal API.
    """
    graph_outputs = frozenset(function_or_graph.outputs)
    # Opset version of the default ONNX domain, if imported; needed for the
    # schema lookup when trimming optional outputs.
    onnx_opset_version = function_or_graph.opset_imports.get("", None)
    count = 0
    # Iterate in reverse so a removed consumer frees its producers for removal
    # in the same pass (assumes the node list is topologically sorted — TODO
    # confirm against the ir container's iteration contract).
    for node in reversed(function_or_graph):
        removable = True
        for output in node.outputs:
            # Live if it is a graph/function output or has any consumer.
            if output in graph_outputs or output.uses():
                removable = False
                break
        if removable:
            function_or_graph.remove(node, safe=True)
            count += 1
        else:
            if onnx_opset_version is not None:
                remove_unused_optional_outputs(node, graph_outputs, onnx_opset_version)
            # Recurse into subgraph attributes; subgraph removals are counted.
            for attr in node.attributes.values():
                if isinstance(attr, ir.AttrGraph):
                    count += process_function_or_graph(attr.value)
                elif isinstance(attr, ir.AttrGraphs):
                    for graph in attr.value:
                        count += process_function_or_graph(graph)
    return count


def remove_unused_nodes(model: ir.Model) -> None:
    """Removes unused nodes from the model."""
    removed = process_function_or_graph(model.graph)
    graph_outputs = frozenset(model.graph.outputs)
    initializers = model.graph.initializers
    # Snapshot values() so deleting entries while scanning is safe.
    for init in list(initializers.values()):
        if init in graph_outputs or init.uses():
            continue
        del initializers[init.name]  # type: ignore[arg-type]
        removed += 1

    for function in model.functions.values():
        removed += process_function_or_graph(function)

    logger.info("Removed %s unused nodes", removed)
141 changes: 141 additions & 0 deletions onnxscript/optimizer/remove_unused_proto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import logging
from typing import Sequence

import onnx
from google.protobuf.internal.containers import ( # type: ignore
RepeatedCompositeFieldContainer,
)

logger = logging.getLogger(__name__)


def remove_unused_optional_outputs(
    n: onnx.NodeProto, used: set, opset_import: Sequence[onnx.OperatorSetIdProto]
) -> None:
    """Drop trailing optional outputs of *n* that are not in *used*.

    Mutates ``n`` in place. Silently does nothing when the node's op schema
    cannot be resolved.
    """
    try:
        # "ai.onnx" is the official spelling of the default ONNX domain; the
        # previous check used the reversed "onnx.ai", which never matches.
        if n.domain not in {"", "ai.onnx"}:
            return
        # Resolve the opset version imported for this node's domain (default 1).
        onnx_opset_version = 1
        for opset in opset_import:
            if opset.domain == n.domain:
                onnx_opset_version = opset.version
        op_schema = onnx.defs.get_schema(n.op_type, onnx_opset_version, domain=n.domain)
    except Exception:
        # Unknown op/opset: be conservative and leave the node untouched.
        return

    if n.op_type == "BatchNormalization":
        # BatchNormalization op has 3 outputs: Y, running_mean, running_var.
        # If running_mean and running_var are not used, remove them together
        # with the training_mode attribute.
        def is_used_output(i: int) -> bool:
            if i < len(n.output):
                return n.output[i] in used
            return False

        if is_used_output(1) or is_used_output(2):
            return
        del n.output[1:]
        for j, attr in enumerate(n.attribute):
            if attr.name == "training_mode":
                del n.attribute[j]
                break

    optional_info = []
    for o in op_schema.outputs:
        # Current ops do not have optional outputs if they have a variable
        # number of outputs.
        if o.option == onnx.defs.OpSchema.FormalParameterOption.Variadic:
            return
        optional_info.append(o.option == onnx.defs.OpSchema.FormalParameterOption.Optional)
    # If the schema declares no optional outputs there is nothing to delete.
    # (The previous check, len([o == 1 for o in optional_info]) == 0, was only
    # true for ops with no outputs at all, contradicting the stated intent.)
    if not any(optional_info):
        return

    # Blank unused optional outputs first, then pop only the *trailing* blanks
    # so the positional meaning of the remaining outputs is preserved.
    for i, out in enumerate(n.output):
        if out not in used and optional_info[i] is True:
            n.output[i] = ""
    # Only delete trailing unused optional outputs
    for o in n.output[::-1]:  # type: ignore[assignment]
        if o == "":
            n.output.pop()
        else:
            return


def compute_used_in_node(n: onnx.NodeProto) -> set[str]:
    """Return every value name consumed by *n*, including inside subgraph attributes."""
    used: set[str] = set(filter(None, n.input))
    for attr in n.attribute:
        if attr.HasField("g"):
            used |= compute_used_in_graph(attr.g)
        elif attr.graphs:
            for subgraph in attr.graphs:
                used |= compute_used_in_graph(subgraph)
    return used


def compute_used_in_graph(g: onnx.GraphProto) -> set[str]:
    """Return the union of value names consumed by all nodes of *g*."""
    return set().union(*(compute_used_in_node(node) for node in g.node))


def process_nodes(
    nodes: RepeatedCompositeFieldContainer[onnx.NodeProto],
    used: set,
    opset_import: Sequence[onnx.OperatorSetIdProto],
) -> int:
    """Delete nodes with no used outputs from *nodes*; return the count removed.

    Scans backwards so a removed consumer no longer keeps its producers alive
    in the same pass. *used* is grown in place with the inputs of every node
    that survives.
    """
    count = 0
    i = len(nodes) - 1
    while i >= 0:
        node = nodes[i]
        # May blank/pop trailing optional outputs before the liveness check.
        remove_unused_optional_outputs(node, used, opset_import)
        used_outputs = [x for x in node.output if x in used]
        if not used_outputs:
            del nodes[i]
            count += 1
            i -= 1
            continue
        # Recurse into graph-valued attributes (subgraphs). Note: removals
        # inside subgraphs are not added to this function's count.
        for attr in node.attribute:
            if attr.HasField("g"):
                process_graph(attr.g, opset_import)
            elif len(attr.graphs) > 0:
                for graph in attr.graphs:
                    process_graph(graph, opset_import)
        used |= compute_used_in_node(node)
        i -= 1
    return count


def process_graph(
    graph: onnx.GraphProto, opset_import: Sequence[onnx.OperatorSetIdProto]
) -> int:
    """Remove unused nodes and initializers from *graph*; return how many were removed."""
    used = {output.name for output in graph.output}
    count = process_nodes(graph.node, used, opset_import)

    # Delete back-to-front so indices of pending entries stay valid.
    i = len(graph.initializer) - 1
    while i >= 0:
        if graph.initializer[i].name not in used:
            del graph.initializer[i]
            count += 1
        i -= 1

    return count


def process_function(
    function: onnx.FunctionProto, opset_import: Sequence[onnx.OperatorSetIdProto]
) -> int:
    """Remove unused nodes from *function*, returning the removal count."""
    return process_nodes(function.node, set(function.output), opset_import)


def remove_unused_nodes(model: onnx.ModelProto) -> None:
    """Removes unused nodes from the model."""
    total = process_graph(model.graph, model.opset_import)
    total += sum(
        process_function(function, model.opset_import) for function in model.functions
    )
    logger.info("Removed %s unused nodes", total)
Loading

0 comments on commit dc31a6e

Please sign in to comment.