mirror of https://github.com/illiumst/marl-factory-grid.git (synced 2025-10-31 12:37:27 +01:00)
	add Munchhausen DQN refactoring
@@ -1,22 +1,20 @@
from typing import NamedTuple, Union, Iterable
from collections import namedtuple, deque
from typing import NamedTuple, Union
from collections import deque
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from stable_baselines3.common.utils import polyak_update
from stable_baselines3.common.buffers import ReplayBuffer
import copy


class Experience(NamedTuple):
    # can be used for a single (s_t, a, r, s_{t+1}) tuple
    # or for a batch of tuples
    observation:      np.ndarray
    next_observation: np.ndarray
    action:           np.ndarray
    reward:           Union[float, np.ndarray]
    done:             Union[bool, np.ndarray]
    priority:         np.ndarray = 1


class BaseBuffer:
@@ -65,9 +63,6 @@ class BaseDQN(nn.Module):
        values = self.value_head(features)
        return values + (advantages - advantages.mean())

    def random_action(self):
        return random.randrange(0, 5)


def soft_update(local_model, target_model, tau):
    # taken from https://github.com/BY571/Munchausen-RL/blob/master/M-DQN.ipynb
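Note: the body of soft_update lies outside this hunk. Assuming it follows the Polyak averaging of the linked M-DQN notebook, a minimal sketch would be:

    def soft_update(local_model, target_model, tau):
        # sketch only; actual implementation is not shown in this diff
        # target <- tau * local + (1 - tau) * target, applied parameter-wise
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau * local_param.data + (1. - tau) * target_param.data)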
@@ -165,6 +160,14 @@ class BaseQlearner:
        next_q_values_raw = self.target_q_net(next_obs).max(dim=-1)[0].reshape(-1, 1).detach()
        return current_q_values, next_q_values_raw

    def _backprop_loss(self, loss):
        # log loss
        self.running_loss.append(loss.item())
        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
        self.optimizer.step()

    def train(self):
        if len(self.buffer) < self.batch_size: return
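Note: _backprop_loss centralizes the loss logging and optimizer step that the two hunks below remove from the individual train() loops. For reference, a self-contained toy example of the same zero_grad -> backward -> clip_grad_norm_ -> step ordering (the network, batch and targets here are made up for illustration):

    import torch
    import torch.nn as nn

    q_net = nn.Linear(4, 5)                                  # stand-in for the Q-network
    optimizer = torch.optim.Adam(q_net.parameters(), lr=7e-4)
    actions = torch.randint(0, 5, (8, 1))                    # toy actions
    pred_q = q_net(torch.randn(8, 4)).gather(-1, actions)    # Q(s, a) of the taken actions
    target_q = torch.randn(8, 1)                             # toy bootstrapped targets
    loss = torch.mean(torch.pow(pred_q - target_q, 2))

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(q_net.parameters(), 10)   # clip gradients before stepping, as above
    optimizer.step()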
@@ -188,14 +191,7 @@ class BaseQlearner:
                    target_q_raw += next_q_values_raw
            target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw
            loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - target_q, 2))

            # log loss
            self.running_loss.append(loss.item())
            # Optimize the model
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
            self.optimizer.step()
            self._backprop_loss(loss)


class MDQN(BaseQlearner):
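Note: spelled out, the target and loss computed here are (assuming a single agent, so that target_q_raw is just the bootstrapped maximum over next-state Q-values, and writing reg_weight as \(\lambda\)):

\[
\text{target\_q} = r + (1 - d)\,\gamma \max_{a'} Q_{\text{target}}(s', a'), \qquad
\mathcal{L} = \operatorname{mean}\big[\lambda\, Q(s, a) + (Q(s, a) - \text{target\_q})^2\big]
\]

i.e. the usual squared TD error plus a penalty proportional to the predicted Q-value itself, weighted by reg_weight.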
@@ -207,7 +203,8 @@ class MDQN(BaseQlearner):
        self.clip0 = clip_l0

    def tau_ln_pi(self, qs):
        # Custom log-sum-exp trick from page 18 to compute the e log-policy terms
        # computes log(softmax(qs/temperature))
        # Custom log-sum-exp trick from page 18 to compute the log-policy terms
        v_k = qs.max(-1)[0].unsqueeze(-1)
        advantage = qs - v_k
        logsum = torch.logsumexp(advantage / self.temperature, -1).unsqueeze(-1)
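Note: the return statement of tau_ln_pi sits outside this hunk; presumably it is advantage - self.temperature * logsum, which by the log-sum-exp identity equals the scaled log-policy tau * log softmax(qs / tau). A quick self-contained check of that identity (toy shapes, names illustrative only):

    import torch

    qs = torch.randn(3, 5)             # toy [batch, n_actions] Q-values
    tau = 0.03                         # temperature
    v_k = qs.max(-1, keepdim=True)[0]  # per-row max, subtracted for numerical stability
    advantage = qs - v_k
    tau_ln_pi = advantage - tau * torch.logsumexp(advantage / tau, -1, keepdim=True)
    # identical (up to float error) to tau * log softmax(qs / tau)
    assert torch.allclose(tau_ln_pi, tau * torch.log_softmax(qs / tau, -1), atol=1e-5)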
@@ -242,21 +239,11 @@ class MDQN(BaseQlearner):

            # Compute loss
            loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - m_q_target, 2))

            # log loss
            self.running_loss.append(loss.item())
            # Optimize the model
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
            self.optimizer.step()

            self._backprop_loss(loss)


if __name__ == '__main__':
    from environments.factory.simple_factory import SimpleFactory, DirtProperties, MovementProperties
    from algorithms.reg_dqn import RegDQN
    from stable_baselines3.common.vec_env import DummyVecEnv

    N_AGENTS = 1

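Note: the construction of m_q_target is not part of this hunk. For reference, the Munchausen-DQN target it presumably implements, with temperature \(\tau\), Munchausen scaling \(\alpha\), the log-policy bonus clipped to [clip_l0, 0], and \(\pi = \mathrm{softmax}(Q_{\text{target}}/\tau)\), is

\[
m\_q\_target = r + \alpha\,\big[\tau \ln \pi(a \mid s)\big]_{\text{clip\_l0}}^{0}
 + \gamma \sum_{a'} \pi(a' \mid s')\big(Q_{\text{target}}(s', a') - \tau \ln \pi(a' \mid s')\big)
\]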
@@ -266,16 +253,8 @@ if __name__ == '__main__':
                                    allow_square_movement=True,
                                    allow_no_op=False)
    env = SimpleFactory(dirt_properties=dirt_props, movement_properties=move_props, n_agents=N_AGENTS, pomdp_radius=2, max_steps=400, omit_agent_slice_in_obs=False, combin_agent_slices_in_obs=True)
    #env = DummyVecEnv([lambda: env])
    from stable_baselines3.dqn import DQN

    #dqn = RegDQN('MlpPolicy', env, verbose=True, buffer_size = 40000, learning_starts = 0, batch_size = 64, learning_rate=0.0008,
    #             target_update_interval = 3500, exploration_fraction = 0.25, exploration_final_eps = 0.05,
    #             train_freq=4, gradient_steps=1, reg_weight=0.05, seed=69)
    #dqn.learn(100000)


    dqn, target_dqn = BaseDQN(), BaseDQN()
    learner = MDQN(dqn, target_dqn, env, BaseBuffer(40000), target_update=3500, lr=0.0008, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
    learner = MDQN(dqn, target_dqn, env, BaseBuffer(40000), target_update=3500, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
                   train_every_n_steps=4, eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25, batch_size=64)
    learner.learn(100000)