diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/ConvNetImageBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/ConvNetImageBackbone.py new file mode 100644 index 000000000..a9d1855c8 --- /dev/null +++ b/autoPyTorch/pipeline/components/setup/network_backbone/ConvNetImageBackbone.py @@ -0,0 +1,122 @@ +from typing import Any, Dict, List, Optional, Tuple + +import ConfigSpace as CS +from ConfigSpace.configuration_space import ConfigurationSpace +from ConfigSpace.hyperparameters import ( + CategoricalHyperparameter, + UniformIntegerHyperparameter +) + +from torch import nn + +from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone import NetworkBackboneComponent +from autoPyTorch.pipeline.components.setup.network_backbone.utils import _activations + + +class ConvNetImageBackbone(NetworkBackboneComponent): + """ + Standard Convolutional Neural Network backbone for images + """ + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + self.bn_args = {"eps": 1e-5, "momentum": 0.1} + + def _get_layer_size(self, w: int, h: int) -> Tuple[int, int]: + cw = ((w - self.config["conv_kernel_size"] + 2 * self.config["conv_kernel_padding"]) + // self.config["conv_kernel_stride"]) + 1 + ch = ((h - self.config["conv_kernel_size"] + 2 * self.config["conv_kernel_padding"]) + // self.config["conv_kernel_stride"]) + 1 + cw, ch = cw // self.config["pool_size"], ch // self.config["pool_size"] + return cw, ch + + def _add_layer(self, layers: List[nn.Module], in_filters: int, out_filters: int) -> None: + layers.append(nn.Conv2d(in_filters, out_filters, + kernel_size=self.config["conv_kernel_size"], + stride=self.config["conv_kernel_stride"], + padding=self.config["conv_kernel_padding"])) + layers.append(nn.BatchNorm2d(out_filters, **self.bn_args)) + layers.append(_activations[self.config["activation"]]()) + layers.append(nn.MaxPool2d(kernel_size=self.config["pool_size"], stride=self.config["pool_size"])) + + def build_backbone(self, input_shape: Tuple[int, ...]) -> nn.Module: + channels, iw, ih = input_shape + layers: List[nn.Module] = [] + init_filter = self.config["conv_init_filters"] + self._add_layer(layers, channels, init_filter) + + cw, ch = self._get_layer_size(iw, ih) + for i in range(2, self.config["num_layers"] + 1): + cw, ch = self._get_layer_size(cw, ch) + if cw == 0 or ch == 0: + break + self._add_layer(layers, init_filter, init_filter * 2) + init_filter *= 2 + backbone = nn.Sequential(*layers) + self.backbone = backbone + return backbone + + @staticmethod + def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + return { + 'shortname': 'ConvNetImageBackbone', + 'name': 'ConvNetImageBackbone', + 'handles_tabular': False, + 'handles_image': True, + 'handles_time_series': False, + } + + @staticmethod + def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, + num_layers: Tuple[Tuple, int] = ((2, 8), 4), + num_init_filters: Tuple[Tuple, int] = ((16, 64), 32), + activation: Tuple[Tuple, str] = (tuple(_activations.keys()), + list(_activations.keys())[0]), + kernel_size: Tuple[Tuple, int] = ((3, 5), 3), + stride: Tuple[Tuple, int] = ((1, 3), 1), + padding: Tuple[Tuple, int] = ((2, 3), 2), + pool_size: Tuple[Tuple, int] = ((2, 3), 2) + ) -> ConfigurationSpace: + cs = CS.ConfigurationSpace() + + min_num_layers, max_num_layers = num_layers[0] + cs.add_hyperparameter(UniformIntegerHyperparameter('num_layers', + lower=min_num_layers, + upper=max_num_layers, + 
default_value=num_layers[1])) + + cs.add_hyperparameter(CategoricalHyperparameter('activation', + choices=activation[0], + default_value=activation[1])) + + min_init_filters, max_init_filters = num_init_filters[0] + cs.add_hyperparameter(UniformIntegerHyperparameter('conv_init_filters', + lower=min_init_filters, + upper=max_init_filters, + default_value=num_init_filters[1])) + + min_kernel_size, max_kernel_size = kernel_size[0] + cs.add_hyperparameter(UniformIntegerHyperparameter('conv_kernel_size', + lower=min_kernel_size, + upper=max_kernel_size, + default_value=kernel_size[1])) + + min_stride, max_stride = stride[0] + cs.add_hyperparameter(UniformIntegerHyperparameter('conv_kernel_stride', + lower=min_stride, + upper=max_stride, + default_value=stride[1])) + + min_padding, max_padding = padding[0] + cs.add_hyperparameter(UniformIntegerHyperparameter('conv_kernel_padding', + lower=min_padding, + upper=max_padding, + default_value=padding[1])) + + min_pool_size, max_pool_size = pool_size[0] + cs.add_hyperparameter(UniformIntegerHyperparameter('pool_size', + lower=min_pool_size, + upper=max_pool_size, + default_value=pool_size[1])) + + return cs diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/image.py b/autoPyTorch/pipeline/components/setup/network_backbone/DenseNetImageBackone.py similarity index 56% rename from autoPyTorch/pipeline/components/setup/network_backbone/image.py rename to autoPyTorch/pipeline/components/setup/network_backbone/DenseNetImageBackone.py index bdf6acb68..98e0eb9b8 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/image.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/DenseNetImageBackone.py @@ -1,7 +1,6 @@ -import logging import math from collections import OrderedDict -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Optional, Tuple import ConfigSpace as CS from ConfigSpace.configuration_space import ConfigurationSpace @@ -16,103 +15,7 @@ from torch.nn import functional as F from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone import NetworkBackboneComponent - -_activations: Dict[str, nn.Module] = { - "relu": nn.ReLU, - "tanh": nn.Tanh, - "sigmoid": nn.Sigmoid -} - - -class ConvNetImageBackbone(NetworkBackboneComponent): - supported_tasks = {"image_classification", "image_regression"} - - def __init__(self, **kwargs: Any): - super().__init__(**kwargs) - self.bn_args = {"eps": 1e-5, "momentum": 0.1} - - def _get_layer_size(self, w: int, h: int) -> Tuple[int, int]: - cw = ((w - self.config["conv_kernel_size"] + 2 * self.config["conv_kernel_padding"]) - // self.config["conv_kernel_stride"]) + 1 - ch = ((h - self.config["conv_kernel_size"] + 2 * self.config["conv_kernel_padding"]) - // self.config["conv_kernel_stride"]) + 1 - cw, ch = cw // self.config["pool_size"], ch // self.config["pool_size"] - return cw, ch - - def _add_layer(self, layers: List[nn.Module], in_filters: int, out_filters: int) -> None: - layers.append(nn.Conv2d(in_filters, out_filters, - kernel_size=self.config["conv_kernel_size"], - stride=self.config["conv_kernel_stride"], - padding=self.config["conv_kernel_padding"])) - layers.append(nn.BatchNorm2d(out_filters, **self.bn_args)) - layers.append(_activations[self.config["activation"]]()) - layers.append(nn.MaxPool2d(kernel_size=self.config["pool_size"], stride=self.config["pool_size"])) - - def build_backbone(self, input_shape: Tuple[int, ...]) -> nn.Module: - channels, iw, ih = input_shape - layers: List[nn.Module] = [] - 
init_filter = self.config["conv_init_filters"] - self._add_layer(layers, channels, init_filter) - - cw, ch = self._get_layer_size(iw, ih) - for i in range(2, self.config["num_layers"] + 1): - cw, ch = self._get_layer_size(cw, ch) - if cw == 0 or ch == 0: - logging.info("> reduce network size due to too small layers.") - break - self._add_layer(layers, init_filter, init_filter * 2) - init_filter *= 2 - backbone = nn.Sequential(*layers) - self.backbone = backbone - return backbone - - @staticmethod - def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[str, Any]: - return { - 'shortname': 'ConvNetImageBackbone', - 'name': 'ConvNetImageBackbone', - 'handles_tabular': False, - 'handles_image': True, - 'handles_time_series': False, - } - - @staticmethod - def get_hyperparameter_search_space(dataset_properties: Optional[Dict[str, str]] = None, - min_num_layers: int = 2, - max_num_layers: int = 5, - min_init_filters: int = 16, - max_init_filters: int = 64, - min_kernel_size: int = 2, - max_kernel_size: int = 5, - min_stride: int = 1, - max_stride: int = 3, - min_padding: int = 2, - max_padding: int = 3, - min_pool_size: int = 2, - max_pool_size: int = 3) -> ConfigurationSpace: - cs = CS.ConfigurationSpace() - - cs.add_hyperparameter(UniformIntegerHyperparameter('num_layers', - lower=min_num_layers, - upper=max_num_layers)) - cs.add_hyperparameter(CategoricalHyperparameter('activation', - choices=list(_activations.keys()))) - cs.add_hyperparameter(UniformIntegerHyperparameter('conv_init_filters', - lower=min_init_filters, - upper=max_init_filters)) - cs.add_hyperparameter(UniformIntegerHyperparameter('conv_kernel_size', - lower=min_kernel_size, - upper=max_kernel_size)) - cs.add_hyperparameter(UniformIntegerHyperparameter('conv_kernel_stride', - lower=min_stride, - upper=max_stride)) - cs.add_hyperparameter(UniformIntegerHyperparameter('conv_kernel_padding', - lower=min_padding, - upper=max_padding)) - cs.add_hyperparameter(UniformIntegerHyperparameter('pool_size', - lower=min_pool_size, - upper=max_pool_size)) - return cs +from autoPyTorch.pipeline.components.setup.network_backbone.utils import _activations class _DenseLayer(nn.Sequential): @@ -177,7 +80,9 @@ def __init__(self, class DenseNetBackbone(NetworkBackboneComponent): - supported_tasks = {"image_classification", "image_regression"} + """ + Dense Net Backbone for images (see https://arxiv.org/pdf/1608.06993.pdf) + """ def __init__(self, **kwargs: Any): super().__init__(**kwargs) @@ -247,39 +152,55 @@ def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[ } @staticmethod - def get_hyperparameter_search_space(dataset_properties: Optional[Dict[str, str]] = None, - min_growth_rate: int = 12, - max_growth_rate: int = 40, - min_num_blocks: int = 3, - max_num_blocks: int = 4, - min_num_layers: int = 4, - max_num_layers: int = 64) -> ConfigurationSpace: + def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, + num_blocks: Tuple[Tuple, int] = ((3, 4), 3), + num_layers: Tuple[Tuple, int] = ((4, 64), 16), + growth_rate: Tuple[Tuple, int] = ((12, 40), 20), + activation: Tuple[Tuple, str] = (tuple(_activations.keys()), + list(_activations.keys())[0]), + use_dropout: Tuple[Tuple, bool] = ((True, False), False), + dropout: Tuple[Tuple, float] = ((0, 0.5), 0.2) + ) -> ConfigurationSpace: cs = CS.ConfigurationSpace() + + min_growth_rate, max_growth_rate = growth_rate[0] growth_rate_hp = UniformIntegerHyperparameter('growth_rate', lower=min_growth_rate, - upper=max_growth_rate) + 
upper=max_growth_rate, + default_value=growth_rate[1]) cs.add_hyperparameter(growth_rate_hp) + min_num_blocks, max_num_blocks = num_blocks[0] blocks_hp = UniformIntegerHyperparameter('blocks', lower=min_num_blocks, - upper=max_num_blocks) + upper=max_num_blocks, + default_value=num_blocks[1]) cs.add_hyperparameter(blocks_hp) activation_hp = CategoricalHyperparameter('activation', - choices=list(_activations.keys())) + choices=activation[0], + default_value=activation[1]) cs.add_hyperparameter(activation_hp) - use_dropout = CategoricalHyperparameter('use_dropout', choices=[True, False]) + use_dropout = CategoricalHyperparameter('use_dropout', + choices=use_dropout[0], + default_value=use_dropout[1]) + + min_dropout, max_dropout = dropout[0] dropout = UniformFloatHyperparameter('dropout', - lower=0.0, - upper=1.0) + lower=min_dropout, + upper=max_dropout, + default_value=dropout[1]) + cs.add_hyperparameters([use_dropout, dropout]) cs.add_condition(CS.EqualsCondition(dropout, use_dropout, True)) for i in range(1, max_num_blocks + 1): + min_num_layers, max_num_layers = num_layers[0] layer_hp = UniformIntegerHyperparameter('layer_in_block_%d' % i, lower=min_num_layers, - upper=max_num_layers) + upper=max_num_layers, + default_value=num_layers[1]) cs.add_hyperparameter(layer_hp) if i > min_num_blocks: diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/InceptionTimeBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/InceptionTimeBackbone.py new file mode 100644 index 000000000..4bf5c8842 --- /dev/null +++ b/autoPyTorch/pipeline/components/setup/network_backbone/InceptionTimeBackbone.py @@ -0,0 +1,180 @@ +from typing import Any, Dict, Optional, Tuple + +from ConfigSpace.configuration_space import ConfigurationSpace +from ConfigSpace.hyperparameters import ( + UniformIntegerHyperparameter +) + +import torch +from torch import nn + +from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone import NetworkBackboneComponent + + +# Code inspired by https://github.com/hfawaz/InceptionTime +# Paper: https://arxiv.org/pdf/1909.04939.pdf +class _InceptionBlock(nn.Module): + def __init__(self, + n_inputs: int, + n_filters: int, + kernel_size: int, + bottleneck: int = None): + super(_InceptionBlock, self).__init__() + self.n_filters = n_filters + self.bottleneck = None \ + if bottleneck is None \ + else nn.Conv1d(n_inputs, bottleneck, kernel_size=1) + + kernel_sizes = [kernel_size // (2 ** i) for i in range(3)] + n_inputs = n_inputs if bottleneck is None else bottleneck + + # create 3 conv layers with different kernel sizes which are applied in parallel + self.pad1 = nn.ConstantPad1d( + padding=self._padding(kernel_sizes[0]), value=0) + self.conv1 = nn.Conv1d(n_inputs, n_filters, kernel_sizes[0]) + + self.pad2 = nn.ConstantPad1d( + padding=self._padding(kernel_sizes[1]), value=0) + self.conv2 = nn.Conv1d(n_inputs, n_filters, kernel_sizes[1]) + + self.pad3 = nn.ConstantPad1d( + padding=self._padding(kernel_sizes[2]), value=0) + self.conv3 = nn.Conv1d(n_inputs, n_filters, kernel_sizes[2]) + + # create 1 maxpool and conv layer which are also applied in parallel + self.maxpool = nn.MaxPool1d(kernel_size=3, stride=1, padding=1) + self.convpool = nn.Conv1d(n_inputs, n_filters, 1) + + self.bn = nn.BatchNorm1d(4 * n_filters) + + def _padding(self, kernel_size: int) -> Tuple[int, int]: + if kernel_size % 2 == 0: + return kernel_size // 2, kernel_size // 2 - 1 + else: + return kernel_size // 2, kernel_size // 2 + + def get_n_outputs(self) -> int: + return 4 * 
self.n_filters + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.bottleneck is not None: + x = self.bottleneck(x) + x1 = self.conv1(self.pad1(x)) + x2 = self.conv2(self.pad2(x)) + x3 = self.conv3(self.pad3(x)) + x4 = self.convpool(self.maxpool(x)) + x = torch.cat([x1, x2, x3, x4], dim=1) + x = self.bn(x) + return torch.relu(x) + + +class _ResidualBlock(nn.Module): + def __init__(self, n_res_inputs: int, n_outputs: int): + super(_ResidualBlock, self).__init__() + self.shortcut = nn.Conv1d(n_res_inputs, n_outputs, 1, bias=False) + self.bn = nn.BatchNorm1d(n_outputs) + + def forward(self, x: torch.Tensor, res: torch.Tensor) -> torch.Tensor: + shortcut = self.shortcut(res) + shortcut = self.bn(shortcut) + x += shortcut + return torch.relu(x) + + +class _InceptionTime(nn.Module): + def __init__(self, + in_features: int, + config: Dict[str, Any]) -> None: + super().__init__() + self.config = config + n_inputs = in_features + n_filters = self.config["num_filters"] + bottleneck_size = self.config["bottleneck_size"] + kernel_size = self.config["kernel_size"] + n_res_inputs = in_features + for i in range(self.config["num_blocks"]): + block = _InceptionBlock(n_inputs=n_inputs, + n_filters=n_filters, + bottleneck=bottleneck_size, + kernel_size=kernel_size) + self.__setattr__(f"inception_block_{i}", block) + + # add a residual block after every 3 inception blocks + if i % 3 == 2: + n_res_outputs = block.get_n_outputs() + self.__setattr__(f"residual_block_{i}", _ResidualBlock(n_res_inputs=n_res_inputs, + n_outputs=n_res_outputs)) + n_res_inputs = n_res_outputs + n_inputs = block.get_n_outputs() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # swap sequence and feature dimensions for use with convolutional nets + x = x.transpose(1, 2).contiguous() + res = x + for i in range(self.config["num_blocks"]): + x = self.__getattr__(f"inception_block_{i}")(x) + if i % 3 == 2: + x = self.__getattr__(f"residual_block_{i}")(x, res) + res = x + x = x.transpose(1, 2).contiguous() + return x + + +class InceptionTimeBackbone(NetworkBackboneComponent): + """ + InceptionTime backbone for time series data (see https://arxiv.org/pdf/1909.04939.pdf). 
+ """ + + def build_backbone(self, input_shape: Tuple[int, ...]) -> nn.Module: + backbone = _InceptionTime(in_features=input_shape[-1], + config=self.config) + self.backbone = backbone + return backbone + + @staticmethod + def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + return { + 'shortname': 'InceptionTimeBackbone', + 'name': 'InceptionTimeBackbone', + 'handles_tabular': False, + 'handles_image': False, + 'handles_time_series': True, + } + + @staticmethod + def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, + num_blocks: Tuple[Tuple, int] = ((1, 10), 5), + num_filters: Tuple[Tuple, int] = ((4, 64), 32), + kernel_size: Tuple[Tuple, int] = ((4, 64), 32), + bottleneck_size: Tuple[Tuple, int] = ((16, 64), 32) + ) -> ConfigurationSpace: + cs = ConfigurationSpace() + + min_num_blocks, max_num_blocks = num_blocks[0] + num_blocks_hp = UniformIntegerHyperparameter("num_blocks", + lower=min_num_blocks, + upper=max_num_blocks, + default_value=num_blocks[1]) + cs.add_hyperparameter(num_blocks_hp) + + min_num_filters, max_num_filters = num_filters[0] + num_filters_hp = UniformIntegerHyperparameter("num_filters", + lower=min_num_filters, + upper=max_num_filters, + default_value=num_filters[1]) + cs.add_hyperparameter(num_filters_hp) + + min_bottleneck_size, max_bottleneck_size = bottleneck_size[0] + bottleneck_size_hp = UniformIntegerHyperparameter("bottleneck_size", + lower=min_bottleneck_size, + upper=max_bottleneck_size, + default_value=bottleneck_size[1]) + cs.add_hyperparameter(bottleneck_size_hp) + + min_kernel_size, max_kernel_size = kernel_size[0] + kernel_size_hp = UniformIntegerHyperparameter("kernel_size", + lower=min_kernel_size, + upper=max_kernel_size, + default_value=kernel_size[1]) + cs.add_hyperparameter(kernel_size_hp) + return cs diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/MLPBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/MLPBackbone.py index 6bf7ec36e..230ddfe96 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/MLPBackbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/MLPBackbone.py @@ -10,12 +10,8 @@ from torch import nn -from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone import ( - NetworkBackboneComponent, -) -from autoPyTorch.pipeline.components.setup.network_backbone.utils import ( - _activations, -) +from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone import NetworkBackboneComponent +from autoPyTorch.pipeline.components.setup.network_backbone.utils import _activations class MLPBackbone(NetworkBackboneComponent): @@ -28,7 +24,6 @@ class MLPBackbone(NetworkBackboneComponent): - Using or not dropout - Specifying the number of units per layers """ - supported_tasks = {"tabular_classification", "tabular_regression"} def build_backbone(self, input_shape: Tuple[int, ...]) -> nn.Module: layers = list() # type: List[nn.Module] diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py index 634aabee0..4433f540c 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py @@ -11,9 +11,7 @@ import torch from torch import nn -from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone import ( - NetworkBackboneComponent, -) +from 
autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone import NetworkBackboneComponent from autoPyTorch.pipeline.components.setup.network_backbone.utils import ( _activations, shake_drop, @@ -26,9 +24,7 @@ class ResNetBackbone(NetworkBackboneComponent): """ Implementation of a Residual Network backbone - """ - supported_tasks = {"tabular_classification", "tabular_regression"} def build_backbone(self, input_shape: Tuple[int, ...]) -> None: layers = list() # type: List[nn.Module] diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedMLPBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedMLPBackbone.py index 607823430..3e8be6b70 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedMLPBackbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedMLPBackbone.py @@ -10,9 +10,7 @@ from torch import nn -from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone import ( - NetworkBackboneComponent, -) +from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone import NetworkBackboneComponent from autoPyTorch.pipeline.components.setup.network_backbone.utils import ( _activations, get_shaped_neuron_counts, @@ -21,10 +19,9 @@ class ShapedMLPBackbone(NetworkBackboneComponent): """ - Implementation of a Shaped MLP -- an MLP with the number of units - arranged so that a given shape is honored + Implementation of a Shaped MLP -- an MLP with the number of units + arranged so that a given shape is honored """ - supported_tasks = {"tabular_classification", "tabular_regression"} def build_backbone(self, input_shape: Tuple[int, ...]) -> nn.Module: layers = list() # type: List[nn.Module] diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py index b3efc7bb1..fbfae28ad 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/ShapedResNetBackbone.py @@ -21,7 +21,6 @@ class ShapedResNetBackbone(ResNetBackbone): """ Implementation of a Residual Network builder with support for shaped number of units per group. 
- """ def build_backbone(self, input_shape: Tuple[int, ...]) -> None: diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/TCNBackbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/TCNBackbone.py new file mode 100644 index 000000000..c9768153f --- /dev/null +++ b/autoPyTorch/pipeline/components/setup/network_backbone/TCNBackbone.py @@ -0,0 +1,172 @@ +from typing import Any, Dict, List, Optional, Tuple + +import ConfigSpace as CS +from ConfigSpace.configuration_space import ConfigurationSpace +from ConfigSpace.hyperparameters import ( + CategoricalHyperparameter, + UniformFloatHyperparameter, + UniformIntegerHyperparameter +) + +import torch +from torch import nn +from torch.nn.utils import weight_norm + +from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone import NetworkBackboneComponent + + +# _Chomp1d, _TemporalBlock and _TemporalConvNet copied from +# https://github.com/locuslab/TCN/blob/master/TCN/tcn.py, Carnegie Mellon University Locus Labs +# Paper: https://arxiv.org/pdf/1803.01271.pdf +class _Chomp1d(nn.Module): + def __init__(self, chomp_size: int): + super(_Chomp1d, self).__init__() + self.chomp_size = chomp_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x[:, :, :-self.chomp_size].contiguous() + + +class _TemporalBlock(nn.Module): + def __init__(self, + n_inputs: int, + n_outputs: int, + kernel_size: int, + stride: int, + dilation: int, + padding: int, + dropout: float = 0.2): + super(_TemporalBlock, self).__init__() + self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, + stride=stride, padding=padding, dilation=dilation)) + self.chomp1 = _Chomp1d(padding) + self.relu1 = nn.ReLU() + self.dropout1 = nn.Dropout(dropout) + + self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, + stride=stride, padding=padding, dilation=dilation)) + self.chomp2 = _Chomp1d(padding) + self.relu2 = nn.ReLU() + self.dropout2 = nn.Dropout(dropout) + + self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1, + self.conv2, self.chomp2, self.relu2, self.dropout2) + self.downsample = nn.Conv1d( + n_inputs, n_outputs, 1) if n_inputs != n_outputs else None + self.relu = nn.ReLU() + # self.init_weights() + + def init_weights(self) -> None: + self.conv1.weight.data.normal_(0, 0.01) + self.conv2.weight.data.normal_(0, 0.01) + if self.downsample is not None: + self.downsample.weight.data.normal_(0, 0.01) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self.net(x) + res = x if self.downsample is None else self.downsample(x) + return self.relu(out + res) + + +class _TemporalConvNet(nn.Module): + def __init__(self, num_inputs: int, num_channels: List[int], kernel_size: int = 2, dropout: float = 0.2): + super(_TemporalConvNet, self).__init__() + layers: List[Any] = [] + num_levels = len(num_channels) + for i in range(num_levels): + dilation_size = 2 ** i + in_channels = num_inputs if i == 0 else num_channels[i - 1] + out_channels = num_channels[i] + layers += [_TemporalBlock(in_channels, + out_channels, + kernel_size, + stride=1, + dilation=dilation_size, + padding=(kernel_size - 1) * dilation_size, + dropout=dropout)] + self.network = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # swap sequence and feature dimensions for use with convolutional nets + x = x.transpose(1, 2).contiguous() + x = self.network(x) + x = x.transpose(1, 2).contiguous() + return x + + +class TCNBackbone(NetworkBackboneComponent): + """ + Temporal Convolutional 
Network backbone for time series data (see https://arxiv.org/pdf/1803.01271.pdf). + """ + + def build_backbone(self, input_shape: Tuple[int, ...]) -> nn.Module: + num_channels = [self.config["num_filters_0"]] + for i in range(1, self.config["num_blocks"]): + num_channels.append(self.config[f"num_filters_{i}"]) + backbone = _TemporalConvNet(input_shape[-1], + num_channels, + kernel_size=self.config["kernel_size"], + dropout=self.config["dropout"] if self.config["use_dropout"] else 0.0 + ) + self.backbone = backbone + return backbone + + @staticmethod + def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + return { + "shortname": "TCNBackbone", + "name": "TCNBackbone", + 'handles_tabular': False, + 'handles_image': False, + 'handles_time_series': True, + } + + @staticmethod + def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None, + num_blocks: Tuple[Tuple, int] = ((1, 10), 5), + num_filters: Tuple[Tuple, int] = ((4, 64), 32), + kernel_size: Tuple[Tuple, int] = ((4, 64), 32), + use_dropout: Tuple[Tuple, bool] = ((True, False), False), + dropout: Tuple[Tuple, float] = ((0.0, 0.5), 0.1) + ) -> ConfigurationSpace: + cs = ConfigurationSpace() + + min_num_blocks, max_num_blocks = num_blocks[0] + num_blocks_hp = UniformIntegerHyperparameter("num_blocks", + lower=min_num_blocks, + upper=max_num_blocks, + default_value=num_blocks[1]) + cs.add_hyperparameter(num_blocks_hp) + + min_kernel_size, max_kernel_size = kernel_size[0] + kernel_size_hp = UniformIntegerHyperparameter("kernel_size", + lower=min_kernel_size, + upper=max_kernel_size, + default_value=kernel_size[1]) + cs.add_hyperparameter(kernel_size_hp) + + use_dropout_hp = CategoricalHyperparameter("use_dropout", + choices=use_dropout[0], + default_value=use_dropout[1]) + cs.add_hyperparameter(use_dropout_hp) + + min_dropout, max_dropout = dropout[0] + dropout_hp = UniformFloatHyperparameter("dropout", + lower=min_dropout, + upper=max_dropout, + default_value=dropout[1]) + cs.add_hyperparameter(dropout_hp) + cs.add_condition(CS.EqualsCondition(dropout_hp, use_dropout_hp, True)) + + for i in range(0, max_num_blocks): + min_num_filters, max_num_filters = num_filters[0] + num_filters_hp = UniformIntegerHyperparameter(f"num_filters_{i}", + lower=min_num_filters, + upper=max_num_filters, + default_value=num_filters[1]) + cs.add_hyperparameter(num_filters_hp) + if i >= min_num_blocks: + cs.add_condition(CS.GreaterThanCondition( + num_filters_hp, num_blocks_hp, i)) + + return cs diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py index 639975c1d..d355005e8 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Any, Dict, Set, Tuple +from typing import Any, Dict, Tuple import torch from torch import nn @@ -12,9 +12,9 @@ class NetworkBackboneComponent(autoPyTorchComponent): """ - Backbone base class + Base class for network backbones. Holds the backbone module and the config which was used to create it. 
""" - supported_tasks: Set = set() + _required_properties = ["name", "shortname", "handles_tabular", "handles_image", "handles_time_series"] def __init__(self, **kwargs: Any): @@ -24,8 +24,15 @@ def __init__(self, def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator: """ - Not used. Just for API compatibility. + Builds the backbone component and assigns it to self.backbone + + Args: + X (X: Dict[str, Any]): Dependencies needed by current component to perform fit + y (Any): not used. To comply with sklearn API + Returns: + Self """ + input_shape = X['X_train'].shape[1:] self.backbone = self.build_backbone( @@ -35,7 +42,8 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator: def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: """ - Adds the scheduler into the fit dictionary 'X' and returns it. + Adds the network head into the fit dictionary 'X' and returns it. + Args: X (Dict[str, Any]): 'X' dictionary Returns: @@ -47,11 +55,13 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: @abstractmethod def build_backbone(self, input_shape: Tuple[int, ...]) -> nn.Module: """ + Builds the backbone module and returns it - Builds the backbone module and assigns it to self.backbone + Args: + input_shape (Tuple[int, ...]): shape of the input to the backbone - :param input_shape: shape of the input - :return: the backbone module + Returns: + nn.Module: backbone module """ raise NotImplementedError() @@ -61,8 +71,11 @@ def get_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]: Can and should be overridden by subclasses that know the output shape without running a dummy forward pass. - :param input_shape: shape of the input - :return: output_shape + Args: + input_shape (Tuple[int, ...]): shape of the input + + Returns: + output_shape (Tuple[int, ...]): shape of the backbone output """ placeholder = torch.randn((2, *input_shape), dtype=torch.float) with torch.no_grad(): @@ -73,6 +86,11 @@ def get_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]: def get_name(cls) -> str: """ Get the name of the backbone - :return: name of the backbone + + Args: + None + + Returns: + str: Name of the backbone """ return cls.get_properties()["shortname"] diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/time_series.py b/autoPyTorch/pipeline/components/setup/network_backbone/time_series.py deleted file mode 100644 index 6663a3565..000000000 --- a/autoPyTorch/pipeline/components/setup/network_backbone/time_series.py +++ /dev/null @@ -1,329 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple - -import ConfigSpace as CS -from ConfigSpace.configuration_space import ConfigurationSpace -from ConfigSpace.hyperparameters import ( - CategoricalHyperparameter, - UniformFloatHyperparameter, - UniformIntegerHyperparameter -) - -import torch -from torch import nn -from torch.nn.utils import weight_norm - -from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone import ( - NetworkBackboneComponent, -) - - -# Code inspired by https://github.com/hfawaz/InceptionTime -# Paper: https://arxiv.org/pdf/1909.04939.pdf -class _InceptionBlock(nn.Module): - def __init__(self, - n_inputs: int, - n_filters: int, - kernel_size: int, - bottleneck: int = None): - super(_InceptionBlock, self).__init__() - self.n_filters = n_filters - self.bottleneck = None \ - if bottleneck is None \ - else nn.Conv1d(n_inputs, bottleneck, kernel_size=1) - - kernel_sizes = [kernel_size // (2 ** i) for i in range(3)] - n_inputs = n_inputs if 
bottleneck is None else bottleneck - - # create 3 conv layers with different kernel sizes which are applied in parallel - self.pad1 = nn.ConstantPad1d( - padding=self._padding(kernel_sizes[0]), value=0) - self.conv1 = nn.Conv1d(n_inputs, n_filters, kernel_sizes[0]) - - self.pad2 = nn.ConstantPad1d( - padding=self._padding(kernel_sizes[1]), value=0) - self.conv2 = nn.Conv1d(n_inputs, n_filters, kernel_sizes[1]) - - self.pad3 = nn.ConstantPad1d( - padding=self._padding(kernel_sizes[2]), value=0) - self.conv3 = nn.Conv1d(n_inputs, n_filters, kernel_sizes[2]) - - # create 1 maxpool and conv layer which are also applied in parallel - self.maxpool = nn.MaxPool1d(kernel_size=3, stride=1, padding=1) - self.convpool = nn.Conv1d(n_inputs, n_filters, 1) - - self.bn = nn.BatchNorm1d(4 * n_filters) - - def _padding(self, kernel_size: int) -> Tuple[int, int]: - if kernel_size % 2 == 0: - return kernel_size // 2, kernel_size // 2 - 1 - else: - return kernel_size // 2, kernel_size // 2 - - def get_n_outputs(self) -> int: - return 4 * self.n_filters - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.bottleneck is not None: - x = self.bottleneck(x) - x1 = self.conv1(self.pad1(x)) - x2 = self.conv2(self.pad2(x)) - x3 = self.conv3(self.pad3(x)) - x4 = self.convpool(self.maxpool(x)) - x = torch.cat([x1, x2, x3, x4], dim=1) - x = self.bn(x) - return torch.relu(x) - - -class _ResidualBlock(nn.Module): - def __init__(self, n_res_inputs: int, n_outputs: int): - super(_ResidualBlock, self).__init__() - self.shortcut = nn.Conv1d(n_res_inputs, n_outputs, 1, bias=False) - self.bn = nn.BatchNorm1d(n_outputs) - - def forward(self, x: torch.Tensor, res: torch.Tensor) -> torch.Tensor: - shortcut = self.shortcut(res) - shortcut = self.bn(shortcut) - x += shortcut - return torch.relu(x) - - -class _InceptionTime(nn.Module): - def __init__(self, - in_features: int, - config: Dict[str, Any]) -> None: - super().__init__() - self.config = config - n_inputs = in_features - n_filters = self.config["num_filters"] - bottleneck_size = self.config["bottleneck_size"] - kernel_size = self.config["kernel_size"] - n_res_inputs = in_features - for i in range(self.config["num_blocks"]): - block = _InceptionBlock(n_inputs=n_inputs, - n_filters=n_filters, - bottleneck=bottleneck_size, - kernel_size=kernel_size) - self.__setattr__(f"inception_block_{i}", block) - - # add a residual block after every 3 inception blocks - if i % 3 == 2: - n_res_outputs = block.get_n_outputs() - self.__setattr__(f"residual_block_{i}", _ResidualBlock(n_res_inputs=n_res_inputs, - n_outputs=n_res_outputs)) - n_res_inputs = n_res_outputs - n_inputs = block.get_n_outputs() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # swap sequence and feature dimensions for use with convolutional nets - x = x.transpose(1, 2).contiguous() - res = x - for i in range(self.config["num_blocks"]): - x = self.__getattr__(f"inception_block_{i}")(x) - if i % 3 == 2: - x = self.__getattr__(f"residual_block_{i}")(x, res) - res = x - x = x.transpose(1, 2).contiguous() - return x - - -class InceptionTimeBackbone(NetworkBackboneComponent): - supported_tasks = {"time_series_classification", "time_series_regression"} - - def build_backbone(self, input_shape: Tuple[int, ...]) -> nn.Module: - backbone = _InceptionTime(in_features=input_shape[-1], - config=self.config) - self.backbone = backbone - return backbone - - @staticmethod - def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[str, Any]: - return { - 'shortname': 'InceptionTimeBackbone', 
- 'name': 'InceptionTimeBackbone', - 'handles_tabular': False, - 'handles_image': False, - 'handles_time_series': True, - } - - @staticmethod - def get_hyperparameter_search_space(dataset_properties: Optional[Dict[str, str]] = None, - min_num_blocks: int = 1, - max_num_blocks: int = 10, - min_num_filters: int = 16, - max_num_filters: int = 64, - min_kernel_size: int = 32, - max_kernel_size: int = 64, - min_bottleneck_size: int = 16, - max_bottleneck_size: int = 64, - ) -> ConfigurationSpace: - cs = ConfigurationSpace() - - num_blocks_hp = UniformIntegerHyperparameter("num_blocks", - lower=min_num_blocks, - upper=max_num_blocks) - cs.add_hyperparameter(num_blocks_hp) - - num_filters_hp = UniformIntegerHyperparameter("num_filters", - lower=min_num_filters, - upper=max_num_filters) - cs.add_hyperparameter(num_filters_hp) - - bottleneck_size_hp = UniformIntegerHyperparameter("bottleneck_size", - lower=min_bottleneck_size, - upper=max_bottleneck_size) - cs.add_hyperparameter(bottleneck_size_hp) - - kernel_size_hp = UniformIntegerHyperparameter("kernel_size", - lower=min_kernel_size, - upper=max_kernel_size) - cs.add_hyperparameter(kernel_size_hp) - return cs - - -# Chomp1d, TemporalBlock and TemporalConvNet copied from -# https://github.com/locuslab/TCN/blob/master/TCN/tcn.py, Carnegie Mellon University Locus Labs -# Paper: https://arxiv.org/pdf/1803.01271.pdf -class _Chomp1d(nn.Module): - def __init__(self, chomp_size: int): - super(_Chomp1d, self).__init__() - self.chomp_size = chomp_size - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x[:, :, :-self.chomp_size].contiguous() - - -class _TemporalBlock(nn.Module): - def __init__(self, - n_inputs: int, - n_outputs: int, - kernel_size: int, - stride: int, - dilation: int, - padding: int, - dropout: float = 0.2): - super(_TemporalBlock, self).__init__() - self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, - stride=stride, padding=padding, dilation=dilation)) - self.chomp1 = _Chomp1d(padding) - self.relu1 = nn.ReLU() - self.dropout1 = nn.Dropout(dropout) - - self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, - stride=stride, padding=padding, dilation=dilation)) - self.chomp2 = _Chomp1d(padding) - self.relu2 = nn.ReLU() - self.dropout2 = nn.Dropout(dropout) - - self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1, - self.conv2, self.chomp2, self.relu2, self.dropout2) - self.downsample = nn.Conv1d( - n_inputs, n_outputs, 1) if n_inputs != n_outputs else None - self.relu = nn.ReLU() - # self.init_weights() - - def init_weights(self) -> None: - self.conv1.weight.data.normal_(0, 0.01) - self.conv2.weight.data.normal_(0, 0.01) - if self.downsample is not None: - self.downsample.weight.data.normal_(0, 0.01) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - out = self.net(x) - res = x if self.downsample is None else self.downsample(x) - return self.relu(out + res) - - -class _TemporalConvNet(nn.Module): - def __init__(self, num_inputs: int, num_channels: List[int], kernel_size: int = 2, dropout: float = 0.2): - super(_TemporalConvNet, self).__init__() - layers: List[Any] = [] - num_levels = len(num_channels) - for i in range(num_levels): - dilation_size = 2 ** i - in_channels = num_inputs if i == 0 else num_channels[i - 1] - out_channels = num_channels[i] - layers += [_TemporalBlock(in_channels, - out_channels, - kernel_size, - stride=1, - dilation=dilation_size, - padding=(kernel_size - 1) * dilation_size, - dropout=dropout)] - self.network = nn.Sequential(*layers) - - 
def forward(self, x: torch.Tensor) -> torch.Tensor: - # swap sequence and feature dimensions for use with convolutional nets - x = x.transpose(1, 2).contiguous() - x = self.network(x) - x = x.transpose(1, 2).contiguous() - return x - - -class TCNBackbone(NetworkBackboneComponent): - supported_tasks = {"time_series_classification", "time_series_regression"} - - def build_backbone(self, input_shape: Tuple[int, ...]) -> nn.Module: - num_channels = [self.config["num_filters_0"]] - for i in range(1, self.config["num_blocks"]): - num_channels.append(self.config[f"num_filters_{i}"]) - backbone = _TemporalConvNet(input_shape[-1], - num_channels, - kernel_size=self.config["kernel_size"], - dropout=self.config["dropout"] if self.config["use_dropout"] else 0.0 - ) - self.backbone = backbone - return backbone - - @staticmethod - def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[str, Any]: - return { - "shortname": "TCNBackbone", - "name": "TCNBackbone", - 'handles_tabular': False, - 'handles_image': False, - 'handles_time_series': True, - } - - @staticmethod - def get_hyperparameter_search_space(dataset_properties: Optional[Dict[str, str]] = None, - min_num_blocks: int = 1, - max_num_blocks: int = 10, - min_num_filters: int = 4, - max_num_filters: int = 64, - min_kernel_size: int = 4, - max_kernel_size: int = 64, - min_dropout: float = 0.0, - max_dropout: float = 0.5 - ) -> ConfigurationSpace: - cs = ConfigurationSpace() - - num_blocks_hp = UniformIntegerHyperparameter("num_blocks", - lower=min_num_blocks, - upper=max_num_blocks) - cs.add_hyperparameter(num_blocks_hp) - - kernel_size_hp = UniformIntegerHyperparameter("kernel_size", - lower=min_kernel_size, - upper=max_kernel_size) - cs.add_hyperparameter(kernel_size_hp) - - use_dropout_hp = CategoricalHyperparameter("use_dropout", - choices=[True, False]) - cs.add_hyperparameter(use_dropout_hp) - - dropout_hp = UniformFloatHyperparameter("dropout", - lower=min_dropout, - upper=max_dropout) - cs.add_hyperparameter(dropout_hp) - cs.add_condition(CS.EqualsCondition(dropout_hp, use_dropout_hp, True)) - - for i in range(0, max_num_blocks): - num_filters_hp = UniformIntegerHyperparameter(f"num_filters_{i}", - lower=min_num_filters, - upper=max_num_filters) - cs.add_hyperparameter(num_filters_hp) - if i >= min_num_blocks: - cs.add_condition(CS.GreaterThanCondition( - num_filters_hp, num_blocks_hp, i)) - - return cs diff --git a/autoPyTorch/pipeline/components/setup/network_head/base_network_head.py b/autoPyTorch/pipeline/components/setup/network_head/base_network_head.py index 72a34fefe..be2a9c7dc 100644 --- a/autoPyTorch/pipeline/components/setup/network_head/base_network_head.py +++ b/autoPyTorch/pipeline/components/setup/network_head/base_network_head.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Any, Dict, Set, Tuple +from typing import Any, Dict, Tuple import torch.nn as nn @@ -10,9 +10,9 @@ class NetworkHeadComponent(autoPyTorchComponent): """ - Head base class + Base class for network heads. Holds the head module and the config which was used to create it. """ - supported_tasks: Set = set() + _required_properties = ["name", "shortname", "handles_tabular", "handles_image", "handles_time_series"] def __init__(self, **kwargs: Any): @@ -22,7 +22,13 @@ def __init__(self, def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator: """ - Not used. Just for API compatibility. 
+ Builds the head component and assigns it to self.head + + Args: + X (Dict[str, Any]): Dependencies needed by current component to perform fit + y (Any): not used. To comply with sklearn API + Returns: + Self """ input_shape = X['X_train'].shape[1:] output_shape = (X['dataset_properties']['num_classes'],) if \ @@ -37,7 +43,8 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator: def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: """ - Adds the scheduler into the fit dictionary 'X' and returns it. + Adds the network head into the fit dictionary 'X' and returns it. + Args: X (Dict[str, Any]): 'X' dictionary Returns: @@ -49,12 +56,14 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: @abstractmethod def build_head(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> nn.Module: """ + Builds the head module and returns it - Builds the head module and assigns it to self.head + Args: + input_shape (Tuple[int, ...]): shape of the input to the head (usually the shape of the backbone output) + output_shape (Tuple[int, ...]): shape of the output of the head - :param input_shape: shape of the input (usually the shape of the backbone output) - :param output_shape: shape of the output - :return: the head module + Returns: + nn.Module: head module """ raise NotImplementedError() @@ -62,6 +71,11 @@ def build_head(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...] def get_name(cls) -> str: """ Get the name of the head - :return: name of the head + + Args: + None + + Returns: + str: Name of the head """ return cls.get_properties()["shortname"] diff --git a/autoPyTorch/pipeline/components/setup/network_head/fully_connected.py b/autoPyTorch/pipeline/components/setup/network_head/fully_connected.py index bd555e03a..f01839234 100644 --- a/autoPyTorch/pipeline/components/setup/network_head/fully_connected.py +++ b/autoPyTorch/pipeline/components/setup/network_head/fully_connected.py @@ -8,25 +8,15 @@ from torch import nn -from autoPyTorch.pipeline.components.setup.network_head.base_network_head import ( - NetworkHeadComponent, -) - -_activations: Dict[str, nn.Module] = { - "relu": nn.ReLU, - "tanh": nn.Tanh, - "sigmoid": nn.Sigmoid -} +from autoPyTorch.pipeline.components.setup.network_head.base_network_head import NetworkHeadComponent +from autoPyTorch.pipeline.components.setup.network_head.utils import _activations class FullyConnectedHead(NetworkHeadComponent): """ - Standard head consisting of a number of fully connected layers. + Head consisting of a number of fully connected layers. Flattens any input to an array of shape [B, prod(input_shape)]. 
""" - supported_tasks = {"tabular_classification", "tabular_regression", - "image_classification", "image_regression", - "time_series_classification", "time_series_regression"} def build_head(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> nn.Module: layers = [nn.Flatten()] @@ -47,8 +37,8 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[ 'shortname': 'FullyConnectedHead', 'name': 'FullyConnectedHead', 'handles_tabular': True, - 'handles_image': False, - 'handles_time_series': False, + 'handles_image': True, + 'handles_time_series': True, } @staticmethod diff --git a/autoPyTorch/pipeline/components/setup/network_head/fully_convolutional.py b/autoPyTorch/pipeline/components/setup/network_head/fully_convolutional.py index ed83fc32e..21ae3eb71 100644 --- a/autoPyTorch/pipeline/components/setup/network_head/fully_convolutional.py +++ b/autoPyTorch/pipeline/components/setup/network_head/fully_convolutional.py @@ -7,15 +7,8 @@ import torch from torch import nn -from autoPyTorch.pipeline.components.setup.network_head.base_network_head import ( - NetworkHeadComponent, -) - -_activations: Dict[str, nn.Module] = { - "relu": nn.ReLU, - "tanh": nn.Tanh, - "sigmoid": nn.Sigmoid -} +from autoPyTorch.pipeline.components.setup.network_head.base_network_head import NetworkHeadComponent +from autoPyTorch.pipeline.components.setup.network_head.utils import _activations class _FullyConvolutional2DHead(nn.Module): @@ -56,7 +49,6 @@ class FullyConvolutional2DHead(NetworkHeadComponent): Head consisting of a number of 2d convolutional connected layers. Applies a global pooling operation in the end. """ - supported_tasks = {"image_classification", "image_regression"} def build_head(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> nn.Module: return _FullyConvolutional2DHead(input_shape=input_shape, @@ -70,8 +62,8 @@ def build_head(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...] 
@staticmethod def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, Union[str, bool]]: return { - 'shortname': 'FullyConvolutionalHead', - 'name': 'FullyConvolutionalHead', + 'shortname': 'FullyConvolutional2DHead', + 'name': 'FullyConvolutional2DHead', 'handles_tabular': False, 'handles_image': True, 'handles_time_series': False, diff --git a/autoPyTorch/pipeline/components/setup/network_head/utils.py b/autoPyTorch/pipeline/components/setup/network_head/utils.py new file mode 100644 index 000000000..21e037395 --- /dev/null +++ b/autoPyTorch/pipeline/components/setup/network_head/utils.py @@ -0,0 +1,7 @@ +import torch + +_activations = { + "relu": torch.nn.ReLU, + "tanh": torch.nn.Tanh, + "sigmoid": torch.nn.Sigmoid +} diff --git a/test/test_pipeline/components/test_setup.py b/test/test_pipeline/components/test_setup.py index 9ab961ddc..07e2f2f03 100644 --- a/test/test_pipeline/components/test_setup.py +++ b/test/test_pipeline/components/test_setup.py @@ -1,21 +1,30 @@ import copy import unittest.mock +from typing import Any, Dict, Optional, Tuple from ConfigSpace.configuration_space import ConfigurationSpace from sklearn.base import clone +import torch +from torch import nn + import autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler_choice as lr_components import \ autoPyTorch.pipeline.components.setup.network_initializer.base_network_init_choice as network_initializer_components # noqa: E501 import autoPyTorch.pipeline.components.setup.optimizer.base_optimizer_choice as optimizer_components +from autoPyTorch import constants +from autoPyTorch.pipeline.components.base_component import ThirdPartyComponents from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler_choice import ( BaseLRComponent, SchedulerChoice ) -from autoPyTorch.pipeline.components.setup.network_head.base_network_head_choice import ( - NetworkHeadChoice, -) +from autoPyTorch.pipeline.components.setup.network_backbone import base_network_backbone_choice +from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone import NetworkBackboneComponent +from autoPyTorch.pipeline.components.setup.network_backbone.base_network_backbone_choice import NetworkBackboneChoice +from autoPyTorch.pipeline.components.setup.network_head import base_network_head_choice +from autoPyTorch.pipeline.components.setup.network_head.base_network_head import NetworkHeadComponent +from autoPyTorch.pipeline.components.setup.network_head.base_network_head_choice import NetworkHeadChoice from autoPyTorch.pipeline.components.setup.network_initializer.base_network_init_choice import ( BaseNetworkInitializerComponent, NetworkInitializerChoice @@ -74,6 +83,40 @@ def get_properties(dataset_properties=None): } +class DummyBackbone(NetworkBackboneComponent): + @staticmethod + def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + return {"name": "DummyBackbone", + "shortname": "DummyBackbone", + "handles_tabular": True, + "handles_image": True, + "handles_time_series": True} + + def build_backbone(self, input_shape: Tuple[int, ...]) -> nn.Module: + return nn.Identity() + + @staticmethod + def get_hyperparameter_search_space(dataset_properties: Optional[Dict[str, str]] = None) -> ConfigurationSpace: + return ConfigurationSpace() + + +class DummyHead(NetworkHeadComponent): + @staticmethod + def get_properties(dataset_properties: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + return {"name": "DummyHead", + "shortname": "DummyHead", + 
"handles_tabular": True, + "handles_image": True, + "handles_time_series": True} + + def build_head(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> nn.Module: + return nn.Identity() + + @staticmethod + def get_hyperparameter_search_space(dataset_properties: Optional[Dict[str, str]] = None) -> ConfigurationSpace: + return ConfigurationSpace() + + class SchedulerTest(unittest.TestCase): def test_every_scheduler_is_valid(self): """ @@ -246,8 +289,160 @@ def test_optimizer_add(self): self.assertIn('DummyOptimizer', str(cs)) +class NetworkBackboneTest(unittest.TestCase): + def test_all_backbones_available(self): + backbone_choice = NetworkBackboneChoice(dataset_properties={}) + + self.assertEqual(len(backbone_choice.get_components().keys()), 8) + + def test_dummy_forward_backward_pass(self): + network_backbone_choice = NetworkBackboneChoice(dataset_properties={}) + + task_types = {constants.IMAGE_CLASSIFICATION: (3, 64, 64), + constants.IMAGE_REGRESSION: (3, 64, 64), + constants.TIMESERIES_CLASSIFICATION: (32, 6), + constants.TIMESERIES_REGRESSION: (32, 6), + constants.TABULAR_CLASSIFICATION: (100,), + constants.TABULAR_REGRESSION: (100,)} + + device = torch.device("cpu") + + for task_type, input_shape in task_types.items(): + dataset_properties = {"task_type": constants.TASK_TYPES_TO_STRING[task_type]} + + cs = network_backbone_choice.get_hyperparameter_search_space(dataset_properties=dataset_properties) + + # test 10 random configurations + for i in range(10): + config = cs.sample_configuration() + network_backbone_choice.set_hyperparameters(config) + backbone = network_backbone_choice.choice.build_backbone(input_shape=input_shape) + self.assertNotEqual(backbone, None) + backbone = backbone.to(device) + dummy_input = torch.randn((2, *input_shape), dtype=torch.float) + output = backbone(dummy_input) + self.assertNotEqual(output.shape[1:], output) + loss = output.sum() + loss.backward() + + def test_every_backbone_is_valid(self): + backbone_choice = NetworkBackboneChoice(dataset_properties={}) + + self.assertEqual(len(backbone_choice.get_components().keys()), 8) + + for name, backbone in backbone_choice.get_components().items(): + config = backbone.get_hyperparameter_search_space().sample_configuration() + estimator = backbone(**config) + estimator_clone = clone(estimator) + estimator_clone_params = estimator_clone.get_params() + + # Make sure all keys are copied properly + for k, v in estimator.get_params().items(): + self.assertIn(k, estimator_clone_params) + + # Make sure the params getter of estimator are honored + klass = estimator.__class__ + new_object_params = estimator.get_params(deep=False) + for name, param in new_object_params.items(): + new_object_params[name] = clone(param, safe=False) + new_object = klass(**new_object_params) + params_set = new_object.get_params(deep=False) + + for name in new_object_params: + param1 = new_object_params[name] + param2 = params_set[name] + self.assertEqual(param1, param2) + + def test_get_set_config_space(self): + """ + Make sure that we can setup a valid choice in the network backbone choice + """ + network_backbone_choice = NetworkBackboneChoice(dataset_properties={}) + for task_type in constants.TASK_TYPES: + dataset_properties = {"task_type": constants.TASK_TYPES_TO_STRING[task_type]} + cs = network_backbone_choice.get_hyperparameter_search_space(dataset_properties) + + # Make sure we can properly set some random configs + # Whereas just one iteration will make sure the algorithm works, + # doing five iterations increase 
the confidence. We will be able to + # catch component specific crashes + for i in range(5): + config = cs.sample_configuration() + config_dict = copy.deepcopy(config.get_dictionary()) + network_backbone_choice.set_hyperparameters(config) + + self.assertEqual(network_backbone_choice.choice.__class__, + network_backbone_choice.get_components()[config_dict['__choice__']]) + + # Then check the choice configuration + selected_choice = config_dict.pop('__choice__', None) + self.assertNotEqual(selected_choice, None) + for key, value in config_dict.items(): + # Remove the selected_choice string from the parameter + # so we can query in the object for it + key = key.replace(selected_choice + ':', '') + # parameters are dynamic, so they exist in config + parameters = vars(network_backbone_choice.choice) + parameters.update(vars(network_backbone_choice.choice)['config']) + self.assertIn(key, parameters) + self.assertEqual(value, parameters[key]) + + def test_add_network_backbone(self): + """Makes sure that a component can be added to the CS""" + # No third party components to start with + self.assertEqual(len(base_network_backbone_choice._addons.components), 0) + + # Then make sure the backbone can be added + base_network_backbone_choice.add_backbone(DummyBackbone) + self.assertEqual(len(base_network_backbone_choice._addons.components), 1) + + cs = NetworkBackboneChoice(dataset_properties={}). \ + get_hyperparameter_search_space(dataset_properties={"task_type": "tabular_classification"}) + self.assertIn("DummyBackbone", str(cs)) + + # clear addons + base_network_backbone_choice._addons = ThirdPartyComponents(NetworkBackboneComponent) + + class NetworkHeadTest(unittest.TestCase): - def test_every_networkHead_is_valid(self): + def test_all_heads_available(self): + network_head_choice = NetworkHeadChoice(dataset_properties={}) + + self.assertEqual(len(network_head_choice.get_components().keys()), 2) + + def test_dummy_forward_backward_pass(self): + network_head_choice = NetworkHeadChoice(dataset_properties={}) + + task_types = {constants.IMAGE_CLASSIFICATION: ((3, 64, 64), (5,)), + constants.IMAGE_REGRESSION: ((3, 64, 64), (1,)), + constants.TIMESERIES_CLASSIFICATION: ((32, 6), (5,)), + constants.TIMESERIES_REGRESSION: ((32, 6), (1,)), + constants.TABULAR_CLASSIFICATION: ((100,), (5,)), + constants.TABULAR_REGRESSION: ((100,), (1,))} + + device = torch.device("cpu") + + for task_type, (input_shape, output_shape) in task_types.items(): + dataset_properties = {"task_type": constants.TASK_TYPES_TO_STRING[task_type]} + if task_type in constants.CLASSIFICATION_TASKS: + dataset_properties["num_classes"] = output_shape[0] + + cs = network_head_choice.get_hyperparameter_search_space(dataset_properties=dataset_properties) + # test 10 random configurations + for i in range(10): + config = cs.sample_configuration() + network_head_choice.set_hyperparameters(config) + head = network_head_choice.choice.build_head(input_shape=input_shape, + output_shape=output_shape) + self.assertNotEqual(head, None) + head = head.to(device) + dummy_input = torch.randn((2, *input_shape), dtype=torch.float) + output = head(dummy_input) + self.assertEqual(output.shape[1:], output_shape) + loss = output.sum() + loss.backward() + + def test_every_head_is_valid(self): """ Makes sure that every network is a valid estimator. That is, we can fully create an object via get/set params. 
@@ -255,18 +450,15 @@ def test_every_networkHead_is_valid(self): This also test that we can properly initialize each one of them """ - networkHead_choice = NetworkHeadChoice(dataset_properties={'task_type': 'tabular_classification'}) - - # Make sure all components are returned - self.assertEqual(len(networkHead_choice.get_components().keys()), 2) + network_head_choice = NetworkHeadChoice(dataset_properties={'task_type': 'tabular_classification'}) # For every network in the components, make sure # that it complies with the scikit learn estimator. # This is important because usually components are forked to workers, # so the set/get params methods should recreate the same object - for name, networkHead in networkHead_choice.get_components().items(): - config = networkHead.get_hyperparameter_search_space().sample_configuration() - estimator = networkHead(**config) + for name, network_head in network_head_choice.get_components().items(): + config = network_head.get_hyperparameter_search_space().sample_configuration() + estimator = network_head(**config) estimator_clone = clone(estimator) estimator_clone_params = estimator_clone.get_params() @@ -288,43 +480,54 @@ def test_every_networkHead_is_valid(self): self.assertEqual(param1, param2) def test_get_set_config_space(self): - """Make sure that we can setup a valid choice in the networkHead - choice""" - networkHead_choice = NetworkHeadChoice(dataset_properties={'task_type': 'tabular_classification'}) - cs = networkHead_choice.get_hyperparameter_search_space( - dataset_properties={"task_type": 'tabular_classification'}) - - # Make sure that all hyperparameters are part of the search space - self.assertListEqual( - sorted(cs.get_hyperparameter('__choice__').choices), - ['fully_connected'] - ) - - # Make sure we can properly set some random configs - # Whereas just one iteration will make sure the algorithm works, - # doing five iterations increase the confidence. We will be able to - # catch component specific crashes - for i in range(5): - config = cs.sample_configuration() - config_dict = copy.deepcopy(config.get_dictionary()) - networkHead_choice.set_hyperparameters(config) + """ + Make sure that we can setup a valid choice in the network head choice + """ + network_head_choice = NetworkHeadChoice(dataset_properties={}) + for task_type in constants.TASK_TYPES: + dataset_properties = {"task_type": constants.TASK_TYPES_TO_STRING[task_type]} + cs = network_head_choice.get_hyperparameter_search_space(dataset_properties) + + # Make sure we can properly set some random configs + # Whereas just one iteration will make sure the algorithm works, + # doing five iterations increase the confidence. 
We will be able to + # catch component specific crashes + for i in range(5): + config = cs.sample_configuration() + config_dict = copy.deepcopy(config.get_dictionary()) + network_head_choice.set_hyperparameters(config) + + self.assertEqual(network_head_choice.choice.__class__, + network_head_choice.get_components()[config_dict['__choice__']]) + + # Then check the choice configuration + selected_choice = config_dict.pop('__choice__', None) + self.assertNotEqual(selected_choice, None) + for key, value in config_dict.items(): + # Remove the selected_choice string from the parameter + # so we can query in the object for it + key = key.replace(selected_choice + ':', '') + # parameters are dynamic, so they exist in config + parameters = vars(network_head_choice.choice) + parameters.update(vars(network_head_choice.choice)['config']) + self.assertIn(key, parameters) + self.assertEqual(value, parameters[key]) + + def test_add_network_head(self): + """Makes sure that a component can be added to the CS""" + # No third party components to start with + self.assertEqual(len(base_network_head_choice._addons.components), 0) - self.assertEqual(networkHead_choice.choice.__class__, - networkHead_choice.get_components()[config_dict['__choice__']]) + # Then make sure the head can be added + base_network_head_choice.add_head(DummyHead) + self.assertEqual(len(base_network_head_choice._addons.components), 1) - # Then check the choice configuration - selected_choice = config_dict.pop('__choice__', None) - self.assertNotEqual(selected_choice, None) - for key, value in config_dict.items(): - # Remove the selected_choice string from the parameter - # so we can query in the object for it + cs = NetworkHeadChoice(dataset_properties={}). \ + get_hyperparameter_search_space(dataset_properties={"task_type": "tabular_classification"}) + self.assertIn("DummyHead", str(cs)) - key = key.replace(selected_choice + ':', '') - # In the case of MLP, parameters are dynamic, so they exist in config - parameters = vars(networkHead_choice.choice) - parameters.update(vars(networkHead_choice.choice)['config']) - self.assertIn(key, parameters) - self.assertEqual(value, parameters[key]) + # clear addons + base_network_head_choice._addons = ThirdPartyComponents(NetworkHeadComponent) class NetworkInitializerTest(unittest.TestCase):