-
Notifications
You must be signed in to change notification settings - Fork 0
/
trainer.py
412 lines (318 loc) · 15.3 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
import wandb
import sys
import numpy as np
import torch
from tensor_dataloader import TensorDataLoader
from dataset import *
from param import *
import time
from torch.optim.lr_scheduler import StepLR
import time
from loss import *
from sklearn.metrics import ndcg_score
def evaluate_mse(prediction, truth):
    """Return (MSE, MAE) between `prediction` and `truth`.

    Both arguments are torch tensors; they are detached, moved to CPU and
    converted to numpy before computing the metrics.

    :param prediction: predicted values (torch tensor, any device)
    :param truth: ground-truth values, same shape as prediction
    :return: (mse, mae) as numpy scalars

    FIX: the original converted only `truth` to numpy and left `pred` as a
    torch tensor, so the subtraction mixed torch/numpy types and the function
    returned torch tensors instead of numpy scalars. Both operands are now
    converted, making the return type consistent.
    """
    pred = prediction.detach().cpu().numpy()
    truth_np = truth.detach().cpu().numpy()
    mse = np.square(pred - truth_np).mean()
    mae = np.absolute(pred - truth_np).mean()
    return mse, mae
def evaluate_ndcg(model, hr_map, num_entity):
    """Mean nDCG of `model` over every (h, r) query in `hr_map`.

    :param model: trained triple-scoring model
    :param hr_map: nested dict {h: {r: {t: w}}} of ground-truth tail weights
    :param num_entity: number of candidate tail entities
    :return: (mean linear-gain nDCG, mean exponential-gain nDCG)
    """
    return Tester(model, num_entity).mean_ndcg(hr_map)
def test_func(test_data, device, model, params, threshold=None, neg_mse_also=False, ndcg_also=False):
    # return mse, mae, neg_mse, ndcg
    # neg also: to test negative samples separately (used for validation)
    #
    # Evaluates `model` on `test_data` in a single batch (batch_size equals
    # the full dataset length, so the loop body runs once and returns from
    # inside it).
    #   - ndcg_also: additionally compute mean nDCG via evaluate_ndcg
    #   - neg_mse_also: also score model-generated negative samples and
    #     return a combined MAE for validation
    # NOTE(review): `threshold` is accepted but never used in this block.
    data = TensorDataLoader(test_data, batch_size=test_data.length)
    for ids, cls in data:
        ids, cls = ids.to(device), cls.to(device)
        # special_subset_indices = (ids[:, 1] == special_rel_index).nonzero().squeeze(1)
        with torch.no_grad():
            prediction, truth = model(ids, cls)
        score = prediction
        label = truth
        # model outputs log-scores: exponentiate before comparing to labels
        mse, mae = evaluate_mse(torch.exp(score), label)
        ndcg = None
        if ndcg_also:
            ndcg = evaluate_ndcg(model, params.hr_map, params.VOCAB_SIZE)
        if not neg_mse_also:
            return mse, mae, None, ndcg
        # test for negative samples
        negative_samples, neg_probs = model.random_negative_sampling(ids, cls, neg_per_pos=1)
        neg_prediction, _ = model(negative_samples, neg_probs)
        neg_mse, neg_mae = evaluate_mse(torch.exp(neg_prediction), neg_probs)
        # weighted mix of positive and negative MAE; NEG_RATIO balances the
        # two terms -- presumably the train-time negatives-per-positive ratio
        # (TODO confirm against param.py)
        combined_mae = (mae+neg_mae)/(1+params.NEG_RATIO) # for validation
        return mse, combined_mae, neg_mse, ndcg
