Skip to content

Commit

Permalink
Fix 361 (#367)
Browse files Browse the repository at this point in the history
* check if N==0, and handle this case

* change position of comment

* Address comments from shuhei
  • Loading branch information
ravinkohli authored Jan 24, 2022
1 parent f612f46 commit c0fb82e
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 1 deletion.
10 changes: 10 additions & 0 deletions autoPyTorch/pipeline/components/training/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,13 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
writer=writer,
)

# It's fine if train_loss is None due to `is_max_time_reached()`
if train_loss is None:
if self.budget_tracker.is_max_time_reached():
break
else:
raise RuntimeError("Got an unexpected None in `train_loss`.")

val_loss, val_metrics, test_loss, test_metrics = None, {}, None, {}
if self.eval_valid_each_epoch(X):
val_loss, val_metrics = self.choice.evaluate(X['val_data_loader'], epoch, writer)
Expand Down Expand Up @@ -334,6 +341,9 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
if 'cuda' in X['device']:
torch.cuda.empty_cache()

if self.run_summary.is_empty():
raise RuntimeError("Budget exhausted without finishing an epoch.")

# wrap up -- add score if not evaluating every epoch
if not self.eval_valid_each_epoch(X):
val_loss, val_metrics = self.choice.evaluate(X['val_data_loader'], epoch, writer)
Expand Down
15 changes: 14 additions & 1 deletion autoPyTorch/pipeline/components/training/trainer/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,16 @@ def repr_last_epoch(self) -> str:
string += '=' * 40
return string

def is_empty(self) -> bool:
    """
    Report whether this object holds no recorded results.

    Only the 'train_loss' entry of `performance_tracker` is inspected:
    a falsy (e.g. empty) value there is treated as "nothing recorded".

    Returns:
        bool: True when the 'train_loss' entry is falsy, False otherwise.
    """
    recorded_train_losses = self.performance_tracker['train_loss']
    if recorded_train_losses:
        return False
    return True


class BaseTrainerComponent(autoPyTorchTrainingComponent):

Expand Down Expand Up @@ -277,7 +287,7 @@ def _scheduler_step(

def train_epoch(self, train_loader: torch.utils.data.DataLoader, epoch: int,
writer: Optional[SummaryWriter],
) -> Tuple[float, Dict[str, float]]:
) -> Tuple[Optional[float], Dict[str, float]]:
"""
Train the model for a single epoch.
Expand Down Expand Up @@ -317,6 +327,9 @@ def train_epoch(self, train_loader: torch.utils.data.DataLoader, epoch: int,
epoch * len(train_loader) + step,
)

if N == 0:
return None, {}

self._scheduler_step(step_interval=StepIntervalUnit.epoch, loss=loss_sum / N)

if self.metrics_during_training:
Expand Down
37 changes: 37 additions & 0 deletions test/test_pipeline/components/training/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,43 @@ def test_train_step(self):
lr = optimizer.param_groups[0]['lr']
assert lr == target_lr

def test_train_epoch_no_step(self):
    """
    Check the return value of `train_epoch` when the epoch ends
    before any train step has been completed.

    The trainer is prepared with a runtime budget of 0 and a mocked
    data loader; the expected outcome is `None` for the training
    loss and an empty dictionary for the metrics.
    """
    cpu = torch.device('cpu')
    net = torch.nn.Linear(1, 1).to(cpu)
    adam = torch.optim.Adam(net.parameters(), lr=1)
    mocked_loader = unittest.mock.MagicMock(spec=torch.utils.data.DataLoader)

    trainer = StandardTrainer()
    # Keyword arguments are evaluated left to right, matching the order
    # in which the equivalent parameter dictionary would be built.
    trainer.prepare(
        metrics=[],
        device=cpu,
        task_type=constants.TABULAR_REGRESSION,
        labels=torch.Tensor([]),
        metrics_during_training=False,
        budget_tracker=BudgetTracker(budget_type='runtime', max_runtime=0),
        criterion=torch.nn.MSELoss,
        optimizer=adam,
        scheduler=torch.optim.lr_scheduler.MultiStepLR(adam, milestones=[3, 5, 6], gamma=2),
        model=net,
        step_interval=StepIntervalUnit.epoch,
    )

    epoch_result = trainer.train_epoch(train_loader=mocked_loader, epoch=0, writer=None)
    loss, metrics = epoch_result

    assert loss is None
    assert metrics == {}


class TestStandardTrainer(BaseTraining):
def test_regression_epoch_training(self, n_samples):
Expand Down
28 changes: 28 additions & 0 deletions test/test_pipeline/test_tabular_classification.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import re
import unittest
import unittest.mock

from ConfigSpace.hyperparameters import (
CategoricalHyperparameter,
Expand Down Expand Up @@ -491,3 +492,30 @@ def test_train_pipeline_with_runtime(fit_dictionary_tabular_dummy):

# More than 200 epochs would have passed in 5 seconds for this dataset
assert len(run_summary.performance_tracker['start_time']) > 100


@pytest.mark.parametrize("fit_dictionary_tabular_dummy", ["classification"], indirect=True)
def test_train_pipeline_with_runtime_max_reached(fit_dictionary_tabular_dummy):
    """
    Make sure the pipeline raises an error when no epoch finishes
    successfully because the maximum runtime has been reached.

    The budget tracker used by the trainer is replaced with a mock that
    always reports the time budget as exhausted, so `pipeline.fit` is
    expected to fail with a `RuntimeError`.
    """

    # Switch the fit dictionary from an epoch budget to a runtime budget
    fit_dictionary_tabular_dummy.pop('epochs', None)
    fit_dictionary_tabular_dummy['budget_type'] = 'runtime'
    fit_dictionary_tabular_dummy['runtime'] = 5
    fit_dictionary_tabular_dummy['early_stopping'] = -1

    pipeline = TabularClassificationPipeline(
        dataset_properties=fit_dictionary_tabular_dummy['dataset_properties'])

    cs = pipeline.get_hyperparameter_search_space()
    config = cs.get_default_configuration()
    pipeline.set_hyperparameters(config)

    tracker_target = 'autoPyTorch.pipeline.components.training.trainer.BudgetTracker'
    with unittest.mock.patch(tracker_target) as mock_tracker_cls:
        # `is_max_time_reached()` is presumably called on an object created
        # from the patched class, i.e. on `mock_tracker_cls.return_value`.
        # Configuring the attribute on the class mock itself would leave the
        # instance method returning a bare (merely truthy) MagicMock, so the
        # stub is set on the instance to return an explicit `True`.
        mock_tracker_cls.return_value.is_max_time_reached.return_value = True
        with pytest.raises(RuntimeError):
            pipeline.fit(fit_dictionary_tabular_dummy)

0 comments on commit c0fb82e

Please sign in to comment.