diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index ea16f4c37..46205f296 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2209,6 +2209,7 @@ def _get_upsample_align_corners_mode(align_corners: bool) -> str: "aten::upsample_nearest1d", "aten::upsample_nearest2d", "aten::upsample_nearest3d", + "aten::upsample_trilinear3d", ), private=True, ) @@ -2528,6 +2529,33 @@ def aten_upsample_trilinear3d( ) +@torch_op("aten::upsample_trilinear3d.vec", trace_only=True) +def aten_upsample_trilinear3d_vec( + self: TReal, + output_size: INT64, + align_corners: bool, + scale_factors: Optional[Sequence[float]], +) -> TReal: + """upsample_trilinear3d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor""" + + coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) + if scale_factors is not None: + result = _aten_upsample_scales( + self, + op.Constant(value_floats=scale_factors), + mode="linear", + coordinate_transformation_mode=coordinate_transformation_mode, + ) + else: + result = _aten_upsample_output_size( + self, + output_size, + mode="linear", + coordinate_transformation_mode=coordinate_transformation_mode, + ) + return result + + def aten_upsample_trilinear3d_backward( grad_output: TensorType, output_size: INT64, diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index de67909e2..d61803e30 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -1828,6 +1828,58 @@ def shape(size, rank, with_batch_channel=True): ) +def sample_inputs_upsample_trilinear3d_vec(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + N, C = 2, 3 + D = 4 + SS = 3 + L = 5 + + align_corners_options = (True, False) + rank = 3 + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + low=-1, + high=1, + ) + + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True, None) + + for align_corners in align_corners_options: + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), shape(S, rank, False), align_corners, None + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), shape(L, rank, False), align_corners, None + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + args=(None, align_corners), + kwargs=dict(scale_factors=(1.7, 1.7, 1.7)), + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + args=(None, align_corners), + kwargs=dict(scale_factors=(0.6, 0.6, 0.6)), + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + args=(None, align_corners), + kwargs=dict(scale_factors=(0.6, 1.7, 4.2)), + ) + + class _TestParamsMaxPoolEmptyStrideBase: # Adapted from https://github.com/pytorch/pytorch/blob/d6d55f8590eab05d2536756fb4efcfb2d07eb81a/torch/testing/_internal/common_methods_invocations.py#L3203 def __init__(self): @@ -2373,12 +2425,19 @@ def __init__(self): supports_out=False, ), opinfo_core.OpInfo( - "ops.aten.upsample_trilinear3d", + "ops.aten.upsample_trilinear3d.default", aten_name="upsample_trilinear3d", dtypes=common_dtype.floating_types_and(torch.bfloat16), sample_inputs_func=sample_inputs_upsample_trilinear3d, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.upsample_trilinear3d.vec", + aten_name="upsample_trilinear3d.vec", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_trilinear3d_vec, + supports_out=False, + ), opinfo_core.OpInfo( "nn.functional.max_pool1d_with_indices", aten_name="max_pool1d_with_indices", diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 5aa78cc11..1fac7dd42 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -2244,10 +2244,15 @@ def _where_input_wrangler( trace_only=True, ), TorchLibOpInfo( - "ops.aten.upsample_trilinear3d", + "ops.aten.upsample_trilinear3d.default", nn_ops.aten_upsample_trilinear3d, trace_only=True, ), + TorchLibOpInfo( + "ops.aten.upsample_trilinear3d.vec", + nn_ops.aten_upsample_trilinear3d_vec, + trace_only=True, + ), TorchLibOpInfo("ones_like", core_ops.aten_ones_like, trace_only=True), TorchLibOpInfo( "roll",