diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py index de4d49ac0..16f498e43 100644 --- a/autoPyTorch/api/base_task.py +++ b/autoPyTorch/api/base_task.py @@ -46,6 +46,7 @@ from autoPyTorch.pipeline.components.training.metrics.utils import calculate_score, get_metrics from autoPyTorch.utils.backend import Backend, create from autoPyTorch.utils.common import FitRequirement, replace_string_bool_to_bool +from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates from autoPyTorch.utils.logging_ import ( PicklableClientLogger, get_named_client_logger, @@ -135,6 +136,7 @@ def __init__( include_components: Optional[Dict] = None, exclude_components: Optional[Dict] = None, backend: Optional[Backend] = None, + search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None ) -> None: self.seed = seed self.n_jobs = n_jobs @@ -178,6 +180,13 @@ def __init__( self.stop_logging_server = None # type: Optional[multiprocessing.synchronize.Event] + self.search_space_updates = search_space_updates + if search_space_updates is not None: + if not isinstance(self.search_space_updates, + HyperparameterSearchSpaceUpdates): + raise ValueError("Expected search space updates to be an instance of" + " HyperparameterSearchSpaceUpdates, got {}".format(type(self.search_space_updates))) + @abstractmethod def _get_required_dataset_properties(self, dataset: BaseDataset) -> Dict[str, Any]: """ @@ -252,7 +261,8 @@ def get_search_space(self, dataset: BaseDataset = None) -> ConfigurationSpace: info=self._get_required_dataset_properties(dataset)) return get_configuration_space(info=dataset.get_dataset_properties(dataset_requirements), include=self.include_components, - exclude=self.exclude_components) + exclude=self.exclude_components, + search_space_updates=self.search_space_updates) raise Exception("No search space initialised and no dataset passed. 
" "Can't create default search space without the dataset") @@ -816,7 +826,8 @@ def search( pipeline_config={**self.pipeline_options, **budget_config}, ensemble_callback=proc_ensemble, logger_port=self._logger_port, - start_num_run=num_run + start_num_run=num_run, + search_space_updates=self.search_space_updates ) try: self.run_history, self.trajectory, budget_type = \ diff --git a/autoPyTorch/api/tabular_classification.py b/autoPyTorch/api/tabular_classification.py index 70eac1c2a..165ca98e7 100644 --- a/autoPyTorch/api/tabular_classification.py +++ b/autoPyTorch/api/tabular_classification.py @@ -9,6 +9,7 @@ from autoPyTorch.datasets.tabular_dataset import TabularDataset from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline from autoPyTorch.utils.backend import Backend +from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates class TabularClassificationTask(BaseTask): @@ -52,6 +53,7 @@ def __init__( include_components: Optional[Dict] = None, exclude_components: Optional[Dict] = None, backend: Optional[Backend] = None, + search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None ): super().__init__( seed=seed, @@ -67,6 +69,7 @@ def __init__( include_components=include_components, exclude_components=exclude_components, backend=backend, + search_space_updates=search_space_updates ) self.task_type = TASK_TYPES_TO_STRING[TABULAR_CLASSIFICATION] diff --git a/autoPyTorch/evaluation/abstract_evaluator.py b/autoPyTorch/evaluation/abstract_evaluator.py index 69619bd2e..65d252852 100644 --- a/autoPyTorch/evaluation/abstract_evaluator.py +++ b/autoPyTorch/evaluation/abstract_evaluator.py @@ -42,6 +42,7 @@ get_metrics, ) from autoPyTorch.utils.backend import Backend +from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates from autoPyTorch.utils.logging_ import PicklableClientLogger, get_named_client_logger from autoPyTorch.utils.pipeline import get_dataset_requirements @@ -200,7 +201,9 @@ def __init__(self, backend: Backend, disable_file_output: Union[bool, List[str]] = False, init_params: Optional[Dict[str, Any]] = None, logger_port: Optional[int] = None, - all_supported_metrics: bool = True) -> None: + all_supported_metrics: bool = True, + search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None + ) -> None: self.starttime = time.time() @@ -218,6 +221,7 @@ def __init__(self, backend: Backend, self.include = include self.exclude = exclude + self.search_space_updates = search_space_updates self.X_train, self.y_train = self.datamanager.train_tensors @@ -324,6 +328,7 @@ def __init__(self, backend: Backend, self.pipelines: Optional[List[BaseEstimator]] = None self.pipeline: Optional[BaseEstimator] = None self.logger.debug("Fit dictionary in Abstract evaluator: {}".format(self.fit_dictionary)) + self.logger.debug("Search space updates :{}".format(self.search_space_updates)) def _get_pipeline(self) -> BaseEstimator: assert self.pipeline_class is not None, "Can't return pipeline, pipeline_class not initialised" @@ -337,7 +342,8 @@ def _get_pipeline(self) -> BaseEstimator: random_state=np.random.RandomState(self.seed), include=self.include, exclude=self.exclude, - init_params=self._init_params) + init_params=self._init_params, + search_space_updates=self.search_space_updates) elif isinstance(self.configuration, str): pipeline = self.pipeline_class(config=self.configuration, dataset_properties=self.dataset_properties, diff --git a/autoPyTorch/evaluation/tae.py 
b/autoPyTorch/evaluation/tae.py index 9562d8051..770625c53 100644 --- a/autoPyTorch/evaluation/tae.py +++ b/autoPyTorch/evaluation/tae.py @@ -25,6 +25,7 @@ from autoPyTorch.evaluation.utils import empty_queue, extract_learning_curve, read_queue from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric from autoPyTorch.utils.backend import Backend +from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates from autoPyTorch.utils.logging_ import PicklableClientLogger, get_named_client_logger @@ -111,6 +112,7 @@ def __init__( ta: typing.Optional[typing.Callable] = None, logger_port: int = None, all_supported_metrics: bool = True, + search_space_updates: typing.Optional[HyperparameterSearchSpaceUpdates] = None ): eval_function = autoPyTorch.evaluation.train_evaluator.eval_function @@ -164,6 +166,8 @@ def __init__( self.resampling_strategy = dm.resampling_strategy self.resampling_strategy_args = dm.resampling_strategy_args + self.search_space_updates = search_space_updates + def run_wrapper( self, run_info: RunInfo, @@ -250,6 +254,7 @@ def run( else: num_run = config.config_id + self.initial_num_run + self.logger.debug("Search space updates: {}".format(self.search_space_updates)) obj_kwargs = dict( queue=queue, config=config, @@ -267,7 +272,8 @@ def run( budget_type=self.budget_type, pipeline_config=self.pipeline_config, logger_port=self.logger_port, - all_supported_metrics=self.all_supported_metrics + all_supported_metrics=self.all_supported_metrics, + search_space_updates=self.search_space_updates ) info: typing.Optional[typing.List[RunValue]] diff --git a/autoPyTorch/evaluation/train_evaluator.py b/autoPyTorch/evaluation/train_evaluator.py index 5e175df9b..f74c2ac19 100644 --- a/autoPyTorch/evaluation/train_evaluator.py +++ b/autoPyTorch/evaluation/train_evaluator.py @@ -20,6 +20,7 @@ from autoPyTorch.evaluation.utils import subsampler from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric from autoPyTorch.utils.backend import Backend +from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates __all__ = ['TrainEvaluator', 'eval_function'] @@ -48,7 +49,8 @@ def __init__(self, backend: Backend, queue: Queue, init_params: Optional[Dict[str, Any]] = None, logger_port: Optional[int] = None, keep_models: Optional[bool] = None, - all_supported_metrics: bool = True) -> None: + all_supported_metrics: bool = True, + search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None) -> None: super().__init__( backend=backend, queue=queue, @@ -65,7 +67,8 @@ def __init__(self, backend: Backend, queue: Queue, budget_type=budget_type, logger_port=logger_port, all_supported_metrics=all_supported_metrics, - pipeline_config=pipeline_config + pipeline_config=pipeline_config, + search_space_updates=search_space_updates ) self.splits = self.datamanager.splits @@ -77,6 +80,7 @@ def __init__(self, backend: Backend, queue: Queue, self.pipelines: List[Optional[BaseEstimator]] = [None] * self.num_folds self.indices: List[Optional[Tuple[Union[np.ndarray, List], Union[np.ndarray, List]]]] = [None] * self.num_folds + self.logger.debug("Search space updates :{}".format(self.search_space_updates)) self.keep_models = keep_models def fit_predict_and_loss(self) -> None: @@ -320,6 +324,7 @@ def eval_function( init_params: Optional[Dict[str, Any]] = None, logger_port: Optional[int] = None, all_supported_metrics: bool = True, + search_space_updates: Optional[HyperparameterSearchSpaceUpdates] 
= None, instance: str = None, ) -> None: evaluator = TrainEvaluator( @@ -338,6 +343,7 @@ def eval_function( budget_type=budget_type, logger_port=logger_port, all_supported_metrics=all_supported_metrics, - pipeline_config=pipeline_config + pipeline_config=pipeline_config, + search_space_updates=search_space_updates ) evaluator.fit_predict_and_loss() diff --git a/autoPyTorch/optimizer/smbo.py b/autoPyTorch/optimizer/smbo.py index efc54a516..9a464a1bb 100644 --- a/autoPyTorch/optimizer/smbo.py +++ b/autoPyTorch/optimizer/smbo.py @@ -16,17 +16,17 @@ from smac.tae.serial_runner import SerialRunner from smac.utils.io.traj_logging import TrajEntry -# TODO: Enable when merged Ensemble -# from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager from autoPyTorch.datasets.base_dataset import BaseDataset from autoPyTorch.datasets.resampling_strategy import ( CrossValTypes, DEFAULT_RESAMPLING_PARAMETERS, HoldoutValTypes, ) +from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager from autoPyTorch.evaluation.tae import ExecuteTaFuncWithQueue, get_cost_of_crash from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric from autoPyTorch.utils.backend import Backend +from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates from autoPyTorch.utils.logging_ import get_named_client_logger from autoPyTorch.utils.stopwatch import StopWatch @@ -101,10 +101,9 @@ def __init__(self, smac_scenario_args: typing.Optional[typing.Dict[str, typing.Any]] = None, get_smac_object_callback: typing.Optional[typing.Callable] = None, all_supported_metrics: bool = True, - # TODO: Re-enable when ensemble merged - # ensemble_callback: typing.Optional[EnsembleBuilderManager] = None, - ensemble_callback: typing.Any = None, + ensemble_callback: typing.Optional[EnsembleBuilderManager] = None, logger_port: typing.Optional[int] = None, + search_space_updates: typing.Optional[HyperparameterSearchSpaceUpdates] = None ): """ Interface to SMAC. 
This method calls the SMAC optimize method, and allows @@ -194,6 +193,8 @@ def __init__(self, self.ensemble_callback = ensemble_callback + self.search_space_updates = search_space_updates + dataset_name_ = "" if dataset_name is None else dataset_name if logger_port is None: self.logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT @@ -254,7 +255,8 @@ def run_smbo(self, func: typing.Optional[typing.Callable] = None ta=func, logger_port=self.logger_port, all_supported_metrics=self.all_supported_metrics, - pipeline_config=self.pipeline_config + pipeline_config=self.pipeline_config, + search_space_updates=self.search_space_updates ) ta = ExecuteTaFuncWithQueue self.logger.info("Created TA") diff --git a/autoPyTorch/pipeline/base_pipeline.py b/autoPyTorch/pipeline/base_pipeline.py index 2d7449c03..bf42dd364 100644 --- a/autoPyTorch/pipeline/base_pipeline.py +++ b/autoPyTorch/pipeline/base_pipeline.py @@ -21,6 +21,7 @@ get_match_array ) from autoPyTorch.utils.common import FitRequirement +from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates class BasePipeline(Pipeline): @@ -52,14 +53,15 @@ class BasePipeline(Pipeline): __metaclass__ = ABCMeta def __init__( - self, - config: Optional[Configuration] = None, - steps: Optional[List[Tuple[str, autoPyTorchChoice]]] = None, - dataset_properties: Optional[Dict[str, Any]] = None, - include: Optional[Dict[str, Any]] = None, - exclude: Optional[Dict[str, Any]] = None, - random_state: Optional[np.random.RandomState] = None, - init_params: Optional[Dict[str, Any]] = None + self, + config: Optional[Configuration] = None, + steps: Optional[List[Tuple[str, autoPyTorchChoice]]] = None, + dataset_properties: Optional[Dict[str, Any]] = None, + include: Optional[Dict[str, Any]] = None, + exclude: Optional[Dict[str, Any]] = None, + random_state: Optional[np.random.RandomState] = None, + init_params: Optional[Dict[str, Any]] = None, + search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None ): self.init_params = init_params if init_params is not None else {} @@ -67,6 +69,7 @@ def __init__( dataset_properties is not None else {} self.include = include if include is not None else {} self.exclude = exclude if exclude is not None else {} + self.search_space_updates = search_space_updates if steps is None: self.steps = self._get_pipeline_steps(dataset_properties) @@ -138,9 +141,9 @@ def predict(self, X: np.ndarray, batch_size: Optional[int] = None return self.named_steps['network'].predict(loader) def set_hyperparameters( - self, - configuration: Configuration, - init_params: Optional[Dict] = None + self, + configuration: Configuration, + init_params: Optional[Dict] = None ) -> 'Pipeline': """Method to set the hyperparameter configuration of the pipeline. 
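For context, the search_space_updates object threaded through BaseTask, AutoMLSMBO and BasePipeline above is assembled by the user before the search starts. Below is a minimal sketch of how such an object could be built; it assumes HyperparameterSearchSpaceUpdates exposes an append(node_name, hyperparameter, value_range, default_value, log) helper that records one update per call (mirroring the (value_range, default_value, log) tuples stored by _apply_search_space_update later in this diff), and the step names "lr_scheduler" and "network_backbone" are illustrative rather than taken from this patch.

from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates

# Sketch only: restrict two hyperparameters before the search starts.
updates = HyperparameterSearchSpaceUpdates()

# For a choice node the hyperparameter is addressed as "<component>:<hyperparameter>",
# which is why _check_search_space_updates splits the name on ':'.
updates.append(node_name="lr_scheduler",                  # assumed step name
               hyperparameter="CosineAnnealingLR:T_max",
               value_range=[50, 200],
               default_value=100)
updates.append(node_name="network_backbone",              # assumed step name
               hyperparameter="ResNetBackbone:num_groups",
               value_range=[1, 5],
               default_value=3)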
@@ -158,7 +161,12 @@ def set_hyperparameters( for node_idx, n_ in enumerate(self.steps): node_name, node = n_ - sub_configuration_space = node.get_hyperparameter_search_space(self.dataset_properties) + updates: Dict[str, Any] = {} + if not isinstance(node, autoPyTorchChoice): + updates = node._get_search_space_updates() + + sub_configuration_space = node.get_hyperparameter_search_space(self.dataset_properties, + **updates) sub_config_dict = {} for param in configuration: if param.startswith('%s:' % node_name): @@ -250,12 +258,12 @@ def __repr__(self) -> str: return string def _get_base_search_space( - self, - cs: ConfigurationSpace, - dataset_properties: Dict[str, Any], - include: Optional[Dict[str, Any]], - exclude: Optional[Dict[str, Any]], - pipeline: List[Tuple[str, autoPyTorchChoice]] + self, + cs: ConfigurationSpace, + dataset_properties: Dict[str, Any], + include: Optional[Dict[str, Any]], + exclude: Optional[Dict[str, Any]], + pipeline: List[Tuple[str, autoPyTorchChoice]] ) -> ConfigurationSpace: if include is None: if self.include is None: @@ -281,6 +289,11 @@ def _get_base_search_space( raise ValueError('Invalid key in exclude: %s; should be one ' 'of %s' % (key, keys)) + if self.search_space_updates is not None: + self._check_search_space_updates(include=include, + exclude=exclude) + self.search_space_updates.apply(pipeline=pipeline) + matches = get_match_array( pipeline, dataset_properties, include=include, exclude=exclude) @@ -302,9 +315,12 @@ def _get_base_search_space( # if the node isn't a choice we can add it immediately because it # must be active (if it wasn't, np.sum(matches) would be zero if not is_choice: + # for mypy + assert not isinstance(node, autoPyTorchChoice) cs.add_configuration_space( node_name, - node.get_hyperparameter_search_space(dataset_properties), + node.get_hyperparameter_search_space(dataset_properties, # type: ignore[arg-type] + **node._get_search_space_updates()), ) # If the node is a choice, we have to figure out which of its # choices are actually legal choices @@ -329,6 +345,65 @@ def _get_base_search_space( return cs + def _check_search_space_updates(self, include: Optional[Dict[str, Any]], + exclude: Optional[Dict[str, Any]]) -> None: + assert self.search_space_updates is not None + for update in self.search_space_updates.updates: + if update.node_name not in self.named_steps.keys(): + raise ValueError("Unknown node name. Expected update node name to be in {} " + "got {}".format(self.named_steps.keys(), update.node_name)) + node = self.named_steps[update.node_name] + if hasattr(node, 'get_components'): + split_hyperparameter = update.hyperparameter.split(':') + + if include is not None and update.node_name in include.keys(): + if split_hyperparameter[0] not in include[update.node_name]: + raise ValueError("Not found {} in include".format(split_hyperparameter[0])) + + if exclude is not None and update.node_name in exclude.keys(): + if split_hyperparameter[0] in exclude[update.node_name]: + raise ValueError("Found {} in exclude".format(split_hyperparameter[0])) + + components = node.get_components() + if split_hyperparameter[0] not in components.keys(): + raise ValueError("Unknown hyperparameter for choice {}. " + "Expected update hyperparameter " + "to be in {} got {}".format(node.__class__.__name__, + components.keys(), split_hyperparameter[0])) + else: + component = components[split_hyperparameter[0]] + if split_hyperparameter[1] not in component. 
\ + get_hyperparameter_search_space(dataset_properties=self.dataset_properties): + # Check if update hyperparameter is in names of + # hyperparameters of the search space + # Example 'num_units' in 'num_units_1', 'num_units_2' + if any([split_hyperparameter[1] in name for name in + component.get_hyperparameter_search_space( + dataset_properties=self.dataset_properties).get_hyperparameter_names()]): + continue + raise ValueError("Unknown hyperparameter for component {}. " + "Expected update hyperparameter " + "to be in {} got {}".format(node.__class__.__name__, + component. + get_hyperparameter_search_space( + dataset_properties=self.dataset_properties). + get_hyperparameter_names(), + split_hyperparameter[1])) + else: + if update.hyperparameter not in node.get_hyperparameter_search_space( + dataset_properties=self.dataset_properties): + if any([update.hyperparameter in name for name in + node.get_hyperparameter_search_space( + dataset_properties=self.dataset_properties).get_hyperparameter_names()]): + continue + raise ValueError("Unknown hyperparameter for component {}. " + "Expected update hyperparameter " + "to be in {} got {}".format(node.__class__.__name__, + node. + get_hyperparameter_search_space( + dataset_properties=self.dataset_properties). + get_hyperparameter_names(), update.hyperparameter)) + def _get_pipeline_steps(self, dataset_properties: Optional[Dict[str, Any]] ) -> List[Tuple[str, autoPyTorchChoice]]: """ diff --git a/autoPyTorch/pipeline/components/base_choice.py b/autoPyTorch/pipeline/components/base_choice.py index 0cf61a12f..b6fffbaf2 100644 --- a/autoPyTorch/pipeline/components/base_choice.py +++ b/autoPyTorch/pipeline/components/base_choice.py @@ -1,6 +1,6 @@ import warnings from collections import OrderedDict -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple, Union from ConfigSpace.configuration_space import Configuration, ConfigurationSpace @@ -49,6 +49,8 @@ def __init__(self, # self.set_hyperparameters(self.configuration) self.choice: Optional[autoPyTorchComponent] = None + self._cs_updates: Dict[str, Tuple] = dict() + def get_fit_requirements(self) -> Optional[List[FitRequirement]]: if self.choice is not None: return self.choice.get_fit_requirements() @@ -244,3 +246,36 @@ def _check_dataset_properties(self, dataset_properties: Dict[str, Any]) -> None: """ assert isinstance(dataset_properties, dict), "dataset_properties must be a dictionary" + + def _apply_search_space_update(self, name: str, new_value_range: Union[List, Tuple], + default_value: Union[int, float, str], log: bool = False) -> None: + """Allows the user to update a hyperparameter + + Arguments: + name {string} -- name of hyperparameter + new_value_range {List[?] -- value range can be either lower, upper or a list of possible conditionals + log {bool} -- is hyperparameter logscale + """ + + if len(new_value_range) == 0: + raise ValueError("The new value range needs at least one value") + self._cs_updates[name] = tuple([new_value_range, default_value, log]) + + def _get_search_space_updates(self, prefix: Optional[str] = None) -> Dict[str, Tuple]: + """Get the search space updates with the given prefix + + Keyword Arguments: + prefix {str} -- Only return search space updates with given prefix (default: {None}) + + Returns: + dict -- Mapping of search space updates. Keys don't contain the prefix. 
+ """ + if prefix is None: + return self._cs_updates + result: Dict[str, Tuple] = dict() + + # iterate over all search space updates of this node and filter the ones out, that have the given prefix + for key in self._cs_updates.keys(): + if key.startswith(prefix): + result[key[len(prefix) + 1:]] = self._cs_updates[key] + return result diff --git a/autoPyTorch/pipeline/components/base_component.py b/autoPyTorch/pipeline/components/base_component.py index 7918bca2f..80fb1ddd0 100644 --- a/autoPyTorch/pipeline/components/base_component.py +++ b/autoPyTorch/pipeline/components/base_component.py @@ -4,7 +4,7 @@ import sys import warnings from collections import OrderedDict -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple, Union from ConfigSpace.configuration_space import Configuration, ConfigurationSpace @@ -91,12 +91,12 @@ def add_component(self, obj: BaseEstimator) -> None: class autoPyTorchComponent(BaseEstimator): - _required_properties: Optional[List[str]] = None def __init__(self) -> None: super().__init__() self._fit_requirements: List[FitRequirement] = list() + self._cs_updates: Dict[str, Tuple] = dict() @classmethod def get_required_properties(cls) -> Optional[List[str]]: @@ -252,3 +252,37 @@ def __str__(self) -> str: """Representation of the current Component""" name = self.get_properties()['name'] return "autoPyTorch.pipeline %s" % name + + def _apply_search_space_update(self, name: str, new_value_range: Union[List, Tuple], + default_value: Union[int, float, str], log: bool = False) -> None: + """Allows the user to update a hyperparameter + + Arguments: + name {string} -- name of hyperparameter + new_value_range {List[?] -- value range can be either lower, upper or a list of possible conditionals + log {bool} -- is hyperparameter logscale + """ + + if len(new_value_range) == 0: + raise ValueError("The new value range needs at least one value") + self._cs_updates[name] = tuple([new_value_range, default_value, log]) + + def _get_search_space_updates(self, prefix: Optional[str] = None) -> Dict[str, Tuple]: + """Get the search space updates with the given prefix + + Keyword Arguments: + prefix {str} -- Only return search space updates with given prefix (default: {None}) + + Returns: + dict -- Mapping of search space updates. Keys don't contain the prefix. 
+ """ + if prefix is None: + return self._cs_updates + result: Dict[str, Tuple] = dict() + + # iterate over all search space updates of this node and keep the ones that have the given prefix + for key in self._cs_updates.keys(): + if key.startswith(prefix): + # different for autopytorch component as the hyperparameter + result[key[len(prefix):]] = self._cs_updates[key] + return result diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/imputation/SimpleImputer.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/imputation/SimpleImputer.py index 4ae9d8d40..95d89726c 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/imputation/SimpleImputer.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/imputation/SimpleImputer.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union from ConfigSpace.configuration_space import ConfigurationSpace from ConfigSpace.hyperparameters import ( @@ -57,20 +57,27 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseImputer: return self @staticmethod - def get_hyperparameter_search_space(dataset_properties: Optional[Dict[str, Any]] = None) -> ConfigurationSpace: + def get_hyperparameter_search_space(dataset_properties: Optional[Dict[str, Any]] = None, + numerical_strategy: Tuple[Tuple, str] = (("mean", "median", + "most_frequent", "constant_zero"), + "mean"), + categorical_strategy: Tuple[Tuple, str] = (("most_frequent", + "constant_!missing!"), + "most_frequent") + ) -> ConfigurationSpace: cs = ConfigurationSpace() assert dataset_properties is not None, "To create hyperparameter search space" \ ", dataset_properties should not be None" if len(dataset_properties['numerical_columns']) != 0: numerical_strategy = CategoricalHyperparameter("numerical_strategy", - ["mean", "median", "most_frequent", "constant_zero"], - default_value="mean") + numerical_strategy[0], + default_value=numerical_strategy[1]) cs.add_hyperparameter(numerical_strategy) if len(dataset_properties['categorical_columns']) != 0: categorical_strategy = CategoricalHyperparameter("categorical_strategy", - ["most_frequent", "constant_!missing!"], - default_value="most_frequent") + categorical_strategy[0], + default_value=categorical_strategy[1]) cs.add_hyperparameter(categorical_strategy) return cs diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/Normalizer.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/Normalizer.py index a28791b2c..f0e6dc0ff 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/Normalizer.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/Normalizer.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union from ConfigSpace.configuration_space import ConfigurationSpace from ConfigSpace.hyperparameters import ( @@ -29,7 +29,6 @@ def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = N self.norm = norm def fit(self, X: Dict[str, Any], y: Any = None) -> BaseScaler: - self.check_requirements(X, y) map_norm = dict({"mean_abs": "l1", "mean_squared": "l2", "max": "max"}) @@ -37,9 +36,12 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseScaler: return self @staticmethod - def get_hyperparameter_search_space(dataset_properties: Optional[Dict[str, Any]] = None) -> ConfigurationSpace: + def 
get_hyperparameter_search_space(dataset_properties: Optional[Dict[str, Any]] = None, + norm: Tuple[Tuple, str] = (("mean_abs", "mean_squared", "max"), + "mean_squared") + ) -> ConfigurationSpace: cs = ConfigurationSpace() - norm = CategoricalHyperparameter("norm", ["mean_abs", "mean_squared", "max"], default_value="mean_squared") + norm = CategoricalHyperparameter("norm", norm[0], default_value=norm[1]) cs.add_hyperparameter(norm) return cs diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/base_scaler_choice.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/base_scaler_choice.py index 0c3357026..718c80d39 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/base_scaler_choice.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/base_scaler_choice.py @@ -57,17 +57,17 @@ def get_hyperparameter_search_space(self, dataset_properties = {**self.dataset_properties, **dataset_properties} - available_preprocessors = self.get_available_components(dataset_properties=dataset_properties, - include=include, - exclude=exclude) + available_scalers = self.get_available_components(dataset_properties=dataset_properties, + include=include, + exclude=exclude) - if len(available_preprocessors) == 0: + if len(available_scalers) == 0: raise ValueError("no scalers found, please add a scaler") if default is None: defaults = ['StandardScaler', 'Normalizer', 'MinMaxScaler', 'NoScaler'] for default_ in defaults: - if default_ in available_preprocessors: + if default_ in available_scalers: default = default_ break @@ -79,16 +79,17 @@ def get_hyperparameter_search_space(self, default_value=default) else: preprocessor = CSH.CategoricalHyperparameter('__choice__', - list(available_preprocessors.keys()), + list(available_scalers.keys()), default_value=default) cs.add_hyperparameter(preprocessor) # add only child hyperparameters of early_preprocessor choices for name in preprocessor.choices: - preprocessor_configuration_space = available_preprocessors[name].\ - get_hyperparameter_search_space(dataset_properties) + updates = self._get_search_space_updates(prefix=name) + config_space = available_scalers[name].get_hyperparameter_search_space(dataset_properties, # type:ignore + **updates) parent_hyperparameter = {'parent': preprocessor, 'value': name} - cs.add_configuration_space(name, preprocessor_configuration_space, + cs.add_configuration_space(name, config_space, parent_hyperparameter=parent_hyperparameter) self.configuration_space = cs @@ -105,5 +106,6 @@ def _check_dataset_properties(self, dataset_properties: Dict[str, Any]) -> None: """ super()._check_dataset_properties(dataset_properties) - assert 'numerical_columns' in dataset_properties.keys() and 'categorical_columns' in dataset_properties.keys(),\ + assert 'numerical_columns' in dataset_properties.keys() and \ + 'categorical_columns' in dataset_properties.keys(), \ "Dataset properties must contain information about the type of columns" diff --git a/autoPyTorch/pipeline/components/setup/augmentation/image/GaussianBlur.py b/autoPyTorch/pipeline/components/setup/augmentation/image/GaussianBlur.py index 6a4ab0c27..e977b0633 100644 --- a/autoPyTorch/pipeline/components/setup/augmentation/image/GaussianBlur.py +++ b/autoPyTorch/pipeline/components/setup/augmentation/image/GaussianBlur.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union import ConfigSpace as CS from 
ConfigSpace.configuration_space import ConfigurationSpace @@ -30,13 +30,19 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseImageAugmenter: @staticmethod def get_hyperparameter_search_space( - dataset_properties: Optional[Dict[str, str]] = None + dataset_properties: Optional[Dict[str, str]] = None, + use_augmenter: Tuple[Tuple, bool] = ((True, False), True), + sigma_min: Tuple[Tuple, int] = ((0, 3), 0), + sigma_offset: Tuple[Tuple, float] = ((0.0, 3.0), 0.5), ) -> ConfigurationSpace: cs = ConfigurationSpace() - use_augmenter = CategoricalHyperparameter('use_augmenter', choices=[True, False], default_value=True) - sigma_min = UniformFloatHyperparameter('sigma_min', lower=0, upper=3, default_value=0) - sigma_offset = UniformFloatHyperparameter('sigma_offset', lower=0, upper=3, default_value=0.5) + use_augmenter = CategoricalHyperparameter('use_augmenter', choices=use_augmenter[0], + default_value=use_augmenter[1]) + sigma_min = UniformFloatHyperparameter('sigma_min', lower=sigma_min[0][0], upper=sigma_min[0][1], + default_value=0) + sigma_offset = UniformFloatHyperparameter('sigma_offset', lower=sigma_offset[0][0], upper=sigma_offset[0][1], + default_value=0.5) cs.add_hyperparameters([use_augmenter, sigma_min, sigma_offset]) # only add hyperparameters to configuration space if we are using the augmenter diff --git a/autoPyTorch/pipeline/components/setup/augmentation/image/GaussianNoise.py b/autoPyTorch/pipeline/components/setup/augmentation/image/GaussianNoise.py index 3f5be3173..064c564cd 100644 --- a/autoPyTorch/pipeline/components/setup/augmentation/image/GaussianNoise.py +++ b/autoPyTorch/pipeline/components/setup/augmentation/image/GaussianNoise.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union import ConfigSpace as CS from ConfigSpace.configuration_space import ConfigurationSpace @@ -29,12 +29,16 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseImageAugmenter: @staticmethod def get_hyperparameter_search_space( - dataset_properties: Optional[Dict[str, str]] = None + dataset_properties: Optional[Dict[str, str]] = None, + use_augmenter: Tuple[Tuple[bool, bool], bool] = ((True, False), True), + sigma_offset: Tuple[Tuple[float, float], float] = ((0.0, 3.0), 0.3) ) -> ConfigurationSpace: cs = ConfigurationSpace() - sigma_offset = UniformFloatHyperparameter('sigma_offset', lower=0, upper=3, default_value=0.3) - use_augmenter = CategoricalHyperparameter('use_augmenter', choices=[True, False], default_value=True) + sigma_offset = UniformFloatHyperparameter('sigma_offset', lower=sigma_offset[0][0], upper=sigma_offset[0][1], + default_value=sigma_offset[1]) + use_augmenter = CategoricalHyperparameter('use_augmenter', choices=use_augmenter[0], + default_value=use_augmenter[1]) cs.add_hyperparameters([use_augmenter, sigma_offset]) # only add hyperparameters to configuration space if we are using the augmenter cs.add_condition(CS.EqualsCondition(sigma_offset, use_augmenter, True)) diff --git a/autoPyTorch/pipeline/components/setup/augmentation/image/RandomAffine.py b/autoPyTorch/pipeline/components/setup/augmentation/image/RandomAffine.py index 01d6f16e5..a9797a6e5 100644 --- a/autoPyTorch/pipeline/components/setup/augmentation/image/RandomAffine.py +++ b/autoPyTorch/pipeline/components/setup/augmentation/image/RandomAffine.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union import ConfigSpace as CS from ConfigSpace.configuration_space import 
ConfigurationSpace @@ -37,18 +37,27 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseImageAugmenter: @staticmethod def get_hyperparameter_search_space( - dataset_properties: Optional[Dict[str, str]] = None + dataset_properties: Optional[Dict[str, str]] = None, + use_augmenter: Tuple[Tuple, bool] = ((True, False), True), + scale_offset: Tuple[Tuple, float] = ((0, 0.4), 0.2), + translate_percent_offset: Tuple[Tuple, float] = ((0, 0.4), 0.2), + shear: Tuple[Tuple, int] = ((0, 45), 30), + rotate: Tuple[Tuple, int] = ((0, 360), 45) ) -> ConfigurationSpace: cs = ConfigurationSpace() - scale_offset = UniformFloatHyperparameter('scale_offset', lower=0, upper=0.4, default_value=0.2) - - translate_percent_offset = UniformFloatHyperparameter('translate_percent_offset', lower=0, upper=0.4, - default_value=0.2) - shear = UniformIntegerHyperparameter('shear', lower=0, upper=45, default_value=30) + scale_offset = UniformFloatHyperparameter('scale_offset', lower=scale_offset[0][0], upper=scale_offset[0][1], + default_value=scale_offset[1]) + + translate_percent_offset = UniformFloatHyperparameter('translate_percent_offset', + lower=translate_percent_offset[0][0], + upper=translate_percent_offset[0][1], + default_value=translate_percent_offset[1]) + shear = UniformIntegerHyperparameter('shear', lower=shear[0][0], upper=shear[0][1], default_value=shear[1]) rotate = UniformIntegerHyperparameter('rotate', lower=0, upper=360, default_value=45) - use_augmenter = CategoricalHyperparameter('use_augmenter', choices=[True, False], default_value=True) + use_augmenter = CategoricalHyperparameter('use_augmenter', choices=use_augmenter[0], + default_value=use_augmenter[1]) cs.add_hyperparameters([scale_offset, translate_percent_offset]) cs.add_hyperparameters([shear, rotate, use_augmenter]) diff --git a/autoPyTorch/pipeline/components/setup/augmentation/image/RandomCutout.py b/autoPyTorch/pipeline/components/setup/augmentation/image/RandomCutout.py index 4a12bbdef..332efbb41 100644 --- a/autoPyTorch/pipeline/components/setup/augmentation/image/RandomCutout.py +++ b/autoPyTorch/pipeline/components/setup/augmentation/image/RandomCutout.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union import ConfigSpace as CS from ConfigSpace.configuration_space import ConfigurationSpace @@ -31,12 +31,14 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseImageAugmenter: @staticmethod def get_hyperparameter_search_space( - dataset_properties: Optional[Dict[str, str]] = None + dataset_properties: Optional[Dict[str, str]] = None, + use_augmenter: Tuple[Tuple, bool] = ((True, False), True), + p: Tuple[Tuple, float] = ((0.2, 1.0), 0.5) ) -> ConfigurationSpace: - cs = ConfigurationSpace() - p = UniformFloatHyperparameter('p', lower=0.2, upper=1, default_value=0.5) - use_augmenter = CategoricalHyperparameter('use_augmenter', choices=[True, False], default_value=True) + p = UniformFloatHyperparameter('p', lower=p[0][0], upper=p[0][1], default_value=p[1]) + use_augmenter = CategoricalHyperparameter('use_augmenter', choices=use_augmenter[0], + default_value=use_augmenter[1]) cs.add_hyperparameters([p, use_augmenter]) # only add hyperparameters to configuration space if we are using the augmenter diff --git a/autoPyTorch/pipeline/components/setup/augmentation/image/Resize.py b/autoPyTorch/pipeline/components/setup/augmentation/image/Resize.py index 7ee10d8d3..eabafa95b 100644 --- a/autoPyTorch/pipeline/components/setup/augmentation/image/Resize.py +++ 
b/autoPyTorch/pipeline/components/setup/augmentation/image/Resize.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union from ConfigSpace.configuration_space import ConfigurationSpace from ConfigSpace.hyperparameters import ( @@ -35,10 +35,12 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseImageAugmenter: @staticmethod def get_hyperparameter_search_space( - dataset_properties: Optional[Dict[str, str]] = None + dataset_properties: Optional[Dict[str, str]] = None, + use_augmenter: Tuple[Tuple, bool] = ((True, False), True), ) -> ConfigurationSpace: cs = ConfigurationSpace() - use_augmenter = CategoricalHyperparameter('use_augmenter', choices=[True, False], default_value=True) + use_augmenter = CategoricalHyperparameter('use_augmenter', choices=use_augmenter[0], + default_value=use_augmenter[1]) cs.add_hyperparameters([use_augmenter]) return cs diff --git a/autoPyTorch/pipeline/components/setup/augmentation/image/ZeroPadAndCrop.py b/autoPyTorch/pipeline/components/setup/augmentation/image/ZeroPadAndCrop.py index bf8ce63fe..257c0a550 100644 --- a/autoPyTorch/pipeline/components/setup/augmentation/image/ZeroPadAndCrop.py +++ b/autoPyTorch/pipeline/components/setup/augmentation/image/ZeroPadAndCrop.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union from ConfigSpace.configuration_space import ConfigurationSpace from ConfigSpace.hyperparameters import ( @@ -41,11 +41,13 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseImageAugmenter: @staticmethod def get_hyperparameter_search_space( - dataset_properties: Optional[Dict[str, str]] = None + dataset_properties: Optional[Dict[str, str]] = None, + percent: Tuple[Tuple, float] = ((0, 0.5), 0.1) ) -> ConfigurationSpace: cs = ConfigurationSpace() - percent = UniformFloatHyperparameter('percent', lower=0, upper=0.5, default_value=0.1) + percent = UniformFloatHyperparameter('percent', lower=percent[0][0], upper=percent[0][1], + default_value=percent[1]) cs.add_hyperparameters([percent]) return cs diff --git a/autoPyTorch/pipeline/components/setup/early_preprocessor/EarlyPreprocessing.py b/autoPyTorch/pipeline/components/setup/early_preprocessor/EarlyPreprocessing.py index cd97a9ba1..5d6def24a 100644 --- a/autoPyTorch/pipeline/components/setup/early_preprocessor/EarlyPreprocessing.py +++ b/autoPyTorch/pipeline/components/setup/early_preprocessor/EarlyPreprocessing.py @@ -47,7 +47,8 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: @staticmethod def get_hyperparameter_search_space( - dataset_properties: Optional[Dict[str, str]] = None + dataset_properties: Optional[Dict[str, str]] = None, + **kwargs: Any ) -> ConfigurationSpace: return ConfigurationSpace() diff --git a/autoPyTorch/pipeline/components/setup/lr_scheduler/CosineAnnealingLR.py b/autoPyTorch/pipeline/components/setup/lr_scheduler/CosineAnnealingLR.py index 870302747..9cbbaa41d 100644 --- a/autoPyTorch/pipeline/components/setup/lr_scheduler/CosineAnnealingLR.py +++ b/autoPyTorch/pipeline/components/setup/lr_scheduler/CosineAnnealingLR.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple from ConfigSpace.configuration_space import ConfigurationSpace from ConfigSpace.hyperparameters import ( @@ -61,10 +61,11 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[ } @staticmethod - def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None + 
def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, + T_max: Tuple[Tuple[int, int], int] = ((10, 500), 200) ) -> ConfigurationSpace: T_max = UniformIntegerHyperparameter( - "T_max", 10, 500, default_value=200) + "T_max", T_max[0][0], T_max[0][1], default_value=T_max[1]) cs = ConfigurationSpace() cs.add_hyperparameters([T_max]) return cs diff --git a/autoPyTorch/pipeline/components/setup/lr_scheduler/CosineAnnealingWarmRestarts.py b/autoPyTorch/pipeline/components/setup/lr_scheduler/CosineAnnealingWarmRestarts.py index fea2d30c4..07c9ffc2a 100644 --- a/autoPyTorch/pipeline/components/setup/lr_scheduler/CosineAnnealingWarmRestarts.py +++ b/autoPyTorch/pipeline/components/setup/lr_scheduler/CosineAnnealingWarmRestarts.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple from ConfigSpace.configuration_space import ConfigurationSpace from ConfigSpace.hyperparameters import UniformFloatHyperparameter, UniformIntegerHyperparameter @@ -67,12 +67,14 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[ } @staticmethod - def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None + def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, + T_0: Tuple[Tuple[int, int], int] = ((1, 20), 1), + T_mult: Tuple[Tuple[float, float], float] = ((1.0, 2.0), 1.0) ) -> ConfigurationSpace: T_0 = UniformIntegerHyperparameter( - "T_0", 1, 20, default_value=1) + "T_0", T_0[0][0], T_0[0][1], default_value=T_0[1]) T_mult = UniformFloatHyperparameter( - "T_mult", 1.0, 2.0, default_value=1.0) + "T_mult", T_mult[0][0], T_mult[0][1], default_value=T_mult[1]) cs = ConfigurationSpace() cs.add_hyperparameters([T_0, T_mult]) return cs diff --git a/autoPyTorch/pipeline/components/setup/lr_scheduler/CyclicLR.py b/autoPyTorch/pipeline/components/setup/lr_scheduler/CyclicLR.py index 8bbf5c237..bc6e4e3ff 100644 --- a/autoPyTorch/pipeline/components/setup/lr_scheduler/CyclicLR.py +++ b/autoPyTorch/pipeline/components/setup/lr_scheduler/CyclicLR.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple from ConfigSpace.configuration_space import ConfigurationSpace from ConfigSpace.hyperparameters import ( @@ -86,15 +86,20 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[ } @staticmethod - def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None + def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, + base_lr: Tuple[Tuple, float] = ((1e-6, 1e-1), 0.01), + mode: Tuple[Tuple, str] = (('triangular', 'triangular2', 'exp_range'), + 'triangular'), + step_size_up: Tuple[Tuple, int] = ((1000, 4000), 2000), + max_lr: Tuple[Tuple, float] = ((1e-3, 1e-1), 0.1) ) -> ConfigurationSpace: base_lr = UniformFloatHyperparameter( - "base_lr", 1e-6, 1e-1, default_value=0.01) - mode = CategoricalHyperparameter('mode', ['triangular', 'triangular2', 'exp_range']) + "base_lr", base_lr[0][0], base_lr[0][1], default_value=base_lr[1]) + mode = CategoricalHyperparameter('mode', choices=mode[0], default_value=mode[1]) step_size_up = UniformIntegerHyperparameter( - "step_size_up", 1000, 4000, default_value=2000) + "step_size_up", step_size_up[0][0], step_size_up[0][1], default_value=step_size_up[1]) max_lr = UniformFloatHyperparameter( - "max_lr", 1e-3, 1e-1, default_value=0.1) + "max_lr", max_lr[0][0], max_lr[0][1], default_value=max_lr[1]) cs = ConfigurationSpace() 
cs.add_hyperparameters([base_lr, mode, step_size_up, max_lr]) return cs diff --git a/autoPyTorch/pipeline/components/setup/lr_scheduler/ExponentialLR.py b/autoPyTorch/pipeline/components/setup/lr_scheduler/ExponentialLR.py index 0e5584da7..d090d710b 100644 --- a/autoPyTorch/pipeline/components/setup/lr_scheduler/ExponentialLR.py +++ b/autoPyTorch/pipeline/components/setup/lr_scheduler/ExponentialLR.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple from ConfigSpace.configuration_space import ConfigurationSpace from ConfigSpace.hyperparameters import ( @@ -62,10 +62,11 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[ } @staticmethod - def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None + def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, + gamma: Tuple[Tuple, float] = ((0.7, 0.9999), 0.9) ) -> ConfigurationSpace: gamma = UniformFloatHyperparameter( - "gamma", 0.7, 0.9999, default_value=0.9) + "gamma", gamma[0][0], gamma[0][1], default_value=gamma[1]) cs = ConfigurationSpace() cs.add_hyperparameters([gamma]) return cs diff --git a/autoPyTorch/pipeline/components/setup/lr_scheduler/ReduceLROnPlateau.py b/autoPyTorch/pipeline/components/setup/lr_scheduler/ReduceLROnPlateau.py index 0eb9dcbff..bd8c9c97a 100644 --- a/autoPyTorch/pipeline/components/setup/lr_scheduler/ReduceLROnPlateau.py +++ b/autoPyTorch/pipeline/components/setup/lr_scheduler/ReduceLROnPlateau.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple from ConfigSpace.configuration_space import ConfigurationSpace from ConfigSpace.hyperparameters import ( @@ -77,13 +77,16 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[ } @staticmethod - def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None + def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, + mode: Tuple[Tuple, str] = (('min', 'max'), 'min'), + patience: Tuple[Tuple, int] = ((5, 20), 10), + factor: Tuple[Tuple[float, float], float] = ((0.01, 0.9), 0.1) ) -> ConfigurationSpace: - mode = CategoricalHyperparameter('mode', ['min', 'max']) + mode = CategoricalHyperparameter('mode', choices=mode[0], default_value=mode[1]) patience = UniformIntegerHyperparameter( - "patience", 5, 20, default_value=10) + "patience", patience[0][0], patience[0][1], default_value=patience[1]) factor = UniformFloatHyperparameter( - "factor", 0.01, 0.9, default_value=0.1) + "factor", factor[0][0], factor[0][1], default_value=factor[1]) cs = ConfigurationSpace() cs.add_hyperparameters([mode, patience, factor]) return cs diff --git a/autoPyTorch/pipeline/components/setup/lr_scheduler/StepLR.py b/autoPyTorch/pipeline/components/setup/lr_scheduler/StepLR.py index 8c94a38b6..035e4e841 100644 --- a/autoPyTorch/pipeline/components/setup/lr_scheduler/StepLR.py +++ b/autoPyTorch/pipeline/components/setup/lr_scheduler/StepLR.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple from ConfigSpace.configuration_space import ConfigurationSpace from ConfigSpace.hyperparameters import ( @@ -68,12 +68,14 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[ } @staticmethod - def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None + def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, + gamma: Tuple[Tuple, float] = 
((0.001, 0.9), 0.1), + step_size: Tuple[Tuple, int] = ((1, 10), 5) ) -> ConfigurationSpace: gamma = UniformFloatHyperparameter( - "gamma", 0.001, 0.9, default_value=0.1) + "gamma", gamma[0][0], gamma[0][1], default_value=gamma[1]) step_size = UniformIntegerHyperparameter( - "step_size", 1, 10, default_value=5) + "step_size", step_size[0][0], step_size[0][1], default_value=step_size[1]) cs = ConfigurationSpace() cs.add_hyperparameters([gamma, step_size]) return cs diff --git a/autoPyTorch/pipeline/components/setup/lr_scheduler/base_scheduler_choice.py b/autoPyTorch/pipeline/components/setup/lr_scheduler/base_scheduler_choice.py index a12c8abad..44d1a7880 100644 --- a/autoPyTorch/pipeline/components/setup/lr_scheduler/base_scheduler_choice.py +++ b/autoPyTorch/pipeline/components/setup/lr_scheduler/base_scheduler_choice.py @@ -157,12 +157,13 @@ def get_hyperparameter_search_space( ) cs.add_hyperparameter(scheduler) for name in available_schedulers: - scheduler_configuration_space = available_schedulers[name]. \ - get_hyperparameter_search_space(dataset_properties) + updates = self._get_search_space_updates(prefix=name) + config_space = available_schedulers[name].get_hyperparameter_search_space(dataset_properties, # type:ignore + **updates) parent_hyperparameter = {'parent': scheduler, 'value': name} cs.add_configuration_space( name, - scheduler_configuration_space, + config_space, parent_hyperparameter=parent_hyperparameter ) diff --git a/autoPyTorch/pipeline/components/setup/network/base_network.py b/autoPyTorch/pipeline/components/setup/network/base_network.py index 2b8c27d1c..b40c7e774 100644 --- a/autoPyTorch/pipeline/components/setup/network/base_network.py +++ b/autoPyTorch/pipeline/components/setup/network/base_network.py @@ -123,7 +123,8 @@ def predict(self, loader: torch.utils.data.DataLoader) -> torch.Tensor: return torch.cat(Y_batch_preds, 0).cpu().numpy() @staticmethod - def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None + def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, + **kwargs: Any ) -> ConfigurationSpace: cs = ConfigurationSpace() return cs diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/MLPBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/MLPBackbone.py index d261841eb..6bf7ec36e 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/MLPBackbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/MLPBackbone.py @@ -74,11 +74,12 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[ @staticmethod def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, - min_mlp_layers: int = 1, - max_mlp_layers: int = 15, - dropout: bool = True, - min_num_units: int = 10, - max_num_units: int = 1024, + num_groups: Tuple[Tuple, int] = ((1, 15), 5), + activation: Tuple[Tuple, str] = (tuple(_activations.keys()), + list(_activations.keys())[0]), + use_dropout: Tuple[Tuple, bool] = ((True, False), False), + num_units: Tuple[Tuple, int] = ((10, 1024), 200), + dropout: Tuple[Tuple, float] = ((0, 0.8), 0.5) ) -> ConfigurationSpace: cs = ConfigurationSpace() @@ -86,26 +87,27 @@ def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, # The number of hidden layers the network will have. 
# Layer blocks are meant to have the same architecture, differing only # by the number of units + min_mlp_layers, max_mlp_layers = num_groups[0] num_groups = UniformIntegerHyperparameter( - "num_groups", min_mlp_layers, max_mlp_layers, default_value=5) + "num_groups", min_mlp_layers, max_mlp_layers, default_value=num_groups[1]) activation = CategoricalHyperparameter( - "activation", choices=list(_activations.keys()) + "activation", choices=activation[0], + default_value=activation[1] ) cs.add_hyperparameters([num_groups, activation]) # We can have dropout in the network for # better generalization - if dropout: - use_dropout = CategoricalHyperparameter( - "use_dropout", choices=[True, False], default_value=False) - cs.add_hyperparameters([use_dropout]) + use_dropout = CategoricalHyperparameter( + "use_dropout", choices=use_dropout[0], default_value=use_dropout[1]) + cs.add_hyperparameters([use_dropout]) for i in range(1, max_mlp_layers + 1): n_units_hp = UniformIntegerHyperparameter("num_units_%d" % i, - lower=min_num_units, - upper=max_num_units, - default_value=200) + lower=num_units[0][0], + upper=num_units[0][1], + default_value=num_units[1]) cs.add_hyperparameter(n_units_hp) if i > min_mlp_layers: @@ -117,20 +119,19 @@ def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, ) ) - if dropout: - dropout_hp = UniformFloatHyperparameter( - "dropout_%d" % i, - lower=0.0, - upper=0.8, - default_value=0.5 - ) - cs.add_hyperparameter(dropout_hp) - dropout_condition_1 = CS.EqualsCondition(dropout_hp, use_dropout, True) - - if i > min_mlp_layers: - dropout_condition_2 = CS.GreaterThanCondition(dropout_hp, num_groups, i - 1) - cs.add_condition(CS.AndConjunction(dropout_condition_1, dropout_condition_2)) - else: - cs.add_condition(dropout_condition_1) + dropout_hp = UniformFloatHyperparameter( + "dropout_%d" % i, + lower=dropout[0][0], + upper=dropout[0][1], + default_value=dropout[1] + ) + cs.add_hyperparameter(dropout_hp) + dropout_condition_1 = CS.EqualsCondition(dropout_hp, use_dropout, True) + + if i > min_mlp_layers: + dropout_condition_2 = CS.GreaterThanCondition(dropout_hp, num_groups, i - 1) + cs.add_condition(CS.AndConjunction(dropout_condition_1, dropout_condition_2)) + else: + cs.add_condition(dropout_condition_1) return cs diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py index 9716f8aab..634aabee0 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py @@ -97,51 +97,65 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[ @staticmethod def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, - min_num_gropus: int = 1, - max_num_groups: int = 9, - min_blocks_per_groups: int = 1, - max_blocks_per_groups: int = 4, - min_num_units: int = 10, - max_num_units: int = 1024, + num_groups: Tuple[Tuple, int] = ((1, 15), 5), + use_dropout: Tuple[Tuple, bool] = ((True, False), False), + num_units: Tuple[Tuple, int] = ((10, 1024), 200), + activation: Tuple[Tuple, str] = (tuple(_activations.keys()), + list(_activations.keys())[0]), + blocks_per_group: Tuple[Tuple, int] = ((1, 4), 2), + dropout: Tuple[Tuple, float] = ((0, 0.8), 0.5), + use_shake_shake: Tuple[Tuple, bool] = ((True, False), True), + use_shake_drop: Tuple[Tuple, bool] = ((True, False), True), + max_shake_drop_probability: Tuple[Tuple, float] = 
((0, 1), 0.5) ) -> ConfigurationSpace: cs = ConfigurationSpace() # The number of groups that will compose the resnet. That is, # a group can have N Resblock. The M number of this N resblock # repetitions is num_groups + min_num_gropus, max_num_groups = num_groups[0] num_groups = UniformIntegerHyperparameter( - "num_groups", lower=min_num_gropus, upper=max_num_groups, default_value=5) + "num_groups", lower=min_num_gropus, upper=max_num_groups, default_value=num_groups[1]) activation = CategoricalHyperparameter( - "activation", choices=list(_activations.keys()) + "activation", choices=activation[0], + default_value=activation[1] ) cs.add_hyperparameters([num_groups, activation]) # We can have dropout in the network for # better generalization - use_dropout = CategoricalHyperparameter( - "use_dropout", choices=[True, False]) + use_dropout = CategoricalHyperparameter("use_dropout", choices=use_dropout[0], default_value=use_dropout[1]) cs.add_hyperparameters([use_dropout]) - use_shake_shake = CategoricalHyperparameter("use_shake_shake", choices=[True, False]) - use_shake_drop = CategoricalHyperparameter("use_shake_drop", choices=[True, False]) + use_shake_shake = CategoricalHyperparameter("use_shake_shake", choices=use_shake_shake[0], + default_value=use_shake_shake[1]) + use_shake_drop = CategoricalHyperparameter("use_shake_drop", choices=use_shake_drop[0], + default_value=use_shake_drop[1]) shake_drop_prob = UniformFloatHyperparameter( - "max_shake_drop_probability", lower=0.0, upper=1.0) + "max_shake_drop_probability", + lower=max_shake_drop_probability[0][0], + upper=max_shake_drop_probability[0][1], + default_value=max_shake_drop_probability[1]) cs.add_hyperparameters([use_shake_shake, use_shake_drop, shake_drop_prob]) cs.add_condition(CS.EqualsCondition(shake_drop_prob, use_shake_drop, True)) # It is the upper bound of the nr of groups, # since the configuration will actually be sampled. 
+ (min_blocks_per_group, max_blocks_per_group), default_blocks_per_group = blocks_per_group[:2] for i in range(0, max_num_groups + 1): n_units = UniformIntegerHyperparameter( "num_units_%d" % i, - lower=min_num_units, - upper=max_num_units, + lower=num_units[0][0], + upper=num_units[0][1], + default_value=num_units[1] ) blocks_per_group = UniformIntegerHyperparameter( - "blocks_per_group_%d" % i, lower=min_blocks_per_groups, - upper=max_blocks_per_groups) + "blocks_per_group_%d" % i, + lower=min_blocks_per_group, + upper=max_blocks_per_group, + default_value=default_blocks_per_group) cs.add_hyperparameters([n_units, blocks_per_group]) @@ -150,7 +164,10 @@ def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, cs.add_condition(CS.GreaterThanCondition(blocks_per_group, num_groups, i - 1)) this_dropout = UniformFloatHyperparameter( - "dropout_%d" % i, lower=0.0, upper=1.0 + "dropout_%d" % i, + lower=dropout[0][0], + upper=dropout[0][1], + default_value=dropout[1] ) cs.add_hyperparameters([this_dropout]) diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedMLPBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedMLPBackbone.py index c7cef2fd6..607823430 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedMLPBackbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedMLPBackbone.py @@ -79,10 +79,16 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[ @staticmethod def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, - min_num_gropus: int = 1, - max_num_groups: int = 15, - min_num_units: int = 10, - max_num_units: int = 1024, + num_groups: Tuple[Tuple, int] = ((1, 15), 5), + max_dropout: Tuple[Tuple, float] = ((0, 1), 0.5), + use_dropout: Tuple[Tuple, bool] = ((True, False), False), + max_units: Tuple[Tuple, int] = ((10, 1024), 200), + output_dim: Tuple[Tuple, int] = ((10, 1024), 200), + mlp_shape: Tuple[Tuple, str] = (('funnel', 'long_funnel', + 'diamond', 'hexagon', + 'brick', 'triangle', 'stairs'), 'funnel'), + activation: Tuple[Tuple, str] = ( + tuple(_activations.keys()), list(_activations.keys())[0]) ) -> ConfigurationSpace: cs = ConfigurationSpace() @@ -91,27 +97,28 @@ def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, # a group can have N Resblock. 
The M number of this N resblock # repetitions is num_groups num_groups = UniformIntegerHyperparameter( - "num_groups", lower=min_num_gropus, upper=max_num_groups, default_value=5) + "num_groups", lower=num_groups[0][0], upper=num_groups[0][1], default_value=num_groups[1]) - mlp_shape = CategoricalHyperparameter('mlp_shape', choices=[ - 'funnel', 'long_funnel', 'diamond', 'hexagon', 'brick', 'triangle', 'stairs' - ]) + mlp_shape = CategoricalHyperparameter('mlp_shape', choices=mlp_shape[0], + default_value=mlp_shape[1]) activation = CategoricalHyperparameter( - "activation", choices=list(_activations.keys()) + "activation", choices=activation[0], + default_value=activation[1] ) - + (min_num_units, max_num_units), default_units = max_units[:2] max_units = UniformIntegerHyperparameter( "max_units", lower=min_num_units, upper=max_num_units, - default_value=200, + default_value=default_units, ) output_dim = UniformIntegerHyperparameter( "output_dim", - lower=min_num_units, - upper=max_num_units + lower=output_dim[0][0], + upper=output_dim[0][1], + default_value=output_dim[1] ) cs.add_hyperparameters([num_groups, activation, mlp_shape, max_units, output_dim]) @@ -119,8 +126,9 @@ def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, # We can have dropout in the network for # better generalization use_dropout = CategoricalHyperparameter( - "use_dropout", choices=[True, False]) - max_dropout = UniformFloatHyperparameter("max_dropout", lower=0.0, upper=1.0) + "use_dropout", choices=use_dropout[0], default_value=use_dropout[1]) + max_dropout = UniformFloatHyperparameter("max_dropout", lower=max_dropout[0][0], upper=max_dropout[0][1], + default_value=max_dropout[1]) cs.add_hyperparameters([use_dropout, max_dropout]) cs.add_condition(CS.EqualsCondition(max_dropout, use_dropout, True)) diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py index 16435d0c9..b3efc7bb1 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py @@ -81,28 +81,29 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[ } @staticmethod - def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, - min_num_gropus: int = 1, - max_num_groups: int = 9, - min_blocks_per_groups: int = 1, - max_blocks_per_groups: int = 4, - min_num_units: int = 10, - max_num_units: int = 1024, + def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, # type: ignore[override] + num_groups: Tuple[Tuple, int] = ((1, 15), 5), + use_dropout: Tuple[Tuple, bool] = ((True, False), False), + max_units: Tuple[Tuple, int] = ((10, 1024), 200), + blocks_per_group: Tuple[Tuple, int] = ((1, 4), 2), + max_dropout: Tuple[Tuple, float] = ((0, 0.8), 0.5), + use_shake_shake: Tuple[Tuple, bool] = ((True, False), True), + use_shake_drop: Tuple[Tuple, bool] = ((True, False), True), + max_shake_drop_probability: Tuple[Tuple, float] = ((0, 1), 0.5), + resnet_shape: Tuple[Tuple, str] = (('funnel', 'long_funnel', + 'diamond', 'hexagon', + 'brick', 'triangle', 'stairs'), 'funnel'), + activation: Tuple[Tuple, str] = ( + tuple(_activations.keys()), list(_activations.keys())[0]), + output_dim: Tuple[Tuple, int] = ((10, 1024), 200), ) -> ConfigurationSpace: cs = ConfigurationSpace() # Support for different shapes resnet_shape = CategoricalHyperparameter( 
'resnet_shape', - choices=[ - 'funnel', - 'long_funnel', - 'diamond', - 'hexagon', - 'brick', - 'triangle', - 'stairs' - ] + choices=resnet_shape[0], + default_value=resnet_shape[1] ) cs.add_hyperparameter(resnet_shape) @@ -110,33 +111,36 @@ def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, # a group can have N Resblock. The M number of this N resblock # repetitions is num_groups num_groups = UniformIntegerHyperparameter( - "num_groups", lower=min_num_gropus, upper=max_num_groups, default_value=5) + "num_groups", lower=num_groups[0][0], upper=num_groups[0][1], default_value=num_groups[1]) blocks_per_group = UniformIntegerHyperparameter( - "blocks_per_group", lower=min_blocks_per_groups, upper=max_blocks_per_groups) + "blocks_per_group", lower=blocks_per_group[0][0], + upper=blocks_per_group[0][1], + default_value=blocks_per_group[1]) activation = CategoricalHyperparameter( - "activation", choices=list(_activations.keys()) + "activation", choices=activation[0], + default_value=activation[1] ) - + (min_num_units, max_num_units), default_units = max_units[:2] output_dim = UniformIntegerHyperparameter( "output_dim", - lower=min_num_units, - upper=max_num_units + lower=output_dim[0][0], + upper=output_dim[0][1], + default_value=output_dim[1] ) cs.add_hyperparameters([num_groups, blocks_per_group, activation, output_dim]) - # We can have dropout in the network for - # better generalization - use_dropout = CategoricalHyperparameter( - "use_dropout", choices=[True, False]) - cs.add_hyperparameters([use_dropout]) - - use_shake_shake = CategoricalHyperparameter("use_shake_shake", choices=[True, False]) - use_shake_drop = CategoricalHyperparameter("use_shake_drop", choices=[True, False]) + use_shake_shake = CategoricalHyperparameter("use_shake_shake", choices=use_shake_shake[0], + default_value=use_shake_shake[1]) + use_shake_drop = CategoricalHyperparameter("use_shake_drop", choices=use_shake_drop[0], + default_value=use_shake_drop[1]) shake_drop_prob = UniformFloatHyperparameter( - "max_shake_drop_probability", lower=0.0, upper=1.0) + "max_shake_drop_probability", + lower=max_shake_drop_probability[0][0], + upper=max_shake_drop_probability[0][1], + default_value=max_shake_drop_probability[1]) cs.add_hyperparameters([use_shake_shake, use_shake_drop, shake_drop_prob]) cs.add_condition(CS.EqualsCondition(shake_drop_prob, use_shake_drop, True)) @@ -144,12 +148,15 @@ def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, "max_units", lower=min_num_units, upper=max_num_units, + default_value=default_units ) cs.add_hyperparameters([max_units]) - max_dropout = UniformFloatHyperparameter( - "max_dropout", lower=0.0, upper=1.0 - ) + use_dropout = CategoricalHyperparameter( + "use_dropout", choices=use_dropout[0], default_value=use_dropout[1]) + max_dropout = UniformFloatHyperparameter("max_dropout", lower=max_dropout[0][0], upper=max_dropout[0][1], + default_value=max_dropout[1]) + cs.add_hyperparameters([use_dropout]) cs.add_hyperparameters([max_dropout]) cs.add_condition(CS.EqualsCondition(max_dropout, use_dropout, True)) diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone_choice.py b/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone_choice.py index 278979d60..71d8a63d3 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone_choice.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone_choice.py @@ -17,7 +17,6 @@ 
NetworkBackboneComponent, ) - directory = os.path.split(__file__)[0] _backbones = find_components(__package__, directory, @@ -167,12 +166,13 @@ def get_hyperparameter_search_space( ) cs.add_hyperparameter(backbone) for name in available_backbones: - backbone_configuration_space = available_backbones[name]. \ - get_hyperparameter_search_space(dataset_properties) + updates = self._get_search_space_updates(prefix=name) + config_space = available_backbones[name].get_hyperparameter_search_space(dataset_properties, # type: ignore + **updates) parent_hyperparameter = {'parent': backbone, 'value': name} cs.add_configuration_space( name, - backbone_configuration_space, + config_space, parent_hyperparameter=parent_hyperparameter ) diff --git a/autoPyTorch/pipeline/components/setup/network_head/base_network_head_choice.py b/autoPyTorch/pipeline/components/setup/network_head/base_network_head_choice.py index 116a707f3..f57e60b91 100644 --- a/autoPyTorch/pipeline/components/setup/network_head/base_network_head_choice.py +++ b/autoPyTorch/pipeline/components/setup/network_head/base_network_head_choice.py @@ -165,12 +165,13 @@ def get_hyperparameter_search_space( ) cs.add_hyperparameter(head) for name in available_heads: - head_configuration_space = available_heads[name]. \ - get_hyperparameter_search_space(dataset_properties) + updates = self._get_search_space_updates(prefix=name) + config_space = available_heads[name].get_hyperparameter_search_space(dataset_properties, # type: ignore + **updates) parent_hyperparameter = {'parent': head, 'value': name} cs.add_configuration_space( name, - head_configuration_space, + config_space, parent_hyperparameter=parent_hyperparameter ) diff --git a/autoPyTorch/pipeline/components/setup/network_head/fully_connected.py b/autoPyTorch/pipeline/components/setup/network_head/fully_connected.py index bada209ec..bd555e03a 100644 --- a/autoPyTorch/pipeline/components/setup/network_head/fully_connected.py +++ b/autoPyTorch/pipeline/components/setup/network_head/fully_connected.py @@ -33,9 +33,9 @@ def build_head(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...] 
in_features = np.prod(input_shape).item() for i in range(1, self.config["num_layers"]): layers.append(nn.Linear(in_features=in_features, - out_features=self.config[f"layer_{i}_units"])) + out_features=self.config[f"units_layer_{i}"])) layers.append(_activations[self.config["activation"]]()) - in_features = self.config[f"layer_{i}_units"] + in_features = self.config[f"units_layer_{i}"] out_features = np.prod(output_shape).item() layers.append(nn.Linear(in_features=in_features, out_features=out_features)) @@ -53,27 +53,34 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[ @staticmethod def get_hyperparameter_search_space(dataset_properties: Optional[Dict[str, str]] = None, - min_num_layers: int = 1, - max_num_layers: int = 4, - min_num_units: int = 64, - max_num_units: int = 512) -> ConfigurationSpace: + num_layers: Tuple[Tuple, int] = ((1, 4), 2), + units_layer: Tuple[Tuple, int] = ((64, 512), 128), + activation: Tuple[Tuple, str] = (tuple(_activations.keys()), + list(_activations.keys())[0]) + ) -> ConfigurationSpace: cs = ConfigurationSpace() + min_num_layers, max_num_layers = num_layers[0] num_layers_hp = UniformIntegerHyperparameter("num_layers", lower=min_num_layers, - upper=max_num_layers) + upper=max_num_layers, + default_value=num_layers[1] + ) - activation_hp = CategoricalHyperparameter("activation", - choices=list(_activations.keys())) + activation_hp = CategoricalHyperparameter( + "activation", choices=activation[0], + default_value=activation[1] + ) cs.add_hyperparameters([num_layers_hp, activation_hp]) cs.add_condition(CS.GreaterThanCondition(activation_hp, num_layers_hp, 1)) for i in range(1, max_num_layers): - num_units_hp = UniformIntegerHyperparameter(f"layer_{i}_units", - lower=min_num_units, - upper=max_num_units) + num_units_hp = UniformIntegerHyperparameter(f"units_layer_{i}", + lower=units_layer[0][0], + upper=units_layer[0][1], + default_value=units_layer[1]) cs.add_hyperparameter(num_units_hp) if i >= min_num_layers: cs.add_condition(CS.GreaterThanCondition(num_units_hp, num_layers_hp, i)) diff --git a/autoPyTorch/pipeline/components/setup/network_initializer/base_network_init_choice.py b/autoPyTorch/pipeline/components/setup/network_initializer/base_network_init_choice.py index 8f4b734d1..ea9b4c1d9 100644 --- a/autoPyTorch/pipeline/components/setup/network_initializer/base_network_init_choice.py +++ b/autoPyTorch/pipeline/components/setup/network_initializer/base_network_init_choice.py @@ -17,7 +17,6 @@ BaseNetworkInitializerComponent ) - directory = os.path.split(__file__)[0] _initializers = find_components(__package__, directory, @@ -47,10 +46,10 @@ def get_components(self) -> Dict[str, autoPyTorchComponent]: return components def get_available_components( - self, - dataset_properties: Optional[Dict[str, str]] = None, - include: List[str] = None, - exclude: List[str] = None, + self, + dataset_properties: Optional[Dict[str, str]] = None, + include: List[str] = None, + exclude: List[str] = None, ) -> Dict[str, autoPyTorchComponent]: """Filters out components based on user provided include/exclude directives, as well as the dataset properties @@ -103,11 +102,11 @@ def get_available_components( return components_dict def get_hyperparameter_search_space( - self, - dataset_properties: Optional[Dict[str, str]] = None, - default: Optional[str] = None, - include: Optional[List[str]] = None, - exclude: Optional[List[str]] = None, + self, + dataset_properties: Optional[Dict[str, str]] = None, + default: Optional[str] = None, + include: 
Optional[List[str]] = None, + exclude: Optional[List[str]] = None, ) -> ConfigurationSpace: """Returns the configuration space of the current chosen components @@ -128,34 +127,35 @@ def get_hyperparameter_search_space( dataset_properties = {} # Compile a list of legal preprocessors for this problem - available_initializers = self.get_available_components( + initializers = self.get_available_components( dataset_properties=dataset_properties, include=include, exclude=exclude) - if len(available_initializers) == 0: + if len(initializers) == 0: raise ValueError("No initializers found") if default is None: defaults = ['XavierInit', ] for default_ in defaults: - if default_ in available_initializers: + if default_ in initializers: default = default_ break initializer = CSH.CategoricalHyperparameter( '__choice__', - list(available_initializers.keys()), + list(initializers.keys()), default_value=default ) cs.add_hyperparameter(initializer) - for name in available_initializers: - initializer_configuration_space = available_initializers[name]. \ - get_hyperparameter_search_space(dataset_properties) + for name in initializers: + updates = self._get_search_space_updates(prefix=name) + config_space = initializers[name].get_hyperparameter_search_space(dataset_properties, # type:ignore + **updates) parent_hyperparameter = {'parent': initializer, 'value': name} cs.add_configuration_space( name, - initializer_configuration_space, + config_space, parent_hyperparameter=parent_hyperparameter ) diff --git a/autoPyTorch/pipeline/components/setup/network_initializer/base_network_initializer.py b/autoPyTorch/pipeline/components/setup/network_initializer/base_network_initializer.py index 306c798da..93f478f1d 100644 --- a/autoPyTorch/pipeline/components/setup/network_initializer/base_network_initializer.py +++ b/autoPyTorch/pipeline/components/setup/network_initializer/base_network_initializer.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Tuple from ConfigSpace.configuration_space import ConfigurationSpace from ConfigSpace.hyperparameters import ( @@ -73,18 +73,14 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: @staticmethod def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, - min_mlp_layers: int = 1, - max_mlp_layers: int = 15, - dropout: bool = True, - min_num_units: int = 10, - max_num_units: int = 1024, + bias_strategy: Tuple[Tuple, str] = (('Zero', 'Normal'), 'Normal') ) -> ConfigurationSpace: cs = ConfigurationSpace() # The strategy for bias initializations bias_strategy = CategoricalHyperparameter( - "bias_strategy", choices=['Zero', 'Normal']) + "bias_strategy", choices=bias_strategy[0], default_value=bias_strategy[1]) cs.add_hyperparameters([bias_strategy]) return cs diff --git a/autoPyTorch/pipeline/components/setup/optimizer/AdamOptimizer.py b/autoPyTorch/pipeline/components/setup/optimizer/AdamOptimizer.py index 1293444ad..aa86e9d0c 100644 --- a/autoPyTorch/pipeline/components/setup/optimizer/AdamOptimizer.py +++ b/autoPyTorch/pipeline/components/setup/optimizer/AdamOptimizer.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple from ConfigSpace.configuration_space import ConfigurationSpace from ConfigSpace.hyperparameters import ( @@ -74,22 +74,26 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[ @staticmethod def get_hyperparameter_search_space(dataset_properties: 
Optional[Dict] = None, + lr: Tuple[Tuple, float, bool] = ((1e-5, 1e-1), 1e-2, True), + beta1: Tuple[Tuple, float] = ((0.85, 0.999), 0.9), + beta2: Tuple[Tuple, float] = ((0.9, 0.9999), 0.9), + weight_decay: Tuple[Tuple, float] = ((0.0, 0.1), 0.0) ) -> ConfigurationSpace: cs = ConfigurationSpace() # The learning rate for the model - lr = UniformFloatHyperparameter('lr', lower=1e-5, upper=1e-1, - default_value=1e-2, log=True) + lr = UniformFloatHyperparameter('lr', lower=lr[0][0], upper=lr[0][1], + default_value=lr[1], log=lr[2]) - beta1 = UniformFloatHyperparameter('beta1', lower=0.85, upper=0.999, - default_value=0.9) + beta1 = UniformFloatHyperparameter('beta1', lower=beta1[0][0], upper=beta1[0][1], + default_value=beta1[1]) - beta2 = UniformFloatHyperparameter('beta2', lower=0.9, upper=0.9999, - default_value=0.9) + beta2 = UniformFloatHyperparameter('beta2', lower=beta2[0][0], upper=beta2[0][1], + default_value=beta2[1]) - weight_decay = UniformFloatHyperparameter('weight_decay', lower=0.0, upper=0.1, - default_value=0.0) + weight_decay = UniformFloatHyperparameter('weight_decay', lower=weight_decay[0][0], upper=weight_decay[0][1], + default_value=weight_decay[1]) cs.add_hyperparameters([lr, beta1, beta2, weight_decay]) diff --git a/autoPyTorch/pipeline/components/setup/optimizer/AdamWOptimizer.py b/autoPyTorch/pipeline/components/setup/optimizer/AdamWOptimizer.py index 74e0504a7..b525cdef7 100644 --- a/autoPyTorch/pipeline/components/setup/optimizer/AdamWOptimizer.py +++ b/autoPyTorch/pipeline/components/setup/optimizer/AdamWOptimizer.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple from ConfigSpace.configuration_space import ConfigurationSpace from ConfigSpace.hyperparameters import ( @@ -74,22 +74,26 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[ @staticmethod def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, + lr: Tuple[Tuple, float, bool] = ((1e-5, 1e-1), 1e-2, True), + beta1: Tuple[Tuple, float] = ((0.85, 0.999), 0.9), + beta2: Tuple[Tuple, float] = ((0.9, 0.9999), 0.9), + weight_decay: Tuple[Tuple, float] = ((0.0, 0.1), 0.0) ) -> ConfigurationSpace: cs = ConfigurationSpace() # The learning rate for the model - lr = UniformFloatHyperparameter('lr', lower=1e-5, upper=1e-1, - default_value=1e-2, log=True) + lr = UniformFloatHyperparameter('lr', lower=lr[0][0], upper=lr[0][1], + default_value=lr[1], log=lr[2]) - beta1 = UniformFloatHyperparameter('beta1', lower=0.85, upper=0.999, - default_value=0.9) + beta1 = UniformFloatHyperparameter('beta1', lower=beta1[0][0], upper=beta1[0][1], + default_value=beta1[1]) - beta2 = UniformFloatHyperparameter('beta2', lower=0.9, upper=0.9999, - default_value=0.9) + beta2 = UniformFloatHyperparameter('beta2', lower=beta2[0][0], upper=beta2[0][1], + default_value=beta2[1]) - weight_decay = UniformFloatHyperparameter('weight_decay', lower=0.0, upper=0.1, - default_value=0.0) + weight_decay = UniformFloatHyperparameter('weight_decay', lower=weight_decay[0][0], upper=weight_decay[0][1], + default_value=weight_decay[1]) cs.add_hyperparameters([lr, beta1, beta2, weight_decay]) diff --git a/autoPyTorch/pipeline/components/setup/optimizer/RMSpropOptimizer.py b/autoPyTorch/pipeline/components/setup/optimizer/RMSpropOptimizer.py index f589584fb..e44bb9495 100644 --- a/autoPyTorch/pipeline/components/setup/optimizer/RMSpropOptimizer.py +++ b/autoPyTorch/pipeline/components/setup/optimizer/RMSpropOptimizer.py @@ -1,4 +1,4 @@ -from typing 
import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple from ConfigSpace.configuration_space import ConfigurationSpace from ConfigSpace.hyperparameters import ( @@ -77,22 +77,26 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[ @staticmethod def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, + lr: Tuple[Tuple, float, bool] = ((1e-5, 1e-1), 1e-2, True), + alpha: Tuple[Tuple, float] = ((0.1, 0.99), 0.99), + weight_decay: Tuple[Tuple, float] = ((0.0, 0.1), 0.0), + momentum: Tuple[Tuple, float] = ((0.0, 0.99), 0.0), ) -> ConfigurationSpace: cs = ConfigurationSpace() # The learning rate for the model - lr = UniformFloatHyperparameter('lr', lower=1e-5, upper=1e-1, - default_value=1e-2, log=True) + lr = UniformFloatHyperparameter('lr', lower=lr[0][0], upper=lr[0][1], + default_value=lr[1], log=lr[2]) - alpha = UniformFloatHyperparameter('alpha', lower=0.1, upper=0.99, - default_value=0.99) + alpha = UniformFloatHyperparameter('alpha', lower=alpha[0][0], upper=alpha[0][1], + default_value=alpha[1]) - weight_decay = UniformFloatHyperparameter('weight_decay', lower=0.0, upper=0.1, - default_value=0.0) + weight_decay = UniformFloatHyperparameter('weight_decay', lower=weight_decay[0][0], upper=weight_decay[0][1], + default_value=weight_decay[1]) - momentum = UniformFloatHyperparameter('momentum', lower=0.0, upper=0.99, - default_value=0.0) + momentum = UniformFloatHyperparameter('momentum', lower=momentum[0][0], upper=momentum[0][1], + default_value=momentum[1]) cs.add_hyperparameters([lr, alpha, weight_decay, momentum]) diff --git a/autoPyTorch/pipeline/components/setup/optimizer/SGDOptimizer.py b/autoPyTorch/pipeline/components/setup/optimizer/SGDOptimizer.py index 831419a39..4396cb381 100644 --- a/autoPyTorch/pipeline/components/setup/optimizer/SGDOptimizer.py +++ b/autoPyTorch/pipeline/components/setup/optimizer/SGDOptimizer.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple from ConfigSpace.configuration_space import ConfigurationSpace from ConfigSpace.hyperparameters import ( @@ -71,19 +71,22 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[ @staticmethod def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, + lr: Tuple[Tuple, float, bool] = ((1e-5, 1e-1), 1e-2, True), + weight_decay: Tuple[Tuple, float] = ((0.0, 0.1), 0.0), + momentum: Tuple[Tuple, float] = ((0.0, 0.99), 0.0), ) -> ConfigurationSpace: cs = ConfigurationSpace() # The learning rate for the model - lr = UniformFloatHyperparameter('lr', lower=1e-5, upper=1e-1, - default_value=1e-2, log=True) + lr = UniformFloatHyperparameter('lr', lower=lr[0][0], upper=lr[0][1], + default_value=lr[1], log=lr[2]) - weight_decay = UniformFloatHyperparameter('weight_decay', lower=0.0, upper=0.1, - default_value=0.0) + weight_decay = UniformFloatHyperparameter('weight_decay', lower=weight_decay[0][0], upper=weight_decay[0][1], + default_value=weight_decay[1]) - momentum = UniformFloatHyperparameter('momentum', lower=0.0, upper=0.99, - default_value=0.0) + momentum = UniformFloatHyperparameter('momentum', lower=momentum[0][0], upper=momentum[0][1], + default_value=momentum[1]) cs.add_hyperparameters([lr, weight_decay, momentum]) diff --git a/autoPyTorch/pipeline/components/setup/optimizer/base_optimizer_choice.py b/autoPyTorch/pipeline/components/setup/optimizer/base_optimizer_choice.py index 82bacf9d2..5196f0bb7 100644 --- 
a/autoPyTorch/pipeline/components/setup/optimizer/base_optimizer_choice.py +++ b/autoPyTorch/pipeline/components/setup/optimizer/base_optimizer_choice.py @@ -125,11 +125,11 @@ def get_hyperparameter_search_space( dataset_properties = {} # Compile a list of legal preprocessors for this problem - available_optimizers = self.get_available_components( + available_optimizer = self.get_available_components( dataset_properties=dataset_properties, include=include, exclude=exclude) - if len(available_optimizers) == 0: + if len(available_optimizer) == 0: raise ValueError("No Optimizer found") if default is None: @@ -140,23 +140,24 @@ def get_hyperparameter_search_space( 'RMSpropOptimizer' ] for default_ in defaults: - if default_ in available_optimizers: + if default_ in available_optimizer: default = default_ break optimizer = CSH.CategoricalHyperparameter( '__choice__', - list(available_optimizers.keys()), + list(available_optimizer.keys()), default_value=default ) cs.add_hyperparameter(optimizer) - for name in available_optimizers: - optimizer_configuration_space = available_optimizers[name]. \ - get_hyperparameter_search_space(dataset_properties) + for name in available_optimizer: + updates = self._get_search_space_updates(prefix=name) + config_space = available_optimizer[name].get_hyperparameter_search_space(dataset_properties, # type: ignore + **updates) parent_hyperparameter = {'parent': optimizer, 'value': name} cs.add_configuration_space( name, - optimizer_configuration_space, + config_space, parent_hyperparameter=parent_hyperparameter ) diff --git a/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py b/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py index 4dd509e17..95a4fccb2 100644 --- a/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py +++ b/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple from ConfigSpace.configuration_space import ConfigurationSpace from ConfigSpace.hyperparameters import ( @@ -248,10 +248,11 @@ def get_torchvision_datasets(self) -> Dict[str, torchvision.datasets.VisionDatas } @staticmethod - def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None + def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, + batch_size: Tuple[Tuple, int] = ((32, 320), 64) ) -> ConfigurationSpace: batch_size = UniformIntegerHyperparameter( - "batch_size", 16, 512, default_value=64) + "batch_size", batch_size[0][0], batch_size[0][1], default_value=batch_size[1]) cs = ConfigurationSpace() cs.add_hyperparameters([batch_size]) return cs diff --git a/autoPyTorch/pipeline/components/training/trainer/MixUpTrainer.py b/autoPyTorch/pipeline/components/training/trainer/MixUpTrainer.py index b391b7d59..a49e17682 100644 --- a/autoPyTorch/pipeline/components/training/trainer/MixUpTrainer.py +++ b/autoPyTorch/pipeline/components/training/trainer/MixUpTrainer.py @@ -61,10 +61,11 @@ def get_properties(dataset_properties: typing.Optional[typing.Dict[str, typing.A } @staticmethod - def get_hyperparameter_search_space(dataset_properties: typing.Optional[typing.Dict] = None + def get_hyperparameter_search_space(dataset_properties: typing.Optional[typing.Dict] = None, + alpha: typing.Tuple[typing.Tuple[float, float], float] = ((0, 1), 0.2) ) -> ConfigurationSpace: alpha = UniformFloatHyperparameter( - "alpha", 0, 1, default_value=0.2) + "alpha", alpha[0][0], alpha[0][1], 
default_value=alpha[1]) cs = ConfigurationSpace() cs.add_hyperparameters([alpha]) return cs diff --git a/autoPyTorch/pipeline/components/training/trainer/StandardTrainer.py b/autoPyTorch/pipeline/components/training/trainer/StandardTrainer.py index 4509c17f6..454d4c625 100644 --- a/autoPyTorch/pipeline/components/training/trainer/StandardTrainer.py +++ b/autoPyTorch/pipeline/components/training/trainer/StandardTrainer.py @@ -39,7 +39,8 @@ def get_properties(dataset_properties: typing.Optional[typing.Dict[str, typing.A } @staticmethod - def get_hyperparameter_search_space(dataset_properties: typing.Optional[typing.Dict] = None + def get_hyperparameter_search_space(dataset_properties: typing.Optional[typing.Dict] = None, + **kwargs: typing.Any ) -> ConfigurationSpace: cs = ConfigurationSpace() return cs diff --git a/autoPyTorch/pipeline/components/training/trainer/base_trainer_choice.py b/autoPyTorch/pipeline/components/training/trainer/base_trainer_choice.py index 971eed3f5..252efdbaf 100755 --- a/autoPyTorch/pipeline/components/training/trainer/base_trainer_choice.py +++ b/autoPyTorch/pipeline/components/training/trainer/base_trainer_choice.py @@ -142,12 +142,13 @@ def get_hyperparameter_search_space( ) cs.add_hyperparameter(trainer) for name in available_trainers: - trainer_configuration_space = available_trainers[name]. \ - get_hyperparameter_search_space(dataset_properties) + updates = self._get_search_space_updates(prefix=name) + config_space = available_trainers[name].get_hyperparameter_search_space(dataset_properties, # type:ignore + **updates) parent_hyperparameter = {'parent': trainer, 'value': name} cs.add_configuration_space( name, - trainer_configuration_space, + config_space, parent_hyperparameter=parent_hyperparameter ) diff --git a/autoPyTorch/pipeline/image_classification.py b/autoPyTorch/pipeline/image_classification.py index 04e70c3d5..108594c5e 100644 --- a/autoPyTorch/pipeline/image_classification.py +++ b/autoPyTorch/pipeline/image_classification.py @@ -13,6 +13,7 @@ ) from autoPyTorch.pipeline.components.setup.augmentation.image.ImageAugmenter import ImageAugmenter from autoPyTorch.pipeline.components.setup.early_preprocessor.EarlyPreprocessing import EarlyPreprocessing +from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates # from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler_choice import SchedulerChoice # from autoPyTorch.pipeline.components.setup.network.base_network_choice import NetworkChoice # from autoPyTorch.pipeline.components.setup.optimizer.base_optimizer_choice import OptimizerChoice @@ -53,11 +54,12 @@ def __init__( include: Optional[Dict[str, Any]] = None, exclude: Optional[Dict[str, Any]] = None, random_state: Optional[np.random.RandomState] = None, - init_params: Optional[Dict[str, Any]] = None + init_params: Optional[Dict[str, Any]] = None, + search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None ): super().__init__( config, steps, dataset_properties, include, exclude, - random_state, init_params) + random_state, init_params, search_space_updates) def fit_transformer( self, diff --git a/autoPyTorch/pipeline/tabular_classification.py b/autoPyTorch/pipeline/tabular_classification.py index 4b0a743be..5059f3536 100644 --- a/autoPyTorch/pipeline/tabular_classification.py +++ b/autoPyTorch/pipeline/tabular_classification.py @@ -31,6 +31,7 @@ from autoPyTorch.pipeline.components.training.trainer.base_trainer_choice import ( TrainerChoice ) +from 
autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
 
 
 class TabularClassificationPipeline(ClassifierMixin, BasePipeline):
@@ -65,11 +66,12 @@ def __init__(
         include: Optional[Dict[str, Any]] = None,
         exclude: Optional[Dict[str, Any]] = None,
         random_state: Optional[np.random.RandomState] = None,
-        init_params: Optional[Dict[str, Any]] = None
+        init_params: Optional[Dict[str, Any]] = None,
+        search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
     ):
         super().__init__(
             config, steps, dataset_properties, include, exclude,
-            random_state, init_params)
+            random_state, init_params, search_space_updates)
 
     def fit_transformer(
         self,
diff --git a/autoPyTorch/pipeline/tabular_regression.py b/autoPyTorch/pipeline/tabular_regression.py
index 2e3efad85..02a668592 100644
--- a/autoPyTorch/pipeline/tabular_regression.py
+++ b/autoPyTorch/pipeline/tabular_regression.py
@@ -31,6 +31,7 @@
 from autoPyTorch.pipeline.components.training.trainer.base_trainer_choice import (
     TrainerChoice
 )
+from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
 
 
 class TabularRegressionPipeline(RegressorMixin, BasePipeline):
@@ -65,11 +66,12 @@ def __init__(
         include: Optional[Dict[str, Any]] = None,
         exclude: Optional[Dict[str, Any]] = None,
         random_state: Optional[np.random.RandomState] = None,
-        init_params: Optional[Dict[str, Any]] = None
+        init_params: Optional[Dict[str, Any]] = None,
+        search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
     ):
         super().__init__(
             config, steps, dataset_properties, include, exclude,
-            random_state, init_params)
+            random_state, init_params, search_space_updates)
 
     def fit_transformer(
         self,
diff --git a/autoPyTorch/utils/hyperparameter_search_space_update.py b/autoPyTorch/utils/hyperparameter_search_space_update.py
new file mode 100644
index 000000000..3f2937686
--- /dev/null
+++ b/autoPyTorch/utils/hyperparameter_search_space_update.py
@@ -0,0 +1,72 @@
+import ast
+import os
+from typing import List, Optional, Tuple, Union
+
+from autoPyTorch.pipeline.components.base_choice import autoPyTorchChoice
+from autoPyTorch.pipeline.components.base_component import autoPyTorchComponent
+
+
+class HyperparameterSearchSpaceUpdate():
+    def __init__(self, node_name: str, hyperparameter: str, value_range: Union[List, Tuple],
+                 default_value: Union[int, float, str], log: bool = False) -> None:
+        self.node_name = node_name
+        self.hyperparameter = hyperparameter
+        self.value_range = value_range
+        self.log = log
+        self.default_value = default_value
+
+    def apply(self, pipeline: List[Tuple[str, Union[autoPyTorchComponent, autoPyTorchChoice]]]) -> None:
+        [node[1]._apply_search_space_update(name=self.hyperparameter,
+                                            new_value_range=self.value_range,
+                                            log=self.log,
+                                            default_value=self.default_value)
+         for node in pipeline if node[0] == self.node_name]
+
+    def __str__(self) -> str:
+        return "{}, {}, {}, {}, {}".format(self.node_name, self.hyperparameter, str(self.value_range),
+                                           self.default_value if isinstance(self.default_value,
+                                                                            str) else self.default_value,
+                                           (" log" if self.log else ""))
+
+
+class HyperparameterSearchSpaceUpdates():
+    def __init__(self, updates: Optional[List[HyperparameterSearchSpaceUpdate]] = None) -> None:
+        self.updates = updates if updates is not None else []
+
+    def apply(self, pipeline: List[Tuple[str, Union[autoPyTorchComponent, autoPyTorchChoice]]]) -> None:
+        for update in self.updates:
+            update.apply(pipeline)
+
+    def append(self, node_name: str, hyperparameter: str, value_range: Union[List, Tuple],
+               default_value: Union[int, float, str], log: bool = False) -> None:
+        self.updates.append(HyperparameterSearchSpaceUpdate(node_name=node_name,
+                                                            hyperparameter=hyperparameter,
+                                                            value_range=value_range,
+                                                            default_value=default_value,
+                                                            log=log))
+
+    def save_as_file(self, path: str) -> None:
+        with open(path, "w") as f:
+            for update in self.updates:
+                print(update.node_name, update.hyperparameter,  # noqa: T001
+                      str(update.value_range), "'{}'".format(update.default_value)
+                      if isinstance(update.default_value, str) else update.default_value,
+                      (" log" if update.log else ""), file=f)
+
+
+def parse_hyperparameter_search_space_updates(updates_file: Optional[str]
+                                              ) -> Optional[HyperparameterSearchSpaceUpdates]:
+    if updates_file is None or os.path.basename(updates_file) == "None":
+        return None
+    with open(updates_file, "r") as f:
+        result = []
+        for line in f:
+            if line.strip() == "":
+                continue
+            line = line.split()  # type: ignore[assignment]
+            node, hyperparameter, value_range = line[0], line[1], ast.literal_eval(line[2] + line[3])
+            default_value = ast.literal_eval(line[4])
+            assert isinstance(value_range, (tuple, list))
+            log = len(line) == 6 and "log" == line[5]
+            result.append(HyperparameterSearchSpaceUpdate(node, hyperparameter, value_range, default_value, log))
+    return HyperparameterSearchSpaceUpdates(result)
diff --git a/autoPyTorch/utils/pipeline.py b/autoPyTorch/utils/pipeline.py
index 6820d2702..3cd0d528f 100644
--- a/autoPyTorch/utils/pipeline.py
+++ b/autoPyTorch/utils/pipeline.py
@@ -14,6 +14,7 @@
 from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
 from autoPyTorch.pipeline.tabular_regression import TabularRegressionPipeline
 from autoPyTorch.utils.common import FitRequirement
+from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
 
 __all__ = [
     'get_dataset_requirements',
@@ -99,6 +100,7 @@ def _get_classification_dataset_requirements(info: Dict[str, Any], include: Dict
 def get_configuration_space(info: Dict[str, Any],
                             include: Optional[Dict] = None,
                             exclude: Optional[Dict] = None,
+                            search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
                             ) -> ConfigurationSpace:
     task_type: int = STRING_TO_TASK_TYPES[info['task_type']]
 
@@ -106,21 +108,26 @@ def get_configuration_space(info: Dict[str, Any],
         return _get_regression_configuration_space(info,
                                                     include if include is not None else {},
                                                     exclude if exclude is not None else {},
+                                                    search_space_updates=search_space_updates
                                                     )
     else:
         return _get_classification_configuration_space(info,
                                                         include if include is not None else {},
                                                         exclude if exclude is not None else {},
+                                                        search_space_updates=search_space_updates
                                                         )
 
 
 def _get_regression_configuration_space(info: Dict[str, Any], include: Dict[str, List[str]],
-                                        exclude: Dict[str, List[str]]) -> ConfigurationSpace:
+                                        exclude: Dict[str, List[str]],
+                                        search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
+                                        ) -> ConfigurationSpace:
     if STRING_TO_TASK_TYPES[info['task_type']] in TABULAR_TASKS:
         configuration_space = TabularRegressionPipeline(
             dataset_properties=info,
             include=include,
-            exclude=exclude
+            exclude=exclude,
+            search_space_updates=search_space_updates
         ).get_hyperparameter_search_space()
         return configuration_space
     else:
@@ -128,15 +135,19 @@ def _get_regression_configuration_space(info: Dict[str, Any], include: Dict[str,
 
 
 def _get_classification_configuration_space(info: Dict[str, Any], include: Dict[str, List[str]],
-                                             exclude:
Dict[str, List[str]]) -> ConfigurationSpace: + exclude: Dict[str, List[str]], + search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None + ) -> ConfigurationSpace: if STRING_TO_TASK_TYPES[info['task_type']] in TABULAR_TASKS: pipeline = TabularClassificationPipeline(dataset_properties=info, - include=include, exclude=exclude) + include=include, exclude=exclude, + search_space_updates=search_space_updates) return pipeline.get_hyperparameter_search_space() elif STRING_TO_TASK_TYPES[info['task_type']] in IMAGE_TASKS: return ImageClassificationPipeline( dataset_properties=info, - include=include, exclude=exclude).\ + include=include, exclude=exclude, + search_space_updates=search_space_updates).\ get_hyperparameter_search_space() else: raise ValueError("Task_type not supported") diff --git a/examples/example_tabular_classification.py b/examples/example_tabular_classification.py index 4693b39bb..d87b29e8b 100644 --- a/examples/example_tabular_classification.py +++ b/examples/example_tabular_classification.py @@ -14,6 +14,7 @@ from autoPyTorch.api.tabular_classification import TabularClassificationTask from autoPyTorch.datasets.tabular_dataset import TabularDataset +from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates # Get the training data for tabular classification @@ -35,6 +36,28 @@ def get_data_to_train() -> typing.Tuple[typing.Any, typing.Any, typing.Any, typi return X_train, X_test, y_train, y_test +def get_search_space_updates(): + """ + Search space updates to the task can be added using HyperparameterSearchSpaceUpdates + Returns: + HyperparameterSearchSpaceUpdates + """ + updates = HyperparameterSearchSpaceUpdates() + updates.append(node_name="data_loader", + hyperparameter="batch_size", + value_range=[16, 512], + default_value=32) + updates.append(node_name="lr_scheduler", + hyperparameter="CosineAnnealingLR:T_max", + value_range=[50, 60], + default_value=55) + updates.append(node_name='network_backbone', + hyperparameter='ResNetBackbone:dropout', + value_range=[0, 0.5], + default_value=0.2) + return updates + + if __name__ == '__main__': # Get data to train X_train, X_test, y_train, y_test = get_data_to_train() @@ -44,7 +67,8 @@ def get_data_to_train() -> typing.Tuple[typing.Any, typing.Any, typing.Any, typi X=X_train, Y=y_train, X_test=X_test, Y_test=y_test) - api = TabularClassificationTask(delete_tmp_folder_after_terminate=False,) + api = TabularClassificationTask(delete_tmp_folder_after_terminate=False, + search_space_updates=get_search_space_updates()) api.search(dataset=datamanager, optimize_metric='accuracy', total_walltime_limit=500, func_eval_time_limit=150) print(api.run_history, api.trajectory) y_pred = api.predict(X_test) diff --git a/test/conftest.py b/test/conftest.py index 1e66ed72a..195f51e13 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -14,6 +14,7 @@ from autoPyTorch.datasets.tabular_dataset import TabularDataset from autoPyTorch.utils.backend import create +from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates from autoPyTorch.utils.pipeline import get_dataset_requirements @@ -23,7 +24,6 @@ def slugify(text): @pytest.fixture(scope="function") def backend(request): - test_dir = os.path.dirname(__file__) tmp = slugify(os.path.join( test_dir, '.tmp__%s__%s' % (request.module.__name__, request.node.name))) @@ -57,7 +57,9 @@ def session_run_at_end(): break except OSError: time.sleep(1) + return session_run_at_end + request.addfinalizer(get_finalizer(tmp, 
output)) return backend @@ -74,7 +76,6 @@ def output_dir(request): def _dir_fixture(dir_type, request): - test_dir = os.path.dirname(__file__) dir = os.path.join( test_dir, '.%s__%s__%s' % (dir_type, request.module.__name__, request.node.name) @@ -135,7 +136,9 @@ def session_run_at_end(): client.shutdown() client.close() del client + return session_run_at_end + request.addfinalizer(get_finalizer(client.scheduler_info()['address'])) return client @@ -312,3 +315,47 @@ def dataset_traditional_classifier_num_categorical(): y = y.astype(np.int) X, y = X[:200].to_numpy(), y[:200].to_numpy().astype(np.int) return X, y + + +@pytest.fixture +def search_space_updates(): + updates = HyperparameterSearchSpaceUpdates() + updates.append(node_name="imputer", + hyperparameter="numerical_strategy", + value_range=("mean", "most_frequent"), + default_value="mean") + updates.append(node_name="data_loader", + hyperparameter="batch_size", + value_range=[16, 512], + default_value=32) + updates.append(node_name="lr_scheduler", + hyperparameter="CosineAnnealingLR:T_max", + value_range=[50, 60], + default_value=55) + updates.append(node_name='network_backbone', + hyperparameter='ResNetBackbone:dropout', + value_range=[0, 0.5], + default_value=0.2) + return updates + + +@pytest.fixture +def error_search_space_updates(): + updates = HyperparameterSearchSpaceUpdates() + updates.append(node_name="imputer", + hyperparameter="num_str", + value_range=("mean", "most_frequent"), + default_value="mean") + updates.append(node_name="data_loader", + hyperparameter="batch_size", + value_range=[16, 512], + default_value=32) + updates.append(node_name="lr_scheduler", + hyperparameter="CosineAnnealingLR:T_max", + value_range=[50, 60], + default_value=55) + updates.append(node_name='network_backbone', + hyperparameter='ResNetBackbone:dropout', + value_range=[0, 0.5], + default_value=0.2) + return updates diff --git a/test/test_pipeline/test_pipeline.py b/test/test_pipeline/test_pipeline.py index b54083935..668930d57 100644 --- a/test/test_pipeline/test_pipeline.py +++ b/test/test_pipeline/test_pipeline.py @@ -13,6 +13,7 @@ def __init__(self, a=0, b='orange', random_state=None): self.a = a self.b = b self.fitted = False + self._cs_updates = {} def get_hyperparameter_search_space(self, dataset_properties=None): cs = CS.ConfigurationSpace() @@ -65,6 +66,7 @@ def base_pipeline(): ('DummyComponent1', DummyComponent(a=10, b='red')), ('DummyChoice', DummyChoice(base_pipeline.dataset_properties)) ] + base_pipeline.search_space_updates = None return base_pipeline diff --git a/test/test_pipeline/test_tabular_classification.py b/test/test_pipeline/test_tabular_classification.py index 6d19833ad..8a96004e9 100644 --- a/test/test_pipeline/test_tabular_classification.py +++ b/test/test_pipeline/test_tabular_classification.py @@ -1,3 +1,12 @@ +import os +import re + +from ConfigSpace.hyperparameters import ( + CategoricalHyperparameter, + UniformFloatHyperparameter, + UniformIntegerHyperparameter, +) + import numpy as np import pytest @@ -7,12 +16,34 @@ from autoPyTorch.pipeline.components.setup.early_preprocessor.utils import get_preprocess_transforms from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline from autoPyTorch.utils.common import FitRequirement +from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates, \ + parse_hyperparameter_search_space_updates @pytest.mark.parametrize("fit_dictionary", ['fit_dictionary_numerical_only', 'fit_dictionary_categorical_only', 
'fit_dictionary_num_and_categorical'], indirect=True) class TestTabularClassification: + def _assert_pipeline_search_space(self, pipeline, search_space_updates): + config_space = pipeline.get_hyperparameter_search_space() + for update in search_space_updates.updates: + try: + assert update.node_name + ':' + update.hyperparameter in config_space + hyperparameter = config_space.get_hyperparameter(update.node_name + ':' + update.hyperparameter) + except AssertionError: + assert any(update.node_name + ':' + update.hyperparameter in name + for name in config_space.get_hyperparameter_names()), \ + "Can't find hyperparameter: {}".format(update.hyperparameter) + hyperparameter = config_space.get_hyperparameter(update.node_name + ':' + update.hyperparameter + '_1') + assert update.default_value == hyperparameter.default_value + if isinstance(hyperparameter, (UniformIntegerHyperparameter, UniformFloatHyperparameter)): + assert update.value_range[0] == hyperparameter.lower + assert update.value_range[1] == hyperparameter.upper + if hasattr(update, 'log'): + assert update.log == hyperparameter.log + elif isinstance(hyperparameter, CategoricalHyperparameter): + assert update.value_range == hyperparameter.choices + def test_pipeline_fit(self, fit_dictionary): """This test makes sure that the pipeline is able to fit given random combinations of hyperparameters across the pipeline""" @@ -155,8 +186,8 @@ def test_network_optimizer_lr_handshake(self, fit_dictionary): # Then fitting a optimizer should fail if no network: assert 'optimizer' in pipeline.named_steps.keys() with pytest.raises( - ValueError, - match=r"To fit .+?, expected fit dictionary to have 'network' but got .*" + ValueError, + match=r"To fit .+?, expected fit dictionary to have 'network' but got .*" ): pipeline.named_steps['optimizer'].fit({'dataset_properties': {}}, None) @@ -167,8 +198,8 @@ def test_network_optimizer_lr_handshake(self, fit_dictionary): # Then fitting a optimizer should fail if no network: assert 'lr_scheduler' in pipeline.named_steps.keys() with pytest.raises( - ValueError, - match=r"To fit .+?, expected fit dictionary to have 'optimizer' but got .*" + ValueError, + match=r"To fit .+?, expected fit dictionary to have 'optimizer' but got .*" ): pipeline.named_steps['lr_scheduler'].fit({'dataset_properties': {}}, None) @@ -186,3 +217,70 @@ def test_get_fit_requirements(self, fit_dictionary): assert isinstance(fit_requirements, list) for requirement in fit_requirements: assert isinstance(requirement, FitRequirement) + + def test_apply_search_space_updates(self, fit_dictionary, search_space_updates): + dataset_properties = {'numerical_columns': [1], 'categorical_columns': [2], + 'task_type': 'tabular_classification'} + pipeline = TabularClassificationPipeline(dataset_properties=dataset_properties, + search_space_updates=search_space_updates) + self._assert_pipeline_search_space(pipeline, search_space_updates) + + def test_read_and_update_search_space(self, fit_dictionary, search_space_updates): + import tempfile + path = tempfile.gettempdir() + path = os.path.join(path, 'updates.txt') + # Write to disk + search_space_updates.save_as_file(path=path) + assert os.path.exists(path=path) + + # Read from disk + file_search_space_updates = parse_hyperparameter_search_space_updates(updates_file=path) + assert isinstance(file_search_space_updates, HyperparameterSearchSpaceUpdates) + dataset_properties = {'numerical_columns': [1], 'categorical_columns': [2], + 'task_type': 'tabular_classification'} + pipeline = 
TabularClassificationPipeline(dataset_properties=dataset_properties, + search_space_updates=file_search_space_updates) + assert file_search_space_updates == pipeline.search_space_updates + + def test_error_search_space_updates(self, fit_dictionary, error_search_space_updates): + dataset_properties = {'numerical_columns': [1], 'categorical_columns': [2], + 'task_type': 'tabular_classification'} + try: + _ = TabularClassificationPipeline(dataset_properties=dataset_properties, + search_space_updates=error_search_space_updates) + except Exception as e: + assert isinstance(e, ValueError) + assert re.match(r'Unknown hyperparameter for component .*?\. Expected update ' + r'hyperparameter to be in \[.*?\] got .+', e.args[0]) + + def test_set_range_search_space_updates(self, fit_dictionary): + dataset_properties = {'numerical_columns': [1], 'categorical_columns': [2], + 'task_type': 'tabular_classification'} + config_dict = TabularClassificationPipeline(dataset_properties=dataset_properties). \ + get_hyperparameter_search_space()._hyperparameters + updates = HyperparameterSearchSpaceUpdates() + for i, (name, hyperparameter) in enumerate(config_dict.items()): + if '__choice__' in name: + continue + name = name.split(':') + hyperparameter_name = ':'.join(name[1:]) + if '_' in hyperparameter_name: + if any(l_.isnumeric() for l_ in hyperparameter_name.split('_')[-1]) and 'network' in name[0]: + hyperparameter_name = '_'.join(hyperparameter_name.split('_')[:-1]) + if isinstance(hyperparameter, CategoricalHyperparameter): + value_range = (hyperparameter.choices[0],) + default_value = hyperparameter.choices[0] + else: + value_range = (0, 1) + default_value = 1 + updates.append(node_name=name[0], hyperparameter=hyperparameter_name, + value_range=value_range, default_value=default_value) + pipeline = TabularClassificationPipeline(dataset_properties=dataset_properties, + search_space_updates=updates) + + try: + self._assert_pipeline_search_space(pipeline, updates) + except AssertionError as e: + # As we are setting num_layers to 1 for fully connected + # head, units_layer does not exist in the configspace + assert 'fully_connected:units_layer' in e.args[0]
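A minimal usage sketch of the search-space-update API introduced by this patch (not part of the diff itself): it builds a HyperparameterSearchSpaceUpdates object, writes it with save_as_file and reads it back with parse_hyperparameter_search_space_updates. Node and hyperparameter names mirror those used in examples/example_tabular_classification.py and test/conftest.py; the temporary file path is illustrative.

import os
import tempfile

from autoPyTorch.utils.hyperparameter_search_space_update import (
    HyperparameterSearchSpaceUpdates,
    parse_hyperparameter_search_space_updates,
)

# Each update names a pipeline node, one of its hyperparameters, the new
# value range and the new default, as in the example script in this patch.
updates = HyperparameterSearchSpaceUpdates()
updates.append(node_name="data_loader", hyperparameter="batch_size",
               value_range=[16, 512], default_value=32)
updates.append(node_name="network_backbone", hyperparameter="ResNetBackbone:dropout",
               value_range=[0, 0.5], default_value=0.2)

# save_as_file writes one whitespace-separated line per update:
#   <node_name> <hyperparameter> <value_range> <default_value> [log]
path = os.path.join(tempfile.gettempdir(), "updates.txt")  # illustrative path
updates.save_as_file(path=path)

# parse_hyperparameter_search_space_updates rebuilds the updates from that file;
# the result can be passed to TabularClassificationTask(search_space_updates=...)
# in the same way as the object built in get_search_space_updates() above.
restored = parse_hyperparameter_search_space_updates(updates_file=path)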