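"""Train a Deep Recurrent Q-Network (DRQN) on a partially observable gridworld.

A primary and a target LSTM-based Q-network (Model.Qnetwork) are trained from
whole-episode traces stored in an experience buffer, using a Double-DQN target
and periodic updates of the target network toward the primary network (rate tau).
"""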
import numpy as np
import tensorflow as tf
import os
import csv
from gridworld import gameEnv
from Model import Qnetwork, experience_buffer
from helper import *
env = gameEnv(partial=True, size=9)  # Partially observable gridworld environment (size 9)

# Training parameters
batch_size = 4            # How many experience traces to use for each training step
trace_length = 8          # How long each experience trace is when training
update_freq = 5           # How often (in steps) to perform a training step
y = .99                   # Discount factor on the target Q-values
startE = 1                # Starting chance of a random action
endE = 0.1                # Final chance of a random action
annealing_steps = 10000   # How many steps of training to reduce startE to endE
num_episodes = 10000      # How many episodes of the game environment to train on
pre_train_steps = 10000   # How many steps of random actions before training begins
load_model = False        # Whether to load a saved model
path = "./drqn"           # The path to save our model to
h_size = 512              # Size of the final convolutional layer (fed into the LSTM)
max_epLength = 50         # The max allowed length of an episode
time_per_step = 1         # Length of each step used in gif creation
summaryLength = 100       # Number of episodes to periodically save for analysis
tau = 0.001               # Rate at which the target network is updated toward the primary network
tf.reset_default_graph()
# We define the cells for the primary and target q-networks
cell = tf.contrib.rnn.BasicLSTMCell(num_units=h_size, state_is_tuple=True)
cellT = tf.contrib.rnn.BasicLSTMCell(num_units=h_size, state_is_tuple=True)
mainQN = Qnetwork(h_size, cell, 'main')
targetQN = Qnetwork(h_size, cellT, 'target')
init = tf.global_variables_initializer()
saver = tf.train.Saver(max_to_keep=5)
trainables = tf.trainable_variables()
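# Ops that move the target network's variables toward the primary network's variables
# at rate tau (soft target update; updateTargetGraph is assumed to come from the
# wildcard helper import).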
targetOps = updateTargetGraph(trainables, tau)
myBuffer = experience_buffer()
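# The buffer stores whole episodes; sample() is expected to draw trace_length-step
# slices from randomly chosen episodes for recurrent training.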
# Set the rate of random action decrease.
e = startE
stepDrop = (startE - endE) / annealing_steps
# create lists to contain total rewards and steps per episode
jList = []
rList = []
total_steps = 0
# Make a path for our model to be saved in.
if not os.path.exists(path):
    os.makedirs(path)
# Write the first line of the master log-file for the Control Center
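# Columns are per-episode stats plus references to saved artifacts (IMG/LOG/SAL are
# assumed to point at image, log, and saliency files written by saveToCenter).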
with open('./Center/log.csv', 'w') as myfile:
    wr = csv.writer(myfile, quoting=csv.QUOTE_ALL)
    wr.writerow(['Episode', 'Length', 'Reward', 'IMG', 'LOG', 'SAL'])
with tf.Session() as sess:
    if load_model:
        print('Loading Model...')
        ckpt = tf.train.get_checkpoint_state(path)
        saver.restore(sess, ckpt.model_checkpoint_path)
    sess.run(init)
    updateTarget(targetOps, sess)  # Set the target network to be equal to the primary network.
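    # Main training loop: one iteration per episode. The recurrent state is reset at
    # the start of each episode and carried through every step within it.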
    for i in range(num_episodes):
        episodeBuffer = []
        # Reset environment and get first new observation
        sP = env.reset()
        s = processState(sP)
        d = False
        rAll = 0
        j = 0
        state = (np.zeros([1, h_size]), np.zeros([1, h_size]))  # Reset the recurrent layer's hidden state
        # The Q-Network
        while j < max_epLength:
            j += 1
            # Choose an action greedily from the Q-network (with probability e of a random action)
            if np.random.rand(1) < e or total_steps < pre_train_steps:
                # Random action, but still run the RNN so the recurrent state stays in sync
                state1 = sess.run(mainQN.rnn_state,
                                  feed_dict={mainQN.scalarInput: [s / 255.0], mainQN.trainLength: 1,
                                             mainQN.state_in: state, mainQN.batch_size: 1})
                a = np.random.randint(0, 4)
            else:
                a, state1 = sess.run([mainQN.predict, mainQN.rnn_state],
                                     feed_dict={mainQN.scalarInput: [s / 255.0], mainQN.trainLength: 1,
                                                mainQN.state_in: state, mainQN.batch_size: 1})
                a = a[0]
            s1P, r, d = env.step(a)
            s1 = processState(s1P)
            total_steps += 1
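            # Store the transition as (state, action, reward, next_state, done)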
            episodeBuffer.append(np.reshape(np.array([s, a, r, s1, d]), [1, 5]))
            if total_steps > pre_train_steps:
                if e > endE:
                    e -= stepDrop
                if total_steps % (update_freq) == 0:
                    updateTarget(targetOps, sess)
                    # Reset the recurrent layer's hidden state
                    state_train = (np.zeros([batch_size, h_size]), np.zeros([batch_size, h_size]))
                    trainBatch = myBuffer.sample(batch_size, trace_length)  # Get a random batch of experiences.
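                    # Double-DQN target for each sampled step:
                    #   targetQ = r + y * Q_target(s', argmax_a Q_main(s', a)) * (1 - done)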
                    # Below we perform the Double-DQN update to the target Q-values
                    Q1 = sess.run(mainQN.predict, feed_dict={
                        mainQN.scalarInput: np.vstack(trainBatch[:, 3] / 255.0),
                        mainQN.trainLength: trace_length, mainQN.state_in: state_train,
                        mainQN.batch_size: batch_size})  # Greedy actions for the next states from the primary network
                    Q2 = sess.run(targetQN.Qout, feed_dict={
                        targetQN.scalarInput: np.vstack(trainBatch[:, 3] / 255.0),
                        targetQN.trainLength: trace_length, targetQN.state_in: state_train,
                        targetQN.batch_size: batch_size})  # Q-values for the next states from the target network
                    end_multiplier = -(trainBatch[:, 4] - 1)  # 1 - done: no bootstrapping past terminal steps
                    doubleQ = Q2[range(batch_size * trace_length), Q1]
                    targetQ = trainBatch[:, 2] + (y * doubleQ * end_multiplier)
                    # Update the network with our target values.
                    sess.run(mainQN.updateModel,
                             feed_dict={mainQN.scalarInput: np.vstack(trainBatch[:, 0] / 255.0),
                                        mainQN.targetQ: targetQ,
                                        mainQN.actions: trainBatch[:, 1], mainQN.trainLength: trace_length,
                                        mainQN.state_in: state_train, mainQN.batch_size: batch_size})
            rAll += r
            s = s1
            sP = s1P
            state = state1
            if d:
                break
        # Add the episode to the experience buffer
        bufferArray = np.array(episodeBuffer)
        episodeBuffer = list(zip(bufferArray))
        myBuffer.add(episodeBuffer)
        jList.append(j)
        rList.append(rAll)
        # Periodically save the model.
        if i % 1000 == 0 and i != 0:
            saver.save(sess, path + '/model-' + str(i) + '.cptk')
            print("Saved Model")
        if len(rList) % summaryLength == 0 and len(rList) != 0:
            print(total_steps, np.mean(rList[-summaryLength:]), e)
            saveToCenter(i, rList, jList, np.reshape(np.array(episodeBuffer), [len(episodeBuffer), 5]),
                         summaryLength, h_size, sess, mainQN, time_per_step)
    saver.save(sess, path + '/model-' + str(i) + '.cptk')