forked from yusuketomoto/chainer-fast-neuralstyle
-
Notifications
You must be signed in to change notification settings - Fork 0
/
create_chainer_model.py
executable file
·38 lines (35 loc) · 1.29 KB
/
create_chainer_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from chainer import link
from chainer.links.caffe import CaffeFunction
from chainer import serializers
from net import *
def copy_model(src, dst):
    """Copy parameters from the child links of ``src`` into those of ``dst``.

    Each direct child of ``src`` is matched against the entry of ``dst`` with
    the same name. A child is skipped silently when ``dst`` has no such name in
    its instance ``__dict__`` or when the two children are not of exactly the
    same type. Nested chains are handled recursively. Each child link is then
    copied as a whole only when both sides expose the same set of parameter
    names with equal array shapes; otherwise a message is printed and that
    child is left untouched.

    Copying assigns each source parameter's ``data`` array to the matching
    destination parameter; no explicit copy is made here, so the two models
    presumably share those arrays afterwards.

    Args:
        src: ``chainer.Chain`` to read weights from (e.g. a ``CaffeFunction``).
        dst: ``chainer.Chain`` that receives the weights (e.g. ``VGG``).

    Raises:
        AssertionError: if ``src`` or ``dst`` is not a ``chainer.Chain``.
    """
    assert isinstance(src, link.Chain)
    assert isinstance(dst, link.Chain)
    for child in src.children():
        # Only children that dst holds under the same attribute name are candidates.
        if child.name not in dst.__dict__:
            continue
        dst_child = dst[child.name]
        # Require the exact same class on both sides, not merely a subclass.
        if type(child) is not type(dst_child):
            continue
        if isinstance(child, link.Chain):
            copy_model(child, dst_child)
        # A Chain is also a Link, so a nested chain additionally goes through
        # this whole-link pass after the recursive per-child copy above.
        if isinstance(child, link.Link):
            if _copy_link_params(child, dst_child):
                print('Copy %s' % child.name)
            else:
                print('Ignore %s because of parameter mismatch' % child.name)


def _copy_link_params(src_link, dst_link):
    """Copy all parameters of ``src_link`` into ``dst_link`` if they are compatible.

    Compatible means both links report the same parameter names through
    ``namedparams()`` and every same-named pair has an equal ``data.shape``.

    Returns:
        True if the parameters were copied, False if nothing was changed.
    """
    src_params = dict(src_link.namedparams())
    dst_params = dict(dst_link.namedparams())
    # Pair parameters by name rather than zipping the two iterators: zip stops
    # at the shorter sequence, so a link with an extra parameter (e.g. a bias
    # the other side lacks) would be reported as copied while that parameter
    # was silently skipped, and the outcome would also hinge on both links
    # enumerating their parameters in the same order.
    if set(src_params) != set(dst_params):
        return False
    for name, src_param in src_params.items():
        if src_param.data.shape != dst_params[name].data.shape:
            return False
    # Validate everything first so a mismatch never leaves a half-copied link.
    for name, src_param in src_params.items():
        dst_params[name].data = src_param.data
    return True
# Conversion driver: read the Caffe VGG-16 weights, transfer every layer that
# copy_model accepts into a freshly built VGG network, then write the result
# to disk in NPZ format.

# Step 1: load the reference Caffe model and build the target network.
print('load VGG16 caffemodel')
ref = CaffeFunction('VGG_ILSVRC_16_layers.caffemodel')
vgg = VGG()

# Step 2: move matching layer weights from the Caffe model into VGG.
print('copy weights')
copy_model(ref, vgg)

# Step 3: serialize the populated network.
print('save "vgg16.model"')
serializers.save_npz('vgg16.model', vgg)