
Commit

Added input_transform args to pipeline and image logger
SimonLarsen committed Feb 29, 2024
1 parent e7be377 commit df50772
Showing 3 changed files with 52 additions and 9 deletions.
34 changes: 25 additions & 9 deletions frogbox/callbacks/image_logger.py
@@ -65,13 +65,13 @@ def create_image_logger(
interpolation: InterpolationMode = InterpolationMode.NEAREST,
antialias: bool = True,
num_cols: Optional[int] = None,
normalize_mean: Sequence[float] = (0.0, 0.0, 0.0),
normalize_std: Sequence[float] = (1.0, 1.0, 1.0),
denormalize_input: bool = False,
denormalize_target: bool = False,
normalize_mean: Sequence[float] = (0.0, 0.0, 0.0),
normalize_std: Sequence[float] = (1.0, 1.0, 1.0),
progress: bool = False,
prepare_batch: Callable = _prepare_batch,
input_transform: Callable[[Any], Any] = lambda x: x,
input_transform: Callable[[Any, Any], Any] = lambda x, y: (x, y),
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (
x,
@@ -96,15 +96,31 @@ def create_image_logger(
If `true` antialiasing is used when resizing images.
num_cols : int
Number of columns in image grid.
Defaults to number of elements in returned tuple.
denormalize_input : bool
If `true` input images (x) are denormalized after inference.
denormalize_target : bool
If `true` target images (y and y_pred) are denormalized after inference.
normalize_mean : (float, float, float)
RGB mean values used in image normalization.
normalize_std : (float, float, float)
RGB std.dev. values used in image normalization.
denormalize_input : bool
If `true` input images are denormalized before logging.
denormalize_target : bool
If `true` target images (y and y_pred) are denormalized before logging.
"""
progress : bool
Show progress bar.
prepare_batch : Callable
Function that receives `batch`, `device`, `non_blocking` and
outputs tuple of tensors `(batch_x, batch_y)`.
input_transform : Callable
Function that receives tensors `x` and `y` and outputs tuple of
tensors `(x, y)`.
model_transform : Callable
Function that receives the output from the model during evaluation
and converts it into the predictions:
`y_pred = model_transform(model(x))`.
output_transform : Callable
Function that receives `x`, `y`, `y_pred` and returns tensors to be
logged as images. Default is returning `(x, y_pred, y)`.
""" # noqa: E501
denormalize = Denormalize(
torch.as_tensor(normalize_mean),
torch.as_tensor(normalize_std),
@@ -132,7 +148,7 @@ def _callback(pipeline: SupervisedPipeline):
images = []
for batch in data_iter:
x, y = prepare_batch(batch, device, non_blocking=False)
x = input_transform(x)
x, y = input_transform(x, y)
with torch.inference_mode():
with torch.autocast(
device_type=device.type, enabled=config.amp
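For reference, a minimal sketch of how the new two-argument `input_transform` could be passed to `create_image_logger`. Only the keyword names shown in the hunks above come from the commit; the import path, the normalization values, and the crop transform are assumptions, and any required arguments not visible in this diff are omitted.

from frogbox.callbacks.image_logger import create_image_logger

def crop_for_logging(x, y):
    # Hypothetical joint transform: crop input and target batches to the
    # same region before inference and logging.
    return x[..., :256, :256], y[..., :256, :256]

image_logger = create_image_logger(
    denormalize_input=True,
    denormalize_target=True,
    normalize_mean=(0.485, 0.456, 0.406),
    normalize_std=(0.229, 0.224, 0.225),
    input_transform=crop_for_logging,  # default is lambda x, y: (x, y)
)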
11 changes: 11 additions & 0 deletions frogbox/engines/supervised.py
@@ -15,6 +15,7 @@ def create_supervised_trainer(
device: Union[str, torch.device] = "cpu",
non_blocking: bool = False,
prepare_batch: Callable = _prepare_batch,
input_transform: Callable[[Any, Any], Any] = lambda x, y: (x, y),
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[
[Any, Any, Any, torch.Tensor], Any
@@ -46,6 +47,9 @@ def create_supervised_trainer(
prepare_batch : Callable
Function that receives `batch`, `device`, `non_blocking`
and outputs tuple of tensors `(batch_x, batch_y)`.
input_transform : Callable
Function that receives tensors `x` and `y` and outputs tuple of
tensors `(x, y)`.
model_transform : Callable
Function that receives the output from the model and
converts it into the form as required by the loss function.
@@ -83,6 +87,8 @@ def _update(
model.train()

x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
x, y = input_transform(x, y)

with torch.autocast(device_type=device.type, enabled=amp):
output = model(x)
y_pred = model_transform(output)
@@ -135,6 +141,7 @@ def create_supervised_evaluator(
device: Union[str, torch.device] = "cpu",
non_blocking: bool = False,
prepare_batch: Callable = _prepare_batch,
input_transform: Callable[[Any, Any], Any] = lambda x, y: (x, y),
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (
y_pred,
@@ -163,6 +170,9 @@ def create_supervised_evaluator(
prepare_batch : Callable
Function that receives `batch`, `device`, `non_blocking`
and outputs tuple of tensors `(batch_x, batch_y)`.
input_transform : Callable
Function that receives tensors `x` and `y` and outputs tuple of
tensors `(x, y)`.
model_transform : Callable
Function that receives the output from the model and converts it into
the predictions: `y_pred = model_transform(model(x))`.
@@ -185,6 +195,7 @@ def _step(
x, y = prepare_batch(
batch, device=device, non_blocking=non_blocking
)
x, y = input_transform(x, y)

with torch.autocast(device_type=device.type, enabled=config.amp):
output = model(x)
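The hook added here runs right after `prepare_batch` in both `_update` and `_step`, so a single function can now modify inputs and targets together. A purely illustrative example of such a joint transform (mixup-style blending, not part of this commit; it assumes `x` and `y` are batched tensors of the same shape):

import torch

def mixup_transform(x, y, alpha=0.2):
    # Blend each sample with a randomly chosen partner in the batch.
    # The previous single-argument input_transform could not keep
    # x and y in sync like this.
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(x.size(0), device=x.device)
    return lam * x + (1.0 - lam) * x[perm], lam * y + (1.0 - lam) * y[perm]

Passed as `input_transform=mixup_transform`, the blend is applied to every batch before the autocast forward pass.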
16 changes: 16 additions & 0 deletions frogbox/pipelines/supervised.py
@@ -47,10 +47,18 @@ def __init__(
tags: Optional[Sequence[str]] = None,
group: Optional[str] = None,
prepare_batch: Callable = _prepare_batch,
trainer_input_transform: Callable[[Any, Any], Any] = lambda x, y: (
x,
y,
),
trainer_model_transform: Callable[[Any], Any] = lambda output: output,
trainer_output_transform: Callable[
[Any, Any, Any, torch.Tensor], Any
] = lambda x, y, y_pred, loss: loss.item(),
evaluator_input_transform: Callable[[Any, Any], Any] = lambda x, y: (
x,
y,
),
evaluator_model_transform: Callable[
[Any], Any
] = lambda output: output,
@@ -85,13 +93,19 @@ def __init__(
prepare_batch : Callable
Function that receives `batch`, `device`, `non_blocking` and
outputs tuple of tensors `(batch_x, batch_y)`.
trainer_input_transform : Callable
Function that receives tensors `x` and `y` and outputs tuple of
tensors `(x, y)`.
trainer_model_transform : Callable
Function that receives the output from the model during training
and converts it into the form as required by the loss function.
trainer_output_transform : Callable
Function that receives `x`, `y`, `y_pred`, `loss` and returns value
to be assigned to trainer's `state.output` after each iteration.
Default is returning `loss.item()`.
evaluator_input_transform : Callable
Function that receives tensors `x` and `y` and outputs tuple of
tensors `(x, y)`.
evaluator_model_transform : Callable
Function that receives the output from the model during evaluation
and converts it into the predictions:
Expand Down Expand Up @@ -135,6 +149,7 @@ def __init__(
loss_fn=self.loss_fn,
device=device,
prepare_batch=prepare_batch,
input_transform=trainer_input_transform,
model_transform=trainer_model_transform,
output_transform=trainer_output_transform,
)
@@ -164,6 +179,7 @@ def __init__(
metrics=metrics,
device=device,
prepare_batch=prepare_batch,
input_transform=evaluator_input_transform,
model_transform=evaluator_model_transform,
output_transform=evaluator_output_transform,
)
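A sketch of wiring the new arguments through the pipeline. The two keyword names come from the signature above, and the callback in image_logger.py annotates the pipeline as `SupervisedPipeline`, so that class name and import path are assumed here; the `config` placeholder and the rescaling transform are likewise assumptions.

from frogbox.pipelines.supervised import SupervisedPipeline

def scale_to_unit(x, y):
    # Hypothetical transform: rescale uint8 image batches to [0, 1].
    return x.float() / 255.0, y.float() / 255.0

config = ...  # placeholder; the real constructor arguments are not shown in this diff

pipeline = SupervisedPipeline(
    config=config,  # assumed parameter name; not visible in this hunk
    trainer_input_transform=scale_to_unit,    # applied in the training step
    evaluator_input_transform=scale_to_unit,  # applied during evaluation
)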
