Skip to content

Commit 9318a69

Browse files
MorvanZhou (Morvan Zhou)
authored and committed
fix action shape problem
1 parent a7b14b8 commit 9318a69

1 file changed

Lines changed: 6 additions & 3 deletions

File tree

tutorial-contents/405_DQN_Reinforcement_learning.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -26,14 +26,15 @@
2626
env = env.unwrapped
2727
N_ACTIONS = env.action_space.n
2828
N_STATES = env.observation_space.shape[0]
29+
ENV_A_SHAPE = 0 if isinstance(env.action_space.sample(), int) else env.action_space.sample().shape # to confirm the shape
2930

3031

3132
class Net(nn.Module):
3233
def __init__(self, ):
3334
super(Net, self).__init__()
34-
self.fc1 = nn.Linear(N_STATES, 10)
35+
self.fc1 = nn.Linear(N_STATES, 50)
3536
self.fc1.weight.data.normal_(0, 0.1) # initialization
36-
self.out = nn.Linear(10, N_ACTIONS)
37+
self.out = nn.Linear(50, N_ACTIONS)
3738
self.out.weight.data.normal_(0, 0.1) # initialization
3839

3940
def forward(self, x):
@@ -58,9 +59,11 @@ def choose_action(self, x):
5859
# input only one sample
5960
if np.random.uniform() < EPSILON: # greedy
6061
actions_value = self.eval_net.forward(x)
61-
action = torch.max(actions_value, 1)[1].data.numpy()[0, 0] # return the argmax
62+
action = torch.max(actions_value, 1)[1].data.numpy()
63+
action = action[0, 0] if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE) # return the argmax index
6264
else: # random
6365
action = np.random.randint(0, N_ACTIONS)
66+
action = action if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)
6467
return action
6568

6669
def store_transition(self, s, a, r, s_):

0 commit comments

Comments
 (0)