added mlpmaker

Author: romue
@@ -47,19 +47,22 @@ def soft_update(local_model, target_model, tau):
         target_param.data.copy_(tau*local_param.data + (1.-tau)*target_param.data)
 
 
-def mlp_maker(dims):
-    layers = [('Flatten', nn.Flatten())]
+def mlp_maker(dims, flatten=False, activation='elu', activation_last='identity'):
+    activations = {'elu': nn.ELU, 'relu': nn.ReLU,
+                   'leaky_relu': nn.LeakyReLU, 'tanh': nn.Tanh,
+                   'gelu': nn.GELU, 'identity': nn.Identity}
+    layers = [('Flatten', nn.Flatten())] if flatten else []
     for i in range(1, len(dims)):
-        layers.append((f'Linear#{i - 1}', nn.Linear(dims[i - 1], dims[i])))
-        if i != len(dims) - 1:
-            layers.append(('ELU', nn.ELU()))
+        layers.append((f'Layer #{i - 1}: Linear', nn.Linear(dims[i - 1], dims[i])))
+        activation_str = activation if i != len(dims)-1 else activation_last
+        layers.append((f'Layer #{i - 1}: {activation_str.capitalize()}', activations[activation_str]()))
     return nn.Sequential(OrderedDict(layers))
 
 
 class BaseDQN(nn.Module):
     def __init__(self, dims=[3*5*5, 64, 64, 9]):
         super(BaseDQN, self).__init__()
-        self.net = mlp_maker(dims)
+        self.net = mlp_maker(dims, flatten=True)
 
     def act(self, x) -> np.ndarray:
         with torch.no_grad():
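A minimal usage sketch of the new helper; the function body is copied from the hunk above and the example dims mirror the BaseDQN defaults (a flattened 3x5x5 observation, two 64-unit hidden layers, 9 actions):

from collections import OrderedDict

import torch
import torch.nn as nn


def mlp_maker(dims, flatten=False, activation='elu', activation_last='identity'):
    activations = {'elu': nn.ELU, 'relu': nn.ReLU,
                   'leaky_relu': nn.LeakyReLU, 'tanh': nn.Tanh,
                   'gelu': nn.GELU, 'identity': nn.Identity}
    layers = [('Flatten', nn.Flatten())] if flatten else []
    for i in range(1, len(dims)):
        layers.append((f'Layer #{i - 1}: Linear', nn.Linear(dims[i - 1], dims[i])))
        activation_str = activation if i != len(dims) - 1 else activation_last
        layers.append((f'Layer #{i - 1}: {activation_str.capitalize()}', activations[activation_str]()))
    return nn.Sequential(OrderedDict(layers))


# Hidden layers use ELU, the output layer stays linear (Identity) for Q-values.
net = mlp_maker([3 * 5 * 5, 64, 64, 9], flatten=True)
q_values = net(torch.zeros(1, 3, 5, 5))   # -> shape (1, 9)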
@@ -76,6 +79,7 @@ class BaseDDQN(BaseDQN):
                  value_dims=[64,1],
                  advantage_dims=[64,9]):
         super(BaseDDQN, self).__init__(backbone_dims)
+        self.net = mlp_maker(backbone_dims, flatten=True)
         self.value_head         =  mlp_maker(value_dims)
         self.advantage_head     =  mlp_maker(advantage_dims)
 
@@ -86,13 +90,11 @@ class BaseDDQN(BaseDQN):
         return values + (advantages - advantages.mean())
 
 
-
 class BaseQlearner:
     def __init__(self, q_net, target_q_net, env, buffer_size, target_update, eps_end, n_agents=1,
                  gamma=0.99, train_every_n_steps=4, n_grad_steps=1, tau=1.0, max_grad_norm=10,
                  exploration_fraction=0.2, batch_size=64, lr=1e-4, reg_weight=0.0):
         self.q_net = q_net
-        print(self.q_net)
         self.target_q_net = target_q_net
         self.target_q_net.eval()
         soft_update(self.q_net, self.target_q_net, tau=1.0)
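BaseDDQN now rebuilds its backbone with flatten=True and keeps separate value and advantage heads; the forward pass merges them with the dueling aggregation shown at the top of the hunk above. A self-contained sketch of that combination, where the 3*5*5 input and the 64-unit backbone width are assumed from the BaseDQN/BaseDDQN defaults:

import torch
import torch.nn as nn

# Head sizes follow the defaults visible in the diff (value_dims=[64, 1],
# advantage_dims=[64, 9]); input and backbone widths are assumptions here.
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 5 * 5, 64), nn.ELU(), nn.Linear(64, 64))
value_head = nn.Linear(64, 1)
advantage_head = nn.Linear(64, 9)

features = backbone(torch.zeros(1, 3, 5, 5))
values = value_head(features)                            # V(s), shape (1, 1)
advantages = advantage_head(features)                    # A(s, a), shape (1, 9)
q_values = values + (advantages - advantages.mean())     # dueling aggregation from the hunk above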
@@ -205,14 +207,13 @@ class BaseQlearner:
                     pred_q += q_values
                     target_q_raw += next_q_values_raw
             target_q = experience.reward  + (1 - experience.done) * self.gamma * target_q_raw
-            #print(pred_q[0], target_q_raw[0], target_q[0], experience.reward[0])
             loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - target_q, 2))
             self._backprop_loss(loss)
 
 
-class MDQN(BaseQlearner):
+class MunchhausenQLearner(BaseQlearner):
     def __init__(self, *args, temperature=0.03, alpha=0.9, clip_l0=-1.0, **kwargs):
-        super(MDQN, self).__init__(*args, **kwargs)
+        super(MunchhausenQLearner, self).__init__(*args, **kwargs)
         assert self.n_agents == 1, 'M-DQN currently only supports single agent training'
         self.temperature = temperature
         self.alpha = alpha
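The class is renamed from MDQN to MunchhausenQLearner; its temperature, alpha and clip_l0 hyperparameters are those of Munchausen RL (Vieillard et al., 2020). For reference, a sketch of the standard Munchausen target they parameterise; the function name and tensor shapes are illustrative assumptions, not the exact implementation in this file:

import torch
import torch.nn.functional as F

def munchausen_target(q_target_s, q_target_next, action, reward, done,
                      gamma=0.99, temperature=0.03, alpha=0.9, clip_l0=-1.0):
    # Soft policy from the target network: pi = softmax(Q / temperature)
    log_pi_s = F.log_softmax(q_target_s / temperature, dim=-1)
    log_pi_next = F.log_softmax(q_target_next / temperature, dim=-1)

    # Munchausen bonus: scaled, clipped log-probability of the action actually taken
    log_pi_a = log_pi_s.gather(-1, action)
    munchausen_term = alpha * torch.clamp(temperature * log_pi_a, min=clip_l0, max=0.0)

    # Soft (entropy-regularised) value of the next state
    pi_next = log_pi_next.exp()
    next_v = (pi_next * (q_target_next - temperature * log_pi_next)).sum(dim=-1, keepdim=True)

    return reward + munchausen_term + (1.0 - done) * gamma * next_v

# e.g. a batch of 4 transitions over 9 actions:
q_s, q_next = torch.randn(4, 9), torch.randn(4, 9)
a = torch.randint(0, 9, (4, 1))
r, d = torch.randn(4, 1), torch.zeros(4, 1)
target = munchausen_target(q_s, q_next, a, r, d)   # shape (4, 1)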
@@ -260,7 +261,6 @@ class MDQN(BaseQlearner):
 
 if __name__ == '__main__':
     from environments.factory.simple_factory import SimpleFactory, DirtProperties, MovementProperties
-    from gym.wrappers import FrameStack
 
     N_AGENTS = 1
 
@@ -272,6 +272,6 @@ if __name__ == '__main__':
     env = SimpleFactory(dirt_properties=dirt_props, movement_properties=move_props, n_agents=N_AGENTS, pomdp_radius=2,  max_steps=400, omit_agent_slice_in_obs=False, combin_agent_slices_in_obs=True)
 
     dqn, target_dqn = BaseDDQN(), BaseDDQN()
-    learner = MDQN(dqn, target_dqn, env, 40000, target_update=3500, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
-                   train_every_n_steps=4, eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25, batch_size=64)
+    learner = MunchhausenQLearner(dqn, target_dqn, env, 40000, target_update=3500, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
+                           train_every_n_steps=4, eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25, batch_size=64)
     learner.learn(100000)
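For context, the soft_update helper referenced in the first hunk blends the online weights into the target network: with tau=1.0, as called in BaseQlearner.__init__, it is a hard copy, while the tau=0.95 configured above presumably applies at each target_update. A runnable sketch; the copy line is taken from the diff, the surrounding loop over paired parameters is the usual pattern and an assumption here:

import torch.nn as nn

def soft_update(local_model, target_model, tau):
    # target <- tau * online + (1 - tau) * target; tau=1.0 is a hard copy.
    for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
        target_param.data.copy_(tau * local_param.data + (1. - tau) * target_param.data)

online, target = nn.Linear(4, 2), nn.Linear(4, 2)
soft_update(online, target, tau=1.0)    # hard copy at initialisation
soft_update(online, target, tau=0.95)   # partial update, as configured in __main__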