-
Notifications
You must be signed in to change notification settings - Fork 5
/
resnet.py
102 lines (87 loc) · 3.42 KB
/
resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from collections import namedtuple
from math import sqrt
from sklearn import metrics
import tensorflow as tf
from tensorflow.contrib import skflow
from data_utils import load_CIFAR100
def weight_variable(shape, name=None):
    """Create a trainable weight variable initialized from a truncated normal.

    Args:
        shape: list of ints, the shape of the variable.
        name: optional name for the variable.

    Returns:
        A ``tf.Variable`` with stddev-0.1 truncated-normal initial values.
    """
    initial = tf.truncated_normal(shape, stddev=0.1)
    # Bug fix: `name` was passed positionally, which binds to tf.Variable's
    # second parameter (`trainable`), not `name` — silently breaking training
    # of the variable whenever a name was given. Pass it by keyword.
    return tf.Variable(initial, name=name)
def bias_variable(shape, name=None):
    """Create a bias variable initialized to the constant 0.1.

    Args:
        shape: list of ints, the shape of the variable.
        name: optional name for the variable.

    Returns:
        A ``tf.Variable`` filled with 0.1.
    """
    initial = tf.constant(0.1, shape=shape)
    # Bug fix: the `name` argument was previously ignored; forward it by
    # keyword for consistency with weight_variable.
    return tf.Variable(initial, name=name)
def res_net(x, y, activation=tf.nn.elu):
    """Build a small ResNet model function for skflow's TensorFlowEstimator.

    The network is: one 3x3 conv (16 maps), then 3 residual blocks of 2
    residual units each with 16/32/64 filters (spatial size halved at the
    start of blocks 1 and 2), global average pooling, and a linear
    classifier.

    Args:
        x: input image batch tensor (assumed NHWC — TODO confirm against
            the estimator's feeding pipeline).
        y: one-hot target tensor; its last dimension fixes the class count.
        activation: activation function applied inside the conv layers.

    Returns:
        (predictions, loss): softmax class probabilities and the mean
        cross-entropy loss, as skflow's model_fn contract expects.
    """
    # Initial 3x3 convolution producing 16 feature maps.
    with tf.variable_scope('conv_layer1'):
        net = skflow.ops.conv2d(x, 16, [3, 3], batch_norm=True,
                                activation=activation, bias=False,
                                padding='SAME')
    # Three stages with 16, 32, 64 filters respectively.
    for block in range(3):
        nfilters = 16 << block
        for layer in range(2):
            net_copy = net  # shortcut input for the residual connection
            name = 'block_%d/layer_%d' % (block, layer)
            for i in range(2):
                with tf.variable_scope(name + '/' + str(i)):
                    # Stride 2 (downsampling) only on the first conv of the
                    # first layer of blocks 1 and 2.
                    if layer == 0 and block != 0 and i == 0:
                        up = 1
                    else:
                        up = 0
                    net = skflow.ops.conv2d(net,
                                            nfilters,
                                            [3, 3], [1, 1 + up, 1 + up, 1],
                                            padding='SAME',
                                            activation=activation,
                                            batch_norm=True,
                                            bias=False)
            # Identity shortcut. When the spatial size shrank, average-pool
            # the shortcut and zero-pad its channels to match the doubled
            # filter count (parameter-free "option A" shortcut).
            if net_copy.get_shape().as_list()[1] != net.get_shape().as_list()[1]:
                net_copy = tf.nn.avg_pool(net_copy, [1, 2, 2, 1],
                                          strides=[1, 2, 2, 1], padding='VALID')
                net_copy = tf.pad(net_copy,
                                  [[0, 0], [0, 0], [0, 0],
                                   [int(nfilters / 4), int(nfilters / 4)]])
            net = net + net_copy
    # Global average pooling over the remaining spatial dimensions.
    net_shape = net.get_shape().as_list()
    net = tf.nn.avg_pool(net,
                         ksize=[1, net_shape[1], net_shape[2], 1],
                         strides=[1, 1, 1, 1], padding='VALID')
    net_shape = net.get_shape().as_list()
    net = tf.reshape(net, [-1, net_shape[1] * net_shape[2] * net_shape[3]])
    # Final fully connected classifier; class count taken from y's shape.
    w = tf.get_variable('w', [net.get_shape()[1], y.get_shape()[-1]])
    b = tf.get_variable('b', [y.get_shape()[-1]])
    logits = tf.nn.xw_plus_b(net, w, b)
    h = tf.nn.softmax_cross_entropy_with_logits(logits, y, name='h_raw')
    loss = tf.reduce_mean(h, name='cross_entropy')
    # Bug fix: the original passed `name=name`, silently reusing the stale
    # loop variable leaked from the block loop ('block_2/layer_1') as this
    # op's name. Give the prediction op its own explicit name.
    predictions = tf.nn.softmax(logits, name='predictions')
    return predictions, loss
import time

# Path to the unpacked CIFAR-100 python-format dataset.
path = './dataset/cifar-100-python'
Xtr, Ytr, Xte, Yte = load_CIFAR100(path)
nclass = 20        # CIFAR-100 coarse ("superclass") labels — TODO confirm
batch_size = 256
# One pass over the training set per fit() call.
steps = int(Xtr.shape[0] / batch_size)

# NOTE(review): the original script built module-level `w`/`b` variables here
# (with mutually inconsistent shapes [64, nclass] and [nclass, 1]); they were
# never referenced — res_net creates its own classifier weights via
# tf.get_variable — so the dead construction has been removed.

classifier = skflow.TensorFlowEstimator(
    model_fn=res_net,
    n_classes=nclass, batch_size=batch_size,
    steps=steps,
    learning_rate=0.1, continue_training=True,
    optimizer="Adam",
    verbose=1)

t = time.time()
# Train one epoch at a time for roughly 170 minutes, evaluating and
# checkpointing after every epoch (at least one epoch always runs).
while True:
    classifier.fit(Xtr, Ytr, logdir="models/resnet/")
    # Calculate test-set accuracy.
    score = metrics.accuracy_score(
        Yte, classifier.predict(Xte, batch_size=64))
    now = int((time.time() - t) / 60.0)  # elapsed whole minutes
    print('Accuracy: {0:f}'.format(score) + ' time' + str(now))
    # Save model graph and checkpoints.
    classifier.save("models/resnet/")
    if now > 170:
        break