env = env.unwrapped
N_ACTIONS = env.action_space.n
N_STATES = env.observation_space.shape[0]

# ENV_A_SHAPE records how an action must be shaped before it is returned to
# the environment: 0 means a plain scalar action, anything else is the shape
# of a sampled action array.
# The sample is taken once, and NumPy integer scalars are accepted as well as
# the builtin int: a scalar action may arrive as np.int64 (as newer
# Gym/Gymnasium releases appear to do), which is not an `int` instance.
_sample_action = env.action_space.sample()
ENV_A_SHAPE = 0 if isinstance(_sample_action, (int, np.integer)) else _sample_action.shape
del _sample_action  # keep the module namespace unchanged
2930
3031
3132class Net (nn .Module ):
def __init__(self, n_hidden=50):
    """Create the two linear layers of the network.

    ``fc1`` maps N_STATES inputs to ``n_hidden`` features and ``out`` maps
    those ``n_hidden`` features to N_ACTIONS outputs.

    Args:
        n_hidden: Output width of ``fc1`` and input width of ``out``.
            Defaults to 50, the value that was previously hard-coded.
    """
    super(Net, self).__init__()
    self.fc1 = nn.Linear(N_STATES, n_hidden)
    # Overwrite the default weight init with samples from a normal
    # distribution (mean 0, std 0.1); biases keep their default init.
    self.fc1.weight.data.normal_(0, 0.1)
    self.out = nn.Linear(n_hidden, N_ACTIONS)
    self.out.weight.data.normal_(0, 0.1)
3839
3940 def forward (self , x ):
@@ -58,9 +59,11 @@ def choose_action(self, x):
5859 # input only one sample
5960 if np .random .uniform () < EPSILON : # greedy
6061 actions_value = self .eval_net .forward (x )
61- action = torch .max (actions_value , 1 )[1 ].data .numpy ()[0 , 0 ] # return the argmax
62+ action = torch .max (actions_value , 1 )[1 ].data .numpy ()
63+ action = action [0 , 0 ] if ENV_A_SHAPE == 0 else action .reshape (ENV_A_SHAPE ) # return the argmax index
6264 else : # random
6365 action = np .random .randint (0 , N_ACTIONS )
66+ action = action if ENV_A_SHAPE == 0 else action .reshape (ENV_A_SHAPE )
6467 return action
6568
6669 def store_transition (self , s , a , r , s_ ):
0 commit comments