FixOp (argmax, argmin) | feat(torchlib) (#1594)
aten_argmax and aten_argmin have `dim=None` as the default:

https://github.com/pytorch/pytorch/blob/2369c719d485af0787d95668947125a5605bed88/aten/src/ATen/native/native_functions.yaml#L810

Prior to the "trace all traceable functions" PR, a scripted function could handle unmatched attributes when they were None, but in a traced function they surface as errors about unrecognized arguments to the function.
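
For reference, these are the two behaviors a single overload now has to cover (a minimal sketch using stock torch.argmax; expected outputs shown in comments):

import torch

x = torch.tensor([[1, 5], [9, 3]])

torch.argmax(x)                       # tensor(2): dim=None reduces the flattened input
torch.argmax(x, dim=1)                # tensor([1, 0]): reduce along dim 1
torch.argmax(x, dim=1, keepdim=True)  # tensor([[1], [0]]): reduced dim retained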
titaiwangms authored Jun 9, 2024
1 parent 87618e8 commit 4c3a6be
Showing 2 changed files with 36 additions and 48 deletions.
42 changes: 34 additions & 8 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -693,8 +693,21 @@ def aten_arctanh(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::argmax", traceable=True)
def aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
@torch_op("aten::argmax", trace_only=True)
def aten_argmax(
self: Union[RealType, UINT8], dim: Optional[int] = None, keepdim: bool = False
) -> INT64:
"""argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""

if dim is None:
result = _aten_argmax(self, keepdim)
else:
result = _aten_argmax_dim(self, dim, keepdim)
return result


@torch_op("aten::argmax", private=True, traceable=True)
def _aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
"""argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""

self_is_scaler = IsScalar(self)
@@ -706,8 +719,8 @@ def aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
return result


@torch_op("aten::argmax", traceable=True)
def aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64:
@torch_op("aten::argmax", private=True, traceable=True)
def _aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64:
"""argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""

self_is_scaler = IsScalar(self)
@@ -721,8 +734,21 @@ def aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64:
return result


@torch_op("aten::argmin", traceable=True)
def aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
@torch_op("aten::argmin", trace_only=True)
def aten_argmin(
self: Union[RealType, UINT8], dim: Optional[int] = None, keepdim: bool = False
) -> INT64:
"""argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""

if dim is None:
result = _aten_argmin(self, keepdim)
else:
result = _aten_argmin_dim(self, dim, keepdim)
return result


@torch_op("aten::argmin", private=True, traceable=True)
def _aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
"""argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""

self_is_scaler = IsScalar(self)
@@ -734,8 +760,8 @@ def aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64:
return result


@torch_op("aten::argmin", traceable=True)
def aten_argmin_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64:
@torch_op("aten::argmin", private=True, traceable=True)
def _aten_argmin_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64:
"""argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""

self_is_scaler = IsScalar(self)
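ONNX ArgMax always takes a concrete axis, so the two private implementations above build different graphs: the dim=None overload must implement torch's documented behavior of reducing over the flattened input, while the dim overload maps directly onto ArgMax's axis attribute. A minimal numpy sketch of the equivalent semantics (an illustration only, not the torch_lib code; function names are hypothetical):

import numpy as np

def argmax_full(x: np.ndarray) -> np.ndarray:
    # dim=None semantics: flatten, then ArgMax over the single remaining axis.
    return np.argmax(x.reshape(-1), axis=0).astype(np.int64)

def argmax_dim(x: np.ndarray, dim: int, keepdim: bool = False) -> np.ndarray:
    # Concrete-dim semantics map directly onto ONNX ArgMax(axis, keepdims).
    result = np.argmax(x, axis=dim).astype(np.int64)
    return np.expand_dims(result, axis=dim) if keepdim else result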
42 changes: 2 additions & 40 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -1688,25 +1688,7 @@ def _where_input_wrangler(
matcher=lambda sample: sample.kwargs.get("end") is not None,
reason="arange overload does not support positional 'end' argument",
),
TorchLibOpInfo("argmax", core_ops.aten_argmax)
.skip(
matcher=lambda sample: "dim" in sample.kwargs,
reason="this overload does not support the 'dim' attribute by design",
)
.skip(
matcher=lambda sample: len(sample.input.shape) == 0,
enabled_if=version_utils.onnxruntime_older_than("1.16"),
reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492",
)
.xfail(
dtypes=(torch.int64,),
reason="fixme: ORT did not implement ArgMax for int64. https://github.com/microsoft/onnxruntime/issues/16654",
),
TorchLibOpInfo("argmax_dim", core_ops.aten_argmax_dim)
.xfail(
matcher=lambda sample: "dim" not in sample.kwargs,
reason="this overload requires the 'dim' attribute by design",
)
TorchLibOpInfo("argmax", core_ops.aten_argmax, trace_only=True)
.skip(
matcher=lambda sample: len(sample.input.shape) == 0,
enabled_if=version_utils.onnxruntime_older_than("1.16"),
@@ -1716,25 +1698,7 @@ def _where_input_wrangler(
dtypes=(torch.int64,),
reason="fixme: ORT did not implement ArgMax for int64. https://github.com/microsoft/onnxruntime/issues/16654",
),
TorchLibOpInfo("argmin", core_ops.aten_argmin)
.skip(
matcher=lambda sample: "dim" in sample.kwargs,
reason="this overload does not support the 'dim' attribute by design",
)
.skip(
matcher=lambda sample: len(sample.input.shape) == 0,
enabled_if=version_utils.onnxruntime_older_than("1.16"),
reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492",
)
.xfail(
dtypes=(torch.int64,),
reason="fixme: ORT did not implement ArgMin for int64. https://github.com/microsoft/onnxruntime/issues/16654",
),
TorchLibOpInfo("argmin_dim", core_ops.aten_argmin_dim)
.xfail(
matcher=lambda sample: "dim" not in sample.kwargs,
reason="this overload requires the 'dim' attribute by design",
)
TorchLibOpInfo("argmin", core_ops.aten_argmin, trace_only=True)
.skip(
matcher=lambda sample: len(sample.input.shape) == 0,
enabled_if=version_utils.onnxruntime_older_than("1.16"),
@@ -2399,8 +2363,6 @@ def _where_input_wrangler(
ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))
ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims"))
ops_test_common.duplicate_opinfo(OPS_DB, "arange", ("arange_start", "arange_start_step"))
ops_test_common.duplicate_opinfo(OPS_DB, "argmax", ("argmax_dim",))
ops_test_common.duplicate_opinfo(OPS_DB, "argmin", ("argmin_dim",))
ops_test_common.duplicate_opinfo(OPS_DB, "atleast_1d", ("atleast_1d_Sequence",))
ops_test_common.duplicate_opinfo(OPS_DB, "atleast_2d", ("atleast_2d_Sequence",))
ops_test_common.duplicate_opinfo(OPS_DB, "atleast_3d", ("atleast_3d_Sequence",))
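With a single trace-only entry point handling both cases, the dim dispatch happens in Python at export time rather than inside the ONNX graph. A hypothetical end-to-end check (assuming a PyTorch build where torch.onnx.dynamo_export is available):

import torch

class Model(torch.nn.Module):
    def forward(self, x):
        # Exercises both branches: dim=None (flattened) and an explicit dim.
        return torch.argmax(x), torch.argmax(x, dim=1, keepdim=True)

onnx_program = torch.onnx.dynamo_export(Model(), torch.randn(2, 3))
onnx_program.save("argmax.onnx")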
