multi-agent observation when n_agents > 1

steffen-illium 2021-06-09 13:12:49 +02:00
parent 62c141aa1c
commit cf2378a734
6 changed files with 271 additions and 159 deletions

environments/factory/base_factory.py

@@ -1,6 +1,6 @@
 from argparse import Namespace
 from pathlib import Path
-from typing import List, Union, Iterable, NamedTuple
+from typing import List, Union, Iterable
 import gym
 import numpy as np
@@ -9,115 +9,7 @@ from gym import spaces
 import yaml
 from environments import helpers as h
+from environments.utility_classes import Actions, StateSlice, AgentState, MovementProperties
-
-
-class MovementProperties(NamedTuple):
-    allow_square_movement: bool = True
-    allow_diagonal_movement: bool = False
-    allow_no_op: bool = False
-
-
-class Entity():
-
-    @property
-    def pos(self):
-        return self._pos
-
-    def __init__(self, pos):
-        self._pos = pos
-
-
-class AgentState:
-
-    def __init__(self, i: int, action: int):
-        self.i = i
-        self.action = action
-
-        self.collision_vector = None
-        self.action_valid = None
-        self.pos = None
-        self.info = {}
-
-    @property
-    def collisions(self):
-        return np.argwhere(self.collision_vector != 0).flatten()
-
-    def update(self, **kwargs):  # is this hacky?? o.0
-        for key, value in kwargs.items():
-            if hasattr(self, key):
-                self.__setattr__(key, value)
-            else:
-                raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}')
-
-
-class Register:
-
-    @property
-    def n(self):
-        return len(self)
-
-    def __init__(self):
-        self._register = dict()
-
-    def __len__(self):
-        return len(self._register)
-
-    def __add__(self, other: Union[str, List[str]]):
-        other = other if isinstance(other, list) else [other]
-        assert all([isinstance(x, str) for x in other]), f'All item names have to be of type {str}.'
-        self._register.update({key+len(self._register): value for key, value in enumerate(other)})
-        return self
-
-    def register_additional_items(self, other: Union[str, List[str]]):
-        self_with_additional_items = self + other
-        return self_with_additional_items
-
-    def keys(self):
-        return self._register.keys()
-
-    def items(self):
-        return self._register.items()
-
-    def __getitem__(self, item):
-        return self._register[item]
-
-    def by_name(self, item):
-        return list(self._register.keys())[list(self._register.values()).index(item)]
-
-    def __repr__(self):
-        return f'{self.__class__.__name__}({self._register})'
-
-
-class Actions(Register):
-
-    @property
-    def movement_actions(self):
-        return self._movement_actions
-
-    def __init__(self, movement_properties: MovementProperties):
-        self.allow_no_op = movement_properties.allow_no_op
-        self.allow_diagonal_movement = movement_properties.allow_diagonal_movement
-        self.allow_square_movement = movement_properties.allow_square_movement
-        # FIXME: There is a bug in helpers because there actions are ints. and the order matters.
-        assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), \
-            "There is a bug in helpers!!!"
-
-        super(Actions, self).__init__()
-        if self.allow_square_movement:
-            self + ['north', 'east', 'south', 'west']
-        if self.allow_diagonal_movement:
-            self + ['north-east', 'south-east', 'south-west', 'north-west']
-        self._movement_actions = self._register.copy()
-        if self.allow_no_op:
-            self + 'no-op'
-
-
-class StateSlice(Register):
-
-    def __init__(self, n_agents: int):
-        super(StateSlice, self).__init__()
-        offset = 1
-        self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]])

 # noinspection PyAttributeOutsideInit
@@ -148,9 +40,11 @@ class BaseFactory(gym.Env):
     def movement_actions(self):
         return self._actions.movement_actions

-    def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None,
+    def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = 0,
                  movement_properties: MovementProperties = MovementProperties(),
+                 combin_agent_slices_in_obs: bool = False,
                  omit_agent_slice_in_obs=False, **kwargs):
+        assert combin_agent_slices_in_obs != omit_agent_slice_in_obs, 'Both options are exclusive'
         self.movement_properties = movement_properties
         self.level_name = level_name
@@ -158,6 +52,7 @@ class BaseFactory(gym.Env):
         self.n_agents = n_agents
         self.max_steps = max_steps
         self.pomdp_radius = pomdp_radius
+        self.combin_agent_slices_in_obs = combin_agent_slices_in_obs
         self.omit_agent_slice_in_obs = omit_agent_slice_in_obs
         self.done_at_collision = False
@@ -185,7 +80,7 @@ class BaseFactory(gym.Env):
         raise NotImplementedError('Please register additional actions ')

     def reset(self) -> (np.ndarray, int, bool, dict):
-        self.steps = 0
+        self._steps = 0
         self._agent_states = []
         # Agent placement ...
         agents = np.zeros((self.n_agents, *self._level.shape), dtype=np.int8)
@@ -202,17 +97,25 @@ class BaseFactory(gym.Env):
         # Returns State
         return None

-    def _return_state(self):
+    def _get_observations(self) -> np.ndarray:
+        if self.n_agents == 1:
+            obs = self._build_per_agent_obs(0)
+        elif self.n_agents >= 2:
+            obs = np.stack([self._build_per_agent_obs(agent_i) for agent_i in range(self.n_agents)])
+        return obs
+
+    def _build_per_agent_obs(self, agent_i: int) -> np.ndarray:
         if self.pomdp_radius:
-            pos = self._agent_states[0].pos
-            # pos = [agent_state.pos for agent_state in self.agent_states]
-            # obs = [] ... list comprehension... pos per agent
-            x0, x1 = max(0, pos[0] - self.pomdp_radius), pos[0] + self.pomdp_radius + 1
-            y0, y1 = max(0, pos[1] - self.pomdp_radius), pos[1] + self.pomdp_radius + 1
+            global_pos = self._agent_states[agent_i].pos
+            x0, x1 = max(0, global_pos[0] - self.pomdp_radius), global_pos[0] + self.pomdp_radius + 1
+            y0, y1 = max(0, global_pos[1] - self.pomdp_radius), global_pos[1] + self.pomdp_radius + 1
             obs = self._state[:, x0:x1, y0:y1]
             if obs.shape[1] != self.pomdp_radius * 2 + 1 or obs.shape[2] != self.pomdp_radius * 2 + 1:
                 obs_padded = np.full((obs.shape[0], self.pomdp_radius * 2 + 1, self.pomdp_radius * 2 + 1), 1)
-                a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0]
+                try:
+                    a_pos = np.argwhere(obs[h.AGENT_START_IDX + agent_i] == h.IS_OCCUPIED_CELL)[0]
+                except IndexError:
+                    print('NO')
                 obs_padded[:,
                            abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1],
                            abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
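Note (editor's sketch, not part of the commit): the crop above comes out smaller than (2 * pomdp_radius + 1)^2 whenever an agent stands near a level border, and the padded canvas re-centers the agent inside a full-size window. A minimal, runnable illustration of that re-centering on one toy slice; the radius, slice values, and variable names here are hypothetical:

import numpy as np

r = 2                                    # hypothetical pomdp_radius
level = np.arange(25).reshape(5, 5)      # one toy state slice
pos = (0, 1)                             # agent near the top border

# crop, clipping at the border exactly like the hunk above
x0, x1 = max(0, pos[0] - r), pos[0] + r + 1
y0, y1 = max(0, pos[1] - r), pos[1] + r + 1
obs = level[x0:x1, y0:y1]                # shape (3, 4): too small

# pad with 1 (occupied) and re-center the agent
padded = np.full((2 * r + 1, 2 * r + 1), 1)
ax, ay = pos[0] - x0, pos[1] - y0        # agent's index inside the crop
padded[abs(ax - r):abs(ax - r) + obs.shape[0],
       abs(ay - r):abs(ay - r) + obs.shape[1]] = obs
print(padded.shape)                      # (5, 5), agent sits at the center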
@@ -223,7 +126,13 @@ class BaseFactory(gym.Env):
             obs_new = obs[[key for key, val in self._state_slices.items() if 'agent' not in val]]
             return obs_new
         else:
-            return obs
+            if self.combin_agent_slices_in_obs:
+                agent_obs = np.sum(obs[[key for key, val in self._state_slices.items() if 'agent' in val]],
+                                   axis=0, keepdims=True)
+                obs = np.concatenate((obs[:h.AGENT_START_IDX], agent_obs, obs[h.AGENT_START_IDX+self.n_agents:]))
+                return obs
+            else:
+                return obs
     def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
         raise NotImplementedError
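Note (editor's sketch, not part of the commit): with the new combin_agent_slices_in_obs flag, all per-agent slices collapse into one occupancy layer before the observation is returned. A rough illustration of that reduction, assuming the slice layout "level + one one-hot slice per agent" and AGENT_START_IDX = 1 (an assumption; the real constant lives in environments/helpers):

import numpy as np

n_agents, height, width = 3, 5, 5
AGENT_START_IDX = 1                      # assumed layout: slice 0 = level
obs = np.zeros((1 + n_agents, height, width), dtype=np.int8)
obs[1, 0, 0] = obs[2, 2, 2] = obs[3, 4, 4] = 1   # one-hot agent positions

agent_obs = np.sum(obs[AGENT_START_IDX:AGENT_START_IDX + n_agents],
                   axis=0, keepdims=True)         # all agents on one layer
combined = np.concatenate((obs[:AGENT_START_IDX], agent_obs,
                           obs[AGENT_START_IDX + n_agents:]))
print(combined.shape)                    # (2, 5, 5): level + merged agents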
@@ -231,16 +140,16 @@ class BaseFactory(gym.Env):
     def step(self, actions):
         actions = [actions] if isinstance(actions, int) or np.isscalar(actions) else actions
         assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
-        self.steps += 1
+        self._steps += 1
         done = False

         # Move this in a seperate function?
         agent_states = list()
         for agent_i, action in enumerate(actions):
             agent_i_state = AgentState(agent_i, action)
-            if self._is_moving_action(action):
+            if self._actions.is_moving_action(action):
                 pos, valid = self.move_or_colide(agent_i, action)
-            elif self._is_no_op(action):
+            elif self._actions.is_no_op(action):
                 pos, valid = self.agent_i_position(agent_i), True
             else:
                 pos, valid = self.do_additional_actions(agent_i, action)
@@ -256,24 +165,18 @@ class BaseFactory(gym.Env):
         self._agent_states = agent_states
         reward, info = self.calculate_reward(agent_states)

-        if self.steps >= self.max_steps:
+        if self._steps >= self.max_steps:
             done = True
-        info.update(step_reward=reward, step=self.steps)
+        info.update(step_reward=reward, step=self._steps)
         return None, reward, done, info

-    def _is_moving_action(self, action):
-        return action in self._actions.movement_actions
-
-    def _is_no_op(self, action):
-        return self._actions[action] == 'no-op'
-
     def check_all_collisions(self, agent_states: List[AgentState], collisions: int) -> np.ndarray:
         collision_vecs = np.zeros((len(agent_states), collisions))  # n_agents x n_slices
         for agent_state in agent_states:
             # Register only collisions of moving agents
-            if self._is_moving_action(agent_state.action):
+            if self._actions.is_moving_action(agent_state.action):
                 collision_vecs[agent_state.i] = self.check_collisions(agent_state)
         return collision_vecs
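Note (editor's sketch, not part of the commit): check_all_collisions fills an n_agents x n_slices matrix, and AgentState.collisions then recovers the indices of the slices an agent collided with. A toy illustration with made-up values:

import numpy as np

# rows: agents, columns: state slices (level, agent#1, agent#2)
collision_vecs = np.array([[1, 0, 1],    # agent 0 hit the level and agent#2
                           [0, 0, 0]])   # agent 1 moved without collision
print(np.argwhere(collision_vecs[0] != 0).flatten())   # [0 2]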

environments/factory/simple_factory.py

@@ -1,16 +1,14 @@
 from collections import OrderedDict
-from dataclasses import dataclass
-from pathlib import Path
 from typing import List, Union, NamedTuple
 import random

 import numpy as np

-from environments.factory.base_factory import BaseFactory, AgentState, MovementProperties
+from environments.factory.base_factory import BaseFactory
 from environments import helpers as h
-from environments.logging.monitor import MonitorCallback
 from environments.factory.renderer import Renderer, Entity
+from environments.utility_classes import AgentState, MovementProperties

 DIRT_INDEX = -1
 CLEAN_UP_ACTION = 'clean_up'
@@ -25,13 +23,16 @@ class DirtProperties(NamedTuple):
     max_global_amount: int = 20  # Max dirt amount in the whole environment.


+# noinspection PyAttributeOutsideInit
 class SimpleFactory(BaseFactory):

     @property
     def additional_actions(self) -> Union[str, List[str]]:
         return CLEAN_UP_ACTION

-    def _is_clean_up_action(self, action):
+    def _is_clean_up_action(self, action: Union[str, int]):
+        if isinstance(action, str):
+            action = self._actions.by_name(action)
         return self._actions[action] == CLEAN_UP_ACTION

     def __init__(self, *args, dirt_properties: DirtProperties, verbose=False, **kwargs):
@@ -47,9 +48,9 @@ class SimpleFactory(BaseFactory):
             height, width = self._state.shape[1:]
             self._renderer = Renderer(width, height, view_radius=self.pomdp_radius)

         dirt = [Entity('dirt', [x, y], min(0.15 + self._state[DIRT_INDEX, x, y], 1.5), 'scale')
                 for x, y in np.argwhere(self._state[DIRT_INDEX] > h.IS_FREE_CELL)]
         walls = [Entity('wall', pos) for pos in np.argwhere(self._state[h.LEVEL_IDX] > h.IS_FREE_CELL)]

         def asset_str(agent):
             if any([x is None for x in [self._state_slices[j] for j in agent.collisions]]):
@@ -93,17 +94,18 @@ class SimpleFactory(BaseFactory):
         return pos, cleanup_was_sucessfull

     def step(self, actions):
-        _, r, done, info = super(SimpleFactory, self).step(actions)
+        _, reward, done, info = super(SimpleFactory, self).step(actions)
         if not self._next_dirt_spawn:
             self.spawn_dirt()
             self._next_dirt_spawn = self.dirt_properties.spawn_frequency
         else:
             self._next_dirt_spawn -= 1
-        obs = self._return_state()
-        return obs, r, done, info
+        obs = self._get_observations()
+        return obs, reward, done, info

     def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
-        if action != self._is_moving_action(action):
+        if action != self._actions.is_moving_action(action):
             if self._is_clean_up_action(action):
                 agent_i_pos = self.agent_i_position(agent_i)
                 _, valid = self.clean_up(agent_i_pos)
@@ -119,7 +121,7 @@ class SimpleFactory(BaseFactory):
         self._state = np.concatenate((self._state, dirt_slice))  # dirt is now the last slice
         self.spawn_dirt()
         self._next_dirt_spawn = self.dirt_properties.spawn_frequency
-        obs = self._return_state()
+        obs = self._get_observations()
         return obs

     def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
@@ -141,7 +143,7 @@ class SimpleFactory(BaseFactory):
                                   if entity != self._state_slices.by_name("dirt")]
             if list_of_collisions:
-                self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
+                self.print(f't = {self._steps}\tAgent {agent_state.i} has collisions with '
                            f'{list_of_collisions}')

             if self._is_clean_up_action(agent_state.action):
@@ -155,7 +157,7 @@ class SimpleFactory(BaseFactory):
                               f'at {agent_state.pos}, but was unsucsessfull.')
                     info_dict.update(failed_cleanup_attempt=1)

-            elif self._is_moving_action(agent_state.action):
+            elif self._actions.is_moving_action(agent_state.action):
                 if agent_state.action_valid:
                     # info_dict.update(movement=1)
                     reward -= 0.00
@@ -185,10 +187,11 @@ class SimpleFactory(BaseFactory):

 if __name__ == '__main__':
     render = True

-    import yaml
-    with Path(r'C:\Users\steff\projects\f_iks\debug_out\yaml.txt').open('r') as f:
-        env_kwargs = yaml.load(f)
-    factory = SimpleFactory(**env_kwargs)
+    move_props = MovementProperties(allow_diagonal_movement=True, allow_square_movement=True)
+    dirt_props = DirtProperties()
+    factory = SimpleFactory(movement_properties=move_props, dirt_properties=dirt_props, n_agents=2,
+                            combin_agent_slices_in_obs=True, omit_agent_slice_in_obs=False)

     # dirt_props = DirtProperties()
     # move_props = MovementProperties(allow_diagonal_movement=False, allow_no_op=False)
@@ -200,10 +203,12 @@ if __name__ == '__main__':
     for epoch in range(100):
         random_actions = [[random.randint(0, n_actions) for _ in range(factory.n_agents)] for _ in range(200)]
         env_state = factory.reset()
+        r = 0
         for agent_i_action in random_actions:
-            env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
+            env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
+            r += step_r
             if render:
                 factory.render()
             if done_bool:
                 break
-        print(f'Factory run {epoch} done, reward is:\n {reward}')
+        print(f'Factory run {epoch} done, reward is:\n {r}')

View File

@@ -32,13 +32,15 @@ def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None)
     hue_order = sorted(list(df[hue].unique()))
     try:
         sns.set(rc={'text.usetex': True}, style='whitegrid')
-        _ = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
-                         hue_order=hue_order, hue=hue, style=style)
+        lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
+                                hue_order=hue_order, hue=hue, style=style)
+        lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
         plot(filepath, ext=ext)  # plot raises errors not lineplot!
     except (FileNotFoundError, RuntimeError):
         print('Struggling to plot Figure using LaTeX - going back to normal.')
         plt.close('all')
         sns.set(rc={'text.usetex': False}, style='whitegrid')
-        sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
-                     ci=95, palette=PALETTE, hue_order=hue_order)
+        lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
+                                ci=95, palette=PALETTE, hue_order=hue_order)
+        lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
         plot(filepath, ext=ext)

environments/utility_classes.py (new file)

@@ -0,0 +1,127 @@
+from typing import Union, List, NamedTuple
+import numpy as np
+
+
+class MovementProperties(NamedTuple):
+    allow_square_movement: bool = True
+    allow_diagonal_movement: bool = False
+    allow_no_op: bool = False
+
+
+# Preperations for Entities (not used yet)
+class Entity:
+
+    @property
+    def pos(self):
+        return self._pos
+
+    @property
+    def identifier(self):
+        return self._identifier
+
+    def __init__(self, identifier, pos):
+        self._pos = pos
+        self._identifier = identifier
+
+
+class AgentState:
+
+    def __init__(self, i: int, action: int):
+        self.i = i
+        self.action = action
+
+        self.collision_vector = None
+        self.action_valid = None
+        self.pos = None
+        self.info = {}
+
+    @property
+    def collisions(self):
+        return np.argwhere(self.collision_vector != 0).flatten()
+
+    def update(self, **kwargs):  # is this hacky?? o.0
+        for key, value in kwargs.items():
+            if hasattr(self, key):
+                self.__setattr__(key, value)
+            else:
+                raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}')
+
+
+class Register:
+
+    @property
+    def n(self):
+        return len(self)
+
+    def __init__(self):
+        self._register = dict()
+
+    def __len__(self):
+        return len(self._register)
+
+    def __add__(self, other: Union[str, List[str]]):
+        other = other if isinstance(other, list) else [other]
+        assert all([isinstance(x, str) for x in other]), f'All item names have to be of type {str}.'
+        self._register.update({key+len(self._register): value for key, value in enumerate(other)})
+        return self
+
+    def register_additional_items(self, other: Union[str, List[str]]):
+        self_with_additional_items = self + other
+        return self_with_additional_items
+
+    def keys(self):
+        return self._register.keys()
+
+    def items(self):
+        return self._register.items()
+
+    def __getitem__(self, item):
+        return self._register[item]
+
+    def by_name(self, item):
+        return list(self._register.keys())[list(self._register.values()).index(item)]
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}({self._register})'
+
+
+class Actions(Register):
+
+    @property
+    def movement_actions(self):
+        return self._movement_actions
+
+    def __init__(self, movement_properties: MovementProperties):
+        self.allow_no_op = movement_properties.allow_no_op
+        self.allow_diagonal_movement = movement_properties.allow_diagonal_movement
+        self.allow_square_movement = movement_properties.allow_square_movement
+        # FIXME: There is a bug in helpers because there actions are ints. and the order matters.
+        # assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), \
+        #     "There is a bug in helpers!!!"
+
+        super(Actions, self).__init__()
+        if self.allow_square_movement:
+            self + ['north', 'east', 'south', 'west']
+        if self.allow_diagonal_movement:
+            self + ['north_east', 'south_east', 'south_west', 'north_west']
+        self._movement_actions = self._register.copy()
+        if self.allow_no_op:
+            self + 'no-op'
+
+    def is_moving_action(self, action: Union[str, int]):
+        if isinstance(action, str):
+            return action in self.movement_actions.values()
+        else:
+            return self[action] in self.movement_actions.values()
+
+    def is_no_op(self, action: Union[str, int]):
+        if isinstance(action, str):
+            action = self.by_name(action)
+        return self[action] == 'no-op'
+
+
+class StateSlice(Register):
+
+    def __init__(self, n_agents: int):
+        super(StateSlice, self).__init__()
+        offset = 1
+        self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]])
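Note (editor's sketch, not part of the commit): a quick usage example for the relocated registries; the printed values assume the registration order above:

from environments.utility_classes import Actions, MovementProperties, StateSlice

props = MovementProperties(allow_square_movement=True,
                           allow_diagonal_movement=True,
                           allow_no_op=True)
actions = Actions(props)

print(actions.n)                          # 9: four square, four diagonal, no-op
print(actions.is_moving_action('north'))  # True (lookup by name)
print(actions.is_moving_action(8))        # False, index 8 is 'no-op'
print(actions.is_no_op(8))                # True

state_slices = StateSlice(n_agents=2)
print(list(state_slices.items()))         # [(0, 'level'), (1, 'agent#1'), (2, 'agent#2')]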

main.py

@@ -94,7 +94,7 @@ if __name__ == '__main__':
     dirt_props = DirtProperties(clean_amount=3, gain_amount=0.2, max_global_amount=30,
                                 max_local_amount=5, spawn_frequency=3)
-    move_props = MovementProperties(allow_diagonal_movement=False,
+    move_props = MovementProperties(allow_diagonal_movement=True,
                                     allow_square_movement=True,
                                     allow_no_op=False)
     time_stamp = int(time.time())
@@ -123,7 +123,7 @@ if __name__ == '__main__':
         [MonitorCallback(filepath=out_path / f'monitor_{identifier}.pick', plotting=False)]
     )
-    model.learn(total_timesteps=int(5e5), callback=callbacks)
+    model.learn(total_timesteps=int(1e5), callback=callbacks)

     save_path = out_path / f'model_{identifier}.zip'
     save_path.parent.mkdir(parents=True, exist_ok=True)

main_test.py (new file, 75 lines)

@@ -0,0 +1,75 @@
+# foreign imports
+import warnings
+from pathlib import Path
+
+import yaml
+from natsort import natsorted
+from stable_baselines3.common.callbacks import CallbackList
+from stable_baselines3 import PPO, DQN, A2C
+
+# our imports
+from environments.factory.simple_factory import SimpleFactory
+from environments.logging.monitor import MonitorCallback
+from algorithms.reg_dqn import RegDQN
+from main import compare_runs, combine_runs
+
+warnings.filterwarnings('ignore', category=FutureWarning)
+warnings.filterwarnings('ignore', category=UserWarning)
+
+model_mapping = dict(A2C=A2C, PPO=PPO, DQN=DQN, RegDQN=RegDQN)
+
+
+if __name__ == '__main__':
+
+    # get n policies pi_1, ..., pi_n trained in single agent setting
+    # rewards = []
+    # repeat for x eval runs
+    #   total reward = rollout game for y steps with n policies in multi-agent setting
+    #   rewards += [total reward]
+    # boxplot total rewards
+
+    run_id = '1623078961'
+    model_name = 'PPO'
+    # -----------------------
+    out_path = Path(__file__).parent / 'debug_out'
+    # from sb3_contrib import QRDQN
+
+    model_path = out_path / f'{model_name}_{run_id}'
+    model_files = list(natsorted(model_path.rglob('model_*.zip')))
+    this_model = model_files[0]
+
+    render = True
+
+    model = model_mapping[model_name].load(this_model)
+
+    for seed in range(3):
+        with (model_path / f'env_{model_path.name}.yaml').open('r') as f:
+            env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
+        env_kwargs.update(n_agents=2)
+        env = SimpleFactory(**env_kwargs)
+
+        exp_out_path = model_path / 'exp'
+        callbacks = CallbackList(
+            [MonitorCallback(filepath=exp_out_path / f'future_exp_name', plotting=True)]
+        )
+
+        n_actions = env.action_space.n
+
+        for epoch in range(100):
+            observations = env.reset()
+            if render:
+                env.render()
+            done_bool = False
+            r = 0
+            while not done_bool:
+                actions = [model.predict(obs, deterministic=False)[0] for obs in observations]
+                obs, r, done_bool, info_obj = env.step(actions)
+                if render:
+                    env.render()
+                if done_bool:
+                    break
+            print(f'Factory run {epoch} done, reward is:\n {r}')
+
+    if out_path:
+        combine_runs(out_path.parent)
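Note (editor's sketch, not part of the commit): the comment block at the top of main_test.py outlines the evaluation plan, i.e. roll out the single-agent policies together for several runs and boxplot the total rewards. A minimal sketch of that final aggregation step, with hypothetical reward values:

import matplotlib.pyplot as plt

# hypothetical totals collected from the rollout loop above, one list per seed
rewards = {0: [12.5, 9.0, 14.2], 1: [11.1, 13.3, 8.7], 2: [10.0, 12.8, 9.9]}

plt.boxplot(list(rewards.values()), labels=[f'seed {s}' for s in rewards])
plt.ylabel('total episode reward')
plt.savefig('multi_agent_rewards_boxplot.png')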