mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
added own dqn
This commit is contained in:
parent
7893e1131e
commit
813c9d2c91
@ -183,7 +183,7 @@ class BaseQlearner:
|
||||
|
||||
target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw
|
||||
loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - target_q, 2))
|
||||
print(target_q)
|
||||
#print(target_q)
|
||||
|
||||
# log loss
|
||||
self.running_loss.append(loss.item())
|
||||
@ -210,7 +210,6 @@ if __name__ == '__main__':
|
||||
allow_no_op=False)
|
||||
env = SimpleFactory(dirt_properties=dirt_props, movement_properties=move_props, n_agents=N_AGENTS, pomdp_radius=2, max_steps=400, omit_agent_slice_in_obs=False)
|
||||
# env = DummyVecEnv([lambda: env])
|
||||
print(env)
|
||||
from stable_baselines3.dqn import DQN
|
||||
|
||||
#dqn = RegDQN('MlpPolicy', env, verbose=True, buffer_size = 50000, learning_starts = 64, batch_size = 64,
|
||||
|
Loading…
x
Reference in New Issue
Block a user