[feat] Add an object that realizes the perf over time viz #331

Merged
Changes from all commits (36 commits)
aa9d35a
[feat] Add an object that realizes the perf over time viz
nabenabe0928 Nov 16, 2021
031ce44
[fix] Modify TODOs and add comments to avoid complications
nabenabe0928 Nov 17, 2021
b9636a6
[refactor] [feat] Format visualizer API and integrate this feature in…
nabenabe0928 Nov 17, 2021
fbe7cd3
[refactor] Separate a shared raise error process as a function
nabenabe0928 Nov 17, 2021
bc98805
[refactor] Gather params in Dataclass to look smarter
nabenabe0928 Nov 18, 2021
517ccf1
[refactor] Merge extraction from history to the result manager
nabenabe0928 Nov 24, 2021
29b4aa6
[feat] Merge the viz in the latest version
nabenabe0928 Nov 24, 2021
af6f375
[fix] Fix nan --> worst val so that we can always handle by number
nabenabe0928 Nov 24, 2021
a2084e9
[fix] Fix mypy issues
nabenabe0928 Nov 24, 2021
4039a24
[test] Add test for get_start_time
nabenabe0928 Nov 25, 2021
a44cb86
[test] Add test for order by end time
nabenabe0928 Nov 25, 2021
0b15f66
[test] Add tests for ensemble results
nabenabe0928 Nov 25, 2021
9be8a51
[test] Add tests for merging ensemble results and run history
nabenabe0928 Nov 25, 2021
367b3a4
[test] Add the tests in the case of ensemble_results is None
nabenabe0928 Nov 25, 2021
2a2c564
[fix] Alternate datetime to timestamp in tests to pass universally
nabenabe0928 Nov 26, 2021
27dfbff
[fix] Fix status_msg --> status_type because it does not need to be str
nabenabe0928 Nov 26, 2021
8b1dec5
[fix] Change the name for the homogeniety
nabenabe0928 Nov 29, 2021
4fd6245
[fix] Fix based on the file name change
nabenabe0928 Nov 29, 2021
37ebf7e
[test] Add tests for set_plot_args
nabenabe0928 Nov 29, 2021
d050fd8
[test] Add tests for plot_perf_over_time in BaseTask
nabenabe0928 Nov 29, 2021
2777e1b
[refactor] Replace redundant lines by pytest parametrization
nabenabe0928 Nov 29, 2021
020d7fb
[test] Add tests for _get_perf_and_time
nabenabe0928 Nov 29, 2021
f06eb97
[fix] Remove viz attribute based on Ravin's comment
nabenabe0928 Nov 30, 2021
ca36dc1
[fix] Fix doc-string based on Ravin's comments
nabenabe0928 Nov 30, 2021
38055c8
[refactor] Hide color label settings extraction in dataclass
nabenabe0928 Nov 30, 2021
08e9e12
[test] Add tests for color label dicts extraction
nabenabe0928 Nov 30, 2021
362b61f
[test] Add tests for checking if plt.show is called or not
nabenabe0928 Nov 30, 2021
b5da0d6
[refactor] Address Ravin's comments and add TODO for the refactoring
nabenabe0928 Nov 30, 2021
18e8d1e
[refactor] Change KeyError in EnsembleResults to empty
nabenabe0928 Nov 30, 2021
bef13be
[refactor] Prohibit external updates to make objects more robust
nabenabe0928 Nov 30, 2021
6642e29
[fix] Remove a member variable _opt_scores since it is confusing
nabenabe0928 Nov 30, 2021
45fc875
[example] Add an example how to plot performance over time
nabenabe0928 Nov 30, 2021
a5b37a2
[fix] Fix unexpected train loss when using cross validation
nabenabe0928 Nov 30, 2021
9585b5f
[fix] Remove __main__ from example based on the Ravin's comment
nabenabe0928 Dec 1, 2021
d596c03
[fix] Move results_xxx to utils from API
nabenabe0928 Dec 1, 2021
4fe5ae8
[enhance] Change example for the plot over time to save fig
nabenabe0928 Dec 1, 2021
59 changes: 58 additions & 1 deletion autoPyTorch/api/base_task.py
@@ -21,6 +21,8 @@

import joblib

import matplotlib.pyplot as plt

import numpy as np

import pandas as pd
@@ -29,7 +31,7 @@
from smac.stats.stats import Stats
from smac.tae import StatusType

from autoPyTorch.api.results_manager import ResultsManager, SearchResults
from autoPyTorch import metrics
from autoPyTorch.automl_common.common.utils.backend import Backend, create
from autoPyTorch.constants import (
REGRESSION_TASKS,
@@ -58,6 +60,8 @@
)
from autoPyTorch.utils.parallel import preload_modules
from autoPyTorch.utils.pipeline import get_configuration_space, get_dataset_requirements
from autoPyTorch.utils.results_manager import MetricResults, ResultsManager, SearchResults
from autoPyTorch.utils.results_visualizer import ColorLabelSettings, PlotSettingParams, ResultsVisualizer
from autoPyTorch.utils.single_thread_client import SingleThreadedClient
from autoPyTorch.utils.stopwatch import StopWatch

@@ -1479,3 +1483,56 @@ def sprint_statistics(self) -> str:
scoring_functions=self._scoring_functions,
metric=self._metric
)

    def plot_perf_over_time(
        self,
        metric_name: str,
        ax: Optional[plt.Axes] = None,
        plot_setting_params: PlotSettingParams = PlotSettingParams(),
        color_label_settings: ColorLabelSettings = ColorLabelSettings(),
        *args: Any,
        **kwargs: Any
    ) -> None:
        """
        Visualize the performance over time using matplotlib.
        The plot-related arguments follow matplotlib conventions;
        please refer to the matplotlib documentation for details.

        Args:
            metric_name (str):
                The name of the metric to visualize.
                The available names are listed in
                * autoPyTorch.metrics.CLASSIFICATION_METRICS
                * autoPyTorch.metrics.REGRESSION_METRICS
            ax (Optional[plt.Axes]):
                The axis to plot on (a matplotlib subplot).
                If None, it is created automatically.
            plot_setting_params (PlotSettingParams):
                Parameters for the plot.
            color_label_settings (ColorLabelSettings):
                The color and label settings for each plot.
            args, kwargs (Any):
                Extra arguments passed to ax.plot.
        """
        if not hasattr(metrics, metric_name):
            raise ValueError(
                f'metric_name must be in {list(metrics.CLASSIFICATION_METRICS.keys())} '
                f'or {list(metrics.REGRESSION_METRICS.keys())}, but got {metric_name}'
            )
        if len(self.ensemble_performance_history) == 0:
            raise RuntimeError('Visualization is available only after ensembles are evaluated.')

        results = MetricResults(
            metric=getattr(metrics, metric_name),
            run_history=self.run_history,
            ensemble_performance_history=self.ensemble_performance_history
        )

        colors, labels = color_label_settings.extract_dicts(results)

        ResultsVisualizer().plot_perf_over_time(  # type: ignore
            results=results, plot_setting_params=plot_setting_params,
            colors=colors, labels=labels, ax=ax,
            *args, **kwargs
        )
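To illustrate the guard-and-extract flow that `plot_perf_over_time` performs before delegating to `MetricResults` and `ResultsVisualizer`, here is a minimal stdlib-only sketch. It is not the autoPyTorch implementation: the function name, the dummy metric dicts, and the `"Timestamp"` / `"train_accuracy"` record keys are all illustrative assumptions standing in for the real objects.

```python
# Minimal stdlib-only sketch of the validation and data-extraction pattern
# used by plot_perf_over_time above. All names here are illustrative; the
# real method delegates to MetricResults and ResultsVisualizer instead.
from typing import Any, Dict, List, Tuple

# Stand-ins for autoPyTorch.metrics.CLASSIFICATION_METRICS / REGRESSION_METRICS.
CLASSIFICATION_METRICS = {"accuracy": None, "balanced_accuracy": None}
REGRESSION_METRICS = {"r2": None, "mean_squared_error": None}


def extract_perf_over_time(
    metric_name: str,
    ensemble_performance_history: List[Dict[str, Any]],
) -> List[Tuple[float, float]]:
    """Validate the metric name, then collect (timestamp, score) pairs."""
    known = {**CLASSIFICATION_METRICS, **REGRESSION_METRICS}
    if metric_name not in known:
        # Mirrors the ValueError raised by the real method.
        raise ValueError(
            f"metric_name must be in {list(CLASSIFICATION_METRICS)} "
            f"or {list(REGRESSION_METRICS)}, but got {metric_name}"
        )
    if len(ensemble_performance_history) == 0:
        # Mirrors the RuntimeError: nothing to plot before ensembles exist.
        raise RuntimeError(
            "Visualization is available only after ensembles are evaluated."
        )
    return [
        (entry["Timestamp"], entry[f"train_{metric_name}"])
        for entry in ensemble_performance_history
    ]


# Hypothetical history records; the key names are assumptions.
history = [
    {"Timestamp": 0.0, "train_accuracy": 0.71},
    {"Timestamp": 12.5, "train_accuracy": 0.83},
]
points = extract_perf_over_time("accuracy", history)
```

The extracted `points` are exactly what a visualizer would hand to `ax.plot`; keeping validation separate from plotting is what lets the PR test `_get_perf_and_time` without touching matplotlib.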