Skip to content

Commit

Permalink
Fix test failures (#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
NeelayS authored May 18, 2022
1 parent 5f95543 commit 9b29607
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 62 deletions.
9 changes: 6 additions & 3 deletions KD_Lib/KD/vision/attention/loss_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,20 @@ def forward(self, teacher_output, student_output):
:param student_output (torch.FloatTensor): Prediction made by the student model
"""

A_t = teacher_output[1:]
A_s = student_output[1:]
A_t = teacher_output # [1:]
A_s = student_output # [1:]

loss = 0.0
for (layerT, layerS) in zip(A_t, A_s):

xT = self.single_at_loss(layerT)
xS = self.single_at_loss(layerS)
loss += (xS - xT).pow(self.p).mean()

return loss

def single_at_loss(self, activation):
"""
Function for calculating single attention loss
"""
return F.normalize(activation.pow(self.p).mean(1).view(activation.size(0), -1))
return F.normalize(activation.pow(self.p).view(activation.size(0), -1))
2 changes: 1 addition & 1 deletion KD_Lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

__author__ = """Het Shah"""
__email__ = "divhet163@gmail.com"
__version__ = "__version__ = '0.0.31'"
__version__ = "__version__ = '0.0.32'"
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.0.31
current_version = 0.0.32
commit = True
tag = True

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,6 @@
test_suite="tests",
tests_require=test_requirements,
url="https://github.com/SforAiDL/KD_Lib",
version="0.0.31",
version="0.0.32",
zip_safe=False,
)
111 changes: 55 additions & 56 deletions tests/test_kd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
RCO,
TAKD,
Attention,
BaseClass,
LabelSmoothReg,
MeanTeacher,
MessyCollab,
Expand All @@ -29,7 +28,7 @@
img_size = (32, 32)
img_channels = 3
n_classes = 10
len_dataset = 4
len_dataset = 8
batch_size = 2

train_loader = test_loader = DataLoader(
Expand Down Expand Up @@ -86,7 +85,7 @@ def test_TAKD():

assistant_train_order = [[-1], [-1, 0]]

distil = TAKD(
distiller = TAKD(
teacher,
assistants,
student,
Expand All @@ -98,32 +97,32 @@ def test_TAKD():
student_optimizer,
)

distil.train_teacher(epochs=1, plot_losses=False, save_model=False)
distil.train_assistants(epochs=1, plot_losses=False, save_model=False)
distil.train_student(epochs=1, plot_losses=False, save_model=False)
distil.get_parameters()
distiller.train_teacher(epochs=1, plot_losses=False, save_model=False)
distiller.train_assistants(epochs=1, plot_losses=False, save_model=False)
distiller.train_student(epochs=1, plot_losses=False, save_model=False)
distiller.get_parameters()


# def test_attention():
def test_Attention():

# att = Attention(
# teacher,
# student,
# train_loader,
# test_loader,
# t_optimizer,
# s_optimizer,
# )
distiller = Attention(
teacher,
student,
train_loader,
test_loader,
t_optimizer,
s_optimizer,
)

# att.train_teacher(epochs=1, plot_losses=False, save_model=False)
# att.train_student(epochs=1, plot_losses=False, save_model=False)
# att.evaluate(teacher=False)
# att.get_parameters()
distiller.train_teacher(epochs=1, plot_losses=False, save_model=False)
distiller.train_student(epochs=1, plot_losses=False, save_model=False)
distiller.evaluate(teacher=False)
distiller.get_parameters()


def test_NoisyTeacher():

experiment = NoisyTeacher(
distiller = NoisyTeacher(
teacher,
student,
train_loader,
Expand All @@ -135,10 +134,10 @@ def test_NoisyTeacher():
device="cpu",
)

experiment.train_teacher(epochs=1, plot_losses=False, save_model=False)
experiment.train_student(epochs=1, plot_losses=False, save_model=False)
experiment.evaluate(teacher=False)
experiment.get_parameters()
distiller.train_teacher(epochs=1, plot_losses=False, save_model=False)
distiller.train_student(epochs=1, plot_losses=False, save_model=False)
distiller.evaluate(teacher=False)
distiller.get_parameters()


def test_VirtualTeacher():
Expand All @@ -158,21 +157,21 @@ def test_SelfTraining():
distiller.get_parameters()


# def test_mean_teacher():
# def test_MeanTeacher():

# mt = MeanTeacher(
# teacher_model,
# student_model,
# distiller = MeanTeacher(
# teacher,
# student,
# train_loader,
# test_loader,
# t_optimizer,
# s_optimizer,
# )

# mt.train_teacher(epochs=1, plot_losses=False, save_model=False)
# mt.train_student(epochs=1, plot_losses=False, save_model=False)
# mt.evaluate()
# mt.get_parameters()
# distiller.train_teacher(epochs=1, plot_losses=False, save_model=False)
# distiller.train_student(epochs=1, plot_losses=False, save_model=False)
# distiller.evaluate()
# distiller.get_parameters()


def test_RCO():
Expand All @@ -192,15 +191,15 @@ def test_RCO():
distiller.get_parameters()


# def test_BANN():
def test_BANN():

# model = deepcopy(mock_vision_model)
# optimizer = optim.SGD(model.parameters(), 0.01)
model = deepcopy(mock_vision_model)
optimizer = optim.SGD(model.parameters(), 0.01)

# distiller = BANN(model, train_loader, test_loader, optimizer, num_gen=2)
distiller = BANN(model, train_loader, test_loader, optimizer, num_gen=2)

# distiller.train_student(epochs=1, plot_losses=False, save_model=False)
# distiller.evaluate()
distiller.train_student(epochs=1, plot_losses=False, save_model=False)
# distiller.evaluate()


def test_PS():
Expand Down Expand Up @@ -237,7 +236,7 @@ def test_LSR():
distiller.get_parameters()


def test_soft_random():
def test_SoftRandom():

distiller = SoftRandom(
teacher,
Expand All @@ -254,7 +253,7 @@ def test_soft_random():
distiller.get_parameters()


def test_messy_collab():
def test_MessyCollab():

distiller = MessyCollab(
teacher,
Expand All @@ -271,21 +270,6 @@ def test_messy_collab():
distiller.get_parameters()


# def test_bert2lstm():
# student_model = LSTMNet(
# input_dim=len(text_field.vocab), num_classes=2, dropout_prob=0.5
# )
# optimizer = optim.Adam(student_model.parameters())
#
# experiment = BERT2LSTM(
# student_model, bert2lstm_train_loader, bert2lstm_train_loader, optimizer, train_df, val_df
# )
# # experiment.train_teacher(epochs=1, plot_losses=False, save_model=False)
# experiment.train_student(epochs=1, plot_losses=False, save_model=False)
# experiment.evaluate_student()
# experiment.evaluate_teacher()


def test_DML():

student_1 = deepcopy(mock_vision_model)
Expand All @@ -312,3 +296,18 @@ def test_DML():
)
distiller.evaluate()
distiller.get_parameters()


# def test_BERT2LSTM():
# student_model = LSTMNet(
# input_dim=len(text_field.vocab), num_classes=2, dropout_prob=0.5
# )
# optimizer = optim.Adam(student_model.parameters())
#
# distiller = BERT2LSTM(
# student_model, bert2lstm_train_loader, bert2lstm_train_loader, optimizer, train_df, val_df
# )
# # distiller.train_teacher(epochs=1, plot_losses=False, save_model=False)
# distiller.train_student(epochs=1, plot_losses=False, save_model=False)
# distiller.evaluate_student()
# distiller.evaluate_teacher()

0 comments on commit 9b29607

Please sign in to comment.