new register objects

state slices are now registers
def __get__(int)
def by_name(str)

✌️
This commit is contained in:
steffen-illium 2021-05-29 10:49:39 +02:00
parent efedce579e
commit 7b4e60b0aa
5 changed files with 57 additions and 31 deletions

View File

@ -32,12 +32,40 @@ class AgentState:
raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}') raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}')
class Actions: class Register:
@property @property
def n(self): def n(self):
return len(self) return len(self)
def __init__(self):
self._register = dict()
def __len__(self):
return len(self._register)
def __add__(self, other: Union[str, List[str]]):
other = other if isinstance(other, list) else [other]
assert all([isinstance(x, str) for x in other]), f'All item names have to be of type {str}.'
self._register.update({key+len(self._register): value for key, value in enumerate(other)})
return self
def register_additional_items(self, other: Union[str, List[str]]):
self_with_additional_items = self + other
return self_with_additional_items
def __getitem__(self, item):
return self._register[item]
def by_name(self, item):
return list(self._register.keys())[list(self._register.values()).index(item)]
def __repr__(self):
return f'{self.__class__.__name__}({self._register})'
class Actions(Register):
@property @property
def movement_actions(self): def movement_actions(self):
return self._movement_actions return self._movement_actions
@ -45,33 +73,25 @@ class Actions:
def __init__(self, allow_square_movement=False, allow_diagonal_movement=False, allow_no_op=False): def __init__(self, allow_square_movement=False, allow_diagonal_movement=False, allow_no_op=False):
# FIXME: There is a bug in helpers because there actions are ints. and the order matters. # FIXME: There is a bug in helpers because there actions are ints. and the order matters.
assert not(allow_square_movement is False and allow_diagonal_movement is True), "There is a bug in helpers!!!" assert not(allow_square_movement is False and allow_diagonal_movement is True), "There is a bug in helpers!!!"
super(Actions, self).__init__()
self.allow_no_op = allow_no_op self.allow_no_op = allow_no_op
self.allow_diagonal_movement = allow_diagonal_movement self.allow_diagonal_movement = allow_diagonal_movement
self.allow_square_movement = allow_square_movement self.allow_square_movement = allow_square_movement
self._registerd_actions = dict()
if allow_square_movement: if allow_square_movement:
self + ['north', 'east', 'south', 'west'] self + ['north', 'east', 'south', 'west']
if allow_diagonal_movement: if allow_diagonal_movement:
self + ['north-east', 'south-east', 'south-west', 'north-west'] self + ['north-east', 'south-east', 'south-west', 'north-west']
self._movement_actions = self._registerd_actions.copy() self._movement_actions = self._register.copy()
if self.allow_no_op: if self.allow_no_op:
self + 'no-op' self + 'no-op'
def __len__(self):
return len(self._registerd_actions)
def __add__(self, other: Union[str, List[str]]): class StateSlice(Register):
other = other if isinstance(other, list) else [other]
assert all([isinstance(x, str) for x in other]), f'All action names have to be of type {str}.'
self._registerd_actions.update({key+len(self._registerd_actions): value for key, value in enumerate(other)})
return self
def register_additional_actions(self, other: Union[str, List[str]]): def __init__(self, n_agents: int):
self_with_additional_actions = self + other super(StateSlice, self).__init__()
return self_with_additional_actions offset = 1
self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]])
def __getitem__(self, item):
return self._registerd_actions[item]
class BaseFactory(gym.Env): class BaseFactory(gym.Env):
@ -88,9 +108,6 @@ class BaseFactory(gym.Env):
def movement_actions(self): def movement_actions(self):
return self._actions.movement_actions return self._actions.movement_actions
@property
def string_slices(self):
return {value: key for key, value in self.slice_strings.items()}
def __init__(self, level='simple', n_agents=1, max_steps=int(2e2), **kwargs): def __init__(self, level='simple', n_agents=1, max_steps=int(2e2), **kwargs):
self.n_agents = n_agents self.n_agents = n_agents
@ -104,7 +121,7 @@ class BaseFactory(gym.Env):
self.level = h.one_hot_level( self.level = h.one_hot_level(
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt') h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
) )
self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}} self.state_slices = StateSlice(n_agents)
self.reset() self.reset()
@property @property

View File

