mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 15:26:43 +02:00
new register objects
state slices are now registers
def __get__(int)
def by_name(str)
✌️
This commit is contained in:
parent
efedce579e
commit
7b4e60b0aa
@ -32,12 +32,40 @@ class AgentState:
|
||||
raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}')
|
||||
|
||||
|
||||
class Actions:
|
||||
class Register:
|
||||
|
||||
@property
|
||||
def n(self):
|
||||
return len(self)
|
||||
|
||||
def __init__(self):
|
||||
self._register = dict()
|
||||
|
||||
def __len__(self):
|
||||
return len(self._register)
|
||||
|
||||
def __add__(self, other: Union[str, List[str]]):
|
||||
other = other if isinstance(other, list) else [other]
|
||||
assert all([isinstance(x, str) for x in other]), f'All item names have to be of type {str}.'
|
||||
self._register.update({key+len(self._register): value for key, value in enumerate(other)})
|
||||
return self
|
||||
|
||||
def register_additional_items(self, other: Union[str, List[str]]):
|
||||
self_with_additional_items = self + other
|
||||
return self_with_additional_items
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self._register[item]
|
||||
|
||||
def by_name(self, item):
|
||||
return list(self._register.keys())[list(self._register.values()).index(item)]
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self._register})'
|
||||
|
||||
|
||||
class Actions(Register):
|
||||
|
||||
@property
|
||||
def movement_actions(self):
|
||||
return self._movement_actions
|
||||
@ -45,33 +73,25 @@ class Actions:
|
||||
def __init__(self, allow_square_movement=False, allow_diagonal_movement=False, allow_no_op=False):
|
||||
# FIXME: There is a bug in helpers because there actions are ints. and the order matters.
|
||||
assert not(allow_square_movement is False and allow_diagonal_movement is True), "There is a bug in helpers!!!"
|
||||
super(Actions, self).__init__()
|
||||
self.allow_no_op = allow_no_op
|
||||
self.allow_diagonal_movement = allow_diagonal_movement
|
||||
self.allow_square_movement = allow_square_movement
|
||||
self._registerd_actions = dict()
|
||||
if allow_square_movement:
|
||||
self + ['north', 'east', 'south', 'west']
|
||||
if allow_diagonal_movement:
|
||||
self + ['north-east', 'south-east', 'south-west', 'north-west']
|
||||
self._movement_actions = self._registerd_actions.copy()
|
||||
self._movement_actions = self._register.copy()
|
||||
if self.allow_no_op:
|
||||
self + 'no-op'
|
||||
|
||||
def __len__(self):
|
||||
return len(self._registerd_actions)
|
||||
|
||||
def __add__(self, other: Union[str, List[str]]):
|
||||
other = other if isinstance(other, list) else [other]
|
||||
assert all([isinstance(x, str) for x in other]), f'All action names have to be of type {str}.'
|
||||
self._registerd_actions.update({key+len(self._registerd_actions): value for key, value in enumerate(other)})
|
||||
return self
|
||||
class StateSlice(Register):
|
||||
|
||||
def register_additional_actions(self, other: Union[str, List[str]]):
|
||||
self_with_additional_actions = self + other
|
||||
return self_with_additional_actions
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self._registerd_actions[item]
|
||||
def __init__(self, n_agents: int):
|
||||
super(StateSlice, self).__init__()
|
||||
offset = 1
|
||||
self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]])
|
||||
|
||||
|
||||
class BaseFactory(gym.Env):
|
||||
@ -88,9 +108,6 @@ class BaseFactory(gym.Env):
|
||||
def movement_actions(self):
|
||||
return self._actions.movement_actions
|
||||
|
||||
@property
|
||||
def string_slices(self):
|
||||
return {value: key for key, value in self.slice_strings.items()}
|
||||
|
||||
def __init__(self, level='simple', n_agents=1, max_steps=int(2e2), **kwargs):
|
||||
self.n_agents = n_agents
|
||||
@ -104,7 +121,7 @@ class BaseFactory(gym.Env):
|
||||
self.level = h.one_hot_level(
|
||||
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
|
||||
)
|
||||
self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}}
|
||||
self.state_slices = StateSlice(n_agents)
|
||||
self.reset()
|
||||
|
||||
@property
|
||||
|
@ -38,8 +38,8 @@ class SimpleFactory(BaseFactory):
|
||||
self.verbose = verbose
|
||||
self.max_dirt = 20
|
||||
super(SimpleFactory, self).__init__(*args, **kwargs)
|
||||
self.slice_strings.update({self.state.shape[0]-1: 'dirt'})
|
||||
self.renderer = None # expensive - dont use it when not required !
|
||||
self.state_slices.register_additional_items('dirt')
|
||||
self.renderer = None # expensive - don't use it when not required !
|
||||
|
||||
def render(self):
|
||||
|
||||
@ -52,7 +52,9 @@ class SimpleFactory(BaseFactory):
|
||||
walls = [Entity('wall', pos) for pos in np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL)]
|
||||
|
||||
def asset_str(agent):
|
||||
cols = ' '.join([self.slice_strings[j] for j in agent.collisions])
|
||||
if any([x is None for x in [self.state_slices[j] for j in agent.collisions]]):
|
||||
print('error')
|
||||
cols = ' '.join([self.state_slices[j] for j in agent.collisions])
|
||||
if 'agent' in cols:
|
||||
return 'agent_collision'
|
||||
elif not agent.action_valid or 'level' in cols or 'agent' in cols:
|
||||
@ -131,8 +133,12 @@ class SimpleFactory(BaseFactory):
|
||||
|
||||
for agent_state in agent_states:
|
||||
cols = agent_state.collisions
|
||||
|
||||
list_of_collisions = [self.state_slices[entity] for entity in cols
|
||||
if entity != self.state_slices.by_name("dirt")]
|
||||
|
||||
self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
|
||||
f'{[self.slice_strings[entity] for entity in cols if entity != self.string_slices["dirt"]]}')
|
||||
f'{list_of_collisions}')
|
||||
if self._is_clean_up_action(agent_state.action):
|
||||
if agent_state.action_valid:
|
||||
reward += 1
|
||||
@ -155,8 +161,8 @@ class SimpleFactory(BaseFactory):
|
||||
reward -= 0.25
|
||||
|
||||
for entity in cols:
|
||||
if entity != self.string_slices["dirt"]:
|
||||
self.monitor.set(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1)
|
||||
if entity != self.state_slices.by_name("dirt"):
|
||||
self.monitor.set(f'agent_{agent_state.i}_vs_{self.state_slices[entity]}', 1)
|
||||
|
||||
self.monitor.set('dirt_amount', current_dirt_amount)
|
||||
self.monitor.set('dirty_tiles', dirty_tiles)
|
||||
|
@ -104,6 +104,9 @@ class MonitorCallback(BaseCallback):
|
||||
df = pd.DataFrame(columns=monitor.columns)
|
||||
for _, row in monitor.iterrows():
|
||||
df.loc[df.shape[0]] = row
|
||||
if df is None: # The env exited premature, we catch it.
|
||||
self.closed = True
|
||||
return
|
||||
for column in list(df.columns):
|
||||
if column != 'episode':
|
||||
df[f'{column}_roll'] = df[column].rolling(window=50).mean()
|
||||
|
@ -20,13 +20,13 @@ PALETTE = 10 * (
|
||||
)
|
||||
|
||||
|
||||
def plot(filepath, ext='png', tag='monitor', **kwargs):
|
||||
def plot(filepath, ext='png', **kwargs):
|
||||
plt.rcParams.update(kwargs)
|
||||
|
||||
plt.tight_layout()
|
||||
figure = plt.gcf()
|
||||
plt.show()
|
||||
figure.savefig(str(filepath.parent / f'{filepath.stem}_{tag}_measures.{ext}'), format=ext)
|
||||
figure.savefig(str(filepath), format=ext)
|
||||
|
||||
|
||||
def prepare_plot(filepath, results_df, ext='png', tag=''):
|
||||
|
6
main.py
6
main.py
@ -61,8 +61,8 @@ def combine_runs(run_path: Union[str, PathLike]):
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
combine_runs('debug_out/PPO_1622120377')
|
||||
exit()
|
||||
# combine_runs('debug_out/PPO_1622120377')
|
||||
# exit()
|
||||
|
||||
from stable_baselines3 import PPO # DQN
|
||||
dirt_props = DirtProperties()
|
||||
@ -72,7 +72,7 @@ if __name__ == '__main__':
|
||||
|
||||
for seed in range(5):
|
||||
|
||||
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, allow_diagonal_movement=False, allow_no_op=False)
|
||||
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, allow_diagonal_movement=True, allow_no_op=False)
|
||||
|
||||
model = PPO("MlpPolicy", env, verbose=1, ent_coef=0.0, seed=seed, device='cpu')
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user