diff --git a/frogbox/callbacks/image_logger.py b/frogbox/callbacks/image_logger.py index d294171..fb0984a 100644 --- a/frogbox/callbacks/image_logger.py +++ b/frogbox/callbacks/image_logger.py @@ -65,13 +65,13 @@ def create_image_logger( interpolation: InterpolationMode = InterpolationMode.NEAREST, antialias: bool = True, num_cols: Optional[int] = None, - normalize_mean: Sequence[float] = (0.0, 0.0, 0.0), - normalize_std: Sequence[float] = (1.0, 1.0, 1.0), denormalize_input: bool = False, denormalize_target: bool = False, + normalize_mean: Sequence[float] = (0.0, 0.0, 0.0), + normalize_std: Sequence[float] = (1.0, 1.0, 1.0), progress: bool = False, prepare_batch: Callable = _prepare_batch, - input_transform: Callable[[Any], Any] = lambda x: x, + input_transform: Callable[[Any, Any], Any] = lambda x, y: (x, y), model_transform: Callable[[Any], Any] = lambda output: output, output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: ( x, @@ -96,15 +96,31 @@ def create_image_logger( If `true` antialiasing is used when resizing images. num_cols : int Number of columns in image grid. + Defaults to number of elements in returned tuple. + denormalize_input : bool + If `true` input images (x) a denormalized after inference. + denormalize_target : bool + If `true` target images (y and y_pred) are denormalized after inference. normalize_mean : (float, float, float) RGB mean values used in image normalization. normalize_std : (float, float, float) RGB std.dev. values used in image normalization. - denormalize_input : bool - If `true` input images a denormalized before logging. - denormalize_target : bool - If `true` target images (y and y_pred) are denormalized before logging. - """ + progress : bool + Show progress bar. + prepare_batch : Callable + Function that receives `batch`, `device`, `non_blocking` and + outputs tuple of tensors `(batch_x, batch_y)`. + input_transform : Callable + Function that receives tensors `y` and `y` and outputs tuple of + tensors `(x, y)`. + model_transform : Callable + Function that receives the output from the model during evaluation + and converts it into the predictions: + `y_pred = model_transform(model(x))`. + output_transform : Callable + Function that receives `x`, `y`, `y_pred` and returns tensors to be + logged as images. Default is returning `(x, y_pred, y)`. + """ # noqa: E501 denormalize = Denormalize( torch.as_tensor(normalize_mean), torch.as_tensor(normalize_std), @@ -132,7 +148,7 @@ def _callback(pipeline: SupervisedPipeline): images = [] for batch in data_iter: x, y = prepare_batch(batch, device, non_blocking=False) - x = input_transform(x) + x, y = input_transform(x, y) with torch.inference_mode(): with torch.autocast( device_type=device.type, enabled=config.amp diff --git a/frogbox/engines/supervised.py b/frogbox/engines/supervised.py index 8ef9e35..0c78173 100644 --- a/frogbox/engines/supervised.py +++ b/frogbox/engines/supervised.py @@ -15,6 +15,7 @@ def create_supervised_trainer( device: Union[str, torch.device] = "cpu", non_blocking: bool = False, prepare_batch: Callable = _prepare_batch, + input_transform: Callable[[Any, Any], Any] = lambda x, y: (x, y), model_transform: Callable[[Any], Any] = lambda output: output, output_transform: Callable[ [Any, Any, Any, torch.Tensor], Any @@ -46,6 +47,9 @@ def create_supervised_trainer( prepare_batch : Callable Function that receives `batch`, `device`, `non_blocking` and outputs tuple of tensors `(batch_x, batch_y)`. + input_transform : Callable + Function that receives tensors `y` and `y` and outputs tuple of + tensors `(x, y)`. model_transform : Callable Function that receives the output from the model and convert it into the form as required by the loss function. @@ -83,6 +87,8 @@ def _update( model.train() x, y = prepare_batch(batch, device=device, non_blocking=non_blocking) + x, y = input_transform(x, y) + with torch.autocast(device_type=device.type, enabled=amp): output = model(x) y_pred = model_transform(output) @@ -135,6 +141,7 @@ def create_supervised_evaluator( device: Union[str, torch.device] = "cpu", non_blocking: bool = False, prepare_batch: Callable = _prepare_batch, + input_transform: Callable[[Any, Any], Any] = lambda x, y: (x, y), model_transform: Callable[[Any], Any] = lambda output: output, output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: ( y_pred, @@ -163,6 +170,9 @@ def create_supervised_evaluator( prepare_batch : Callable Function that receives `batch`, `device`, `non_blocking` and outputs tuple of tensors `(batch_x, batch_y)`. + input_transform : Callable + Function that receives tensors `y` and `y` and outputs tuple of + tensors `(x, y)`. model_transform : Callable Function that receives the output from the model and convert it into the predictions: `y_pred = model_transform(model(x))`. @@ -185,6 +195,7 @@ def _step( x, y = prepare_batch( batch, device=device, non_blocking=non_blocking ) + x, y = input_transform(x, y) with torch.autocast(device_type=device.type, enabled=config.amp): output = model(x) diff --git a/frogbox/pipelines/supervised.py b/frogbox/pipelines/supervised.py index 14f8c65..87e855a 100644 --- a/frogbox/pipelines/supervised.py +++ b/frogbox/pipelines/supervised.py @@ -47,10 +47,18 @@ def __init__( tags: Optional[Sequence[str]] = None, group: Optional[str] = None, prepare_batch: Callable = _prepare_batch, + trainer_input_transform: Callable[[Any, Any], Any] = lambda x, y: ( + x, + y, + ), trainer_model_transform: Callable[[Any], Any] = lambda output: output, trainer_output_transform: Callable[ [Any, Any, Any, torch.Tensor], Any ] = lambda x, y, y_pred, loss: loss.item(), + evaluator_input_transform: Callable[[Any, Any], Any] = lambda x, y: ( + x, + y, + ), evaluator_model_transform: Callable[ [Any], Any ] = lambda output: output, @@ -85,6 +93,9 @@ def __init__( prepare_batch : Callable Function that receives `batch`, `device`, `non_blocking` and outputs tuple of tensors `(batch_x, batch_y)`. + trainer_input_transform : Callable + Function that receives tensors `y` and `y` and outputs tuple of + tensors `(x, y)`. trainer_model_transform : Callable Function that receives the output from the model during training and converts it into the form as required by the loss function. @@ -92,6 +103,9 @@ def __init__( Function that receives `x`, `y`, `y_pred`, `loss` and returns value to be assigned to trainer's `state.output` after each iteration. Default is returning `loss.item()`. + evaluator_input_transform : Callable + Function that receives tensors `y` and `y` and outputs tuple of + tensors `(x, y)`. evaluator_model_transform : Callable Function that receives the output from the model during evaluation and converts it into the predictions: @@ -135,6 +149,7 @@ def __init__( loss_fn=self.loss_fn, device=device, prepare_batch=prepare_batch, + input_transform=trainer_input_transform, model_transform=trainer_model_transform, output_transform=trainer_output_transform, ) @@ -164,6 +179,7 @@ def __init__( metrics=metrics, device=device, prepare_batch=prepare_batch, + input_transform=evaluator_input_transform, model_transform=evaluator_model_transform, output_transform=evaluator_output_transform, )