mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 15:26:43 +02:00
new register objects
state slices are now registers
def __getitem__(int)
def by_name(str)
✌️
This commit is contained in:
parent
efedce579e
commit
7b4e60b0aa
@ -32,12 +32,40 @@ class AgentState:
|
|||||||
raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}')
|
raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}')
|
||||||
|
|
||||||
|
|
||||||
class Actions:
|
class Register:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def n(self):
|
def n(self):
|
||||||
return len(self)
|
return len(self)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._register = dict()
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._register)
|
||||||
|
|
||||||
|
def __add__(self, other: Union[str, List[str]]):
|
||||||
|
other = other if isinstance(other, list) else [other]
|
||||||
|
assert all([isinstance(x, str) for x in other]), f'All item names have to be of type {str}.'
|
||||||
|
self._register.update({key+len(self._register): value for key, value in enumerate(other)})
|
||||||
|
return self
|
||||||
|
|
||||||
|
def register_additional_items(self, other: Union[str, List[str]]):
|
||||||
|
self_with_additional_items = self + other
|
||||||
|
return self_with_additional_items
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return self._register[item]
|
||||||
|
|
||||||
|
def by_name(self, item):
|
||||||
|
return list(self._register.keys())[list(self._register.values()).index(item)]
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'{self.__class__.__name__}({self._register})'
|
||||||
|
|
||||||
|
|
||||||
|
class Actions(Register):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def movement_actions(self):
|
def movement_actions(self):
|
||||||
return self._movement_actions
|
return self._movement_actions
|
||||||
@ -45,33 +73,25 @@ class Actions:
|
|||||||
def __init__(self, allow_square_movement=False, allow_diagonal_movement=False, allow_no_op=False):
|
def __init__(self, allow_square_movement=False, allow_diagonal_movement=False, allow_no_op=False):
|
||||||
# FIXME: There is a bug in helpers because there actions are ints. and the order matters.
|
# FIXME: There is a bug in helpers because there actions are ints. and the order matters.
|
||||||
assert not(allow_square_movement is False and allow_diagonal_movement is True), "There is a bug in helpers!!!"
|
assert not(allow_square_movement is False and allow_diagonal_movement is True), "There is a bug in helpers!!!"
|
||||||
|
super(Actions, self).__init__()
|
||||||
self.allow_no_op = allow_no_op
|
self.allow_no_op = allow_no_op
|
||||||
self.allow_diagonal_movement = allow_diagonal_movement
|
self.allow_diagonal_movement = allow_diagonal_movement
|
||||||
self.allow_square_movement = allow_square_movement
|
self.allow_square_movement = allow_square_movement
|
||||||
self._registerd_actions = dict()
|
|
||||||
if allow_square_movement:
|
if allow_square_movement:
|
||||||
self + ['north', 'east', 'south', 'west']
|
self + ['north', 'east', 'south', 'west']
|
||||||
if allow_diagonal_movement:
|
if allow_diagonal_movement:
|
||||||
self + ['north-east', 'south-east', 'south-west', 'north-west']
|
self + ['north-east', 'south-east', 'south-west', 'north-west']
|
||||||
self._movement_actions = self._registerd_actions.copy()
|
self._movement_actions = self._register.copy()
|
||||||
if self.allow_no_op:
|
if self.allow_no_op:
|
||||||
self + 'no-op'
|
self + 'no-op'
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self._registerd_actions)
|
|
||||||
|
|
||||||
def __add__(self, other: Union[str, List[str]]):
|
class StateSlice(Register):
|
||||||
other = other if isinstance(other, list) else [other]
|
|
||||||
assert all([isinstance(x, str) for x in other]), f'All action names have to be of type {str}.'
|
|
||||||
self._registerd_actions.update({key+len(self._registerd_actions): value for key, value in enumerate(other)})
|
|
||||||
return self
|
|
||||||
|
|
||||||
def register_additional_actions(self, other: Union[str, List[str]]):
|
def __init__(self, n_agents: int):
|
||||||
self_with_additional_actions = self + other
|
super(StateSlice, self).__init__()
|
||||||
return self_with_additional_actions
|
offset = 1
|
||||||
|
self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]])
|
||||||
def __getitem__(self, item):
|
|
||||||
return self._registerd_actions[item]
|
|
||||||
|
|
||||||
|
|
||||||
class BaseFactory(gym.Env):
|
class BaseFactory(gym.Env):
|
||||||
@ -88,9 +108,6 @@ class BaseFactory(gym.Env):
|
|||||||
def movement_actions(self):
|
def movement_actions(self):
|
||||||
return self._actions.movement_actions
|
return self._actions.movement_actions
|
||||||
|
|
||||||
@property
|
|
||||||
def string_slices(self):
|
|
||||||
return {value: key for key, value in self.slice_strings.items()}
|
|
||||||
|
|
||||||
def __init__(self, level='simple', n_agents=1, max_steps=int(2e2), **kwargs):
|
def __init__(self, level='simple', n_agents=1, max_steps=int(2e2), **kwargs):
|
||||||
self.n_agents = n_agents
|
self.n_agents = n_agents
|
||||||
@ -104,7 +121,7 @@ class BaseFactory(gym.Env):
|
|||||||
self.level = h.one_hot_level(
|
self.level = h.one_hot_level(
|
||||||
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
|
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
|
||||||
)
|
)
|
||||||
self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}}
|
self.state_slices = StateSlice(n_agents)
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -38,8 +38,8 @@ class SimpleFactory(BaseFactory):
|
|||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.max_dirt = 20
|
self.max_dirt = 20
|
||||||
super(SimpleFactory, self).__init__(*args, **kwargs)
|
super(SimpleFactory, self).__init__(*args, **kwargs)
|
||||||
self.slice_strings.update({self.state.shape[0]-1: 'dirt'})
|
self.state_slices.register_additional_items('dirt')
|
||||||
self.renderer = None # expensive - dont use it when not required !
|
self.renderer = None # expensive - don't use it when not required !
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
|
|
||||||
@ -52,7 +52,9 @@ class SimpleFactory(BaseFactory):
|
|||||||
walls = [Entity('wall', pos) for pos in np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL)]
|
walls = [Entity('wall', pos) for pos in np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL)]
|
||||||
|
|
||||||
def asset_str(agent):
|
def asset_str(agent):
|
||||||
cols = ' '.join([self.slice_strings[j] for j in agent.collisions])
|
if any([x is None for x in [self.state_slices[j] for j in agent.collisions]]):
|
||||||
|
print('error')
|
||||||
|
cols = ' '.join([self.state_slices[j] for j in agent.collisions])
|
||||||
if 'agent' in cols:
|
if 'agent' in cols:
|
||||||
return 'agent_collision'
|
return 'agent_collision'
|
||||||
elif not agent.action_valid or 'level' in cols or 'agent' in cols:
|
elif not agent.action_valid or 'level' in cols or 'agent' in cols:
|
||||||
@ -131,8 +133,12 @@ class SimpleFactory(BaseFactory):
|
|||||||
|
|
||||||
for agent_state in agent_states:
|
for agent_state in agent_states:
|
||||||
cols = agent_state.collisions
|
cols = agent_state.collisions
|
||||||
|
|
||||||
|
list_of_collisions = [self.state_slices[entity] for entity in cols
|
||||||
|
if entity != self.state_slices.by_name("dirt")]
|
||||||
|
|
||||||
self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
|
self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
|
||||||
f'{[self.slice_strings[entity] for entity in cols if entity != self.string_slices["dirt"]]}')
|
f'{list_of_collisions}')
|
||||||
if self._is_clean_up_action(agent_state.action):
|
if self._is_clean_up_action(agent_state.action):
|
||||||
if agent_state.action_valid:
|
if agent_state.action_valid:
|
||||||
reward += 1
|
reward += 1
|
||||||
@ -155,8 +161,8 @@ class SimpleFactory(BaseFactory):
|
|||||||
reward -= 0.25
|
reward -= 0.25
|
||||||
|
|
||||||
for entity in cols:
|
for entity in cols:
|
||||||
if entity != self.string_slices["dirt"]:
|
if entity != self.state_slices.by_name("dirt"):
|
||||||
self.monitor.set(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1)
|
self.monitor.set(f'agent_{agent_state.i}_vs_{self.state_slices[entity]}', 1)
|
||||||
|
|
||||||
self.monitor.set('dirt_amount', current_dirt_amount)
|
self.monitor.set('dirt_amount', current_dirt_amount)
|
||||||
self.monitor.set('dirty_tiles', dirty_tiles)
|
self.monitor.set('dirty_tiles', dirty_tiles)
|
||||||
|
@ -104,6 +104,9 @@ class MonitorCallback(BaseCallback):
|
|||||||
df = pd.DataFrame(columns=monitor.columns)
|
df = pd.DataFrame(columns=monitor.columns)
|
||||||
for _, row in monitor.iterrows():
|
for _, row in monitor.iterrows():
|
||||||
df.loc[df.shape[0]] = row
|
df.loc[df.shape[0]] = row
|
||||||
|
if df is None: # The env exited premature, we catch it.
|
||||||
|
self.closed = True
|
||||||
|
return
|
||||||
for column in list(df.columns):
|
for column in list(df.columns):
|
||||||
if column != 'episode':
|
if column != 'episode':
|
||||||
df[f'{column}_roll'] = df[column].rolling(window=50).mean()
|
df[f'{column}_roll'] = df[column].rolling(window=50).mean()
|
||||||
|
@ -20,13 +20,13 @@ PALETTE = 10 * (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def plot(filepath, ext='png', tag='monitor', **kwargs):
|
def plot(filepath, ext='png', **kwargs):
|
||||||
plt.rcParams.update(kwargs)
|
plt.rcParams.update(kwargs)
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
figure = plt.gcf()
|
figure = plt.gcf()
|
||||||
plt.show()
|
plt.show()
|
||||||
figure.savefig(str(filepath.parent / f'{filepath.stem}_{tag}_measures.{ext}'), format=ext)
|
figure.savefig(str(filepath), format=ext)
|
||||||
|
|
||||||
|
|
||||||
def prepare_plot(filepath, results_df, ext='png', tag=''):
|
def prepare_plot(filepath, results_df, ext='png', tag=''):
|
||||||
|
6
main.py
6
main.py
@ -61,8 +61,8 @@ def combine_runs(run_path: Union[str, PathLike]):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
combine_runs('debug_out/PPO_1622120377')
|
# combine_runs('debug_out/PPO_1622120377')
|
||||||
exit()
|
# exit()
|
||||||
|
|
||||||
from stable_baselines3 import PPO # DQN
|
from stable_baselines3 import PPO # DQN
|
||||||
dirt_props = DirtProperties()
|
dirt_props = DirtProperties()
|
||||||
@ -72,7 +72,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
for seed in range(5):
|
for seed in range(5):
|
||||||
|
|
||||||
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, allow_diagonal_movement=False, allow_no_op=False)
|
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, allow_diagonal_movement=True, allow_no_op=False)
|
||||||
|
|
||||||
model = PPO("MlpPolicy", env, verbose=1, ent_coef=0.0, seed=seed, device='cpu')
|
model = PPO("MlpPolicy", env, verbose=1, ent_coef=0.0, seed=seed, device='cpu')
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user