Mirror of https://github.com/illiumst/marl-factory-grid.git

Merge remote-tracking branch 'origin/main'

@@ -1,9 +1,16 @@
import re
import torch
import numpy as np
import yaml
from pathlib import Path
from salina import instantiate_class
from salina import TAgent
from salina.agents.gyma import (
    AutoResetGymAgent,
    _torch_type,
    _format_frame,
    _torch_cat_dict
)


def load_yaml_file(path: Path):
@@ -18,13 +25,86 @@ def add_env_props(cfg):
                             n_actions=env.action_space.n))




AGENT_PREFIX = 'agent#'
REWARD       =  'reward'
CUMU_REWARD  = 'cumulated_reward'
OBS          = 'env_obs'
SEP          = '_'
ACTION       = 'action'


def access_str(agent_i, name, prefix=''):
    return f'{prefix}{AGENT_PREFIX}{agent_i}{SEP}{name}'


class AutoResetGymMultiAgent(AutoResetGymAgent):
    def __init__(self, *args, **kwargs):
        super(AutoResetGymMultiAgent, self).__init__(*args, **kwargs)

    def per_agent_values(self, name, values):
        return {access_str(agent_i, name): value
                for agent_i, value in zip(range(self.n_agents), values)}

    def _initialize_envs(self, n):
        super()._initialize_envs(n)
        n_agents_list = [self.envs[i].unwrapped.n_agents for i in range(n)]
        assert all(n_agents == n_agents_list[0] for n_agents in n_agents_list), \
            'All envs must have the same number of agents.'
        self.n_agents = n_agents_list[0]

    def _reset(self, k, save_render):
        ret = super()._reset(k, save_render)
        obs = ret['env_obs'].squeeze()
        self.cumulated_reward[k] = [0.0]*self.n_agents
        obs      = self.per_agent_values(OBS,  [_format_frame(obs[i]) for i in range(self.n_agents)])
        cumu_rew = self.per_agent_values(CUMU_REWARD, torch.zeros(self.n_agents, 1).float().unbind())
        rewards  = self.per_agent_values(REWARD,      torch.zeros(self.n_agents, 1).float().unbind())
        ret.update(cumu_rew)
        ret.update(rewards)
        ret.update(obs)
        for remove in ['env_obs', 'cumulated_reward', 'reward']:
            del ret[remove]
        return ret

    def _step(self, k, action, save_render):
        self.timestep[k] += 1
        env = self.envs[k]
        if len(action.size()) == 0:
            action = action.item()
            assert isinstance(action, int)
        else:
            action = np.array(action.tolist())
        o, r, d, _ = env.step(action)
        self.cumulated_reward[k] = [x+y for x, y in zip(r, self.cumulated_reward[k])]
        observation = self.per_agent_values(OBS, [_format_frame(o[i]) for i in range(self.n_agents)])
        if d:
            self.is_running[k] = False
        if save_render:
            image = env.render(mode="image").unsqueeze(0)
            observation["rendering"] = image
        rewards           = self.per_agent_values(REWARD, torch.tensor(r).float().view(-1, 1).unbind())
        cumulated_rewards = self.per_agent_values(CUMU_REWARD, torch.tensor(self.cumulated_reward[k]).float().view(-1, 1).unbind())
        ret = {
            **observation,
            **rewards,
            **cumulated_rewards,
            "done": torch.tensor([d]),
            "initial_state": torch.tensor([False]),
            "timestep": torch.tensor([self.timestep[k]])
        }
        return _torch_type(ret)


class CombineActionsAgent(TAgent):
    def __init__(self, pattern=r'^agent\d_action$'):
    def __init__(self):
        super().__init__()
        self.pattern = pattern
        self.pattern = fr'^{AGENT_PREFIX}\d{SEP}{ACTION}$'

    def forward(self, t, **kwargs):
        keys = list(self.workspace.keys())
        action_keys = sorted([k for k in keys if bool(re.match(self.pattern, k))])
        actions = torch.cat([self.get((k, t)) for k in action_keys], 0)
        self.set((f'action', t), actions.unsqueeze(0))
        actions = actions if len(action_keys) <= 1 else actions.unsqueeze(0)
        self.set((f'action', t), actions)

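For orientation, a minimal standalone sketch (illustrative, not part of the commit) of the key scheme the constants above define: access_str() builds per-agent workspace keys, and the regex that CombineActionsAgent now derives from the same constants matches exactly the per-agent action keys.

import re

AGENT_PREFIX = 'agent#'
SEP = '_'
ACTION = 'action'

def access_str(agent_i, name, prefix=''):
    # same helper as in the diff above
    return f'{prefix}{AGENT_PREFIX}{agent_i}{SEP}{name}'

pattern = fr'^{AGENT_PREFIX}\d{SEP}{ACTION}$'            # resolves to ^agent#\d_action$

print(access_str(0, 'action'))                           # -> agent#0_action
print(access_str(1, 'reward', prefix='env/'))            # -> env/agent#1_reward
print(bool(re.match(pattern, 'agent#0_action')))         # -> True
print(bool(re.match(pattern, 'env/agent#1_reward')))     # -> False: only bare *_action keys are combined
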
@@ -1,4 +1,4 @@
def make(env_name, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3):
def make(env_name, pomdp_r=2, max_steps=400, stack_n_frames=3, n_agents=1,  individual_rewards=False):
    import yaml
    from pathlib import Path
    from environments.factory.combined_factories import DirtItemFactory
@@ -12,7 +12,8 @@ def make(env_name, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3):
    obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED,
                                      frames_to_stack=stack_n_frames, pomdp_r=pomdp_r)

    factory_kwargs = dict(n_agents=n_agents, max_steps=max_steps, obs_prop=obs_props,
    factory_kwargs = dict(n_agents=n_agents, individual_rewards=individual_rewards,
                          max_steps=max_steps, obs_prop=obs_props,
                          mv_prop=MovementProperties(**dictionary['movement_props']),
                          dirt_prop=DirtProperties(**dictionary['dirt_props']),
                          record_episodes=False, verbose=False, **dictionary['factory_props']

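A possible call under the new make() signature (illustrative only; the import path follows the config's classname entry environments.factory.make, and the keyword values mirror the yaml further below):

from environments.factory import make   # assumed import path, per the config's classname

env = make('DirtyFactory-v0', n_agents=2, pomdp_r=2, max_steps=400,
           stack_n_frames=3, individual_rewards=True)
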
@@ -15,12 +15,11 @@ from environments.helpers import Constants as c, Constants
from environments import helpers as h
from environments.factory.base.objects import Agent, Tile, Action
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders
from environments.utility_classes import MovementProperties, ObservationProperties
from environments.utility_classes import MovementProperties, ObservationProperties, MarlFrameStack
from environments.utility_classes import AgentRenderOptions as a_obs

import simplejson


REC_TAC = 'rec_'


@@ -57,7 +56,7 @@ class BaseFactory(gym.Env):

    def __enter__(self):
        return self if self.obs_prop.frames_to_stack == 0 else \
            FrameStack(self, self.obs_prop.frames_to_stack)
            MarlFrameStack(FrameStack(self, self.obs_prop.frames_to_stack))

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

@@ -1,4 +1,6 @@
from typing import NamedTuple, Union
import gym
from gym.wrappers.frame_stack import FrameStack


class AgentRenderOptions(object):
@@ -21,3 +23,14 @@ class ObservationProperties(NamedTuple):
    cast_shadows = True
    frames_to_stack: int = 0
    pomdp_r: int = 0


class MarlFrameStack(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)

    def observation(self, observation):
        if isinstance(self.env, FrameStack) and self.env.unwrapped.n_agents > 1:
            return observation[0:].swapaxes(0, 1)
        return observation


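A quick illustration (not from the commit) of the axis swap MarlFrameStack.observation applies when the wrapped env is a FrameStack with more than one agent. The concrete shape is an assumption for the example only: if FrameStack stacks along the leading axis, a per-step observation of shape (n_agents, channels, h, w) becomes (frames, n_agents, channels, h, w), and swapaxes(0, 1) regroups it per agent.

import numpy as np

frames, n_agents, channels, h, w = 3, 2, 1, 5, 5        # assumed layout, for illustration
stacked = np.zeros((frames, n_agents, channels, h, w))  # what a FrameStack-style wrapper might emit

per_agent = stacked.swapaxes(0, 1)                      # the same call MarlFrameStack makes
print(per_agent.shape)                                  # (2, 3, 1, 5, 5): (n_agents, frames, channels, h, w)
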
@@ -9,15 +9,22 @@ from pathlib import Path
import numpy as np
from tqdm import tqdm
import time
from algorithms.utils import add_env_props, load_yaml_file, CombineActionsAgent
from algorithms.utils import (
    add_env_props,
    load_yaml_file,
    CombineActionsAgent,
    AutoResetGymMultiAgent,
    access_str,
    AGENT_PREFIX, REWARD, CUMU_REWARD, OBS, SEP
)


class A2CAgent(TAgent):
    def __init__(self, observation_size, hidden_size, n_actions, agent_id=-1, marl=False):
    def __init__(self, observation_size, hidden_size, n_actions, agent_id):
        super().__init__()
        observation_size = np.prod(observation_size)
        print(observation_size)
        self.agent_id = agent_id
        self.marl = marl
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(observation_size, hidden_size),
@@ -31,10 +38,7 @@ class A2CAgent(TAgent):
        self.critic_head = nn.Linear(hidden_size, 1)

    def get_obs(self, t):
        observation = self.get(("env/env_obs", t))
        if self.marl:
            observation = observation.permute(2, 0, 1, 3, 4, 5)
            observation = observation[self.agent_id]
        observation = self.get((f'env/{access_str(self.agent_id, OBS)}', t))
        return observation

    def forward(self, t, stochastic, **kwargs):
@@ -47,17 +51,16 @@ class A2CAgent(TAgent):
            action = torch.distributions.Categorical(probs).sample()
        else:
            action = probs.argmax(1)
        agent_str = f'agent{self.agent_id}_'
        self.set((f'{agent_str}action', t), action)
        self.set((f'{agent_str}action_probs', t), probs)
        self.set((f'{agent_str}critic', t), critic)
        self.set((f'{access_str(self.agent_id, "action")}', t), action)
        self.set((f'{access_str(self.agent_id, "action_probs")}', t), probs)
        self.set((f'{access_str(self.agent_id, "critic")}', t), critic)


if __name__ == '__main__':
    # Setup workspace
    uid = time.time()
    workspace = Workspace()
    n_agents = 1
    n_agents = 2

    # load config
    cfg = load_yaml_file(Path(__file__).parent / 'sat_mad.yaml')
@@ -65,15 +68,14 @@ if __name__ == '__main__':
    cfg['env'].update({'n_agents': n_agents})

    # instantiate agent and env
    env_agent = AutoResetGymAgent(
    env_agent = AutoResetGymMultiAgent(
        get_class(cfg['env']),
        get_arguments(cfg['env']),
        n_envs=1
    )

    a2c_agents = [instantiate_class({**cfg['agent'],
                                     'agent_id': agent_id,
                                     'marl':     n_agents > 1})
                                     'agent_id': agent_id})
                  for agent_id in range(n_agents)]

    # combine agents
@@ -99,11 +101,13 @@ if __name__ == '__main__':

            for agent_id in range(n_agents):
                critic, done, action_probs, reward, action = workspace[
                    f"agent{agent_id}_critic", "env/done",
                    f'agent{agent_id}_action_probs', "env/reward",
                    f"agent{agent_id}_action"
                    access_str(agent_id, 'critic'),
                    "env/done",
                    access_str(agent_id, 'action_probs'),
                    access_str(agent_id, 'reward', 'env/'),
                    access_str(agent_id, 'action')
                ]
                td = gae(critic, reward, done, 0.99, 0.3)
                td = gae(critic, reward, done, 0.98, 0.25)
                td_error = td ** 2
                critic_loss = td_error.mean()
                entropy_loss = Categorical(action_probs).entropy().mean()
@@ -118,16 +122,18 @@ if __name__ == '__main__':
                optimizer = optimizers[agent_id]
                optimizer.zero_grad()
                loss.backward()
                #torch.nn.utils.clip_grad_norm_(a2c_agents[agent_id].parameters(), 2)
                #torch.nn.utils.clip_grad_norm_(a2c_agents[agent_id].parameters(), .5)
                optimizer.step()

                # Compute the cumulated reward on final_state
                creward = workspace["env/cumulated_reward"]
                creward = creward[done]
                if creward.size()[0] > 0:
                    cum_r = creward.mean().item()
                    if cum_r > best:
                    #    torch.save(a2c_agent.state_dict(), Path(__file__).parent / f'agent_{uid}.pt')
                        best = cum_r
                    pbar.set_description(f"Cum. r: {cum_r:.2f}, Best r. so far: {best:.2f}", refresh=True)
                rews = ''
                for agent_i in range(n_agents):
                    creward = workspace['env/'+access_str(agent_i, CUMU_REWARD)]
                    creward = creward[done]
                    if creward.size()[0] > 0:
                        rews += f'{AGENT_PREFIX}{agent_i}: {creward.mean().item():.2f}  |  '
                        """if cum_r > best:
                            torch.save(a2c_agent.state_dict(), Path(__file__).parent / f'agent_{uid}.pt')
                            best = cum_r"""
                        pbar.set_description(rews, refresh=True)

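The per-agent update above relies on a gae() helper whose implementation is not shown in this diff. Below is a rough, generic sketch of generalized advantage estimation with the same argument order as the call site, (critic, reward, done, discount, gae_lambda); the actual helper used by the script may differ.

import torch

def gae_sketch(critic, reward, done, discount, gae_lambda):
    """Generic GAE over tensors whose first dimension is time (T, ...)."""
    not_done = 1.0 - done.float()
    # one-step TD residuals: r_{t+1} + gamma * V_{t+1} * (1 - done_{t+1}) - V_t
    delta = reward[1:] + discount * critic[1:] * not_done[1:] - critic[:-1]
    advantage = torch.zeros_like(delta)
    running = torch.zeros_like(delta[0])
    for t in reversed(range(delta.size(0))):
        running = delta[t] + discount * gae_lambda * not_done[t + 1] * running
        advantage[t] = running
    return advantage

# e.g.: td = gae_sketch(critic, reward, done, 0.98, 0.25)
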
@@ -5,21 +5,22 @@ agent:
  n_actions:        10

env:
  classname:      environments.factory.make
  env_name:       "DirtyFactory-v0"
  n_agents:       1
  pomdp_r:        2
  max_steps:      400
  stack_n_frames: 3
  classname:          environments.factory.make
  env_name:           "DirtyFactory-v0"
  n_agents:           1
  pomdp_r:            2
  max_steps:          400
  stack_n_frames:     3
  individual_rewards: True

algorithm:
  max_epochs:             1000000
  n_envs:                 1
  n_timesteps:            16
  n_timesteps:            10
  discount_factor:        0.99
  entropy_coef:           0.01
  critic_coef:            1.0
  gae:                    0.3
  gae:                    0.25
  optimizer:
    classname:            torch.optim.Adam
    lr:                   0.0003
