Merge remote-tracking branch 'origin/main'

Steffen Illium 2021-11-24 17:39:42 +01:00
commit 3833a0a064
6 changed files with 144 additions and 44 deletions

View File

@@ -1,9 +1,16 @@
 import re
 import torch
+import numpy as np
 import yaml
 from pathlib import Path
 from salina import instantiate_class
 from salina import TAgent
+from salina.agents.gyma import (
+    AutoResetGymAgent,
+    _torch_type,
+    _format_frame,
+    _torch_cat_dict
+)
 
 
 def load_yaml_file(path: Path):
@@ -18,13 +25,86 @@ def add_env_props(cfg):
                              n_actions=env.action_space.n))
 
 
+AGENT_PREFIX = 'agent#'
+REWARD = 'reward'
+CUMU_REWARD = 'cumulated_reward'
+OBS = 'env_obs'
+SEP = '_'
+ACTION = 'action'
+
+
+def access_str(agent_i, name, prefix=''):
+    return f'{prefix}{AGENT_PREFIX}{agent_i}{SEP}{name}'
+
+
+class AutoResetGymMultiAgent(AutoResetGymAgent):
+    def __init__(self, *args, **kwargs):
+        super(AutoResetGymMultiAgent, self).__init__(*args, **kwargs)
+
+    def per_agent_values(self, name, values):
+        return {access_str(agent_i, name): value
+                for agent_i, value in zip(range(self.n_agents), values)}
+
+    def _initialize_envs(self, n):
+        super()._initialize_envs(n)
+        n_agents_list = [self.envs[i].unwrapped.n_agents for i in range(n)]
+        assert all(n_agents == n_agents_list[0] for n_agents in n_agents_list), \
+            'All envs must have the same number of agents.'
+        self.n_agents = n_agents_list[0]
+
+    def _reset(self, k, save_render):
+        ret = super()._reset(k, save_render)
+        obs = ret['env_obs'].squeeze()
+        self.cumulated_reward[k] = [0.0]*self.n_agents
+        obs = self.per_agent_values(OBS, [_format_frame(obs[i]) for i in range(self.n_agents)])
+        cumu_rew = self.per_agent_values(CUMU_REWARD, torch.zeros(self.n_agents, 1).float().unbind())
+        rewards = self.per_agent_values(REWARD, torch.zeros(self.n_agents, 1).float().unbind())
+        ret.update(cumu_rew)
+        ret.update(rewards)
+        ret.update(obs)
+        for remove in ['env_obs', 'cumulated_reward', 'reward']:
+            del ret[remove]
+        return ret
+
+    def _step(self, k, action, save_render):
+        self.timestep[k] += 1
+        env = self.envs[k]
+        if len(action.size()) == 0:
+            action = action.item()
+            assert isinstance(action, int)
+        else:
+            action = np.array(action.tolist())
+        o, r, d, _ = env.step(action)
+        self.cumulated_reward[k] = [x+y for x, y in zip(r, self.cumulated_reward[k])]
+        observation = self.per_agent_values(OBS, [_format_frame(o[i]) for i in range(self.n_agents)])
+        if d:
+            self.is_running[k] = False
+        if save_render:
+            image = env.render(mode="image").unsqueeze(0)
+            observation["rendering"] = image
+        rewards = self.per_agent_values(REWARD, torch.tensor(r).float().view(-1, 1).unbind())
+        cumulated_rewards = self.per_agent_values(CUMU_REWARD, torch.tensor(self.cumulated_reward[k]).float().view(-1, 1).unbind())
+        ret = {
+            **observation,
+            **rewards,
+            **cumulated_rewards,
+            "done": torch.tensor([d]),
+            "initial_state": torch.tensor([False]),
+            "timestep": torch.tensor([self.timestep[k]])
+        }
+        return _torch_type(ret)
+
+
 class CombineActionsAgent(TAgent):
-    def __init__(self, pattern=r'^agent\d_action$'):
+    def __init__(self):
         super().__init__()
-        self.pattern = pattern
+        self.pattern = fr'^{AGENT_PREFIX}\d{SEP}{ACTION}$'
 
     def forward(self, t, **kwargs):
         keys = list(self.workspace.keys())
         action_keys = sorted([k for k in keys if bool(re.match(self.pattern, k))])
         actions = torch.cat([self.get((k, t)) for k in action_keys], 0)
-        self.set((f'action', t), actions.unsqueeze(0))
+        actions = actions if len(action_keys) <= 1 else actions.unsqueeze(0)
+        self.set((f'action', t), actions)
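
Note (annotation, not part of the commit): the per-agent workspace keys produced by access_str, and later collected by CombineActionsAgent via regex, look as sketched below. The snippet only restates the constants defined in this file; the concrete keys are illustrative.

import re

AGENT_PREFIX, SEP, ACTION, REWARD = 'agent#', '_', 'action', 'reward'

def access_str(agent_i, name, prefix=''):
    return f'{prefix}{AGENT_PREFIX}{agent_i}{SEP}{name}'

print(access_str(0, REWARD))          # -> 'agent#0_reward'
print(access_str(1, REWARD, 'env/'))  # -> 'env/agent#1_reward'

# CombineActionsAgent gathers exactly the un-prefixed per-agent action keys:
pattern = fr'^{AGENT_PREFIX}\d{SEP}{ACTION}$'
print(bool(re.match(pattern, 'agent#0_action')))      # True
print(bool(re.match(pattern, 'env/agent#0_action')))  # False (prefixed keys are ignored)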

View File

@@ -1,4 +1,4 @@
-def make(env_name, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3):
+def make(env_name, pomdp_r=2, max_steps=400, stack_n_frames=3, n_agents=1, individual_rewards=False):
     import yaml
     from pathlib import Path
     from environments.factory.combined_factories import DirtItemFactory
@@ -12,7 +12,8 @@ def make(env_name, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3):
     obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED,
                                       frames_to_stack=stack_n_frames, pomdp_r=pomdp_r)
-    factory_kwargs = dict(n_agents=n_agents, max_steps=max_steps, obs_prop=obs_props,
+    factory_kwargs = dict(n_agents=n_agents, individual_rewards=individual_rewards,
+                          max_steps=max_steps, obs_prop=obs_props,
                           mv_prop=MovementProperties(**dictionary['movement_props']),
                           dirt_prop=DirtProperties(**dictionary['dirt_props']),
                           record_episodes=False, verbose=False, **dictionary['factory_props']
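
Note (annotation, not part of the commit): with the widened signature above, a multi-agent factory with per-agent rewards can be requested directly. A usage sketch, assuming the module is importable as environments.factory (the classname used in sat_mad.yaml below):

from environments.factory import make

# New keywords: n_agents moved to the end of the signature, individual_rewards was added.
env = make('DirtyFactory-v0',
           pomdp_r=2, max_steps=400, stack_n_frames=3,
           n_agents=2, individual_rewards=True)
obs = env.reset()  # standard gym API; with individual_rewards=True, step() is expected
                   # to return one reward per agent (see AutoResetGymMultiAgent above)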

View File

@@ -15,12 +15,11 @@ from environments.helpers import Constants as c, Constants
 from environments import helpers as h
 from environments.factory.base.objects import Agent, Tile, Action
 from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders
-from environments.utility_classes import MovementProperties, ObservationProperties
+from environments.utility_classes import MovementProperties, ObservationProperties, MarlFrameStack
 from environments.utility_classes import AgentRenderOptions as a_obs
 import simplejson
 
 REC_TAC = 'rec_'
@@ -57,7 +56,7 @@ class BaseFactory(gym.Env):
     def __enter__(self):
         return self if self.obs_prop.frames_to_stack == 0 else \
-            FrameStack(self, self.obs_prop.frames_to_stack)
+            MarlFrameStack(FrameStack(self, self.obs_prop.frames_to_stack))
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.close()

View File

@@ -1,4 +1,6 @@
 from typing import NamedTuple, Union
+import gym
+from gym.wrappers.frame_stack import FrameStack
 
 
 class AgentRenderOptions(object):
@@ -21,3 +23,14 @@ class ObservationProperties(NamedTuple):
     cast_shadows = True
     frames_to_stack: int = 0
     pomdp_r: int = 0
+
+
+class MarlFrameStack(gym.ObservationWrapper):
+    def __init__(self, env):
+        super().__init__(env)
+
+    def observation(self, observation):
+        if isinstance(self.env, FrameStack) and self.env.unwrapped.n_agents > 1:
+            return observation[0:].swapaxes(0, 1)
+        return observation
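
Note (annotation, not part of the commit): what the new wrapper does to a stacked multi-agent observation, shown with illustrative shapes (the real observation dimensions come from the factory and will differ):

import numpy as np

# gym's FrameStack stacks along a leading frame axis, so a multi-agent observation
# arrives as (n_frames, n_agents, *obs_shape); per-agent indexing downstream wants
# (n_agents, n_frames, *obs_shape).
n_frames, n_agents, h, w = 3, 2, 5, 5           # illustrative sizes only
stacked = np.zeros((n_frames, n_agents, h, w))
swapped = stacked[0:].swapaxes(0, 1)            # same slice + swap as in observation()
print(swapped.shape)                            # -> (2, 3, 5, 5)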

View File

@@ -9,15 +9,22 @@ from pathlib import Path
 import numpy as np
 from tqdm import tqdm
 import time
-from algorithms.utils import add_env_props, load_yaml_file, CombineActionsAgent
+from algorithms.utils import (
+    add_env_props,
+    load_yaml_file,
+    CombineActionsAgent,
+    AutoResetGymMultiAgent,
+    access_str,
+    AGENT_PREFIX, REWARD, CUMU_REWARD, OBS, SEP
+)
 
 
 class A2CAgent(TAgent):
-    def __init__(self, observation_size, hidden_size, n_actions, agent_id=-1, marl=False):
+    def __init__(self, observation_size, hidden_size, n_actions, agent_id):
         super().__init__()
         observation_size = np.prod(observation_size)
+        print(observation_size)
         self.agent_id = agent_id
-        self.marl = marl
         self.model = nn.Sequential(
             nn.Flatten(),
             nn.Linear(observation_size, hidden_size),
@@ -31,10 +38,7 @@ class A2CAgent(TAgent):
         self.critic_head = nn.Linear(hidden_size, 1)
 
     def get_obs(self, t):
-        observation = self.get(("env/env_obs", t))
-        if self.marl:
-            observation = observation.permute(2, 0, 1, 3, 4, 5)
-            observation = observation[self.agent_id]
+        observation = self.get((f'env/{access_str(self.agent_id, OBS)}', t))
         return observation
 
     def forward(self, t, stochastic, **kwargs):
@@ -47,17 +51,16 @@
             action = torch.distributions.Categorical(probs).sample()
         else:
             action = probs.argmax(1)
-        agent_str = f'agent{self.agent_id}_'
-        self.set((f'{agent_str}action', t), action)
-        self.set((f'{agent_str}action_probs', t), probs)
-        self.set((f'{agent_str}critic', t), critic)
+        self.set((f'{access_str(self.agent_id, "action")}', t), action)
+        self.set((f'{access_str(self.agent_id, "action_probs")}', t), probs)
+        self.set((f'{access_str(self.agent_id, "critic")}', t), critic)
 
 
 if __name__ == '__main__':
     # Setup workspace
     uid = time.time()
     workspace = Workspace()
-    n_agents = 1
+    n_agents = 2
 
     # load config
     cfg = load_yaml_file(Path(__file__).parent / 'sat_mad.yaml')
@@ -65,15 +68,14 @@ if __name__ == '__main__':
     cfg['env'].update({'n_agents': n_agents})
 
     # instantiate agent and env
-    env_agent = AutoResetGymAgent(
+    env_agent = AutoResetGymMultiAgent(
        get_class(cfg['env']),
        get_arguments(cfg['env']),
        n_envs=1
    )
     a2c_agents = [instantiate_class({**cfg['agent'],
-                                     'agent_id': agent_id,
-                                     'marl': n_agents > 1})
+                                     'agent_id': agent_id})
                  for agent_id in range(n_agents)]
 
     # combine agents
@@ -99,11 +101,13 @@ if __name__ == '__main__':
         for agent_id in range(n_agents):
             critic, done, action_probs, reward, action = workspace[
-                f"agent{agent_id}_critic", "env/done",
-                f'agent{agent_id}_action_probs', "env/reward",
-                f"agent{agent_id}_action"
+                access_str(agent_id, 'critic'),
+                "env/done",
+                access_str(agent_id, 'action_probs'),
+                access_str(agent_id, 'reward', 'env/'),
+                access_str(agent_id, 'action')
            ]
-            td = gae(critic, reward, done, 0.99, 0.3)
+            td = gae(critic, reward, done, 0.98, 0.25)
             td_error = td ** 2
             critic_loss = td_error.mean()
             entropy_loss = Categorical(action_probs).entropy().mean()
@@ -118,16 +122,18 @@ if __name__ == '__main__':
             optimizer = optimizers[agent_id]
             optimizer.zero_grad()
             loss.backward()
-            #torch.nn.utils.clip_grad_norm_(a2c_agents[agent_id].parameters(), 2)
+            #torch.nn.utils.clip_grad_norm_(a2c_agents[agent_id].parameters(), .5)
             optimizer.step()
 
         # Compute the cumulated reward on final_state
-        creward = workspace["env/cumulated_reward"]
-        creward = creward[done]
-        if creward.size()[0] > 0:
-            cum_r = creward.mean().item()
-            if cum_r > best:
-                # torch.save(a2c_agent.state_dict(), Path(__file__).parent / f'agent_{uid}.pt')
-                best = cum_r
-        pbar.set_description(f"Cum. r: {cum_r:.2f}, Best r. so far: {best:.2f}", refresh=True)
+        rews = ''
+        for agent_i in range(n_agents):
+            creward = workspace['env/'+access_str(agent_i, CUMU_REWARD)]
+            creward = creward[done]
+            if creward.size()[0] > 0:
+                rews += f'{AGENT_PREFIX}{agent_i}: {creward.mean().item():.2f} | '
+                """if cum_r > best:
+                    torch.save(a2c_agent.state_dict(), Path(__file__).parent / f'agent_{uid}.pt')
+                    best = cum_r"""
+        pbar.set_description(rews, refresh=True)
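
Note (annotation, not part of the commit): gae() is imported from the surrounding codebase and is called above as gae(critic, reward, done, discount_factor, gae_coef), now with 0.98 and 0.25. As a reference for what a helper with that call shape typically computes, a self-contained generalized-advantage-estimation sketch; the real helper's indexing and masking conventions may differ:

import torch

def gae_sketch(critic, reward, done, discount_factor, gae_coef):
    # Time-major tensors of shape (T, B): critic[t] = V(s_t); reward[t] and done[t]
    # describe the arrival at step t. Returns advantages of shape (T-1, B).
    not_done = (~done[1:]).float()
    deltas = reward[1:] + discount_factor * not_done * critic[1:] - critic[:-1]
    advantages = torch.zeros_like(deltas)
    running = torch.zeros_like(deltas[0])
    for t in reversed(range(deltas.size(0))):
        # A_t = delta_t + discount * lambda * A_{t+1}, reset at episode boundaries
        running = deltas[t] + discount_factor * gae_coef * not_done[t] * running
        advantages[t] = running
    return advantages

# e.g. td = gae_sketch(critic, reward, done, 0.98, 0.25)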

View File

@@ -5,21 +5,22 @@ agent:
   n_actions: 10
 
 env:
   classname: environments.factory.make
   env_name: "DirtyFactory-v0"
   n_agents: 1
   pomdp_r: 2
   max_steps: 400
   stack_n_frames: 3
+  individual_rewards: True
 
 algorithm:
   max_epochs: 1000000
   n_envs: 1
-  n_timesteps: 16
+  n_timesteps: 10
   discount_factor: 0.99
   entropy_coef: 0.01
   critic_coef: 1.0
-  gae: 0.3
+  gae: 0.25
 
 optimizer:
   classname: torch.optim.Adam
   lr: 0.0003
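
Note (annotation, not part of the commit): a quick sketch of how this config is consumed by the training script above — the file is read with load_yaml_file, n_agents is overridden at runtime, and everything in the env block except classname ends up as keyword arguments for environments.factory.make (roughly what salina's get_class/get_arguments resolve). The path below is illustrative:

from pathlib import Path
from algorithms.utils import load_yaml_file  # helper from the first file in this commit

cfg = load_yaml_file(Path('sat_mad.yaml'))   # illustrative path
cfg['env'].update({'n_agents': 2})           # the script overrides this before instantiation

env_kwargs = {k: v for k, v in cfg['env'].items() if k != 'classname'}
print(env_kwargs)
# -> {'env_name': 'DirtyFactory-v0', 'n_agents': 2, 'pomdp_r': 2,
#     'max_steps': 400, 'stack_n_frames': 3, 'individual_rewards': True}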