Skip to content

Commit

Permalink
[torchlib] Mark a few ops as traceable (#1889)
Browse files Browse the repository at this point in the history
- pow
- sqrt
- rsqrt
- round
  • Loading branch information
justinchuby authored Oct 3, 2024
1 parent 35fdcf5 commit d68d652
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
9 changes: 5 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6619,7 +6619,8 @@ def aten_positive(self: TensorType) -> TensorType:
"aten::pow.Tensor_Tensor",
"aten::pow.Tensor_Scalar",
"_operator::pow",
)
),
traceable=True,
)
def aten_pow(self: TReal, exponent: TTensor) -> TReal:
"""pow(Tensor self, Tensor exponent) -> Tensor"""
Expand Down Expand Up @@ -7304,7 +7305,7 @@ def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> Te
raise NotImplementedError()


@torch_op("aten::round")
@torch_op("aten::round", traceable=True)
def aten_round(self: TFloat) -> TFloat:
"""round(Tensor self) -> Tensor"""

Expand Down Expand Up @@ -7353,7 +7354,7 @@ def aten_rrelu(
raise NotImplementedError()


@torch_op("aten::rsqrt")
@torch_op("aten::rsqrt", traceable=True)
def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
"""rsqrt(Tensor self) -> Tensor"""

Expand Down Expand Up @@ -7810,7 +7811,7 @@ def aten_split_with_sizes_copy(
raise NotImplementedError()


@torch_op("aten::sqrt")
@torch_op("aten::sqrt", traceable=True)
def aten_sqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
"""sqrt(Tensor self) -> Tensor"""

Expand Down
1 change: 0 additions & 1 deletion tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1349,7 +1349,6 @@ def _where_input_wrangler(
.xfail(
variant_name="decimals_0",
reason="This variant does not accept decimals",
test_class_name="TestOutputConsistencyEager",
)
.xfail(
variant_name="decimals_3",
Expand Down

0 comments on commit d68d652

Please sign in to comment.