diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 9a105482c..731d12370 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2197,85 +2197,99 @@ def aten_unflatten_dense_tensors( raise NotImplementedError() -@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bicubic2d.vec"), trace_only=True) -def aten_upsample_bicubic2d( - self: TReal, - output_size: INT64, - align_corners: bool, - scale_factors: Optional[TFloat] = None, -) -> TReal: - """upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor - upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor - """ +def _get_upsample_align_corners_mode(align_corners: bool) -> str: + return "align_corners" if align_corners else "pytorch_half_pixel" - if output_size is not None: - result = _aten_upsample_output_size(self, output_size, align_corners, "cubic") - else: - result = _aten_upsample_scales(self, scale_factors, align_corners, "cubic") - return result - -@torch_op("aten::upsample_bicubic2d", private=True) +@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bilinear2d"), private=True) def _aten_upsample_output_size( self: TReal, output_size: INT64, - align_corners: bool, - str_mode: str, + mode: str, + coordinate_transformation_mode: str, ) -> TReal: self_shape = op.Shape(self) starts = op.Constant(value_ints=[0]) ends = op.Constant(value_ints=[2]) batch_channel = op.Slice(self_shape, starts, ends) output_size = op.Concat(batch_channel, output_size, axis=0) - if align_corners: - result = op.Resize( - self, - None, - None, - output_size, - mode=str_mode, - coordinate_transformation_mode="align_corners", - ) - else: - result = op.Resize( - self, - None, - None, - output_size, - mode=str_mode, - coordinate_transformation_mode="pytorch_half_pixel", - ) - - return result + return op.Resize( + self, + None, + None, + output_size, + mode=mode, + coordinate_transformation_mode=coordinate_transformation_mode, + nearest_mode="floor", + ) -@torch_op("aten::upsample_bicubic2d", private=True) +@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bilinear2d"), private=True) def _aten_upsample_scales( self: TReal, scale_factors: TFloat, - align_corners: bool, - str_mode: str, + mode: str, + coordinate_transformation_mode: str, ) -> TReal: scale_factors = op.Cast(scale_factors, to=FLOAT.dtype) scale_factors = op.Concat(op.Constant(value_floats=[1.0, 1.0]), scale_factors, axis=0) - if align_corners: - result = op.Resize( + return op.Resize( + self, + None, + scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w] + None, + mode=mode, + coordinate_transformation_mode=coordinate_transformation_mode, + nearest_mode="floor", + ) + + +@torch_op("aten::upsample_bicubic2d", trace_only=True) +def aten_upsample_bicubic2d( + self: TReal, + output_size: INT64, + align_corners: bool, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> TReal: + """upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" + + # NOTE: Based on experimentation, scales_h and scales_w are always ignored in PyTorch, + # unless when align_corners is True, in which case we do not know what is going on. + coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) + return _aten_upsample_output_size( + self, + output_size, + mode="cubic", + coordinate_transformation_mode=coordinate_transformation_mode, + ) + + +@torch_op("aten::upsample_bicubic2d.vec", trace_only=True) +def aten_upsample_bicubic2d_vec( + self: TReal, + output_size: INT64, + align_corners: bool, + scale_factors: Optional[Sequence[float]], +) -> TReal: + """upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor""" + + coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) + if scale_factors is not None: + result = _aten_upsample_scales( self, - None, - scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w] - None, - mode=str_mode, - coordinate_transformation_mode="align_corners", + op.Constant(value_floats=scale_factors), + mode="cubic", + coordinate_transformation_mode=coordinate_transformation_mode, ) else: - result = op.Resize( + result = _aten_upsample_output_size( self, - None, - scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w] - None, - mode=str_mode, - coordinate_transformation_mode="pytorch_half_pixel", + output_size, + mode="cubic", + coordinate_transformation_mode=coordinate_transformation_mode, ) + return result @@ -2302,18 +2316,15 @@ def aten_upsample_bilinear2d( ) -> TReal: """upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" - coordinate_transformation_mode = "align_corners" if align_corners else "pytorch_half_pixel" - if output_size is not None: - result = _aten_upsample_bilinear2d_output_size( - self, output_size, coordinate_transformation_mode - ) - else: - assert scales_h is not None - assert scales_h == scales_w, f"scale_h({scales_h}) != scale_w({scales_w})" - result = _aten_upsample_bilinear2d_scales( - self, scales_h, scales_w, coordinate_transformation_mode - ) - return result + # NOTE: Based on experimentation, scales_h and scales_w are always ignored in PyTorch, + # unless when align_corners is True, in which case we do not know what is going on. + coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) + return _aten_upsample_output_size( + self, + output_size, + coordinate_transformation_mode=coordinate_transformation_mode, + mode="linear", + ) @torch_op("aten::upsample_bilinear2d.vec", trace_only=True) @@ -2324,60 +2335,24 @@ def aten_upsample_bilinear2d_vec( scale_factors: Optional[Sequence[float]], ) -> TReal: """upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor""" - scales_h = scale_factors[0] if scale_factors is not None else None - scales_w = scale_factors[1] if scale_factors is not None else None - return aten_upsample_bilinear2d(self, output_size, align_corners, scales_h, scales_w) - -@torch_op("aten::upsample_bilinear2d", private=True) -def _aten_upsample_bilinear2d_output_size( - self: TReal, - output_size: INT64, - coordinate_transformation_mode: str, -) -> TReal: - """upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" - - self_shape = op.Shape(self) - starts = op.Constant(value_ints=[0]) - ends = op.Constant(value_ints=[2]) - batch_channel = op.Slice(self_shape, starts, ends) - output_size = op.Concat(batch_channel, output_size, axis=0) - return op.Resize( - self, - None, - None, - output_size, - mode="linear", - coordinate_transformation_mode=coordinate_transformation_mode, - nearest_mode="floor", - ) - - -@torch_op("aten::upsample_bilinear2d", private=True) -def _aten_upsample_bilinear2d_scales( - self: TReal, - scales_h: float, - scales_w: float, - coordinate_transformation_mode: str, -) -> TReal: - """upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" + coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) + if scale_factors is not None: + result = _aten_upsample_scales( + self, + op.Constant(value_floats=scale_factors), + mode="linear", + coordinate_transformation_mode=coordinate_transformation_mode, + ) + else: + result = _aten_upsample_output_size( + self, + output_size, + mode="linear", + coordinate_transformation_mode=coordinate_transformation_mode, + ) - neg_1 = op.Constant(value_ints=[-1]) - scales = op.Concat( - op.Constant(value_floats=[1.0, 1.0]), - op.Reshape(op.Constant(value_float=scales_h), neg_1), - op.Reshape(op.Constant(value_float=scales_w), neg_1), - axis=0, - ) - return op.Resize( - self, - None, - scales, # format should be: [1.0, 1.0, scale_h, scale_w] - None, - mode="linear", - coordinate_transformation_mode=coordinate_transformation_mode, - nearest_mode="floor", - ) + return result def aten_upsample_bilinear2d_backward( diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index a2243cb08..26920953a 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -1431,7 +1431,7 @@ def sample_inputs_unfold(op_info, device, dtype, requires_grad, **kwargs): yield opinfo_core.SampleInput(t, args=(dimension, size, step)) -def sample_inputs_upsample_bicubic2d(op_info, device, dtype, requires_grad, **kwargs): +def sample_inputs_upsample_2d(op_info, device, dtype, requires_grad, **kwargs): del op_info del kwargs @@ -1470,15 +1470,77 @@ def shape(size, rank, with_batch_channel=True): ) yield opinfo_core.SampleInput( make_arg(shape(D, rank)), - None, # output_size - align_corners, - (1.7, 1.7), # scaler + args=(shape(L, rank, False), align_corners), + kwargs=dict(scales_h=0.6, scales_w=4.2), + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + args=(shape(L, rank, False), align_corners), + kwargs=dict(scales_h=4.2, scales_w=0.6), + ) + + +def sample_inputs_upsample_2d_vec(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + N, C = 2, 3 + D = 4 + SS = 3 + L = 5 + + align_corners_options = (True, False) + rank = 2 + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + low=-1, + high=1, + ) + + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True, None) + + for align_corners in align_corners_options: + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), shape(S, rank, False), align_corners, None ) yield opinfo_core.SampleInput( make_arg(shape(D, rank)), - None, # if this is None, the scalar must be list + shape(L, rank, False), align_corners, - (0.6, 0.6), + None, + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + args=( + None, # output_size + align_corners, + ), + kwargs=dict(scale_factors=(1.7, 1.7)), + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + args=( + None, # if this is None, the scalar must be list + align_corners, + ), + kwargs=dict(scale_factors=(0.6, 0.6)), + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + args=( + None, # if this is None, the scalar must be list + align_corners, + ), + kwargs=dict(scale_factors=(0.6, 4.2)), ) @@ -1948,10 +2010,31 @@ def __init__(self): supports_out=False, ), opinfo_core.OpInfo( - "ops.aten.upsample_bicubic2d", + "ops.aten.upsample_bicubic2d.default", aten_name="upsample_bicubic2d", dtypes=common_dtype.floating_types_and(torch.bfloat16), - sample_inputs_func=sample_inputs_upsample_bicubic2d, + sample_inputs_func=sample_inputs_upsample_2d, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.upsample_bicubic2d.vec", + aten_name="upsample_bicubic2d.vec", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_2d_vec, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.upsample_bilinear2d.default", + aten_name="upsample_bilinear2d", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_2d, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.upsample_bilinear2d.vec", + aten_name="upsample_bilinear2d.vec", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_2d_vec, supports_out=False, ), opinfo_core.OpInfo( diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 67efe8047..368a05170 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -415,57 +415,6 @@ def _sum_input_wrangler( return args, kwargs -def _upsample_bilinear2d_input_wrangler( - args: list[Any], kwargs: dict[str, Any] -) -> tuple[list[Any], dict[str, Any]]: - # Wrangler for the signature difference between - # 'nn.functional.upsample_bilinear' - # and - # 'aten::upsample_bilinear2d' - # https://pytorch.org/docs/stable/generated/torch.nn.functional.upsample_bilinear.html - if "size" in kwargs: - args.append(np.array(kwargs["size"], dtype=np.int64)) - del kwargs["size"] # promote tensor type kwargs to args - else: - args.append(None) - if "align_corners" in kwargs: - args.append(kwargs["align_corners"]) - del kwargs["align_corners"] - else: - args.append(True) # Fill in the default value - if "scale_factor" in kwargs: - kwargs["scales_h"] = kwargs["scale_factor"] - kwargs["scales_w"] = kwargs["scale_factor"] - del kwargs["scale_factor"] # adapt the function signature - return args, kwargs - - -def _upsample_bilinear2d_vec_input_wrangler( - args: list[Any], kwargs: dict[str, Any] -) -> tuple[list[Any], dict[str, Any]]: - # Wrangler for the signature difference between - # 'nn.functional.upsample_bilinear' - # and - # 'aten::upsample_bilinear2d.vec' - # https://pytorch.org/docs/stable/generated/torch.nn.functional.upsample_bilinear.html - if "size" in kwargs: - args.append(np.array(kwargs["size"], dtype=np.int64)) - del kwargs["size"] # promote tensor type kwargs to args - else: - args.append(None) - if "align_corners" in kwargs: - args.append(kwargs["align_corners"]) - del kwargs["align_corners"] - else: - args.append(True) # Fill in the default value - if "scale_factor" in kwargs: - args.append([kwargs["scale_factor"]] * 2) - del kwargs["scale_factor"] # promote tensor type kwargs to args - else: - args.append(None) - return args, kwargs - - def _upsample_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: @@ -2156,21 +2105,32 @@ def _where_input_wrangler( test_class_name="TestOutputConsistencyEager", ), TorchLibOpInfo( - "nn.functional.upsample_bilinear2d", + "ops.aten.upsample_bilinear2d.default", nn_ops.aten_upsample_bilinear2d, - input_wrangler=_upsample_bilinear2d_input_wrangler, trace_only=True, + ).xfail( + matcher=lambda sample: sample.args[1] is False + and sample.kwargs.get("scales_h") is not None, + reason="fixme: align_corners=False output mismatch when scales are provided", ), TorchLibOpInfo( - "nn.functional.upsample_bilinear2d", + "ops.aten.upsample_bilinear2d.vec", nn_ops.aten_upsample_bilinear2d_vec, - input_wrangler=_upsample_bilinear2d_vec_input_wrangler, trace_only=True, ), TorchLibOpInfo( - "ops.aten.upsample_bicubic2d", + "ops.aten.upsample_bicubic2d.default", nn_ops.aten_upsample_bicubic2d, trace_only=True, + ).xfail( + matcher=lambda sample: sample.args[1] is False + and sample.kwargs.get("scales_h") is not None, + reason="fixme: align_corners=False output mismatch when scales are provided", + ), + TorchLibOpInfo( + "ops.aten.upsample_bicubic2d.vec", + nn_ops.aten_upsample_bicubic2d_vec, + trace_only=True, ), TorchLibOpInfo( "nn.functional.upsample_nearest2d", @@ -2403,11 +2363,6 @@ def _where_input_wrangler( "nn.functional.celu", ("nn.functional.celu_type_promoted",), ) -ops_test_common.duplicate_opinfo( - OPS_DB, - "nn.functional.upsample_bilinear", - ("nn.functional.upsample_bilinear2d",), -) ops_test_common.duplicate_opinfo( OPS_DB, "nn.functional.upsample_nearest",