added own dqn

This commit is contained in:
romue 2021-06-17 16:26:07 +02:00
parent 7893e1131e
commit 813c9d2c91

View File

@@ -183,7 +183,7 @@ class BaseQlearner:
target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw
loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - target_q, 2)) loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - target_q, 2))
print(target_q) #print(target_q)
# log loss # log loss
self.running_loss.append(loss.item()) self.running_loss.append(loss.item())
@@ -210,7 +210,6 @@ if __name__ == '__main__':
allow_no_op=False) allow_no_op=False)
env = SimpleFactory(dirt_properties=dirt_props, movement_properties=move_props, n_agents=N_AGENTS, pomdp_radius=2, max_steps=400, omit_agent_slice_in_obs=False) env = SimpleFactory(dirt_properties=dirt_props, movement_properties=move_props, n_agents=N_AGENTS, pomdp_radius=2, max_steps=400, omit_agent_slice_in_obs=False)
# env = DummyVecEnv([lambda: env]) # env = DummyVecEnv([lambda: env])
print(env)
from stable_baselines3.dqn import DQN from stable_baselines3.dqn import DQN
#dqn = RegDQN('MlpPolicy', env, verbose=True, buffer_size = 50000, learning_starts = 64, batch_size = 64, #dqn = RegDQN('MlpPolicy', env, verbose=True, buffer_size = 50000, learning_starts = 64, batch_size = 64,