diff --git a/algorithms/utils.py b/algorithms/utils.py
index bf3e2cb..d72046a 100644
--- a/algorithms/utils.py
+++ b/algorithms/utils.py
@@ -1,9 +1,16 @@
 import re
 import torch
+import numpy as np
 import yaml
 from pathlib import Path
 from salina import instantiate_class
 from salina import TAgent
+from salina.agents.gyma import (
+    AutoResetGymAgent,
+    _torch_type,
+    _format_frame,
+    _torch_cat_dict
+)
 
 
 def load_yaml_file(path: Path):
@@ -18,13 +25,86 @@ def add_env_props(cfg):
                                n_actions=env.action_space.n))
 
 
+
+
+AGENT_PREFIX = 'agent#'
+REWARD = 'reward'
+CUMU_REWARD = 'cumulated_reward'
+OBS = 'env_obs'
+SEP = '_'
+ACTION = 'action'
+
+
+def access_str(agent_i, name, prefix=''):
+    return f'{prefix}{AGENT_PREFIX}{agent_i}{SEP}{name}'
+
+
+class AutoResetGymMultiAgent(AutoResetGymAgent):
+    def __init__(self, *args, **kwargs):
+        super(AutoResetGymMultiAgent, self).__init__(*args, **kwargs)
+
+    def per_agent_values(self, name, values):
+        return {access_str(agent_i, name): value
+                for agent_i, value in zip(range(self.n_agents), values)}
+
+    def _initialize_envs(self, n):
+        super()._initialize_envs(n)
+        n_agents_list = [self.envs[i].unwrapped.n_agents for i in range(n)]
+        assert all(n_agents == n_agents_list[0] for n_agents in n_agents_list), \
+            'All envs must have the same number of agents.'
+        self.n_agents = n_agents_list[0]
+
+    def _reset(self, k, save_render):
+        ret = super()._reset(k, save_render)
+        obs = ret['env_obs'].squeeze()
+        self.cumulated_reward[k] = [0.0]*self.n_agents
+        obs = self.per_agent_values(OBS, [_format_frame(obs[i]) for i in range(self.n_agents)])
+        cumu_rew = self.per_agent_values(CUMU_REWARD, torch.zeros(self.n_agents, 1).float().unbind())
+        rewards = self.per_agent_values(REWARD, torch.zeros(self.n_agents, 1).float().unbind())
+        ret.update(cumu_rew)
+        ret.update(rewards)
+        ret.update(obs)
+        for remove in ['env_obs', 'cumulated_reward', 'reward']:
+            del ret[remove]
+        return ret
+
+    def _step(self, k, action, save_render):
+        self.timestep[k] += 1
+        env = self.envs[k]
+        if len(action.size()) == 0:
+            action = action.item()
+            assert isinstance(action, int)
+        else:
+            action = np.array(action.tolist())
+        o, r, d, _ = env.step(action)
+        self.cumulated_reward[k] = [x+y for x, y in zip(r, self.cumulated_reward[k])]
+        observation = self.per_agent_values(OBS, [_format_frame(o[i]) for i in range(self.n_agents)])
+        if d:
+            self.is_running[k] = False
+        if save_render:
+            image = env.render(mode="image").unsqueeze(0)
+            observation["rendering"] = image
+        rewards = self.per_agent_values(REWARD, torch.tensor(r).float().view(-1, 1).unbind())
+        cumulated_rewards = self.per_agent_values(CUMU_REWARD, torch.tensor(self.cumulated_reward[k]).float().view(-1, 1).unbind())
+        ret = {
+            **observation,
+            **rewards,
+            **cumulated_rewards,
+            "done": torch.tensor([d]),
+            "initial_state": torch.tensor([False]),
+            "timestep": torch.tensor([self.timestep[k]])
+        }
+        return _torch_type(ret)
+
+
 class CombineActionsAgent(TAgent):
-    def __init__(self, pattern=r'^agent\d_action$'):
+    def __init__(self):
         super().__init__()
-        self.pattern = pattern
+        self.pattern = fr'^{AGENT_PREFIX}\d{SEP}{ACTION}$'
 
     def forward(self, t, **kwargs):
         keys = list(self.workspace.keys())
         action_keys = sorted([k for k in keys if bool(re.match(self.pattern, k))])
         actions = torch.cat([self.get((k, t)) for k in action_keys], 0)
-        self.set((f'action', t), actions.unsqueeze(0))
+        actions = actions if len(action_keys) <= 1 else actions.unsqueeze(0)
+        self.set((f'action', t), actions)
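
For context, a minimal self-contained sketch of the key-naming convention these helpers establish (constants copied from the hunk above; the asserts only illustrate the resulting workspace keys):

    # Naming scheme used for per-agent workspace variables (values copied from the patch).
    AGENT_PREFIX, SEP = 'agent#', '_'

    def access_str(agent_i, name, prefix=''):
        return f'{prefix}{AGENT_PREFIX}{agent_i}{SEP}{name}'

    assert access_str(0, 'env_obs') == 'agent#0_env_obs'
    assert access_str(1, 'cumulated_reward', 'env/') == 'env/agent#1_cumulated_reward'
    # CombineActionsAgent matches r'^agent#\d_action$', so it collects
    # 'agent#0_action', 'agent#1_action', ... and joins them into a single 'action' entry.
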
diff --git a/environments/factory/__init__.py b/environments/factory/__init__.py
index dca0135..23346f9 100644
--- a/environments/factory/__init__.py
+++ b/environments/factory/__init__.py
@@ -1,4 +1,4 @@
-def make(env_name, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3):
+def make(env_name, pomdp_r=2, max_steps=400, stack_n_frames=3, n_agents=1, individual_rewards=False):
     import yaml
     from pathlib import Path
     from environments.factory.combined_factories import DirtItemFactory
@@ -12,7 +12,8 @@ def make(env_name, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3):
     obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED,
                                       frames_to_stack=stack_n_frames,
                                       pomdp_r=pomdp_r)
-    factory_kwargs = dict(n_agents=n_agents, max_steps=max_steps, obs_prop=obs_props,
+    factory_kwargs = dict(n_agents=n_agents, individual_rewards=individual_rewards,
+                          max_steps=max_steps, obs_prop=obs_props,
                           mv_prop=MovementProperties(**dictionary['movement_props']),
                           dirt_prop=DirtProperties(**dictionary['dirt_props']),
                           record_episodes=False, verbose=False, **dictionary['factory_props']
diff --git a/environments/factory/base/base_factory.py b/environments/factory/base/base_factory.py
index 889b949..db08573 100644
--- a/environments/factory/base/base_factory.py
+++ b/environments/factory/base/base_factory.py
@@ -15,12 +15,11 @@ from environments.helpers import Constants as c, Constants
 from environments import helpers as h
 from environments.factory.base.objects import Agent, Tile, Action
 from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders
-from environments.utility_classes import MovementProperties, ObservationProperties
+from environments.utility_classes import MovementProperties, ObservationProperties, MarlFrameStack
 from environments.utility_classes import AgentRenderOptions as a_obs
 
 import simplejson
 
-
 REC_TAC = 'rec_'
 
 
@@ -57,7 +56,7 @@ class BaseFactory(gym.Env):
 
     def __enter__(self):
         return self if self.obs_prop.frames_to_stack == 0 else \
-            FrameStack(self, self.obs_prop.frames_to_stack)
+            MarlFrameStack(FrameStack(self, self.obs_prop.frames_to_stack))
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.close()
diff --git a/environments/utility_classes.py b/environments/utility_classes.py
index 5ca2caf..cdcfd56 100644
--- a/environments/utility_classes.py
+++ b/environments/utility_classes.py
@@ -1,4 +1,6 @@
 from typing import NamedTuple, Union
+import gym
+from gym.wrappers.frame_stack import FrameStack
 
 
 class AgentRenderOptions(object):
@@ -21,3 +23,14 @@ class ObservationProperties(NamedTuple):
     cast_shadows = True
     frames_to_stack: int = 0
     pomdp_r: int = 0
+
+
+class MarlFrameStack(gym.ObservationWrapper):
+    def __init__(self, env):
+        super().__init__(env)
+
+    def observation(self, observation):
+        if isinstance(self.env, FrameStack) and self.env.unwrapped.n_agents > 1:
+            return observation[0:].swapaxes(0, 1)
+        return observation
+
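
A rough illustration of what MarlFrameStack is meant to do, under the assumption that gym's FrameStack yields observations shaped (frames, n_agents, *obs_shape) for a multi-agent factory (the shapes below are illustrative, not taken from the environment):

    import numpy as np

    frames, n_agents, h, w = 3, 2, 5, 5
    stacked = np.zeros((frames, n_agents, h, w))         # assumed FrameStack output layout
    per_agent = stacked[0:].swapaxes(0, 1)               # same operation as MarlFrameStack.observation
    assert per_agent.shape == (n_agents, frames, h, w)   # leading axis now indexes agents
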
diff --git a/studies/sat_mad.py b/studies/sat_mad.py
index 35ddda5..6bf3e22 100644
--- a/studies/sat_mad.py
+++ b/studies/sat_mad.py
@@ -9,15 +9,22 @@ from pathlib import Path
 import numpy as np
 from tqdm import tqdm
 import time
-from algorithms.utils import add_env_props, load_yaml_file, CombineActionsAgent
+from algorithms.utils import (
+    add_env_props,
+    load_yaml_file,
+    CombineActionsAgent,
+    AutoResetGymMultiAgent,
+    access_str,
+    AGENT_PREFIX, REWARD, CUMU_REWARD, OBS, SEP
+)
 
 
 class A2CAgent(TAgent):
-    def __init__(self, observation_size, hidden_size, n_actions, agent_id=-1, marl=False):
+    def __init__(self, observation_size, hidden_size, n_actions, agent_id):
         super().__init__()
         observation_size = np.prod(observation_size)
+        print(observation_size)
         self.agent_id = agent_id
-        self.marl = marl
         self.model = nn.Sequential(
             nn.Flatten(),
             nn.Linear(observation_size, hidden_size),
@@ -31,10 +38,7 @@ class A2CAgent(TAgent):
         self.critic_head = nn.Linear(hidden_size, 1)
 
     def get_obs(self, t):
-        observation = self.get(("env/env_obs", t))
-        if self.marl:
-            observation = observation.permute(2, 0, 1, 3, 4, 5)
-            observation = observation[self.agent_id]
+        observation = self.get((f'env/{access_str(self.agent_id, OBS)}', t))
         return observation
 
     def forward(self, t, stochastic, **kwargs):
@@ -47,17 +51,16 @@ class A2CAgent(TAgent):
             action = torch.distributions.Categorical(probs).sample()
         else:
             action = probs.argmax(1)
-        agent_str = f'agent{self.agent_id}_'
-        self.set((f'{agent_str}action', t), action)
-        self.set((f'{agent_str}action_probs', t), probs)
-        self.set((f'{agent_str}critic', t), critic)
+        self.set((f'{access_str(self.agent_id, "action")}', t), action)
+        self.set((f'{access_str(self.agent_id, "action_probs")}', t), probs)
+        self.set((f'{access_str(self.agent_id, "critic")}', t), critic)
 
 
 if __name__ == '__main__':
     # Setup workspace
     uid = time.time()
     workspace = Workspace()
-    n_agents = 1
+    n_agents = 2
 
     # load config
     cfg = load_yaml_file(Path(__file__).parent / 'sat_mad.yaml')
@@ -65,15 +68,14 @@ if __name__ == '__main__':
     cfg['env'].update({'n_agents': n_agents})
 
     # instantiate agent and env
-    env_agent = AutoResetGymAgent(
+    env_agent = AutoResetGymMultiAgent(
         get_class(cfg['env']),
         get_arguments(cfg['env']),
         n_envs=1
     )
 
     a2c_agents = [instantiate_class({**cfg['agent'],
-                                     'agent_id': agent_id,
-                                     'marl': n_agents > 1})
+                                     'agent_id': agent_id})
                   for agent_id in range(n_agents)]
 
     # combine agents
@@ -99,11 +101,13 @@ if __name__ == '__main__':
 
         for agent_id in range(n_agents):
             critic, done, action_probs, reward, action = workspace[
-                f"agent{agent_id}_critic", "env/done",
-                f'agent{agent_id}_action_probs', "env/reward",
-                f"agent{agent_id}_action"
+                access_str(agent_id, 'critic'),
+                "env/done",
+                access_str(agent_id, 'action_probs'),
+                access_str(agent_id, 'reward', 'env/'),
+                access_str(agent_id, 'action')
             ]
-            td = gae(critic, reward, done, 0.99, 0.3)
+            td = gae(critic, reward, done, 0.98, 0.25)
             td_error = td ** 2
             critic_loss = td_error.mean()
             entropy_loss = Categorical(action_probs).entropy().mean()
@@ -118,16 +122,18 @@ if __name__ == '__main__':
             optimizer = optimizers[agent_id]
             optimizer.zero_grad()
             loss.backward()
-            #torch.nn.utils.clip_grad_norm_(a2c_agents[agent_id].parameters(), 2)
+            #torch.nn.utils.clip_grad_norm_(a2c_agents[agent_id].parameters(), .5)
             optimizer.step()
 
             # Compute the cumulated reward on final_state
-            creward = workspace["env/cumulated_reward"]
-            creward = creward[done]
-            if creward.size()[0] > 0:
-                cum_r = creward.mean().item()
-                if cum_r > best:
-                    # torch.save(a2c_agent.state_dict(), Path(__file__).parent / f'agent_{uid}.pt')
-                    best = cum_r
-                pbar.set_description(f"Cum. r: {cum_r:.2f}, Best r. so far: {best:.2f}", refresh=True)
+            rews = ''
+            for agent_i in range(n_agents):
+                creward = workspace['env/'+access_str(agent_i, CUMU_REWARD)]
+                creward = creward[done]
+                if creward.size()[0] > 0:
+                    rews += f'{AGENT_PREFIX}{agent_i}: {creward.mean().item():.2f} | '
+                    """if cum_r > best:
+                        torch.save(a2c_agent.state_dict(), Path(__file__).parent / f'agent_{uid}.pt')
+                        best = cum_r"""
+            pbar.set_description(rews, refresh=True)
 
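
The per-agent loss above relies on the imported gae helper with the new hyperparameters (0.98, 0.25). As a reference point only, here is a self-contained sketch of standard GAE(lambda), under the assumption that gae(critic, reward, done, discount, lam) follows the usual recursion over the time axis:

    import torch

    def gae_sketch(critic, reward, done, discount=0.98, lam=0.25):
        # critic, reward: float tensors of shape (T, n_envs); done: bool tensor of shape (T, n_envs)
        not_done = (~done).float()
        advantage = torch.zeros_like(critic)
        last = torch.zeros_like(critic[0])
        for t in reversed(range(critic.size(0) - 1)):
            # one-step TD error, masking out transitions that cross an episode boundary
            delta = reward[t + 1] + discount * not_done[t + 1] * critic[t + 1] - critic[t]
            last = delta + discount * lam * not_done[t + 1] * last
            advantage[t] = last
        return advantage
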
diff --git a/studies/sat_mad.yaml b/studies/sat_mad.yaml
index 5f7ea46..7e7ca71 100644
--- a/studies/sat_mad.yaml
+++ b/studies/sat_mad.yaml
@@ -5,21 +5,22 @@ agent:
   n_actions: 10
 
 env:
-  classname: environments.factory.make
-  env_name: "DirtyFactory-v0"
-  n_agents: 1
-  pomdp_r: 2
-  max_steps: 400
-  stack_n_frames: 3
+  classname: environments.factory.make
+  env_name: "DirtyFactory-v0"
+  n_agents: 1
+  pomdp_r: 2
+  max_steps: 400
+  stack_n_frames: 3
+  individual_rewards: True
 
 algorithm:
   max_epochs: 1000000
   n_envs: 1
-  n_timesteps: 16
+  n_timesteps: 10
   discount_factor: 0.99
   entropy_coef: 0.01
   critic_coef: 1.0
-  gae: 0.3
+  gae: 0.25
 optimizer:
   classname: torch.optim.Adam
   lr: 0.0003
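
A hedged sketch of how this config is consumed by studies/sat_mad.py (names taken from the patch; salina's get_class/get_arguments/instantiate_class resolve the classname entries, and add_env_props is assumed to only copy the env's observation/action sizes into the agent config, as the hunk in algorithms/utils.py suggests):

    from pathlib import Path
    from salina import get_class, get_arguments, instantiate_class
    from algorithms.utils import load_yaml_file, add_env_props, AutoResetGymMultiAgent

    cfg = load_yaml_file(Path('studies') / 'sat_mad.yaml')
    add_env_props(cfg)                   # fills in observation_size / n_actions for the agent config
    cfg['env'].update({'n_agents': 2})   # mirrors n_agents = 2 set in the training script

    env_agent = AutoResetGymMultiAgent(get_class(cfg['env']), get_arguments(cfg['env']), n_envs=1)
    a2c_agents = [instantiate_class({**cfg['agent'], 'agent_id': i}) for i in range(2)]
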