Skip to content

Commit

Permalink
Add Op (upsample_trilinear_vec) | feat(torchlib) (#1592)
Browse files Browse the repository at this point in the history
  • Loading branch information
titaiwangms authored Jun 6, 2024
1 parent 87b3006 commit 1c154c9
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 2 deletions.
28 changes: 28 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2209,6 +2209,7 @@ def _get_upsample_align_corners_mode(align_corners: bool) -> str:
"aten::upsample_nearest1d",
"aten::upsample_nearest2d",
"aten::upsample_nearest3d",
"aten::upsample_trilinear3d",
),
private=True,
)
Expand Down Expand Up @@ -2528,6 +2529,33 @@ def aten_upsample_trilinear3d(
)


@torch_op("aten::upsample_trilinear3d.vec", trace_only=True)
def aten_upsample_trilinear3d_vec(
self: TReal,
output_size: INT64,
align_corners: bool,
scale_factors: Optional[Sequence[float]],
) -> TReal:
"""upsample_trilinear3d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor"""

coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners)
if scale_factors is not None:
result = _aten_upsample_scales(
self,
op.Constant(value_floats=scale_factors),
mode="linear",
coordinate_transformation_mode=coordinate_transformation_mode,
)
else:
result = _aten_upsample_output_size(
self,
output_size,
mode="linear",
coordinate_transformation_mode=coordinate_transformation_mode,
)
return result


def aten_upsample_trilinear3d_backward(
grad_output: TensorType,
output_size: INT64,
Expand Down
61 changes: 60 additions & 1 deletion tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1828,6 +1828,58 @@ def shape(size, rank, with_batch_channel=True):
)


def sample_inputs_upsample_trilinear3d_vec(op_info, device, dtype, requires_grad, **kwargs):
del op_info
del kwargs

N, C = 2, 3
D = 4
SS = 3
L = 5

align_corners_options = (True, False)
rank = 3

def shape(size, rank, with_batch_channel=True):
if with_batch_channel:
return tuple([N, C] + ([size] * rank))
return tuple([size] * rank)

make_arg = functools.partial(
torch_testing.make_tensor,
device=device,
dtype=dtype,
requires_grad=requires_grad,
low=-1,
high=1,
)

yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True, None)

for align_corners in align_corners_options:
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)), shape(S, rank, False), align_corners, None
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)), shape(L, rank, False), align_corners, None
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
args=(None, align_corners),
kwargs=dict(scale_factors=(1.7, 1.7, 1.7)),
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
args=(None, align_corners),
kwargs=dict(scale_factors=(0.6, 0.6, 0.6)),
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
args=(None, align_corners),
kwargs=dict(scale_factors=(0.6, 1.7, 4.2)),
)


class _TestParamsMaxPoolEmptyStrideBase:
# Adapted from https://github.com/pytorch/pytorch/blob/d6d55f8590eab05d2536756fb4efcfb2d07eb81a/torch/testing/_internal/common_methods_invocations.py#L3203
def __init__(self):
Expand Down Expand Up @@ -2373,12 +2425,19 @@ def __init__(self):
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.upsample_trilinear3d",
"ops.aten.upsample_trilinear3d.default",
aten_name="upsample_trilinear3d",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
sample_inputs_func=sample_inputs_upsample_trilinear3d,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.upsample_trilinear3d.vec",
aten_name="upsample_trilinear3d.vec",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
sample_inputs_func=sample_inputs_upsample_trilinear3d_vec,
supports_out=False,
),
opinfo_core.OpInfo(
"nn.functional.max_pool1d_with_indices",
aten_name="max_pool1d_with_indices",
Expand Down
7 changes: 6 additions & 1 deletion tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2244,10 +2244,15 @@ def _where_input_wrangler(
trace_only=True,
),
TorchLibOpInfo(
"ops.aten.upsample_trilinear3d",
"ops.aten.upsample_trilinear3d.default",
nn_ops.aten_upsample_trilinear3d,
trace_only=True,
),
TorchLibOpInfo(
"ops.aten.upsample_trilinear3d.vec",
nn_ops.aten_upsample_trilinear3d_vec,
trace_only=True,
),
TorchLibOpInfo("ones_like", core_ops.aten_ones_like, trace_only=True),
TorchLibOpInfo(
"roll",
Expand Down

0 comments on commit 1c154c9

Please sign in to comment.