def train_func(train_data, train_test_data, dev_data, test_data,
               best_metric, optimizer, rule_configs, device, model, params,
               verbose=True):
    """Train `model` for one epoch over `train_data`.

    Every 20 steps (or every 10 steps when batch_size > 4096) the current
    loss is logged to wandb and the model is evaluated on the dev and test
    splits; `best_metric` is updated in place whenever a metric improves.
    A final test-split evaluation runs after the epoch.

    :param best_metric: dict of best-so-far metrics; mutated and returned
    :param rule_configs: accepted but unused here (kept for interface
        compatibility with callers)
    :return: (accumulated training loss for the epoch, best_metric)

    NOTE(review): the source's indentation was stripped; the periodic
    evaluation section is assumed to live inside the `step % 20` branch --
    confirm against the original repository.
    """
    train_loss = 0
    batch_size = params.BATCH_SIZE
    data = TensorDataLoader(train_data, batch_size=batch_size, shuffle=True)
    # expose true heads/tails so the model can filter negative samples
    model.true_head, model.true_tail = train_data.true_head, train_data.true_tail  # for negative sampling
    step = 0
    for ids, cls in data:
        model.train()
        ids, cls = ids.to(device), cls.to(device)
        loss, pos_loss, neg_loss, logic_loss = my_loss(model, ids, cls, params)
        train_loss += loss.item()
        # FIX: gradients were never cleared, so each backward() accumulated
        # the gradients of every previous step; zero them first.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step += 1
        if step % 20 == 0 or (batch_size > 4096 and step % 10 == 0):
            wandb.log({'Train Loss': loss})
            # test with neg MSE
            model.eval()
            # test
            test_MSE, test_MAE, _, _ = test_func(test_data, device, model, params, ndcg_also=False)
            wandb.log({'Test MSE': test_MSE})
            wandb.log({'Test MAE': test_MAE})
            # validation: mix positive and negative MSE weighted by NEG_RATIO
            valid_pos_mse, valid_mae, valid_neg_mse, _ = test_func(dev_data, device, model, params, neg_mse_also=True)
            valid_mse = (valid_pos_mse + params.NEG_RATIO * valid_neg_mse) / (1 + params.NEG_RATIO)
            wandb.log({'Valid MSE': valid_mse})
            wandb.log({'Valid pos MSE': valid_pos_mse})
            wandb.log({'Valid neg MSE': valid_neg_mse})
            wandb.log({'Valid MAE': valid_mae})
            if verbose:
                print(
                    f'\tLoss: {loss:.10f}(train)\t|pos_loss: {pos_loss:.3f}\t|neg_loss: {neg_loss:.3f}\t|logic_loss:{logic_loss:.3f}')
                if params.early_stop == 'valid_mse':
                    print(f'\t\tTest (with neg) MSE: {test_MSE:.3f}')
                    print(f'\t\tValid MSE: {valid_mse:.3f}\t|Valid pos: {valid_pos_mse:.3f}|Valid neg: {valid_neg_mse:.3f}')
            # track best-so-far metrics (lower is better for all four)
            if test_MSE < best_metric['test_mse']:
                best_metric['test_mse'] = test_MSE
                wandb.log({'Best Test MSE': best_metric['test_mse']})
            if test_MAE < best_metric['test_mae']:
                best_metric['test_mae'] = test_MAE
                wandb.log({'Best Test MAE': best_metric['test_mae']})
            if valid_mse < best_metric['valid_mse']:
                best_metric['valid_mse'] = valid_mse
                wandb.log({'Best Valid MSE': best_metric['valid_mse']})
            if valid_mae < best_metric['valid_mae']:
                best_metric['valid_mae'] = valid_mae
                wandb.log({'Best Valid MAE': best_metric['valid_mae']})
    # test at the end of epoch
    test_MSE, test_MAE, _, _ = test_func(test_data, device, model, params, ndcg_also=False)
    if params.early_stop == 'valid_mse':
        print(f'Test MSE: {test_MSE}')
        print(f'Test MAE: {test_MAE}')
    return train_loss, best_metric
