-
Notifications
You must be signed in to change notification settings - Fork 1
/
do_hpo.py
104 lines (93 loc) · 3.76 KB
/
do_hpo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from ConfigSpace import Configuration, ConfigurationSpace
from ConfigSpace.read_and_write import json as config_space_json_r_w
from smac import HyperparameterOptimizationFacade, Scenario
import json
import os
import time
import argparse
import utils
# Locations of the default train/validation/test split files (JSON).
# No active code in this module reads them; they remain available for
# experiments that want to pin the surrogate to a fixed split.
train_paths, val_paths, test_paths = (
    "configs/data_splits/default_split/train_paths.json",
    "configs/data_splits/default_split/val_paths.json",
    "configs/data_splits/default_split/test_paths.json",
)
class HPOModel:
    """Expose a surrogate model's training run as a SMAC target function.

    The search space is read from a ConfigSpace JSON file. Each call to
    :meth:`do_hpo` builds the surrogate selected by ``model`` from
    ``utils.model_dict``, trains it with the sampled hyperparameters and
    returns the ``"rmse"`` entry of the metrics its ``train()`` returns
    (SMAC minimizes this value).
    """

    def __init__(
        self,
        model: str,
        dataset_root: str,
        log_dir: str,
        data_config_path: str,
        model_config_path: str,
        seed: int = 0,
        device=None,
        metric=None
    ):
        """Load the data config and the model search space.

        Args:
            model: Key into ``utils.model_dict`` selecting the surrogate class.
            dataset_root: Forwarded to the surrogate as ``data_root``.
            log_dir: Forwarded to the surrogate as ``log_dir``.
            data_config_path: Path to a JSON file holding the data config.
            model_config_path: Path to a ConfigSpace search space serialized
                as JSON.
            seed: Stored as ``self.seed``. Individual trials use the seed
                that SMAC passes to :meth:`do_hpo`.
            device: Forwarded to the surrogate unchanged.
            metric: Forwarded to the surrogate unchanged.
        """
        # Keep the model key on the instance so do_hpo does not depend on a
        # module-global ``args`` (which only exists when run as a script).
        self.model = model
        self.seed = seed
        self.dataset_root = dataset_root
        with open(data_config_path, "r") as f:
            self.data_config = json.load(f)
        # Alternative source for the search space: utils.get_model_configspace(model)
        with open(model_config_path, "r") as f:
            self.model_configspace = config_space_json_r_w.read(f.read())
        self.device = device
        self.metric = metric
        self.log_dir = log_dir

    def get_model_configspace(self) -> ConfigurationSpace:
        """Return the search space loaded from ``model_config_path``."""
        return self.model_configspace

    def do_hpo(self, model_config: Configuration, seed: int = 0) -> float:
        """Train one surrogate with ``model_config`` and return its RMSE.

        Args:
            model_config: Hyperparameter configuration sampled by SMAC.
            seed: Seed supplied by SMAC for this trial; forwarded to the
                surrogate.

        Returns:
            ``valid_metrics["rmse"]``, where ``valid_metrics`` is the value
            returned by the surrogate's ``train()``.
        """
        surrogate_model = utils.model_dict[self.model](
            data_root=self.dataset_root,
            log_dir=self.log_dir,
            seed=seed,
            model_config=model_config.get_dictionary(),
            data_config=self.data_config,
            device=self.device,
            metric=self.metric,
        )
        # NOTE: to pin a fixed split, surrogate_model.train_paths / val_paths /
        # test_paths could be set here from the module-level JSON path constants.
        valid_metrics = surrogate_model.train()
        return valid_metrics["rmse"]
if __name__ == "__main__":
    # NOTE: ``args`` is intentionally left module-global; HPOModel.do_hpo in
    # older revisions of this file reads ``args.model`` directly.
    parser = argparse.ArgumentParser("Surrogate HPO")
    parser.add_argument("--dataset_root", type=str, help="Path to dataset root dir")
    # The next three are used unconditionally below (os.path.join / open),
    # so make argparse reject their absence instead of crashing later.
    parser.add_argument("--model", type=str, required=True, help="Surrogate model")
    parser.add_argument(
        "--data_config_path", type=str, required=True, help="Data config path"
    )
    parser.add_argument(
        "--model_config_path", type=str, required=True, help="Model config path"
    )
    parser.add_argument("--device", default=None, help="Device, None if acc")
    parser.add_argument("--metric", default=None, help="Metric, None if acc")
    parser.add_argument(
        "--log_dir",
        default="experiments/hpo_surrogate",
        type=str,
        help="Log directory",
    )
    parser.add_argument("--seed", type=int, default=6, help="Seed")
    args = parser.parse_args()

    # Log layout: <log_dir>/<model>[/<device>/<metric>]/<YYYYmmdd-HHMMSS>-<seed>
    if args.device is not None:
        # parser.error (not assert) so the check survives ``python -O``.
        if args.metric is None:
            parser.error(f"Missing metric for {args.device}")
        args.log_dir = os.path.join(args.log_dir, args.model, args.device, args.metric)
    else:
        args.log_dir = os.path.join(args.log_dir, args.model)
    args.log_dir = os.path.join(
        args.log_dir, "{}-{}".format(time.strftime("%Y%m%d-%H%M%S"), args.seed)
    )
    os.makedirs(args.log_dir)

    hpo_model = HPOModel(
        args.model,
        args.dataset_root,
        args.log_dir,
        args.data_config_path,
        args.model_config_path,
        seed=args.seed,
        device=args.device,
        metric=args.metric,
    )
    scenario = Scenario(
        hpo_model.get_model_configspace(),
        deterministic=True,
        walltime_limit=600,  # seconds; the trial budget could instead be capped via n_trials=200
        # Forward --seed so SMAC (and the seed it hands to each trial) honours it.
        seed=args.seed,
    )
    smac = HyperparameterOptimizationFacade(scenario, hpo_model.do_hpo)
    incumbent = smac.optimize()

    print('Final Incumbent:')
    print(incumbent)
    save_file = os.path.join(
        args.log_dir, f'{args.model}_{args.device}_{args.metric}_config.json'
    )
    with open(save_file, 'w') as f:
        json.dump(incumbent.get_dictionary(), f)