Skip to content

Commit 3b23f6b

Browse files
Morvan Zhou
authored and committed
update for torch 0.2
1 parent f31c88b commit 3b23f6b

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

tutorial-contents/405_DQN_Reinforcement_learning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def choose_action(self, x):
5858
# input only one sample
5959
if np.random.uniform() < EPSILON: # greedy
6060
actions_value = self.eval_net.forward(x)
61-
action = torch.max(actions_value, 1)[1].data.numpy()[0, 0] # return the argmax
61+
action = torch.max(actions_value, 1)[1].data.numpy()[0] # return the argmax
6262
else: # random
6363
action = np.random.randint(0, N_ACTIONS)
6464
return action
@@ -87,7 +87,7 @@ def learn(self):
8787
# q_eval w.r.t the action in experience
8888
q_eval = self.eval_net(b_s).gather(1, b_a) # shape (batch, 1)
8989
q_next = self.target_net(b_s_).detach() # detach from graph, don't backpropagate
90-
q_target = b_r + GAMMA * q_next.max(1)[0] # shape (batch, 1)
90+
q_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1) # shape (batch, 1)
9191
loss = self.loss_func(q_eval, q_target)
9292

9393
self.optimizer.zero_grad()

0 commit comments

Comments (0)