def run_train(
        model, run, train_dataset, train_test_dataset, dev_dataset, test_dataset,
        optimizer, params, verbose=True):
    """Full training loop with early stopping and checkpointing.

    Trains for up to params.EPOCH epochs via train_func. Depending on
    params.early_stop ('ndcg', 'valid_mse', or 'valid_mae'), checkpoints the
    model whenever the corresponding metric improves and stops early after a
    patience window (200 epochs for ndcg, 50 for the mse/mae criteria).

    NOTE(review): `run` and `train_test_dataset` are only passed through /
    unused here; `join` is presumably os.path.join pulled in by a star
    import -- confirm against dataset.py / param.py.
    """
    # best-so-far metrics; MSE seeded at 1, MAE at 100 (effectively +inf),
    # ndcg at 0 (higher is better)
    best_metric = {
        'test_mse': 1,
        'valid_mse': 1,
        'ndcg': 0,
        'test_mae': 100,
        'valid_mae': 100
    }
    last_best_metric = best_metric.copy()
    last_best_ndcg = 0
    last_best_epoch = 0  # for early stopping
    start_time = time.time()
    for epoch in range(params.EPOCH):
        loss, best_metric = train_func(train_dataset, train_test_dataset, dev_dataset, test_dataset,
                                       best_metric, optimizer, params.RULE_CONFIGS, params.device, model, params=params, verbose=verbose)
        secs = int(time.time() - start_time)
        mins = secs / 60
        secs = secs % 60
        print('Epoch: %d' % (epoch + 1), " | time in %d minutes, %d seconds" % (mins, secs))
        # nDCG-based early stopping: evaluated only every 10th epoch
        if params.early_stop == 'ndcg' and epoch % 10 == 0:
            print('####NDCG####')
            linear_ndcg, exp_ndcg = evaluate_ndcg(model, params.hr_map, params.VOCAB_SIZE)
            print(f'Test ndcg (linear, exp): {linear_ndcg:.3f}, {exp_ndcg:.3f}')
            wandb.log({'ndcg': linear_ndcg})
            wandb.log({'exp_ndcg': exp_ndcg})
            if linear_ndcg > last_best_ndcg:
                # improvement: record and checkpoint
                last_best_ndcg = linear_ndcg
                last_best_epoch = epoch
                wandb.log({'best_ndcg': linear_ndcg})
                wandb.log({'best_exp_ndcg': exp_ndcg})
                wandb.log({'epoch': last_best_epoch})
                torch.save(model, join(params.model_dir, f'{params.whichmodel}-{wandb.run.id}.pt'))
            else:
                # patience of 200 epochs since the last improvement
                if epoch >= 1 and epoch-last_best_epoch >= 200:
                    print('***best epoch:***', last_best_epoch)
                    wandb.log({'best_ndcg': linear_ndcg})
                    wandb.log({'best_exp_ndcg': exp_ndcg})
                    # print('***best metric:***', last_best_metric)
                    wandb.log({'epoch': last_best_epoch})
                    break  # early stop
        # early stopping
        # print('best', best_metric['valid_mse'], 'last best', last_best_metric['valid_mse'])
        if params.early_stop == 'valid_mse':
            if epoch >= 1 and best_metric['valid_mse'] >= last_best_metric['valid_mse']:  # no improvement or already overfit
                print('epoch', epoch, 'last_best_epoch', last_best_epoch)
                if epoch - last_best_epoch >= 50:  # patience
                    print('***best epoch:***', last_best_epoch)
                    print('***best metric:***', last_best_metric)
                    wandb.log({'epoch': last_best_epoch})
                    # run.finish() # end wandb watch
                    break
            else:
                # improvement: remember metrics and checkpoint the model
                last_best_metric = best_metric.copy()
                last_best_epoch = epoch
                torch.save(model, join(params.model_dir, f'{params.whichmodel}-{wandb.run.id}.pt'))
        if params.early_stop == 'valid_mae':
            if epoch >= 1 and best_metric['valid_mae'] >= last_best_metric['valid_mae']:  # no improvement or already overfit
                print('epoch', epoch, 'last_best_epoch', last_best_epoch)
                if epoch - last_best_epoch >= 50:  # patience
                    print('***best epoch:***', last_best_epoch)
                    print('***best metric:***', last_best_metric)
                    wandb.log({'epoch': last_best_epoch})
                    # run.finish() # end wandb watch
                    break
            else:
                # improvement: remember metrics and checkpoint the model
                last_best_metric = best_metric.copy()
                last_best_epoch = epoch
                torch.save(model, join(params.model_dir, f'{params.whichmodel}-{wandb.run.id}.pt'))