@ -38,8 +38,8 @@ class SimpleFactory(BaseFactory):
self.verbose = verbose self.verbose = verbose
self.max_dirt = 20 self.max_dirt = 20
super(SimpleFactory, self).__init__(*args, **kwargs) super(SimpleFactory, self).__init__(*args, **kwargs)
self.slice_strings.update({self.state.shape[0]-1: 'dirt'}) self.state_slices.register_additional_items('dirt')
self.renderer = None # expensive - dont use it when not required ! self.renderer = None # expensive - don't use it when not required !
def render(self): def render(self):
@ -52,7 +52,9 @@ class SimpleFactory(BaseFactory):
walls = [Entity('wall', pos) for pos in np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL)] walls = [Entity('wall', pos) for pos in np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL)]
def asset_str(agent): def asset_str(agent):
cols = ' '.join([self.slice_strings[j] for j in agent.collisions]) if any([x is None for x in [self.state_slices[j] for j in agent.collisions]]):
print('error')
cols = ' '.join([self.state_slices[j] for j in agent.collisions])
if 'agent' in cols: if 'agent' in cols:
return 'agent_collision' return 'agent_collision'
elif not agent.action_valid or 'level' in cols or 'agent' in cols: elif not agent.action_valid or 'level' in cols or 'agent' in cols:
@ -131,8 +133,12 @@ class SimpleFactory(BaseFactory):
for agent_state in agent_states: for agent_state in agent_states:
cols = agent_state.collisions cols = agent_state.collisions
list_of_collisions = [self.state_slices[entity] for entity in cols
if entity != self.state_slices.by_name("dirt")]
self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with ' self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
f'{[self.slice_strings[entity] for entity in cols if entity != self.string_slices["dirt"]]}') f'{list_of_collisions}')
if self._is_clean_up_action(agent_state.action): if self._is_clean_up_action(agent_state.action):
if agent_state.action_valid: if agent_state.action_valid:
reward += 1 reward += 1
@ -155,8 +161,8 @@ class SimpleFactory(BaseFactory):
reward -= 0.25 reward -= 0.25
for entity in cols: for entity in cols:
if entity != self.string_slices["dirt"]: if entity != self.state_slices.by_name("dirt"):
self.monitor.set(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1) self.monitor.set(f'agent_{agent_state.i}_vs_{self.state_slices[entity]}', 1)
self.monitor.set('dirt_amount', current_dirt_amount) self.monitor.set('dirt_amount', current_dirt_amount)
self.monitor.set('dirty_tiles', dirty_tiles) self.monitor.set('dirty_tiles', dirty_tiles)

View File

@ -104,6 +104,9 @@ class MonitorCallback(BaseCallback):
df = pd.DataFrame(columns=monitor.columns) df = pd.DataFrame(columns=monitor.columns)
for _, row in monitor.iterrows(): for _, row in monitor.iterrows():
df.loc[df.shape[0]] = row df.loc[df.shape[0]] = row
if df is None: # The env exited premature, we catch it.
self.closed = True
return
for column in list(df.columns): for column in list(df.columns):
if column != 'episode': if column != 'episode':
df[f'{column}_roll'] = df[column].rolling(window=50).mean() df[f'{column}_roll'] = df[column].rolling(window=50).mean()

View File

@ -20,13 +20,13 @@ PALETTE = 10 * (
) )
def plot(filepath, ext='png', tag='monitor', **kwargs): def plot(filepath, ext='png', **kwargs):
plt.rcParams.update(kwargs) plt.rcParams.update(kwargs)
plt.tight_layout() plt.tight_layout()
figure = plt.gcf() figure = plt.gcf()
plt.show() plt.show()
figure.savefig(str(filepath.parent / f'{filepath.stem}_{tag}_measures.{ext}'), format=ext) figure.savefig(str(filepath), format=ext)
def prepare_plot(filepath, results_df, ext='png', tag=''): def prepare_plot(filepath, results_df, ext='png', tag=''):

View File

@ -61,8 +61,8 @@ def combine_runs(run_path: Union[str, PathLike]):
if __name__ == '__main__': if __name__ == '__main__':
combine_runs('debug_out/PPO_1622120377') # combine_runs('debug_out/PPO_1622120377')
exit() # exit()
from stable_baselines3 import PPO # DQN from stable_baselines3 import PPO # DQN
dirt_props = DirtProperties() dirt_props = DirtProperties()
@ -72,7 +72,7 @@ if __name__ == '__main__':
for seed in range(5): for seed in range(5):
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, allow_diagonal_movement=False, allow_no_op=False) env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, allow_diagonal_movement=True, allow_no_op=False)
model = PPO("MlpPolicy", env, verbose=1, ent_coef=0.0, seed=seed, device='cpu') model = PPO("MlpPolicy", env, verbose=1, ent_coef=0.0, seed=seed, device='cpu')