diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py index a997c505b..edd505d86 100644 --- a/autoPyTorch/api/base_task.py +++ b/autoPyTorch/api/base_task.py @@ -21,6 +21,8 @@ import joblib +import matplotlib.pyplot as plt + import numpy as np import pandas as pd @@ -29,7 +31,7 @@ from smac.stats.stats import Stats from smac.tae import StatusType -from autoPyTorch.api.results_manager import ResultsManager, SearchResults +from autoPyTorch import metrics from autoPyTorch.automl_common.common.utils.backend import Backend, create from autoPyTorch.constants import ( REGRESSION_TASKS, @@ -58,6 +60,8 @@ ) from autoPyTorch.utils.parallel import preload_modules from autoPyTorch.utils.pipeline import get_configuration_space, get_dataset_requirements +from autoPyTorch.utils.results_manager import MetricResults, ResultsManager, SearchResults +from autoPyTorch.utils.results_visualizer import ColorLabelSettings, PlotSettingParams, ResultsVisualizer from autoPyTorch.utils.single_thread_client import SingleThreadedClient from autoPyTorch.utils.stopwatch import StopWatch @@ -1479,3 +1483,56 @@ def sprint_statistics(self) -> str: scoring_functions=self._scoring_functions, metric=self._metric ) + + def plot_perf_over_time( + self, + metric_name: str, + ax: Optional[plt.Axes] = None, + plot_setting_params: PlotSettingParams = PlotSettingParams(), + color_label_settings: ColorLabelSettings = ColorLabelSettings(), + *args: Any, + **kwargs: Any + ) -> None: + """ + Visualize the performance over time using matplotlib. + The plot related arguments are based on matplotlib. + Please refer to the matplotlib documentation for more details. + + Args: + metric_name (str): + The name of metric to visualize. + The names are available in + * autoPyTorch.metrics.CLASSIFICATION_METRICS + * autoPyTorch.metrics.REGRESSION_METRICS + ax (Optional[plt.Axes]): + axis to plot (subplots of matplotlib). + If None, it will be created automatically. + plot_setting_params (PlotSettingParams): + Parameters for the plot. + color_label_settings (ColorLabelSettings): + The settings of a pair of color and label for each plot. + args, kwargs (Any): + Arguments for the ax.plot. 
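+
+        Example:
+            A minimal usage sketch (assuming ``search`` has already been run;
+            the variable names are illustrative):
+
+            >>> params = PlotSettingParams(xscale='log', xlabel='Runtime', ylabel='Accuracy')
+            >>> api.plot_perf_over_time(metric_name='accuracy', plot_setting_params=params)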
+ """ + + if not hasattr(metrics, metric_name): + raise ValueError( + f'metric_name must be in {list(metrics.CLASSIFICATION_METRICS.keys())} ' + f'or {list(metrics.REGRESSION_METRICS.keys())}, but got {metric_name}' + ) + if len(self.ensemble_performance_history) == 0: + raise RuntimeError('Visualization is available only after ensembles are evaluated.') + + results = MetricResults( + metric=getattr(metrics, metric_name), + run_history=self.run_history, + ensemble_performance_history=self.ensemble_performance_history + ) + + colors, labels = color_label_settings.extract_dicts(results) + + ResultsVisualizer().plot_perf_over_time( # type: ignore + results=results, plot_setting_params=plot_setting_params, + colors=colors, labels=labels, ax=ax, + *args, **kwargs + ) diff --git a/autoPyTorch/api/results_manager.py b/autoPyTorch/api/results_manager.py deleted file mode 100644 index e52d21613..000000000 --- a/autoPyTorch/api/results_manager.py +++ /dev/null @@ -1,326 +0,0 @@ -import io -from typing import Any, Dict, List, Optional, Tuple, Union - -from ConfigSpace.configuration_space import Configuration - -import numpy as np - -import scipy - -from smac.runhistory.runhistory import RunHistory, RunValue -from smac.tae import StatusType -from smac.utils.io.traj_logging import TrajEntry - -from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric - - -# TODO remove StatusType.RUNNING at some point in the future when the new SMAC 0.13.2 -# is the new minimum required version! -STATUS2MSG = { - StatusType.SUCCESS: 'Success', - StatusType.DONOTADVANCE: 'Success (but did not advance to higher budget)', - StatusType.TIMEOUT: 'Timeout', - StatusType.CRASHED: 'Crash', - StatusType.ABORT: 'Abort', - StatusType.MEMOUT: 'Memory out' -} - - -def cost2metric(cost: float, metric: autoPyTorchMetric) -> float: - """ - Revert cost metric evaluated in SMAC to the original metric. - - The conversion is defined in: - autoPyTorch/pipeline/components/training/metrics/utils.py::calculate_loss - cost = metric._optimum - metric._sign * original_metric_value - ==> original_metric_value = metric._sign * (metric._optimum - cost) - """ - return metric._sign * (metric._optimum - cost) - - -def _extract_metrics_info( - run_value: RunValue, - scoring_functions: List[autoPyTorchMetric] -) -> Dict[str, float]: - """ - Extract the metric information given a run_value - and a list of metrics of interest. - - Args: - run_value (RunValue): - The information for each config evaluation. - scoring_functions (List[autoPyTorchMetric]): - The list of metrics to retrieve the info. - """ - - if run_value.status not in (StatusType.SUCCESS, StatusType.DONOTADVANCE): - # Additional info for metrics is not available in this case. 
- return {metric.name: np.nan for metric in scoring_functions} - - cost_info = run_value.additional_info['opt_loss'] - avail_metrics = cost_info.keys() - - return { - metric.name: cost2metric(cost=cost_info[metric.name], metric=metric) - if metric.name in avail_metrics else np.nan - for metric in scoring_functions - } - - -class SearchResults: - def __init__( - self, - metric: autoPyTorchMetric, - scoring_functions: List[autoPyTorchMetric], - run_history: RunHistory - ): - self.metric_dict: Dict[str, List[float]] = { - metric.name: [] - for metric in scoring_functions - } - self._opt_scores: List[float] = [] - self._fit_times: List[float] = [] - self.configs: List[Configuration] = [] - self.status_types: List[str] = [] - self.budgets: List[float] = [] - self.config_ids: List[int] = [] - self.is_traditionals: List[bool] = [] - self.additional_infos: List[Optional[Dict[str, Any]]] = [] - self.rank_test_scores: np.ndarray = np.array([]) - self._scoring_functions = scoring_functions - self._metric = metric - - self._extract_results_from_run_history(run_history) - - @property - def opt_scores(self) -> np.ndarray: - return np.asarray(self._opt_scores) - - @property - def fit_times(self) -> np.ndarray: - return np.asarray(self._fit_times) - - def update( - self, - config: Configuration, - status: str, - budget: float, - fit_time: float, - config_id: int, - is_traditional: bool, - additional_info: Dict[str, Any], - score: float, - metric_info: Dict[str, float] - ) -> None: - - self.status_types.append(status) - self.configs.append(config) - self.budgets.append(budget) - self.config_ids.append(config_id) - self.is_traditionals.append(is_traditional) - self.additional_infos.append(additional_info) - self._fit_times.append(fit_time) - self._opt_scores.append(score) - - for metric_name, val in metric_info.items(): - self.metric_dict[metric_name].append(val) - - def clear(self) -> None: - self._opt_scores = [] - self._fit_times = [] - self.configs = [] - self.status_types = [] - self.budgets = [] - self.config_ids = [] - self.additional_infos = [] - self.is_traditionals = [] - self.rank_test_scores = np.array([]) - - def _extract_results_from_run_history(self, run_history: RunHistory) -> None: - """ - Extract the information to match this class format. - - Args: - run_history (RunHistory): - The history of config evals from SMAC. 
- """ - - self.clear() # Delete cache before the extraction - - for run_key, run_value in run_history.data.items(): - config_id = run_key.config_id - config = run_history.ids_config[config_id] - - status_msg = STATUS2MSG.get(run_value.status, None) - if run_value.status in (StatusType.STOP, StatusType.RUNNING): - continue - elif status_msg is None: - raise ValueError(f'Unexpected run status: {run_value.status}') - - is_traditional = False # If run is not successful, unsure ==> not True ==> False - if run_value.additional_info is not None: - is_traditional = run_value.additional_info['configuration_origin'] == 'traditional' - - self.update( - status=status_msg, - config=config, - budget=run_key.budget, - fit_time=run_value.time, - score=cost2metric(cost=run_value.cost, metric=self._metric), - metric_info=_extract_metrics_info(run_value=run_value, scoring_functions=self._scoring_functions), - is_traditional=is_traditional, - additional_info=run_value.additional_info, - config_id=config_id - ) - - self.rank_test_scores = scipy.stats.rankdata( - -1 * self._metric._sign * self.opt_scores, # rank order - method='min' - ) - - -class ResultsManager: - def __init__(self, *args: Any, **kwargs: Any): - """ - Attributes: - run_history (RunHistory): - A `SMAC Runshistory `_ - object that holds information about the runs of the target algorithm made during search - ensemble_performance_history (List[Dict[str, Any]]): - The list of ensemble performance in the optimization. - The list includes the `timestamp`, `result on train set`, and `result on test set` - trajectory (List[TrajEntry]): - A list of all incumbent configurations during search - """ - self.run_history: RunHistory = RunHistory() - self.ensemble_performance_history: List[Dict[str, Any]] = [] - self.trajectory: List[TrajEntry] = [] - - def _check_run_history(self) -> None: - if self.run_history is None: - raise RuntimeError("No Run History found, search has not been called.") - - if self.run_history.empty(): - raise RuntimeError("Run History is empty. Something went wrong, " - "SMAC was not able to fit any model?") - - def get_incumbent_results( - self, - metric: autoPyTorchMetric, - include_traditional: bool = False - ) -> Tuple[Configuration, Dict[str, Union[int, str, float]]]: - """ - Get Incumbent config and the corresponding results - - Args: - metric (autoPyTorchMetric): - A metric that is evaluated when searching with fit AutoPytorch. - include_traditional (bool): - Whether to include results from tradtional pipelines - - Returns: - Configuration (CS.ConfigurationSpace): - The incumbent configuration - Dict[str, Union[int, str, float]]: - Additional information about the run of the incumbent configuration. 
- """ - self._check_run_history() - - results = SearchResults(metric=metric, scoring_functions=[], run_history=self.run_history) - - if not include_traditional: - non_traditional = ~np.array(results.is_traditionals) - scores = results.opt_scores[non_traditional] - indices = np.arange(len(results.configs))[non_traditional] - else: - scores = results.opt_scores - indices = np.arange(len(results.configs)) - - incumbent_idx = indices[np.nanargmax(metric._sign * scores)] - incumbent_config = results.configs[incumbent_idx] - incumbent_results = results.additional_infos[incumbent_idx] - - assert incumbent_results is not None # mypy check - return incumbent_config, incumbent_results - - def get_search_results( - self, - scoring_functions: List[autoPyTorchMetric], - metric: autoPyTorchMetric - ) -> SearchResults: - """ - This attribute is populated with data from `self.run_history` - and contains information about the configurations, and their - corresponding metric results, status of run, parameters and - the budget - - Args: - scoring_functions (List[autoPyTorchMetric]): - Metrics to show in the results. - metric (autoPyTorchMetric): - A metric that is evaluated when searching with fit AutoPytorch. - - Returns: - SearchResults: - An instance that contains the results from search - """ - self._check_run_history() - return SearchResults(metric=metric, scoring_functions=scoring_functions, run_history=self.run_history) - - def sprint_statistics( - self, - dataset_name: str, - scoring_functions: List[autoPyTorchMetric], - metric: autoPyTorchMetric - ) -> str: - """ - Prints statistics about the SMAC search. - - These statistics include: - - 1. Optimisation Metric - 2. Best Optimisation score achieved by individual pipelines - 3. Total number of target algorithm runs - 4. Total number of successful target algorithm runs - 5. Total number of crashed target algorithm runs - 6. Total number of target algorithm runs that exceeded the time limit - 7. Total number of successful target algorithm runs that exceeded the memory limit - - Args: - dataset_name (str): - The dataset name that was used in the run. - scoring_functions (List[autoPyTorchMetric]): - Metrics to show in the results. - metric (autoPyTorchMetric): - A metric that is evaluated when searching with fit AutoPytorch. 
- - Returns: - (str): - Formatted string with statistics - """ - search_results = self.get_search_results(scoring_functions, metric) - success_msgs = (STATUS2MSG[StatusType.SUCCESS], STATUS2MSG[StatusType.DONOTADVANCE]) - sio = io.StringIO() - sio.write("autoPyTorch results:\n") - sio.write(f"\tDataset name: {dataset_name}\n") - sio.write(f"\tOptimisation Metric: {metric}\n") - - num_runs = len(search_results.status_types) - num_success = sum([s in success_msgs for s in search_results.status_types]) - num_crash = sum([s == STATUS2MSG[StatusType.CRASHED] for s in search_results.status_types]) - num_timeout = sum([s == STATUS2MSG[StatusType.TIMEOUT] for s in search_results.status_types]) - num_memout = sum([s == STATUS2MSG[StatusType.MEMOUT] for s in search_results.status_types]) - - if num_success > 0: - best_score = metric._sign * np.nanmax(metric._sign * search_results.opt_scores) - sio.write(f"\tBest validation score: {best_score}\n") - - sio.write(f"\tNumber of target algorithm runs: {num_runs}\n") - sio.write(f"\tNumber of successful target algorithm runs: {num_success}\n") - sio.write(f"\tNumber of crashed target algorithm runs: {num_crash}\n") - sio.write(f"\tNumber of target algorithms that exceeded the time " - f"limit: {num_timeout}\n") - sio.write(f"\tNumber of target algorithms that exceeded the memory " - f"limit: {num_memout}\n") - - return sio.getvalue() diff --git a/autoPyTorch/evaluation/train_evaluator.py b/autoPyTorch/evaluation/train_evaluator.py index 010948b55..37926a8c0 100644 --- a/autoPyTorch/evaluation/train_evaluator.py +++ b/autoPyTorch/evaluation/train_evaluator.py @@ -254,10 +254,15 @@ def fit_predict_and_loss(self) -> None: # train_losses is a list of dicts. It is # computed using the target metric (self.metric). - train_loss = np.average([train_losses[i][str(self.metric)] - for i in range(self.num_folds)], - weights=train_fold_weights, - ) + train_loss = {} + for metric in train_losses[0].keys(): + train_loss[metric] = np.average( + [ + train_losses[i][metric] + for i in range(self.num_folds) + ], + weights=train_fold_weights + ) opt_loss = {} # self.logger.debug("OPT LOSSES: {}".format(opt_losses if opt_losses is not None else None)) diff --git a/autoPyTorch/utils/results_manager.py b/autoPyTorch/utils/results_manager.py new file mode 100644 index 000000000..c1860b0f6 --- /dev/null +++ b/autoPyTorch/utils/results_manager.py @@ -0,0 +1,686 @@ +import io +from datetime import datetime +from typing import Any, Dict, List, Tuple, Union + +from ConfigSpace.configuration_space import Configuration + +import numpy as np + +import scipy + +from smac.runhistory.runhistory import RunHistory, RunKey, RunValue +from smac.tae import StatusType +from smac.utils.io.traj_logging import TrajEntry + +from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric + + +# TODO remove StatusType.RUNNING at some point in the future when the new SMAC 0.13.2 +# is the new minimum required version! +STATUS_TYPES = [ + StatusType.SUCCESS, + # Success (but did not advance to higher budget such as cutoff by hyperband) + StatusType.DONOTADVANCE, + StatusType.TIMEOUT, + StatusType.CRASHED, + StatusType.ABORT, + StatusType.MEMOUT +] + + +def cost2metric(cost: float, metric: autoPyTorchMetric) -> float: + """ + Revert cost metric evaluated in SMAC to the original metric. 
+ + The conversion is defined in: + autoPyTorch/pipeline/components/training/metrics/utils.py::calculate_loss + cost = metric._optimum - metric._sign * original_metric_value + ==> original_metric_value = metric._sign * (metric._optimum - cost) + """ + return metric._sign * (metric._optimum - cost) + + +def get_start_time(run_history: RunHistory) -> float: + """ + Get start time of optimization. + + Args: + run_history (RunHistory): + The history of config evals from SMAC. + + Returns: + starttime (float): + The start time of the first training. + """ + + start_times = [] + for run_value in run_history.data.values(): + if run_value.status in (StatusType.STOP, StatusType.RUNNING): + continue + elif run_value.status not in STATUS_TYPES: + raise ValueError(f'Unexpected run status: {run_value.status}') + + start_times.append(run_value.starttime) + + return float(np.min(start_times)) # mypy redefinition + + +def _extract_metrics_info( + run_value: RunValue, + scoring_functions: List[autoPyTorchMetric], + inference_name: str +) -> Dict[str, float]: + """ + Extract the metric information given a run_value + and a list of metrics of interest. + + Args: + run_value (RunValue): + The information for each config evaluation. + scoring_functions (List[autoPyTorchMetric]): + The list of metrics to retrieve the info. + inference_name (str): + The name of the inference. Either `train`, `opt` or `test`. + + Returns: + metric_info (Dict[str, float]): + The metric values of interest. + Since the metrics in additional_info are `cost`, + we transform them into the original form. + """ + + if run_value.status not in (StatusType.SUCCESS, StatusType.DONOTADVANCE): + # Additional info for metrics is not available in this case. + return {metric.name: metric._worst_possible_result for metric in scoring_functions} + + inference_choices = ['train', 'opt', 'test'] + if inference_name not in inference_choices: + raise ValueError(f'inference_name must be in {inference_choices}, but got {inference_choices}') + + cost_info = run_value.additional_info[f'{inference_name}_loss'] + avail_metrics = cost_info.keys() + + return { + metric.name: cost2metric(cost=cost_info[metric.name], metric=metric) + if metric.name in avail_metrics else metric._worst_possible_result + for metric in scoring_functions + } + + +class EnsembleResults: + def __init__( + self, + metric: autoPyTorchMetric, + ensemble_performance_history: List[Dict[str, Any]], + order_by_endtime: bool = False + ): + """ + The wrapper class for ensemble_performance_history. + This class extracts the information from ensemble_performance_history + and allows other class to easily handle the history. + + Attributes: + train_scores (List[float]): + The ensemble scores on the training dataset. + test_scores (List[float]): + The ensemble scores on the test dataset. + end_times (List[float]): + The end time of the end of each ensemble evaluation. + Each element is a float timestamp. + empty (bool): + Whether the ensemble history about `self.metric` is empty or not. + metric (autoPyTorchMetric): + The information about the metric to contain. + In the case when such a metric does not exist in the record, + This class raises KeyError. + """ + self._test_scores: List[float] = [] + self._train_scores: List[float] = [] + self._end_times: List[float] = [] + self._metric = metric + self._empty = True # Initial state is empty. 
+ self._instantiated = False + + self._extract_results_from_ensemble_performance_history(ensemble_performance_history) + if order_by_endtime: + self._sort_by_endtime() + + self._instantiated = True + + @property + def train_scores(self) -> np.ndarray: + return np.asarray(self._train_scores) + + @property + def test_scores(self) -> np.ndarray: + return np.asarray(self._test_scores) + + @property + def end_times(self) -> np.ndarray: + return np.asarray(self._end_times) + + @property + def metric_name(self) -> str: + return self._metric.name + + def empty(self) -> bool: + """ This is not property to follow coding conventions. """ + return self._empty + + def _update(self, data: Dict[str, Any]) -> None: + if self._instantiated: + raise RuntimeError( + 'EnsembleResults should not be overwritten once instantiated. ' + 'Instantiate new object rather than using update.' + ) + + self._train_scores.append(data[f'train_{self.metric_name}']) + self._test_scores.append(data[f'test_{self.metric_name}']) + self._end_times.append(datetime.timestamp(data['Timestamp'])) + + def _sort_by_endtime(self) -> None: + """ + Since the default order is by start time + and parallel computation might change the order of ending, + this method provides the feature to sort by end time. + Note that this method is destructive. + """ + if self._instantiated: + raise RuntimeError( + 'EnsembleResults should not be overwritten once instantiated. ' + 'Instantiate new object with order_by_endtime=True.' + ) + + order = np.argsort(self._end_times) + + self._train_scores = self.train_scores[order].tolist() + self._test_scores = self.test_scores[order].tolist() + self._end_times = self.end_times[order].tolist() + + def _extract_results_from_ensemble_performance_history( + self, + ensemble_performance_history: List[Dict[str, Any]] + ) -> None: + """ + Extract information to from `ensemble_performance_history` + to match the format of this class format. + + Args: + ensemble_performance_history (List[Dict[str, Any]]): + The history of the ensemble performance from EnsembleBuilder. + Its key must be either `train_xxx`, `test_xxx` or `Timestamp`. + """ + + if ( + len(ensemble_performance_history) == 0 + or f'train_{self.metric_name}' not in ensemble_performance_history[0].keys() + ): + self._empty = True + return + + self._empty = False # We can extract ==> not empty + for data in ensemble_performance_history: + self._update(data) + + +class SearchResults: + def __init__( + self, + metric: autoPyTorchMetric, + scoring_functions: List[autoPyTorchMetric], + run_history: RunHistory, + order_by_endtime: bool = False + ): + """ + The wrapper class for run_history. + This class extracts the information from run_history + and allows other class to easily handle the history. + Note that the data is sorted by starttime by default and + metric_dict has the original form of metric value, i.e. not necessarily cost. + + Attributes: + train_metric_dict (Dict[str, List[float]]): + The extracted train metric information at each evaluation. + Each list keeps the metric information specified by scoring_functions and metric. + opt_metric_dict (Dict[str, List[float]]): + The extracted opt metric information at each evaluation. + Each list keeps the metric information specified by scoring_functions and metric. + test_metric_dict (Dict[str, List[float]]): + The extracted test metric information at each evaluation. + Each list keeps the metric information specified by scoring_functions and metric. 
+ fit_times (List[float]): + The time needed to fit each model. + end_times (List[float]): + The end time of the end of each evaluation. + Each element is a float timestamp. + configs (List[Configuration]): + The configurations at each evaluation. + status_types (List[StatusType]): + The list of status types of each evaluation (e.g. success, crush). + budgets (List[float]): + The budgets used for each evaluation. + Here, budget refers to the definition in Hyperband or Successive halving. + config_ids (List[int]): + The ID of each configuration. Since we use cutoff such as in Hyperband, + we need to store it to know whether each configuration is a suvivor. + is_traditionals (List[bool]): + Whether each configuration is from traditional machine learning methods. + additional_infos (List[Dict[str, float]]): + It usually serves as the source of each metric at each evaluation. + In other words, train or test performance is extracted from this info. + rank_opt_scores (np.ndarray): + The rank of each evaluation among all the evaluations. + metric (autoPyTorchMetric): + The metric of the main interest. + scoring_functions (List[autoPyTorchMetric]): + The list of metrics to contain in the additional_infos. + """ + if metric not in scoring_functions: + scoring_functions.append(metric) + + self.train_metric_dict: Dict[str, List[float]] = {metric.name: [] for metric in scoring_functions} + self.opt_metric_dict: Dict[str, List[float]] = {metric.name: [] for metric in scoring_functions} + self.test_metric_dict: Dict[str, List[float]] = {metric.name: [] for metric in scoring_functions} + + self._fit_times: List[float] = [] + self._end_times: List[float] = [] + self.configs: List[Configuration] = [] + self.status_types: List[StatusType] = [] + self.budgets: List[float] = [] + self.config_ids: List[int] = [] + self.is_traditionals: List[bool] = [] + self.additional_infos: List[Dict[str, float]] = [] + self.rank_opt_scores: np.ndarray = np.array([]) + self._scoring_functions = scoring_functions + self._metric = metric + self._instantiated = False + + self._extract_results_from_run_history(run_history) + if order_by_endtime: + self._sort_by_endtime() + + self._instantiated = True + + @property + def train_scores(self) -> np.ndarray: + """ training metric values at each evaluation """ + return np.asarray(self.train_metric_dict[self.metric_name]) + + @property + def opt_scores(self) -> np.ndarray: + """ validation metric values at each evaluation """ + return np.asarray(self.opt_metric_dict[self.metric_name]) + + @property + def test_scores(self) -> np.ndarray: + """ test metric values at each evaluation """ + return np.asarray(self.test_metric_dict[self.metric_name]) + + @property + def fit_times(self) -> np.ndarray: + return np.asarray(self._fit_times) + + @property + def end_times(self) -> np.ndarray: + return np.asarray(self._end_times) + + @property + def metric_name(self) -> str: + return self._metric.name + + def _update( + self, + config: Configuration, + run_key: RunKey, + run_value: RunValue + ) -> None: + + if self._instantiated: + raise RuntimeError( + 'SearchResults should not be overwritten once instantiated. ' + 'Instantiate new object rather than using update.' 
+ ) + elif run_value.status in (StatusType.STOP, StatusType.RUNNING): + return + elif run_value.status not in STATUS_TYPES: + raise ValueError(f'Unexpected run status: {run_value.status}') + + is_traditional = False # If run is not successful, unsure ==> not True ==> False + if run_value.additional_info is not None: + is_traditional = run_value.additional_info['configuration_origin'] == 'traditional' + + self.status_types.append(run_value.status) + self.configs.append(config) + self.budgets.append(run_key.budget) + self.config_ids.append(run_key.config_id) + self.is_traditionals.append(is_traditional) + self.additional_infos.append(run_value.additional_info) + self._fit_times.append(run_value.time) + self._end_times.append(run_value.endtime) + + for inference_name in ['train', 'opt', 'test']: + metric_info = _extract_metrics_info( + run_value=run_value, + scoring_functions=self._scoring_functions, + inference_name=inference_name + ) + for metric_name, val in metric_info.items(): + getattr(self, f'{inference_name}_metric_dict')[metric_name].append(val) + + def _sort_by_endtime(self) -> None: + """ + Since the default order is by start time + and parallel computation might change the order of ending, + this method provides the feature to sort by end time. + Note that this method is destructive. + """ + if self._instantiated: + raise RuntimeError( + 'SearchResults should not be overwritten once instantiated. ' + 'Instantiate new object with order_by_endtime=True.' + ) + + order = np.argsort(self._end_times) + + self.train_metric_dict = {name: [arr[idx] for idx in order] for name, arr in self.train_metric_dict.items()} + self.opt_metric_dict = {name: [arr[idx] for idx in order] for name, arr in self.opt_metric_dict.items()} + self.test_metric_dict = {name: [arr[idx] for idx in order] for name, arr in self.test_metric_dict.items()} + + self._fit_times = [self._fit_times[idx] for idx in order] + self._end_times = [self._end_times[idx] for idx in order] + self.status_types = [self.status_types[idx] for idx in order] + self.budgets = [self.budgets[idx] for idx in order] + self.config_ids = [self.config_ids[idx] for idx in order] + self.is_traditionals = [self.is_traditionals[idx] for idx in order] + self.additional_infos = [self.additional_infos[idx] for idx in order] + + # Don't use numpy slicing to avoid version dependency (cast config to object might cause issues) + self.configs = [self.configs[idx] for idx in order] + + # Only rank_opt_scores is np.ndarray + self.rank_opt_scores = self.rank_opt_scores[order] + + def _extract_results_from_run_history(self, run_history: RunHistory) -> None: + """ + Extract the information to match this class format. + + Args: + run_history (RunHistory): + The history of config evals from SMAC. + """ + + for run_key, run_value in run_history.data.items(): + config = run_history.ids_config[run_key.config_id] + self._update(config=config, run_key=run_key, run_value=run_value) + + self.rank_opt_scores = scipy.stats.rankdata( + -1 * self._metric._sign * self.opt_scores, # rank order + method='min' + ) + + +class MetricResults: + def __init__( + self, + metric: autoPyTorchMetric, + run_history: RunHistory, + ensemble_performance_history: List[Dict[str, Any]] + ): + """ + The wrapper class for ensemble_performance_history. + This class extracts the information from ensemble_performance_history + and allows other class to easily handle the history. + Note that all the data is sorted by endtime! 
+ + Attributes: + start_time (float): + The timestamp at the very beginning of the optimization. + cum_times (np.ndarray): + The runtime needed to reach the end of each evaluation. + The time unit is second. + metric (autoPyTorchMetric): + The information about the metric to contain. + search_results (SearchResults): + The instance to fetch the metric values of `self.metric` + from run_history. + ensemble_results (EnsembleResults): + The instance to fetch the metric values of `self.metric` + from ensemble_performance_history. + If there is no information available, self.empty() returns True. + data (Dict[str, np.ndarray]): + Keys are `{single, ensemble}::{train, opt, test}::{metric.name}`. + Each array contains the evaluated values for the corresponding category. + """ + self.start_time = get_start_time(run_history) + self.metric = metric + self.search_results = SearchResults( + metric=metric, + run_history=run_history, + scoring_functions=[], + order_by_endtime=True + ) + self.ensemble_results = EnsembleResults( + metric=metric, + ensemble_performance_history=ensemble_performance_history, + order_by_endtime=True + ) + + if ( + not self.ensemble_results.empty() + and self.search_results.end_times[-1] < self.ensemble_results.end_times[-1] + ): + # Augment runtime table with the final available end time + self.cum_times = np.hstack( + [self.search_results.end_times - self.start_time, + [self.ensemble_results.end_times[-1] - self.start_time]] + ) + else: + self.cum_times = self.search_results.end_times - self.start_time + + self.data: Dict[str, np.ndarray] = {} + self._extract_results() + + def _extract_results(self) -> None: + """ Extract metric values of `self.metric` and store them in `self.data`. """ + metric_name = self.metric.name + for inference_name in ['train', 'test', 'opt']: + # TODO: Extract information from self.search_results + data = getattr(self.search_results, f'{inference_name}_metric_dict')[metric_name] + self.data[f'single::{inference_name}::{metric_name}'] = np.array(data) + + if self.ensemble_results.empty() or inference_name == 'opt': + continue + + data = getattr(self.ensemble_results, f'{inference_name}_scores') + self.data[f'ensemble::{inference_name}::{metric_name}'] = np.array(data) + + def get_ensemble_merged_data(self) -> Dict[str, np.ndarray]: + """ + Merge the ensemble performance data to the closest time step + available in the run_history. + One performance metric will be allocated to one time step. + Other time steps will be filled by the worst possible value. 
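+        For example, if an ensemble evaluation ends 12 seconds after the start
+        and the run history checkpoints are at 10 and 15 seconds, its score is
+        assigned to the 15-second slot (the first checkpoint at or after it);
+        when several scores fall on the same slot, the better one is kept.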
+ + Returns: + data (Dict[str, np.ndarray]): + Merged data as mentioned above + """ + + data = {k: v.copy() for k, v in self.data.items()} # deep copy + + if self.ensemble_results.empty(): # no ensemble data available + return data + + train_scores, test_scores = self.ensemble_results.train_scores, self.ensemble_results.test_scores + end_times = self.ensemble_results.end_times + cur, timestep_size, sign = 0, self.cum_times.size, self.metric._sign + key_train, key_test = f'ensemble::train::{self.metric.name}', f'ensemble::test::{self.metric.name}' + + train_perfs = np.full_like(self.cum_times, self.metric._worst_possible_result) + test_perfs = np.full_like(self.cum_times, self.metric._worst_possible_result) + + for timestamp, train_score, test_score in zip(end_times, train_scores, test_scores): + avail_time = timestamp - self.start_time + while cur < timestep_size and self.cum_times[cur] < avail_time: + # Guarantee that cum_times[cur] >= avail_time + cur += 1 + + # results[cur] is the closest available checkpoint after or at the avail_time + # ==> Assign this data to that checkpoint + time_index = min(cur, timestep_size - 1) + # If there already exists a previous allocated value, update by a better value + train_perfs[time_index] = sign * max(sign * train_perfs[time_index], sign * train_score) + test_perfs[time_index] = sign * max(sign * test_perfs[time_index], sign * test_score) + + data.update({key_train: train_perfs, key_test: test_perfs}) + return data + + +class ResultsManager: + def __init__(self, *args: Any, **kwargs: Any): + """ + This module is used to gather result information for BaseTask. + In other words, this module is supposed to be wrapped by BaseTask. + + Attributes: + run_history (RunHistory): + A `SMAC Runshistory `_ + object that holds information about the runs of the target algorithm made during search + ensemble_performance_history (List[Dict[str, Any]]): + The history of the ensemble performance from EnsembleBuilder. + Its keys are `train_xxx`, `test_xxx` or `Timestamp`. + trajectory (List[TrajEntry]): + A list of all incumbent configurations during search + """ + self.run_history: RunHistory = RunHistory() + self.ensemble_performance_history: List[Dict[str, Any]] = [] + self.trajectory: List[TrajEntry] = [] + + def _check_run_history(self) -> None: + if self.run_history is None: + raise RuntimeError("No Run History found, search has not been called.") + + if self.run_history.empty(): + raise RuntimeError("Run History is empty. Something went wrong, " + "SMAC was not able to fit any model?") + + def get_incumbent_results( + self, + metric: autoPyTorchMetric, + include_traditional: bool = False + ) -> Tuple[Configuration, Dict[str, Union[int, str, float]]]: + """ + Get Incumbent config and the corresponding results + + Args: + metric (autoPyTorchMetric): + A metric that is evaluated when searching with fit AutoPytorch. + include_traditional (bool): + Whether to include results from tradtional pipelines + + Returns: + Configuration (CS.ConfigurationSpace): + The incumbent configuration + Dict[str, Union[int, str, float]]: + Additional information about the run of the incumbent configuration. 
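+
+        Example:
+            An illustrative sketch (``results_manager`` stands for a
+            ``ResultsManager`` instance whose ``run_history`` was populated by a search):
+
+            >>> from autoPyTorch.metrics import accuracy
+            >>> config, run_info = results_manager.get_incumbent_results(metric=accuracy)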
+ """ + self._check_run_history() + + results = SearchResults(metric=metric, scoring_functions=[], run_history=self.run_history) + + if not include_traditional: + non_traditional = ~np.array(results.is_traditionals) + scores = results.opt_scores[non_traditional] + indices = np.arange(len(results.configs))[non_traditional] + else: + scores = results.opt_scores + indices = np.arange(len(results.configs)) + + incumbent_idx = indices[np.argmax(metric._sign * scores)] + incumbent_config = results.configs[incumbent_idx] + incumbent_results = results.additional_infos[incumbent_idx] + + assert incumbent_results is not None # mypy check + return incumbent_config, incumbent_results + + def get_search_results( + self, + scoring_functions: List[autoPyTorchMetric], + metric: autoPyTorchMetric + ) -> SearchResults: + """ + This attribute is populated with data from `self.run_history` + and contains information about the configurations, and their + corresponding metric results, status of run, parameters and + the budget + + Args: + scoring_functions (List[autoPyTorchMetric]): + Metrics to show in the results. + metric (autoPyTorchMetric): + A metric that is evaluated when searching with fit AutoPytorch. + + Returns: + SearchResults: + An instance that contains the results from search + """ + self._check_run_history() + return SearchResults(metric=metric, scoring_functions=scoring_functions, run_history=self.run_history) + + def sprint_statistics( + self, + dataset_name: str, + scoring_functions: List[autoPyTorchMetric], + metric: autoPyTorchMetric + ) -> str: + """ + Prints statistics about the SMAC search. + + These statistics include: + + 1. Optimisation Metric + 2. Best Optimisation score achieved by individual pipelines + 3. Total number of target algorithm runs + 4. Total number of successful target algorithm runs + 5. Total number of crashed target algorithm runs + 6. Total number of target algorithm runs that exceeded the time limit + 7. Total number of successful target algorithm runs that exceeded the memory limit + + Args: + dataset_name (str): + The dataset name that was used in the run. + scoring_functions (List[autoPyTorchMetric]): + Metrics to show in the results. + metric (autoPyTorchMetric): + A metric that is evaluated when searching with fit AutoPytorch. 
+ + Returns: + (str): + Formatted string with statistics + """ + search_results = self.get_search_results(scoring_functions, metric) + success_status = (StatusType.SUCCESS, StatusType.DONOTADVANCE) + sio = io.StringIO() + sio.write("autoPyTorch results:\n") + sio.write(f"\tDataset name: {dataset_name}\n") + sio.write(f"\tOptimisation Metric: {metric}\n") + + num_runs = len(search_results.status_types) + num_success = sum([s in success_status for s in search_results.status_types]) + num_crash = sum([s == StatusType.CRASHED for s in search_results.status_types]) + num_timeout = sum([s == StatusType.TIMEOUT for s in search_results.status_types]) + num_memout = sum([s == StatusType.MEMOUT for s in search_results.status_types]) + + if num_success > 0: + best_score = metric._sign * np.max(metric._sign * search_results.opt_scores) + sio.write(f"\tBest validation score: {best_score}\n") + + sio.write(f"\tNumber of target algorithm runs: {num_runs}\n") + sio.write(f"\tNumber of successful target algorithm runs: {num_success}\n") + sio.write(f"\tNumber of crashed target algorithm runs: {num_crash}\n") + sio.write(f"\tNumber of target algorithms that exceeded the time " + f"limit: {num_timeout}\n") + sio.write(f"\tNumber of target algorithms that exceeded the memory " + f"limit: {num_memout}\n") + + return sio.getvalue() diff --git a/autoPyTorch/utils/results_visualizer.py b/autoPyTorch/utils/results_visualizer.py new file mode 100644 index 000000000..64c87ba94 --- /dev/null +++ b/autoPyTorch/utils/results_visualizer.py @@ -0,0 +1,310 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, Optional, Tuple + +import matplotlib.pyplot as plt + +import numpy as np + +from autoPyTorch.utils.results_manager import MetricResults + + +plt.rcParams["font.family"] = "Times New Roman" +plt.rcParams["font.size"] = 18 + + +@dataclass(frozen=True) +class ColorLabelSettings: + """ + The settings for each plot. + If None is provided, those plots are omitted. + + Attributes: + single_train (Optional[Tuple[Optional[str], Optional[str]]]): + The setting for the plot of the optimal single train result. + single_opt (Optional[Tuple[Optional[str], Optional[str]]]): + The setting for the plot of the optimal single result used in optimization. + single_test (Optional[Tuple[Optional[str], Optional[str]]]): + The setting for the plot of the optimal single test result. + ensemble_train (Optional[Tuple[Optional[str], Optional[str]]]): + The setting for the plot of the optimal ensemble train result. + ensemble_test (Optional[Tuple[Optional[str], Optional[str]]]): + The setting for the plot of the optimal ensemble test result. + """ + single_train: Optional[Tuple[Optional[str], Optional[str]]] = ('red', None) + single_opt: Optional[Tuple[Optional[str], Optional[str]]] = ('blue', None) + single_test: Optional[Tuple[Optional[str], Optional[str]]] = ('green', None) + ensemble_train: Optional[Tuple[Optional[str], Optional[str]]] = ('brown', None) + ensemble_test: Optional[Tuple[Optional[str], Optional[str]]] = ('purple', None) + + def extract_dicts( + self, + results: MetricResults + ) -> Tuple[Dict[str, Optional[str]], Dict[str, Optional[str]]]: + """ + Args: + results (MetricResults): + The results of the optimization in the base task API. + It determines what keys to include. + + Returns: + colors, labels (Tuple[Dict[str, Optional[str]], Dict[str, Optional[str]]]): + The dicts for colors and labels. 
+ The keys are determined by results and each label and color + are determined by each instantiation. + Note that the keys include the metric name. + """ + + colors, labels = {}, {} + + for key, color_label in vars(self).items(): + if color_label is None: + continue + + prefix = '::'.join(key.split('_')) + try: + new_key = [key for key in results.data.keys() if key.startswith(prefix)][0] + colors[new_key], labels[new_key] = color_label + except IndexError: # ensemble does not always have results + pass + + return colors, labels + + +@dataclass(frozen=True) +class PlotSettingParams: + """ + Parameters for the plot environment. + + Attributes: + n_points (int): + The number of points to plot. + xlabel (Optional[str]): + The label in the x axis. + ylabel (Optional[str]): + The label in the y axis. + xscale (str): + The scale of x axis. + yscale (str): + The scale of y axis. + title (Optional[str]): + The title of the subfigure. + xlim (Tuple[float, float]): + The range of x axis. + ylim (Tuple[float, float]): + The range of y axis. + legend (bool): + Whether to have legend in the figure. + legend_loc (str): + The location of the legend. + show (bool): + Whether to show the plot. + args, kwargs (Any): + Arguments for the ax.plot. + """ + n_points: int = 20 + xscale: str = 'linear' + yscale: str = 'linear' + xlabel: Optional[str] = None + ylabel: Optional[str] = None + title: Optional[str] = None + xlim: Optional[Tuple[float, float]] = None + ylim: Optional[Tuple[float, float]] = None + legend: bool = True + legend_loc: str = 'best' + show: bool = False + figsize: Optional[Tuple[int, int]] = None + + +class ScaleChoices(Enum): + linear = 'linear' + log = 'log' + + +def _get_perf_and_time( + cum_results: np.ndarray, + cum_times: np.ndarray, + plot_setting_params: PlotSettingParams, + worst_val: float +) -> Tuple[np.ndarray, np.ndarray]: + """ + Get the performance and time step to plot. + + Args: + cum_results (np.ndarray): + The cumulated performance per evaluation. + cum_times (np.ndarray): + The cumulated runtime at the end of each evaluation. + plot_setting_params (PlotSettingParams): + Parameters for the plot. + worst_val (float): + The worst possible value given a metric. + + Returns: + check_points (np.ndarray): + The time in second where the plot will happen. + perf_by_time_step (np.ndarray): + The best performance at the corresponding time in second + where the plot will happen. 
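+
+    For example, with ``n_points=3``, ``cum_times=[1, 5, 9]`` and best-so-far
+    ``cum_results=[0.5, 0.7, 0.9]`` on a linear scale, the check points are
+    (up to a tiny offset) ``[1, 5, 9]`` and ``perf_by_time_step`` becomes
+    ``[0.5, 0.7, 0.9]``; check points earlier than the first finished
+    evaluation keep ``worst_val``.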
+ """ + + scale_choices = [s.name for s in ScaleChoices] + if plot_setting_params.xscale not in scale_choices or plot_setting_params.yscale not in scale_choices: + raise ValueError(f'xscale and yscale must be in {scale_choices}, ' + f'but got xscale={plot_setting_params.xscale}, yscale={plot_setting_params.yscale}') + + n_evals, runtime_lb, runtime_ub = cum_results.size, cum_times[0], cum_times[-1] + + if plot_setting_params.xscale == 'log': + # Take the even time interval in the log scale and revert + check_points = np.exp(np.linspace(np.log(runtime_lb), np.log(runtime_ub), plot_setting_params.n_points)) + else: + check_points = np.linspace(runtime_lb, runtime_ub, plot_setting_params.n_points) + + check_points += 1e-8 # Prevent float error + + # The worst possible value is always at the head + perf_by_time_step = np.full_like(check_points, worst_val) + cur = 0 + + for i, check_point in enumerate(check_points): + while cur < n_evals and cum_times[cur] <= check_point: + # Guarantee that cum_times[cur] > check_point + # ==> cum_times[cur - 1] <= check_point + cur += 1 + if cur: # filter cur - 1 == -1 + # results[cur - 1] was obtained before or at the checkpoint + # ==> The best performance up to this checkpoint + perf_by_time_step[i] = cum_results[cur - 1] + + if plot_setting_params.yscale == 'log' and np.any(perf_by_time_step < 0): + raise ValueError('log scale is not available when performance metric can be negative.') + + return check_points, perf_by_time_step + + +class ResultsVisualizer: + @staticmethod + def _set_plot_args( + ax: plt.Axes, + plot_setting_params: PlotSettingParams + ) -> None: + if plot_setting_params.xlim is not None: + ax.set_xlim(*plot_setting_params.xlim) + if plot_setting_params.ylim is not None: + ax.set_ylim(*plot_setting_params.ylim) + + if plot_setting_params.xlabel is not None: + ax.set_xlabel(plot_setting_params.xlabel) + if plot_setting_params.ylabel is not None: + ax.set_ylabel(plot_setting_params.ylabel) + + ax.set_xscale(plot_setting_params.xscale) + ax.set_yscale(plot_setting_params.yscale) + if plot_setting_params.xscale == 'log' or plot_setting_params.yscale == 'log': + ax.grid(True, which='minor', color='gray', linestyle=':') + + ax.grid(True, which='major', color='black') + + if plot_setting_params.legend: + ax.legend(loc=plot_setting_params.legend_loc) + + if plot_setting_params.title is not None: + ax.set_title(plot_setting_params.title) + if plot_setting_params.show: + plt.show() + + @staticmethod + def _plot_individual_perf_over_time( + ax: plt.Axes, + cum_times: np.ndarray, + cum_results: np.ndarray, + worst_val: float, + plot_setting_params: PlotSettingParams, + label: Optional[str] = None, + color: Optional[str] = None, + *args: Any, + **kwargs: Any + ) -> None: + """ + Plot the incumbent performance of the AutoPytorch over time. + This method is created to make plot_perf_over_time more readable + and it is not supposed to be used only in this class, but not from outside. + + Args: + ax (plt.Axes): + axis to plot (subplots of matplotlib). + cum_times (np.ndarray): + The cumulated time until each end of config evaluation. + results (np.ndarray): + The cumulated performance per evaluation. + worst_val (float): + The worst possible value given a metric. + plot_setting_params (PlotSettingParams): + Parameters for the plot. + label (Optional[str]): + The name of the plot. + color (Optional[str]): + Color of the plot. + args, kwargs (Any): + Arguments for the ax.plot. 
+ """ + check_points, perf_by_time_step = _get_perf_and_time( + cum_results=cum_results, + cum_times=cum_times, + plot_setting_params=plot_setting_params, + worst_val=worst_val + ) + + ax.plot(check_points, perf_by_time_step, color=color, label=label, *args, **kwargs) + + def plot_perf_over_time( + self, + results: MetricResults, + plot_setting_params: PlotSettingParams, + colors: Dict[str, Optional[str]], + labels: Dict[str, Optional[str]], + ax: Optional[plt.Axes] = None, + *args: Any, + **kwargs: Any + ) -> None: + """ + Plot the incumbent performance of the AutoPytorch over time. + + Args: + results (MetricResults): + The module that handles results from various sources. + plot_setting_params (PlotSettingParams): + Parameters for the plot. + labels (Dict[str, Optional[str]]): + The name of the plot. + colors (Dict[str, Optional[str]]): + Color of the plot. + ax (Optional[plt.Axes]): + axis to plot (subplots of matplotlib). + If None, it will be created automatically. + args, kwargs (Any): + Arguments for the ax.plot. + """ + if ax is None: + _, ax = plt.subplots(nrows=1, ncols=1) + + data = results.get_ensemble_merged_data() + cum_times = results.cum_times + minimize = (results.metric._sign == -1) + + for key in data.keys(): + _label, _color, _perfs = labels[key], colors[key], data[key] + # Take the best results over time + _cum_perfs = np.minimum.accumulate(_perfs) if minimize else np.maximum.accumulate(_perfs) + + self._plot_individual_perf_over_time( # type: ignore + ax=ax, cum_results=_cum_perfs, cum_times=cum_times, + plot_setting_params=plot_setting_params, + worst_val=results.metric._worst_possible_result, + label=_label if _label is not None else ' '.join(key.split('::')), + color=_color, + *args, **kwargs + ) + + self._set_plot_args(ax=ax, plot_setting_params=plot_setting_params) diff --git a/examples/40_advanced/example_plot_over_time.py b/examples/40_advanced/example_plot_over_time.py new file mode 100644 index 000000000..9c103452e --- /dev/null +++ b/examples/40_advanced/example_plot_over_time.py @@ -0,0 +1,82 @@ +""" +============================== +Plot the Performance over Time +============================== + +Auto-Pytorch uses SMAC to fit individual machine learning algorithms +and then ensembles them together using `Ensemble Selection +`_. + +The following examples shows how to plot both the performance +of the individual models and their respective ensemble. + +Additionally, as we are compatible with matplotlib, +you can input any args or kwargs that are compatible with ax.plot. +In the case when you would like to create multipanel visualization, +please input plt.Axes obtained from matplotlib.pyplot.subplots. 
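+For example, ``fig, axes = plt.subplots(nrows=1, ncols=2)`` creates two panels,
+and each ``axes[i]`` can be passed to a separate ``plot_perf_over_time`` call
+(the two-panel layout here is only illustrative).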
+ +""" +import warnings + +import numpy as np +import pandas as pd + +from sklearn import model_selection + +import matplotlib.pyplot as plt + +from autoPyTorch.api.tabular_classification import TabularClassificationTask +from autoPyTorch.utils.results_visualizer import PlotSettingParams + + +warnings.simplefilter(action='ignore', category=UserWarning) +warnings.simplefilter(action='ignore', category=FutureWarning) + + +############################################################################ +# Task Definition +# =============== +n_samples, dim = 100, 2 +X = np.random.random((n_samples, dim)) * 2 - 1 +y = ((X ** 2).sum(axis=-1) < 2 / np.pi).astype(np.int32) +print(y) + +X, y = pd.DataFrame(X), pd.DataFrame(y) +X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y) + +############################################################################ +# API Instantiation and Searching +# =============================== +api = TabularClassificationTask(seed=42) + +api.search(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test, + optimize_metric='accuracy', total_walltime_limit=120, func_eval_time_limit_secs=10) + +############################################################################ +# Create Setting Parameters Object +# ================================ +metric_name = 'accuracy' + +params = PlotSettingParams( + xscale='log', + xlabel='Runtime', + ylabel='Accuracy', + title='Toy Example', + show=False # If you would like to show, make it True +) + +############################################################################ +# Plot with the Specified Setting Parameters +# ========================================== +_, ax = plt.subplots() + +api.plot_perf_over_time( + ax=ax, # You do not have to provide. + metric_name=metric_name, + plot_setting_params=params, + marker='*', + markersize=10 +) + +# plt.show() might cause issue depending on environments +plt.savefig('example_plot_over_time.png') diff --git a/test/test_api/test_results_manager.py b/test/test_api/test_results_manager.py deleted file mode 100644 index 4c6e7a7ae..000000000 --- a/test/test_api/test_results_manager.py +++ /dev/null @@ -1,232 +0,0 @@ -import json -import os -from test.test_api.utils import make_dict_run_history_data -from unittest.mock import MagicMock - -import ConfigSpace.hyperparameters as CSH -from ConfigSpace.configuration_space import Configuration, ConfigurationSpace - -import numpy as np - -import pytest - -from smac.runhistory.runhistory import RunHistory, StatusType - -from autoPyTorch.api.base_task import BaseTask -from autoPyTorch.api.results_manager import ResultsManager, STATUS2MSG, SearchResults, cost2metric -from autoPyTorch.metrics import accuracy, balanced_accuracy, log_loss - - -def _check_status(status): - """ Based on runhistory_B.json """ - ans = [ - STATUS2MSG[StatusType.SUCCESS], STATUS2MSG[StatusType.SUCCESS], - STATUS2MSG[StatusType.SUCCESS], STATUS2MSG[StatusType.SUCCESS], - STATUS2MSG[StatusType.SUCCESS], STATUS2MSG[StatusType.SUCCESS], - STATUS2MSG[StatusType.CRASHED], STATUS2MSG[StatusType.SUCCESS], - STATUS2MSG[StatusType.SUCCESS], STATUS2MSG[StatusType.SUCCESS], - STATUS2MSG[StatusType.SUCCESS], STATUS2MSG[StatusType.SUCCESS], - STATUS2MSG[StatusType.SUCCESS], STATUS2MSG[StatusType.SUCCESS], - STATUS2MSG[StatusType.TIMEOUT], STATUS2MSG[StatusType.TIMEOUT], - ] - assert isinstance(status, list) - assert isinstance(status[0], str) - assert status == ans - - -def _check_costs(costs): - """ Based on runhistory_B.json """ - ans = [0.15204678362573099, 
0.4444444444444444, 0.5555555555555556, 0.29824561403508776, - 0.4444444444444444, 0.4444444444444444, 1.0, 0.5555555555555556, 0.4444444444444444, - 0.15204678362573099, 0.15204678362573099, 0.4035087719298246, 0.4444444444444444, - 0.4444444444444444, 1.0, 1.0] - assert np.allclose(1 - np.array(costs), ans) - assert isinstance(costs, np.ndarray) - assert costs.dtype is np.dtype(np.float) - - -def _check_fit_times(fit_times): - """ Based on runhistory_B.json """ - ans = [3.154788017272949, 3.2763524055480957, 22.723600149154663, 4.990685224533081, 10.684926509857178, - 9.947429180145264, 11.687273979187012, 8.478890419006348, 5.485020637512207, 11.514830589294434, - 15.370736837387085, 23.846530199050903, 6.757539510726929, 15.061991930007935, 50.010520696640015, - 22.011935234069824] - - assert np.allclose(fit_times, ans) - assert isinstance(fit_times, np.ndarray) - assert fit_times.dtype is np.dtype(np.float) - - -def _check_budgets(budgets): - """ Based on runhistory_B.json """ - ans = [5.555555555555555, 5.555555555555555, 5.555555555555555, 5.555555555555555, - 5.555555555555555, 5.555555555555555, 5.555555555555555, 5.555555555555555, - 5.555555555555555, 16.666666666666664, 50.0, 16.666666666666664, 16.666666666666664, - 16.666666666666664, 50.0, 50.0] - assert np.allclose(budgets, ans) - assert isinstance(budgets, list) - assert isinstance(budgets[0], float) - - -def _check_additional_infos(status_types, additional_infos): - for i, status in enumerate(status_types): - info = additional_infos[i] - if status in (STATUS2MSG[StatusType.SUCCESS], STATUS2MSG[StatusType.DONOTADVANCE]): - metric_info = info.get('opt_loss', None) - assert metric_info is not None - elif info is not None: - metric_info = info.get('opt_loss', None) - assert metric_info is None - - -def _check_metric_dict(metric_dict, status_types): - assert isinstance(metric_dict['accuracy'], list) - assert metric_dict['accuracy'][0] > 0 - assert isinstance(metric_dict['balanced_accuracy'], list) - assert metric_dict['balanced_accuracy'][0] > 0 - - for key, vals in metric_dict.items(): - # ^ is a XOR operator - # True and False / False and True must be fulfilled - assert all([(s == STATUS2MSG[StatusType.SUCCESS]) ^ isnan - for s, isnan in zip(status_types, np.isnan(vals))]) - - -def test_extract_results_from_run_history(): - # test the raise error for the `status_msg is None` - run_history = RunHistory() - cs = ConfigurationSpace() - config = Configuration(cs, {}) - run_history.add( - config=config, - cost=0.0, - time=1.0, - status=StatusType.CAPPED, - ) - with pytest.raises(ValueError) as excinfo: - SearchResults(metric=accuracy, scoring_functions=[], run_history=run_history) - - assert excinfo._excinfo[0] == ValueError - - -def test_search_results_sprint_statistics(): - api = BaseTask() - for method in ['get_search_results', 'sprint_statistics', 'get_incumbent_results']: - with pytest.raises(RuntimeError) as excinfo: - getattr(api, method)() - - assert excinfo._excinfo[0] == RuntimeError - - run_history_data = json.load(open(os.path.join(os.path.dirname(__file__), - '.tmp_api/runhistory_B.json'), - mode='r'))['data'] - api._results_manager.run_history = MagicMock() - api.run_history.empty = MagicMock(return_value=False) - - # The run_history has 16 runs + 1 run interruption ==> 16 runs - api.run_history.data = make_dict_run_history_data(run_history_data) - api._metric = accuracy - api.dataset_name = 'iris' - api._scoring_functions = [accuracy, balanced_accuracy] - api.search_space = MagicMock(spec=ConfigurationSpace) - 
search_results = api.get_search_results() - - _check_status(search_results.status_types) - _check_costs(search_results.opt_scores) - _check_fit_times(search_results.fit_times) - _check_budgets(search_results.budgets) - _check_metric_dict(search_results.metric_dict, search_results.status_types) - _check_additional_infos(status_types=search_results.status_types, - additional_infos=search_results.additional_infos) - - # config_ids can duplicate because of various budget size - config_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 10, 11, 12, 10, 13] - assert config_ids == search_results.config_ids - - # assert that contents of search_results are of expected types - assert isinstance(search_results.rank_test_scores, np.ndarray) - assert search_results.rank_test_scores.dtype is np.dtype(np.int) - assert isinstance(search_results.configs, list) - - n_success, n_timeout, n_memoryout, n_crashed = 13, 2, 0, 1 - msg = ["autoPyTorch results:", f"\tDataset name: {api.dataset_name}", - f"\tOptimisation Metric: {api._metric.name}", - f"\tBest validation score: {max(search_results.opt_scores)}", - "\tNumber of target algorithm runs: 16", f"\tNumber of successful target algorithm runs: {n_success}", - f"\tNumber of crashed target algorithm runs: {n_crashed}", - f"\tNumber of target algorithms that exceeded the time limit: {n_timeout}", - f"\tNumber of target algorithms that exceeded the memory limit: {n_memoryout}"] - - assert isinstance(api.sprint_statistics(), str) - assert all([m1 == m2 for m1, m2 in zip(api.sprint_statistics().split("\n"), msg)]) - - -@pytest.mark.parametrize('run_history', (None, RunHistory())) -def test_check_run_history(run_history): - manager = ResultsManager() - manager.run_history = run_history - - with pytest.raises(RuntimeError) as excinfo: - manager._check_run_history() - - assert excinfo._excinfo[0] == RuntimeError - - -T, NT = 'traditional', 'non-traditional' -SCORES = [0.1 * (i + 1) for i in range(10)] - - -@pytest.mark.parametrize('include_traditional', (True, False)) -@pytest.mark.parametrize('metric', (accuracy, log_loss)) -@pytest.mark.parametrize('origins', ([T] * 5 + [NT] * 5, [T, NT] * 5, [NT] * 5 + [T] * 5)) -@pytest.mark.parametrize('scores', (SCORES, SCORES[::-1])) -def test_get_incumbent_results(include_traditional, metric, origins, scores): - manager = ResultsManager() - cs = ConfigurationSpace() - cs.add_hyperparameter(CSH.UniformFloatHyperparameter('a', lower=0, upper=1)) - - configs = [0.1 * (i + 1) for i in range(len(scores))] - if metric.name == "log_loss": - # This is to detect mis-computation in reversion - metric._optimum = 0.1 - - best_cost, best_idx = np.inf, -1 - for idx, (a, origin, score) in enumerate(zip(configs, origins, scores)): - config = Configuration(cs, {'a': a}) - - # conversion defined in: - # autoPyTorch/pipeline/components/training/metrics/utils.py::calculate_loss - cost = metric._optimum - metric._sign * score - manager.run_history.add( - config=config, - cost=cost, - time=1.0, - status=StatusType.SUCCESS, - additional_info={'opt_loss': {metric.name: score}, - 'configuration_origin': origin} - ) - if cost > best_cost: - continue - - if include_traditional: - best_cost, best_idx = cost, idx - elif origin != T: - best_cost, best_idx = cost, idx - - incumbent_config, incumbent_results = manager.get_incumbent_results( - metric=metric, - include_traditional=include_traditional - ) - - assert isinstance(incumbent_config, Configuration) - assert isinstance(incumbent_results, dict) - best_score, best_a = scores[best_idx], configs[best_idx] - assert 
np.allclose( - [best_score, best_score, best_a], - [cost2metric(best_cost, metric), - incumbent_results['opt_loss'][metric.name], - incumbent_config['a']] - ) - - if not include_traditional: - assert incumbent_results['configuration_origin'] != T diff --git a/test/test_api/.tmp_api/runhistory_B.json b/test/test_utils/runhistory.json similarity index 100% rename from test/test_api/.tmp_api/runhistory_B.json rename to test/test_utils/runhistory.json diff --git a/test/test_utils/test_results_manager.py b/test/test_utils/test_results_manager.py new file mode 100644 index 000000000..60ee11f42 --- /dev/null +++ b/test/test_utils/test_results_manager.py @@ -0,0 +1,484 @@ +import json +import os +from datetime import datetime +from test.test_api.utils import make_dict_run_history_data +from unittest.mock import MagicMock + +import ConfigSpace.hyperparameters as CSH +from ConfigSpace.configuration_space import Configuration, ConfigurationSpace + +import numpy as np + +import pytest + +from smac.runhistory.runhistory import RunHistory, RunKey, RunValue, StatusType + +from autoPyTorch.api.base_task import BaseTask +from autoPyTorch.metrics import accuracy, balanced_accuracy, log_loss +from autoPyTorch.utils.results_manager import ( + EnsembleResults, + MetricResults, + ResultsManager, + SearchResults, + cost2metric, + get_start_time +) + + +T, NT = 'traditional', 'non-traditional' +SCORES = [0.1 * (i + 1) for i in range(10)] +END_TIMES = [8, 4, 3, 6, 0, 7, 1, 9, 2, 5] + + +def _check_status(status): + """ Based on runhistory.json """ + ans = [ + StatusType.SUCCESS, StatusType.SUCCESS, + StatusType.SUCCESS, StatusType.SUCCESS, + StatusType.SUCCESS, StatusType.SUCCESS, + StatusType.CRASHED, StatusType.SUCCESS, + StatusType.SUCCESS, StatusType.SUCCESS, + StatusType.SUCCESS, StatusType.SUCCESS, + StatusType.SUCCESS, StatusType.SUCCESS, + StatusType.TIMEOUT, StatusType.TIMEOUT, + ] + assert isinstance(status, list) + assert isinstance(status[0], StatusType) + assert status == ans + + +def _check_costs(costs): + """ Based on runhistory.json """ + ans = [0.15204678362573099, 0.4444444444444444, 0.5555555555555556, 0.29824561403508776, + 0.4444444444444444, 0.4444444444444444, 1.0, 0.5555555555555556, 0.4444444444444444, + 0.15204678362573099, 0.15204678362573099, 0.4035087719298246, 0.4444444444444444, + 0.4444444444444444, 1.0, 1.0] + assert np.allclose(1 - np.array(costs), ans) + assert isinstance(costs, np.ndarray) + assert costs.dtype is np.dtype(np.float) + + +def _check_end_times(end_times): + """ Based on runhistory.json """ + ans = [1637342642.7887495, 1637342647.2651122, 1637342675.2555833, 1637342681.334954, + 1637342693.2717755, 1637342704.341065, 1637342726.1866672, 1637342743.3274522, + 1637342749.9442234, 1637342762.5487585, 1637342779.192385, 1637342804.3368232, + 1637342820.8067145, 1637342846.0210106, 1637342897.1205413, 1637342928.7456856] + + assert np.allclose(end_times, ans) + assert isinstance(end_times, np.ndarray) + assert end_times.dtype is np.dtype(np.float) + + +def _check_fit_times(fit_times): + """ Based on runhistory.json """ + ans = [3.154788017272949, 3.2763524055480957, 22.723600149154663, 4.990685224533081, 10.684926509857178, + 9.947429180145264, 11.687273979187012, 8.478890419006348, 5.485020637512207, 11.514830589294434, + 15.370736837387085, 23.846530199050903, 6.757539510726929, 15.061991930007935, 50.010520696640015, + 22.011935234069824] + + assert np.allclose(fit_times, ans) + assert isinstance(fit_times, np.ndarray) + assert fit_times.dtype is np.dtype(np.float) + 
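The cost checks above rely on the conversion used throughout these tests, cost = metric._optimum - metric._sign * score, which cost2metric reverts. A minimal sketch for accuracy, assuming _optimum == 1 and _sign == 1 (attribute values assumed for illustration):

from autoPyTorch.metrics import accuracy
from autoPyTorch.utils.results_manager import cost2metric

score = 0.848                                              # hypothetical validation accuracy of one run
cost = accuracy._optimum - accuracy._sign * score          # the loss value stored for that run
assert abs(cost - (1 - score)) < 1e-12                     # under the assumed _optimum/_sign, cost == 1 - score
assert abs(cost2metric(cost, accuracy) - score) < 1e-12    # cost2metric reverts the conversion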
+ +def _check_budgets(budgets): + """ Based on runhistory.json """ + ans = [5.555555555555555, 5.555555555555555, 5.555555555555555, 5.555555555555555, + 5.555555555555555, 5.555555555555555, 5.555555555555555, 5.555555555555555, + 5.555555555555555, 16.666666666666664, 50.0, 16.666666666666664, 16.666666666666664, + 16.666666666666664, 50.0, 50.0] + assert np.allclose(budgets, ans) + assert isinstance(budgets, list) + assert isinstance(budgets[0], float) + + +def _check_additional_infos(status_types, additional_infos): + for i, status in enumerate(status_types): + info = additional_infos[i] + if status in (StatusType.SUCCESS, StatusType.DONOTADVANCE): + metric_info = info.get('opt_loss', None) + assert metric_info is not None + elif info is not None: + metric_info = info.get('opt_loss', None) + assert metric_info is None + + +def _check_metric_dict(metric_dict, status_types, worst_val): + assert isinstance(metric_dict['accuracy'], list) + assert metric_dict['accuracy'][0] > 0 + assert isinstance(metric_dict['balanced_accuracy'], list) + assert metric_dict['balanced_accuracy'][0] > 0 + + for key, vals in metric_dict.items(): + # ^ is a XOR operator + # True and False / False and True must be fulfilled + assert all([(s == StatusType.SUCCESS) ^ np.isclose([val], [worst_val]) + for s, val in zip(status_types, vals)]) + + +def _check_metric_results(scores, metric, run_history, ensemble_performance_history): + if metric.name == 'accuracy': # Check the case when ensemble does not have the metric name + dummy_history = [{'Timestamp': datetime(2000, 1, 1), 'train_log_loss': 1, 'test_log_loss': 1}] + mr = MetricResults(metric, run_history, dummy_history) + # ensemble_results should be None because ensemble evaluated log_loss + assert mr.ensemble_results.empty() + data = mr.get_ensemble_merged_data() + # since ensemble_results is None, merged_data must be identical to the run_history data + assert all(np.allclose(data[key], mr.data[key]) for key in data.keys()) + + mr = MetricResults(metric, run_history, ensemble_performance_history) + perfs = np.array([cost2metric(s, metric) for s in scores]) + modified_scores = scores[::2] + [0] + modified_scores.insert(2, 0) + ens_perfs = np.array([s for s in modified_scores]) + assert np.allclose(mr.data[f'single::train::{metric.name}'], perfs) + assert np.allclose(mr.data[f'single::opt::{metric.name}'], perfs) + assert np.allclose(mr.data[f'single::test::{metric.name}'], perfs) + assert np.allclose(mr.data[f'ensemble::train::{metric.name}'], ens_perfs) + assert np.allclose(mr.data[f'ensemble::test::{metric.name}'], ens_perfs) + + # the end times of synthetic ensemble is [0.25, 0.45, 0.45, 0.65, 0.85, 0.85] + # the end times of synthetic run history is 0.1 * np.arange(1, 9) or 0.1 * np.arange(2, 10) + ensemble_ends_later = mr.search_results.end_times[-1] < mr.ensemble_results.end_times[-1] + indices = [2, 4, 4, 6, 8, 8] if ensemble_ends_later else [1, 3, 3, 5, 7, 7] + + merged_data = mr.get_ensemble_merged_data() + worst_val = metric._worst_possible_result + minimize = metric._sign == -1 + ans = np.full_like(mr.cum_times, worst_val) + for idx, s in zip(indices, mr.ensemble_results.train_scores): + ans[idx] = min(ans[idx], s) if minimize else max(ans[idx], s) + + assert np.allclose(ans, merged_data[f'ensemble::train::{metric.name}']) + assert np.allclose(ans, merged_data[f'ensemble::test::{metric.name}']) + + +def test_extract_results_from_run_history(): + # test the raise error for the `status_msg is None` + run_history = RunHistory() + cs = ConfigurationSpace() 
+ config = Configuration(cs, {}) + run_history.add( + config=config, + cost=0.0, + time=1.0, + status=StatusType.CAPPED, + ) + with pytest.raises(ValueError) as excinfo: + SearchResults(metric=accuracy, scoring_functions=[], run_history=run_history) + + assert excinfo._excinfo[0] == ValueError + + +def test_raise_error_in_update_and_sort_by_time(): + cs = ConfigurationSpace() + cs.add_hyperparameter(CSH.UniformFloatHyperparameter('a', lower=0, upper=1)) + config = Configuration(cs, {'a': 0.1}) + + sr = SearchResults(metric=accuracy, scoring_functions=[], run_history=RunHistory()) + er = EnsembleResults(metric=accuracy, ensemble_performance_history=[]) + + with pytest.raises(RuntimeError) as excinfo: + sr._update( + config=config, + run_key=RunKey(config_id=0, instance_id=0, seed=0), + run_value=RunValue( + cost=0, time=1, status=StatusType.SUCCESS, + starttime=0, endtime=1, additional_info={} + ) + ) + + assert excinfo._excinfo[0] == RuntimeError + + with pytest.raises(RuntimeError) as excinfo: + sr._sort_by_endtime() + + assert excinfo._excinfo[0] == RuntimeError + + with pytest.raises(RuntimeError) as excinfo: + er._update(data={}) + + assert excinfo._excinfo[0] == RuntimeError + + with pytest.raises(RuntimeError) as excinfo: + er._sort_by_endtime() + + +@pytest.mark.parametrize('starttimes', (list(range(10)), list(range(10))[::-1])) +@pytest.mark.parametrize('status_types', ( + [StatusType.SUCCESS] * 9 + [StatusType.STOP], + [StatusType.RUNNING] + [StatusType.SUCCESS] * 9 +)) +def test_get_start_time(starttimes, status_types): + run_history = RunHistory() + cs = ConfigurationSpace() + cs.add_hyperparameter(CSH.UniformFloatHyperparameter('a', lower=0, upper=1)) + endtime = 1e9 + kwargs = dict(cost=1.0, endtime=endtime) + for starttime, status_type in zip(starttimes, status_types): + config = Configuration(cs, {'a': 0.1 * starttime}) + run_history.add( + config=config, + starttime=starttime, + time=endtime - starttime, + status=status_type, + **kwargs + ) + starttime = get_start_time(run_history) + + # this rule is strictly defined on the inputs defined from pytest + ans = min(t for s, t in zip(status_types, starttimes) if s == StatusType.SUCCESS) + assert starttime == ans + + +def test_raise_error_in_get_start_time(): + # test the raise error for the `status_msg is None` + run_history = RunHistory() + cs = ConfigurationSpace() + config = Configuration(cs, {}) + run_history.add( + config=config, + cost=0.0, + time=1.0, + status=StatusType.CAPPED, + ) + + with pytest.raises(ValueError) as excinfo: + get_start_time(run_history) + + assert excinfo._excinfo[0] == ValueError + + +def test_search_results_sort_by_endtime(): + run_history = RunHistory() + n_configs = len(SCORES) + cs = ConfigurationSpace() + cs.add_hyperparameter(CSH.UniformFloatHyperparameter('a', lower=0, upper=1)) + order = np.argsort(END_TIMES) + ans = np.array(SCORES)[order].tolist() + status_types = [StatusType.SUCCESS, StatusType.DONOTADVANCE] * (n_configs // 2) + + for i, (fixed_val, et, status) in enumerate(zip(SCORES, END_TIMES, status_types)): + config = Configuration(cs, {'a': fixed_val}) + run_history.add( + config=config, cost=fixed_val, + status=status, budget=fixed_val, + time=et - fixed_val, starttime=fixed_val, endtime=et, + additional_info={ + 'a': fixed_val, + 'configuration_origin': [T, NT][i % 2], + 'train_loss': {accuracy.name: fixed_val - 0.1}, + 'opt_loss': {accuracy.name: fixed_val}, + 'test_loss': {accuracy.name: fixed_val + 0.1} + } + ) + + sr = SearchResults(accuracy, scoring_functions=[], 
run_history=run_history, order_by_endtime=True) + assert sr.budgets == ans + assert np.allclose(accuracy._optimum - accuracy._sign * sr.opt_scores, ans) + assert np.allclose(accuracy._optimum - accuracy._sign * sr.train_scores, np.array(ans) - accuracy._sign * 0.1) + assert np.allclose(accuracy._optimum - accuracy._sign * sr.test_scores, np.array(ans) + accuracy._sign * 0.1) + assert np.allclose(1 - sr.opt_scores, ans) + assert sr._end_times == list(range(n_configs)) + assert all(c.get('a') == val for val, c in zip(ans, sr.configs)) + assert all(info['a'] == val for val, info in zip(ans, sr.additional_infos)) + assert np.all(np.array([s for s in status_types])[order] == np.array(sr.status_types)) + assert sr.is_traditionals == np.array([True, False] * 5)[order].tolist() + assert np.allclose(sr.fit_times, np.subtract(np.arange(n_configs), ans)) + + +def test_ensemble_results(): + order = np.argsort(END_TIMES) + end_times = [datetime.timestamp(datetime(2000, et + 1, 1)) for et in END_TIMES] + ensemble_performance_history = [ + {'Timestamp': datetime(2000, et + 1, 1), 'train_accuracy': s1, 'test_accuracy': s2} + for et, s1, s2 in zip(END_TIMES, SCORES, SCORES[::-1]) + ] + + er = EnsembleResults(log_loss, ensemble_performance_history) + assert er.empty() + + er = EnsembleResults(accuracy, ensemble_performance_history) + assert er._train_scores == SCORES + assert np.allclose(er.train_scores, SCORES) + assert er._test_scores == SCORES[::-1] + assert np.allclose(er.test_scores, SCORES[::-1]) + assert np.allclose(er.end_times, end_times) + + er = EnsembleResults(accuracy, ensemble_performance_history, order_by_endtime=True) + assert np.allclose(er.train_scores, np.array(SCORES)[order]) + assert np.allclose(er.test_scores, np.array(SCORES[::-1])[order]) + assert np.allclose(er.end_times, np.array(end_times)[order]) + + +@pytest.mark.parametrize('metric', (accuracy, log_loss)) +@pytest.mark.parametrize('scores', (SCORES[:8], SCORES[:8][::-1])) +@pytest.mark.parametrize('ensemble_ends_later', (True, False)) +def test_metric_results(metric, scores, ensemble_ends_later): + # since datetime --> timestamp varies between machines and float64 might not + # be able to handle time precisely enough, we might need to change t0 in the future.
+ # Basically, it happens because this test is checking by the precision of milli second + t0, ms_unit = (1970, 1, 1, 9, 0, 0), 100000 + ensemble_performance_history = [ + {'Timestamp': datetime(*t0, ms_unit * 2 * (i + 1) + ms_unit // 2), + f'train_{metric.name}': s, + f'test_{metric.name}': s} + for i, s in enumerate(scores[::2]) + ] + # Add a record with the exact same stamp as the last one + ensemble_performance_history.append( + {'Timestamp': datetime(*t0, ms_unit * 8 + ms_unit // 2), + f'train_{metric.name}': 0, + f'test_{metric.name}': 0} + ) + # Add a record with the exact same stamp as a middle one + ensemble_performance_history.append( + {'Timestamp': datetime(*t0, ms_unit * 4 + ms_unit // 2), + f'train_{metric.name}': 0, + f'test_{metric.name}': 0} + ) + + run_history = RunHistory() + cs = ConfigurationSpace() + cs.add_hyperparameter(CSH.UniformFloatHyperparameter('a', lower=0, upper=1)) + + for i, fixed_val in enumerate(scores): + config = Configuration(cs, {'a': fixed_val}) + st = datetime.timestamp(datetime(*t0, ms_unit * (i + 1 - ensemble_ends_later))) + et = datetime.timestamp(datetime(*t0, ms_unit * (i + 2 - ensemble_ends_later))) + run_history.add( + config=config, cost=1, budget=0, + time=0.1, starttime=st, endtime=et, + status=StatusType.SUCCESS, + additional_info={ + 'configuration_origin': T, + 'train_loss': {f'{metric.name}': fixed_val}, + 'opt_loss': {f'{metric.name}': fixed_val}, + 'test_loss': {f'{metric.name}': fixed_val} + } + ) + _check_metric_results(scores, metric, run_history, ensemble_performance_history) + + +def test_search_results_sprint_statistics(): + api = BaseTask() + for method in ['get_search_results', 'sprint_statistics', 'get_incumbent_results']: + with pytest.raises(RuntimeError) as excinfo: + getattr(api, method)() + + assert excinfo._excinfo[0] == RuntimeError + + run_history_data = json.load(open(os.path.join(os.path.dirname(__file__), + 'runhistory.json'), + mode='r'))['data'] + api._results_manager.run_history = MagicMock() + api.run_history.empty = MagicMock(return_value=False) + + # The run_history has 16 runs + 1 run interruption ==> 16 runs + api.run_history.data = make_dict_run_history_data(run_history_data) + api._metric = accuracy + api.dataset_name = 'iris' + api._scoring_functions = [accuracy, balanced_accuracy] + api.search_space = MagicMock(spec=ConfigurationSpace) + worst_val = api._metric._worst_possible_result + search_results = api.get_search_results() + + _check_status(search_results.status_types) + _check_costs(search_results.opt_scores) + _check_end_times(search_results.end_times) + _check_fit_times(search_results.fit_times) + _check_budgets(search_results.budgets) + _check_metric_dict(search_results.opt_metric_dict, search_results.status_types, worst_val) + _check_additional_infos(status_types=search_results.status_types, + additional_infos=search_results.additional_infos) + + # config_ids can duplicate because of various budget size + config_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 10, 11, 12, 10, 13] + assert config_ids == search_results.config_ids + + # assert that contents of search_results are of expected types + assert isinstance(search_results.rank_opt_scores, np.ndarray) + assert search_results.rank_opt_scores.dtype is np.dtype(np.int) + assert isinstance(search_results.configs, list) + + n_success, n_timeout, n_memoryout, n_crashed = 13, 2, 0, 1 + msg = ["autoPyTorch results:", f"\tDataset name: {api.dataset_name}", + f"\tOptimisation Metric: {api._metric.name}", + f"\tBest validation score: 
{max(search_results.opt_scores)}", + "\tNumber of target algorithm runs: 16", f"\tNumber of successful target algorithm runs: {n_success}", + f"\tNumber of crashed target algorithm runs: {n_crashed}", + f"\tNumber of target algorithms that exceeded the time limit: {n_timeout}", + f"\tNumber of target algorithms that exceeded the memory limit: {n_memoryout}"] + + assert isinstance(api.sprint_statistics(), str) + assert all([m1 == m2 for m1, m2 in zip(api.sprint_statistics().split("\n"), msg)]) + + +@pytest.mark.parametrize('run_history', (None, RunHistory())) +def test_check_run_history(run_history): + manager = ResultsManager() + manager.run_history = run_history + + with pytest.raises(RuntimeError) as excinfo: + manager._check_run_history() + + assert excinfo._excinfo[0] == RuntimeError + + +@pytest.mark.parametrize('include_traditional', (True, False)) +@pytest.mark.parametrize('metric', (accuracy, log_loss)) +@pytest.mark.parametrize('origins', ([T] * 5 + [NT] * 5, [T, NT] * 5, [NT] * 5 + [T] * 5)) +@pytest.mark.parametrize('scores', (SCORES, SCORES[::-1])) +def test_get_incumbent_results(include_traditional, metric, origins, scores): + manager = ResultsManager() + cs = ConfigurationSpace() + cs.add_hyperparameter(CSH.UniformFloatHyperparameter('a', lower=0, upper=1)) + + configs = [0.1 * (i + 1) for i in range(len(scores))] + if metric.name == "log_loss": + # This is to detect mis-computation in reversion + metric._optimum = 0.1 + + best_cost, best_idx = np.inf, -1 + for idx, (a, origin, score) in enumerate(zip(configs, origins, scores)): + config = Configuration(cs, {'a': a}) + + # conversion defined in: + # autoPyTorch/pipeline/components/training/metrics/utils.py::calculate_loss + cost = metric._optimum - metric._sign * score + manager.run_history.add( + config=config, + cost=cost, + time=1.0, + status=StatusType.SUCCESS, + additional_info={'train_loss': {metric.name: cost}, + 'opt_loss': {metric.name: cost}, + 'test_loss': {metric.name: cost}, + 'configuration_origin': origin} + ) + if cost > best_cost: + continue + + if include_traditional: + best_cost, best_idx = cost, idx + elif origin != T: + best_cost, best_idx = cost, idx + + incumbent_config, incumbent_results = manager.get_incumbent_results( + metric=metric, + include_traditional=include_traditional + ) + + assert isinstance(incumbent_config, Configuration) + assert isinstance(incumbent_results, dict) + best_score, best_a = scores[best_idx], configs[best_idx] + assert np.allclose( + [best_score, best_score, best_a], + [cost2metric(best_cost, metric), + cost2metric(incumbent_results['opt_loss'][metric.name], metric), + incumbent_config['a']] + ) + + if not include_traditional: + assert incumbent_results['configuration_origin'] != T diff --git a/test/test_utils/test_results_visualizer.py b/test/test_utils/test_results_visualizer.py new file mode 100644 index 000000000..926d21e6f --- /dev/null +++ b/test/test_utils/test_results_visualizer.py @@ -0,0 +1,274 @@ +import json +import os +from datetime import datetime +from test.test_api.utils import make_dict_run_history_data +from unittest.mock import MagicMock + +from ConfigSpace import ConfigurationSpace + +import matplotlib.pyplot as plt + +import numpy as np + +import pytest + +from autoPyTorch.api.base_task import BaseTask +from autoPyTorch.metrics import accuracy, balanced_accuracy +from autoPyTorch.utils.results_visualizer import ( + ColorLabelSettings, + PlotSettingParams, + ResultsVisualizer, + _get_perf_and_time +) + + +TEST_CL = ('test color', 'test label') + + 
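Mirroring the setup of test_get_incumbent_results above, a minimal usage sketch of ResultsManager.get_incumbent_results; the hyperparameter 'a', the two runs and their scores are made up for illustration:

import ConfigSpace.hyperparameters as CSH
from ConfigSpace.configuration_space import Configuration, ConfigurationSpace

from smac.runhistory.runhistory import StatusType

from autoPyTorch.metrics import accuracy
from autoPyTorch.utils.results_manager import ResultsManager

cs = ConfigurationSpace()
cs.add_hyperparameter(CSH.UniformFloatHyperparameter('a', lower=0, upper=1))

manager = ResultsManager()
# a stronger traditional baseline and a slightly weaker non-traditional run
for a, origin, score in [(0.1, 'traditional', 0.95), (0.2, 'non-traditional', 0.90)]:
    cost = accuracy._optimum - accuracy._sign * score  # same conversion as in the tests above
    manager.run_history.add(
        config=Configuration(cs, {'a': a}),
        cost=cost,
        time=1.0,
        status=StatusType.SUCCESS,
        additional_info={'train_loss': {accuracy.name: cost},
                         'opt_loss': {accuracy.name: cost},
                         'test_loss': {accuracy.name: cost},
                         'configuration_origin': origin}
    )

# With include_traditional=True the better traditional baseline becomes the incumbent ...
config, info = manager.get_incumbent_results(metric=accuracy, include_traditional=True)
assert info['configuration_origin'] == 'traditional'

# ... while include_traditional=False restricts the incumbent to non-traditional runs.
config, info = manager.get_incumbent_results(metric=accuracy, include_traditional=False)
assert config['a'] == 0.2 and info['configuration_origin'] != 'traditional'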
+@pytest.mark.parametrize('cl_settings', ( + ColorLabelSettings(single_opt=TEST_CL), + ColorLabelSettings(single_opt=TEST_CL, single_test=None, single_train=None) +)) +@pytest.mark.parametrize('with_ensemble', (True, False)) +def test_extract_dicts(cl_settings, with_ensemble): + dummy_keys = [name for name in [ + 'single::train::dummy', + 'single::opt::dummy', + 'single::test::dummy', + 'ensemble::train::dummy', + 'ensemble::test::dummy' + ] if ( + (with_ensemble or not name.startswith('ensemble')) + and getattr(cl_settings, "_".join(name.split('::')[:2])) is not None + ) + ] + + results = MagicMock() + results.data.keys = MagicMock(return_value=dummy_keys) + cd, ld = cl_settings.extract_dicts(results) + assert set(dummy_keys) == set(cd.keys()) + assert set(dummy_keys) == set(ld.keys()) + + opt_key = 'single::opt::dummy' + assert TEST_CL == (cd[opt_key], ld[opt_key]) + + +@pytest.mark.parametrize('params', ( + PlotSettingParams(show=True), + PlotSettingParams(show=False) +)) +def test_plt_show_in_set_plot_args(params): # TODO + plt.show = MagicMock() + _, ax = plt.subplots(nrows=1, ncols=1) + viz = ResultsVisualizer() + + viz._set_plot_args(ax, params) + assert plt.show._mock_called == params.show + plt.close() + + +@pytest.mark.parametrize('params', ( + PlotSettingParams(xscale='none', yscale='none'), + PlotSettingParams(xscale='none', yscale='log'), + PlotSettingParams(xscale='none', yscale='none'), + PlotSettingParams(xscale='none', yscale='log') +)) +def test_raise_value_error_in_set_plot_args(params): # TODO + _, ax = plt.subplots(nrows=1, ncols=1) + viz = ResultsVisualizer() + + with pytest.raises(ValueError) as excinfo: + viz._set_plot_args(ax, params) + + assert excinfo._excinfo[0] == ValueError + plt.close() + + +@pytest.mark.parametrize('params', ( + PlotSettingParams(xlim=(-100, 100), ylim=(-200, 200)), + PlotSettingParams(xlabel='x label', ylabel='y label'), + PlotSettingParams(xscale='log', yscale='log'), + PlotSettingParams(legend=False, title='Title') +)) +def test_set_plot_args(params): # TODO + _, ax = plt.subplots(nrows=1, ncols=1) + viz = ResultsVisualizer() + viz._set_plot_args(ax, params) + + if params.xlim is not None: + assert ax.get_xlim() == params.xlim + if params.ylim is not None: + assert ax.get_ylim() == params.ylim + + assert ax.xaxis.get_label()._text == ('' if params.xlabel is None else params.xlabel) + assert ax.yaxis.get_label()._text == ('' if params.ylabel is None else params.ylabel) + assert ax.get_title() == ('' if params.title is None else params.title) + assert params.xscale == ax.get_xscale() + assert params.yscale == ax.get_yscale() + + if params.legend: + assert ax.get_legend() is not None + else: + assert ax.get_legend() is None + + plt.close() + + +@pytest.mark.parametrize('metric_name', ('unknown', 'accuracy')) +def test_raise_error_in_plot_perf_over_time_in_base_task(metric_name): + api = BaseTask() + + if metric_name == 'unknown': + with pytest.raises(ValueError) as excinfo: + api.plot_perf_over_time(metric_name) + assert excinfo._excinfo[0] == ValueError + else: + with pytest.raises(RuntimeError) as excinfo: + api.plot_perf_over_time(metric_name) + assert excinfo._excinfo[0] == RuntimeError + + +@pytest.mark.parametrize('metric_name', ('balanced_accuracy', 'accuracy')) +def test_plot_perf_over_time(metric_name): # TODO + dummy_history = [{'Timestamp': datetime(2022, 1, 1), 'train_accuracy': 1, 'test_accuracy': 1}] + api = BaseTask() + run_history_data = json.load(open(os.path.join(os.path.dirname(__file__), + 'runhistory.json'), + 
mode='r'))['data'] + api._results_manager.run_history = MagicMock() + api.run_history.empty = MagicMock(return_value=False) + + # The run_history has 16 runs + 1 run interruption ==> 16 runs + api.run_history.data = make_dict_run_history_data(run_history_data) + api._results_manager.ensemble_performance_history = dummy_history + api._metric = accuracy + api.dataset_name = 'iris' + api._scoring_functions = [accuracy, balanced_accuracy] + api.search_space = MagicMock(spec=ConfigurationSpace) + + api.plot_perf_over_time(metric_name=metric_name) + _, ax = plt.subplots(nrows=1, ncols=1) + api.plot_perf_over_time(metric_name=metric_name, ax=ax) + + # remove ensemble keys if metric name is not for the opt score + ans = set([ + name + for name in [f'single train {metric_name}', + f'single test {metric_name}', + f'single opt {metric_name}', + f'ensemble train {metric_name}', + f'ensemble test {metric_name}'] + if metric_name == api._metric.name or not name.startswith('ensemble') + ]) + legend_set = set([txt._text for txt in ax.get_legend().texts]) + assert ans == legend_set + plt.close() + + +@pytest.mark.parametrize('params', ( + PlotSettingParams(xscale='none', yscale='none'), + PlotSettingParams(xscale='none', yscale='log'), + PlotSettingParams(xscale='log', yscale='none'), + PlotSettingParams(yscale='log') +)) +def test_raise_error_get_perf_and_time(params): + results = np.linspace(-1, 1, 10) + cum_times = np.linspace(0, 1, 10) + + with pytest.raises(ValueError) as excinfo: + _get_perf_and_time( + cum_results=results, + cum_times=cum_times, + plot_setting_params=params, + worst_val=np.inf + ) + + assert excinfo._excinfo[0] == ValueError + + +@pytest.mark.parametrize('params', ( + PlotSettingParams(n_points=20, xscale='linear', yscale='linear'), + PlotSettingParams(n_points=20, xscale='log', yscale='log') +)) +def test_get_perf_and_time(params): + y_min, y_max = 1e-5, 1 + results = np.linspace(y_min, y_max, 10) + cum_times = np.linspace(y_min, y_max, 10) + + check_points, perf_by_time_step = _get_perf_and_time( + cum_results=results, + cum_times=cum_times, + plot_setting_params=params, + worst_val=np.inf + ) + + times_ans = np.linspace( + y_min if params.xscale == 'linear' else np.log(y_min), + y_max if params.xscale == 'linear' else np.log(y_max), + params.n_points + ) + times_ans = times_ans if params.xscale == 'linear' else np.exp(times_ans) + assert np.allclose(check_points, times_ans) + + if params.xscale == 'linear': + """ + each time step to record the result + [1.00000000e-05, 5.26410526e-02, 1.05272105e-01, 1.57903158e-01, + 2.10534211e-01, 2.63165263e-01, 3.15796316e-01, 3.68427368e-01, + 4.21058421e-01, 4.73689474e-01, 5.26320526e-01, 5.78951579e-01, + 6.31582632e-01, 6.84213684e-01, 7.36844737e-01, 7.89475789e-01, + 8.42106842e-01, 8.94737895e-01, 9.47368947e-01, 1.00000000e+00] + + The time steps when each result was recorded + [ + 1.0000e-05, # cover index 0 ~ 2 + 1.1112e-01, # cover index 3, 4 + 2.2223e-01, # cover index 5, 6 + 3.3334e-01, # cover index 7, 8 + 4.4445e-01, # cover index 9, 10 + 5.5556e-01, # cover index 11, 12 + 6.6667e-01, # cover index 13, 14 + 7.7778e-01, # cover index 15, 16 + 8.8889e-01, # cover index 17, 18 + 1.0000e+00 # cover index 19 + ] + Since the sequence is monotonically increasing, + if multiple elements cover the same index, take the best. 
+ """ + results_ans = [r for r in results] + results_ans = [results[0]] + results_ans + results_ans[:-1] + results_ans = np.sort(results_ans) + else: + """ + each time step to record the result + [1.00000000e-05, 1.83298071e-05, 3.35981829e-05, 6.15848211e-05, + 1.12883789e-04, 2.06913808e-04, 3.79269019e-04, 6.95192796e-04, + 1.27427499e-03, 2.33572147e-03, 4.28133240e-03, 7.84759970e-03, + 1.43844989e-02, 2.63665090e-02, 4.83293024e-02, 8.85866790e-02, + 1.62377674e-01, 2.97635144e-01, 5.45559478e-01, 1.00000000e+00] + + The time steps when each result was recorded + [ + 1.0000e-05, # cover index 0 ~ 15 + 1.1112e-01, # cover index 16 + 2.2223e-01, # cover index 17 + 3.3334e-01, # cover index 18 + 4.4445e-01, # cover index 18 + 5.5556e-01, # cover index 19 + 6.6667e-01, # cover index 19 + 7.7778e-01, # cover index 19 + 8.8889e-01, # cover index 19 + 1.0000e+00 # cover index 19 + ] + Since the sequence is monotonically increasing, + if multiple elements cover the same index, take the best. + """ + results_ans = [ + *([results[0]] * 16), + results[1], + results[2], + results[4], + results[-1] + ] + + assert np.allclose(perf_by_time_step, results_ans)
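To make the expectations in test_get_perf_and_time easier to follow, here is a reference sketch of the bucketing rule the assertions above encode. The function name is made up and this is not the library's _get_perf_and_time implementation; it assumes cum_results is already a cumulative (running-best) sequence, assumes strictly positive times for the log scale, and omits the ValueError handling for unsupported scales exercised above.

import numpy as np

def perf_over_checkpoints(cum_times, cum_results, n_points=20, xscale='linear', worst_val=np.inf):
    # Place n_points check points over the observed time range, spaced linearly
    # or log-spaced depending on xscale.
    if xscale == 'log':
        check_points = np.exp(np.linspace(np.log(cum_times[0]), np.log(cum_times[-1]), n_points))
        # snap the endpoints to avoid exp/log round-off excluding the first or last record
        check_points[0], check_points[-1] = cum_times[0], cum_times[-1]
    else:
        check_points = np.linspace(cum_times[0], cum_times[-1], n_points)

    perf, idx, current = [], 0, worst_val
    for cp in check_points:
        # Consume every result recorded up to this check point; since cum_results
        # is cumulative, the latest such entry is also the best seen so far.
        while idx < len(cum_times) and cum_times[idx] <= cp:
            current = cum_results[idx]
            idx += 1
        perf.append(current)
    return check_points, np.array(perf)

# For the linear and log inputs used in test_get_perf_and_time above, this
# reproduces the results_ans step values, e.g.:
#   perf_over_checkpoints(np.linspace(1e-5, 1, 10), np.linspace(1e-5, 1, 10),
#                         n_points=20, xscale='log')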