class NDCGRankingTestDataset(TensorDataset):
    """Candidate-tail dataset for the ranking task.

    For one fixed (h, r) query it enumerates every candidate triple
    (h, r, t) for t in [0, num_entities), precomputed as a single
    (num_entities, 3) LongTensor.
    """

    def __init__(self, h, r, num_entities):
        self.h, self.r = h, r
        self.num_entities = num_entities
        self.length = num_entities
        # precompute the full candidate list once
        self.candidate_triples = self.get_all_candidate_triples()

    def get_all_candidate_triples(self):
        """Return the (num_entities, 3) tensor [[h, r, 0], [h, r, 1], ...]."""
        tails = torch.arange(self.num_entities, dtype=torch.long)
        heads = torch.full_like(tails, self.h)
        rels = torch.full_like(tails, self.r)
        return torch.stack((heads, rels, tails), dim=1)

    def __getitem__(self, index):
        return self.candidate_triples[index, :]

    def __len__(self):
        return self.length
class Tester:
    """Ranking evaluator: computes mean nDCG of a trained triple-scoring
    model over (head, relation) queries.

    The model is called as model(ids, cls) and returns (log_scores, truth);
    actual scores are exp(log_scores).
    """

    class IndexScore:
        """
        The score of a tail when h and r is given.
        It's used in the ranking task to facilitate comparison and sorting.
        Print w as 3 digit precision float.
        """

        def __init__(self, index, score):
            self.index = index
            self.score = score

        def __lt__(self, other):
            # ordering by score only; callers sort(reverse=True) for descending
            return self.score < other.score

        def __repr__(self):
            # return "(index: %d, w:%.3f)" % (self.index, self.score)
            return "(%d, %.3f)" % (self.index, self.score)

        def __str__(self):
            return "(index: %d, w:%.3f)" % (self.index, self.score)

    def __init__(self, model, num_entity):
        """
        :param model: trained scoring model
        :param num_entity: total number of candidate tail entities
        :type test_dataset: ShirleyTripleDataset
        """
        self.model = model
        self.num_entity = num_entity

    def get_score(self, h, r, i):
        # Score a single triple (h, r, i). `cls` is a dummy label required
        # by the model's call signature; the score is exp(log_score).
        ids = torch.LongTensor([[h, r, i]])
        cls = torch.Tensor([0])  # dummy
        log_score, _ = self.model(ids, cls)
        return torch.exp(log_score).detach().cpu().numpy()[0]

    def get_t_ranks(self, h, r, ts):
        """
        Given some t index, return the ranks for each t
        :return:

        Scores all num_entity candidate tails for (h, r) in one batch, then
        for each ground-truth tail in `ts` counts how many candidates score
        at least as high (rank 1 = best). Ranks are computed on log-scores,
        which preserves ordering since exp is monotonic.
        """
        ranking_dataset = NDCGRankingTestDataset(
            h, r, self.num_entity
        )  # for one hr
        candidates_data = TensorDataLoader(
            ranking_dataset,
            batch_size=ranking_dataset.length,
            shuffle=False
        )
        with torch.no_grad():
            for ids in candidates_data:  # only one batch
                ids = ids  # [[h,r,0],[h,r,1]...]
                cls = torch.zeros(ids.shape[0])
                log_scores, _ = self.model(ids, cls)
                scores = log_scores
                grt_scores = scores[ts]
                # rank of each ground-truth tail = #candidates scoring >= it
                ranks = np.array([(scores >= s).sum().detach().cpu().numpy() for s in grt_scores])
                # print('ranks', ranks)
                break
        return ranks

    def ndcg0(self, h, r, tw_truth):
        """
        Compute nDCG(normalized discounted cummulative gain)
        sum(score_ground_truth / log2(rank+1)) / max_possible_dcg
        :param tw_truth: [IndexScore1, IndexScore2, ...], soreted by IndexScore.score descending
        :return:

        Returns (linear-gain nDCG, exponential-gain nDCG, ranks).
        NOTE(review): the DCG discount uses log2(rank + 2) while the ideal
        DCG uses log2(best_possible_rank + 1); since ranks start at 1 these
        denominators are offset by one -- confirm this asymmetry is intended.
        """
        # prediction
        ts = [tw.index for tw in tw_truth]
        ranks = self.get_t_ranks(h, r, ts)
        # linear gain
        gains = np.array([tw.score for tw in tw_truth])
        discounts = np.log2(ranks + 2)  # avoid division by 0
        discounted_gains = gains / discounts
        dcg = np.sum(discounted_gains)  # discounted cumulative gain
        # normalize: ideal rank of each gain among the ground-truth gains
        best_possible_ranks = np.array([(gains >= g).sum() for g in gains])  # gains [0.9, 0.8, 0.8, 0.7] -> [1,3,3,4]
        max_possible_dcg = np.sum(gains / np.log2(best_possible_ranks + 1))
        # max_possible_dcg = np.sum(gains / np.log2(np.arange(len(gains)) + 2))
        ndcg = dcg / max_possible_dcg  # normalized discounted cumulative gain
        # exponential gain: emphasizes high-weight tails (2^w - 1)
        exp_gains = np.array([2 ** tw.score - 1 for tw in tw_truth])
        exp_discounted_gains = exp_gains / discounts
        exp_dcg = np.sum(exp_discounted_gains)
        # normalize
        exp_best_possible_ranks = np.array([(exp_gains >= g).sum() for g in exp_gains])
        exp_max_possible_dcg = np.sum(exp_gains / np.log2(exp_best_possible_ranks + 1))
        # exp_max_possible_dcg = np.sum(exp_gains / np.log2(np.arange(len(gains)) + 2))
        exp_ndcg = exp_dcg / exp_max_possible_dcg
        return ndcg, exp_ndcg, ranks

    def ndcg(self, h, r, tw_truth):
        """Compute nDCG for one (h, r) query via sklearn's ndcg_score.

        Builds a dense relevance vector over all entities (zeros except the
        ground-truth tails) and compares it against the model's scores for
        every candidate tail.

        NOTE(review): the exponential-gain computation is commented out, so
        the second return value is the *linear* nDCG duplicated, and the
        third (ranks) is always None -- callers should not rely on them.
        """
        with torch.no_grad():
            gains = torch.zeros(self.num_entity)
            indices = torch.LongTensor([tw.index for tw in tw_truth])
            weights = torch.FloatTensor([tw.score for tw in tw_truth])
            gains[indices] = weights
            # exp_gains = torch.exp2(gains) - 1
            ranking_dataset = NDCGRankingTestDataset(
                h, r, self.num_entity
            )  # for one hr
            candidates_data = TensorDataLoader(
                ranking_dataset,
                batch_size=ranking_dataset.length,
                shuffle=False
            )
            for ids in candidates_data:  # only one batch
                ids = ids  # [[h,r,0],[h,r,1]...]
                cls = torch.zeros(ids.shape[0])
                log_scores, _ = self.model(ids, cls)
                scores = torch.exp(log_scores)
                linear_ndcg = ndcg_score(gains.unsqueeze(0).detach().cpu().numpy(), scores.unsqueeze(0).detach().cpu().numpy())
                # exp_ndcg = ndcg_score(exp_gains.unsqueeze(0).detach().cpu().numpy(), scores.unsqueeze(0).detach().cpu().numpy())
                return linear_ndcg, linear_ndcg, None

    def mean_ndcg(self, hr_map):
        """
        :param hr_map: {h:{r:{t:w}}}
        :return:

        Averages self.ndcg over every (h, r) query in hr_map and returns
        (mean linear nDCG, mean exp nDCG). `res` accumulates per-query
        details for debugging only and is never returned.
        """
        ndcg_sum = 0  # nDCG with linear gain
        exp_ndcg_sum = 0
        count = 0
        t0 = time.time()
        # debug ndcg
        res = []  # [(h,r,tw_truth, ndcg)]
        for h in hr_map:
            for r in hr_map[h]:
                tw_dict = hr_map[h][r]  # {t:w}
                tw_truth = [self.IndexScore(t, w) for t, w in tw_dict.items()]
                tw_truth.sort(reverse=True)  # descending on w
                ndcg, exp_ndcg, ranks = self.ndcg(h, r, tw_truth)  # nDCG with linear gain and exponential gain
                ndcg_sum += ndcg
                exp_ndcg_sum += exp_ndcg
                count += 1
                # debug
                res.append((h, r, tw_truth, ndcg, ranks))
        return ndcg_sum / count, exp_ndcg_sum / count