mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 15:26:43 +02:00
actions are now objects :P
This commit is contained in:
parent
36fe59c95c
commit
604ffc3b60
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Union, Iterable, TypedDict
|
from typing import List, Union, Iterable
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
@ -34,40 +34,51 @@ class AgentState:
|
|||||||
|
|
||||||
class Actions:
|
class Actions:
|
||||||
|
|
||||||
def __init__(self, allow_square_movement=True, allow_diagonal_movement=True, allow_no_OP=True):
|
@property
|
||||||
self.allow_no_OP = allow_no_OP
|
def n(self):
|
||||||
|
return len(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def movement_actions(self):
|
||||||
|
return self._movement_actions
|
||||||
|
|
||||||
|
def __init__(self, allow_square_movement=False, allow_diagonal_movement=False, allow_no_op=False):
|
||||||
|
# FIXME: There is a bug in helpers because there actions are ints. and the order matters.
|
||||||
|
assert not(allow_square_movement is False and allow_diagonal_movement is True), "There is a bug in helpers!!!"
|
||||||
|
self.allow_no_op = allow_no_op
|
||||||
self.allow_diagonal_movement = allow_diagonal_movement
|
self.allow_diagonal_movement = allow_diagonal_movement
|
||||||
self.allow_square_movement = allow_square_movement
|
self.allow_square_movement = allow_square_movement
|
||||||
self._registerd_actions = dict()
|
self._registerd_actions = dict()
|
||||||
if allow_square_movement:
|
if allow_square_movement:
|
||||||
self + {key: val for key, val in enumerate(['north', 'east', 'south', 'west'])}
|
self + ['north', 'east', 'south', 'west']
|
||||||
if allow_diagonal_movement:
|
if allow_diagonal_movement:
|
||||||
self + {key: val for key, val in enumerate(['north-east', 'south-east', 'south-west', 'north-west'])}
|
self + ['north-east', 'south-east', 'south-west', 'north-west']
|
||||||
|
|
||||||
self._movement_actions = self._registerd_actions.copy()
|
self._movement_actions = self._registerd_actions.copy()
|
||||||
if self.allow_no_OP:
|
if self.allow_no_op:
|
||||||
self + {0:'no-op'}
|
self + 'no-op'
|
||||||
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self._registerd_actions)
|
return len(self._registerd_actions)
|
||||||
|
|
||||||
def __add__(self, other: dict):
|
def __add__(self, other: Union[str, List[str]]):
|
||||||
assert all([isinstance(x, int) for x in other.keys()]), f'All action keys have to be of type {int}.'
|
other = other if isinstance(other, list) else [other]
|
||||||
assert all([isinstance(x, str) for x in other.values()]), f'All action values have to be of type {str}.'
|
assert all([isinstance(x, str) for x in other]), f'All action names have to be of type {str}.'
|
||||||
self._registerd_actions.update({key+len(self._registerd_actions): value for key,value in other.items()})
|
self._registerd_actions.update({key+len(self._registerd_actions): value for key, value in enumerate(other)})
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def register_additional_actions(self, other:dict):
|
def register_additional_actions(self, other: Union[str, List[str]]):
|
||||||
self_with_additional_actions = self + other
|
self_with_additional_actions = self + other
|
||||||
return self_with_additional_actions
|
return self_with_additional_actions
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return self._registerd_actions[item]
|
||||||
|
|
||||||
|
|
||||||
class BaseFactory(gym.Env):
|
class BaseFactory(gym.Env):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_space(self):
|
def action_space(self):
|
||||||
return spaces.Discrete(self._registered_actions)
|
return spaces.Discrete(self._actions.n)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def observation_space(self):
|
def observation_space(self):
|
||||||
@ -75,27 +86,20 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def movement_actions(self):
|
def movement_actions(self):
|
||||||
if self._movement_actions is None:
|
return self._actions.movement_actions
|
||||||
self._movement_actions = dict()
|
|
||||||
if self.allow_square_movement:
|
|
||||||
self._movement_actions.update(
|
|
||||||
)
|
|
||||||
if self.allow_diagonal_movement:
|
|
||||||
self.{key: val for key, val in zip(range(4), ['ne', 'ne', 'nw', 'nw'])}
|
|
||||||
|
|
||||||
return self._movement_actions
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def string_slices(self):
|
def string_slices(self):
|
||||||
return {value: key for key, value in self.slice_strings.items()}
|
return {value: key for key, value in self.slice_strings.items()}
|
||||||
|
|
||||||
def __init__(self, level='simple', n_agents=1, max_steps=int(2e2)):
|
def __init__(self, level='simple', n_agents=1, max_steps=int(2e2), **kwargs):
|
||||||
self.n_agents = n_agents
|
self.n_agents = n_agents
|
||||||
self.max_steps = max_steps
|
self.max_steps = max_steps
|
||||||
self.done_at_collision = False
|
self.done_at_collision = False
|
||||||
self._actions = Actions(allow_square_movement=True, allow_diagonal_movement=True, allow_no_OP=False)
|
_actions = Actions(allow_square_movement=kwargs.get('allow_square_movement', True),
|
||||||
|
allow_diagonal_movement=kwargs.get('allow_diagonal_movement', True),
|
||||||
|
allow_no_op=kwargs.get('allow_no_op', True))
|
||||||
|
self._actions = _actions + self.additional_actions
|
||||||
|
|
||||||
self.level = h.one_hot_level(
|
self.level = h.one_hot_level(
|
||||||
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
|
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
|
||||||
@ -103,7 +107,16 @@ class BaseFactory(gym.Env):
|
|||||||
self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}}
|
self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}}
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def register_additional_actions(self) -> dict:
|
@property
|
||||||
|
def additional_actions(self) -> Union[str, List[str]]:
|
||||||
|
"""
|
||||||
|
When heriting from this Base Class, you musst implement this methode!!!
|
||||||
|
Please return a dict with the given types -> {int: str}.
|
||||||
|
The int should start at 0.
|
||||||
|
|
||||||
|
:return: An Actions-object holding all actions with keys in range 0-n.
|
||||||
|
:rtype: Actions
|
||||||
|
"""
|
||||||
raise NotImplementedError('Please register additional actions ')
|
raise NotImplementedError('Please register additional actions ')
|
||||||
|
|
||||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
def reset(self) -> (np.ndarray, int, bool, dict):
|
||||||
@ -125,7 +138,7 @@ class BaseFactory(gym.Env):
|
|||||||
# Returns State
|
# Returns State
|
||||||
return self.state
|
return self.state
|
||||||
|
|
||||||
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def step(self, actions):
|
def step(self, actions):
|
||||||
@ -143,7 +156,7 @@ class BaseFactory(gym.Env):
|
|||||||
elif self._is_no_op(action):
|
elif self._is_no_op(action):
|
||||||
pos, valid = self.agent_i_position(agent_i), True
|
pos, valid = self.agent_i_position(agent_i), True
|
||||||
else:
|
else:
|
||||||
pos, valid = self.additional_actions(agent_i, action)
|
pos, valid = self.do_additional_actions(agent_i, action)
|
||||||
# Update state accordingly
|
# Update state accordingly
|
||||||
agent_i_state.update(pos=pos, action_valid=valid)
|
agent_i_state.update(pos=pos, action_valid=valid)
|
||||||
agent_states.append(agent_i_state)
|
agent_states.append(agent_i_state)
|
||||||
@ -162,10 +175,10 @@ class BaseFactory(gym.Env):
|
|||||||
return self.state, reward, done, info
|
return self.state, reward, done, info
|
||||||
|
|
||||||
def _is_moving_action(self, action):
|
def _is_moving_action(self, action):
|
||||||
return self._registered_actions[action] in self.movement_actions
|
return action in self._actions.movement_actions
|
||||||
|
|
||||||
def _is_no_op(self, action):
|
def _is_no_op(self, action):
|
||||||
return self._registered_actions[action] == 'no-op'
|
return self._actions[action] == 'no-op'
|
||||||
|
|
||||||
def check_all_collisions(self, agent_states: List[AgentState], collisions: int) -> np.ndarray:
|
def check_all_collisions(self, agent_states: List[AgentState], collisions: int) -> np.ndarray:
|
||||||
collision_vecs = np.zeros((len(agent_states), collisions)) # n_agents x n_slices
|
collision_vecs = np.zeros((len(agent_states), collisions)) # n_agents x n_slices
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List
|
from typing import List, Union
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -12,7 +12,7 @@ from environments.logging.monitor import MonitorCallback
|
|||||||
from environments.factory.renderer import Renderer, Entity
|
from environments.factory.renderer import Renderer, Entity
|
||||||
|
|
||||||
DIRT_INDEX = -1
|
DIRT_INDEX = -1
|
||||||
|
CLEAN_UP_ACTION = 'clean_up'
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DirtProperties:
|
class DirtProperties:
|
||||||
@ -26,13 +26,14 @@ class DirtProperties:
|
|||||||
|
|
||||||
class SimpleFactory(BaseFactory):
|
class SimpleFactory(BaseFactory):
|
||||||
|
|
||||||
def register_additional_actions(self):
|
@property
|
||||||
return 1
|
def additional_actions(self) -> Union[str, List[str]]:
|
||||||
|
return CLEAN_UP_ACTION
|
||||||
|
|
||||||
def _is_clean_up_action(self, action):
|
def _is_clean_up_action(self, action):
|
||||||
return self.action_space.n - 1 == action
|
return self._actions[action] == CLEAN_UP_ACTION
|
||||||
|
|
||||||
def __init__(self, *args, dirt_properties: DirtProperties, verbose=False, force_skip_render=False, **kwargs):
|
def __init__(self, *args, dirt_properties: DirtProperties, verbose=False, **kwargs):
|
||||||
self._dirt_properties = dirt_properties
|
self._dirt_properties = dirt_properties
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.max_dirt = 20
|
self.max_dirt = 20
|
||||||
@ -98,7 +99,7 @@ class SimpleFactory(BaseFactory):
|
|||||||
self.next_dirt_spawn -= 1
|
self.next_dirt_spawn -= 1
|
||||||
return self.state, r, done, info
|
return self.state, r, done, info
|
||||||
|
|
||||||
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||||
if action != self._is_moving_action(action):
|
if action != self._is_moving_action(action):
|
||||||
if self._is_clean_up_action(action):
|
if self._is_clean_up_action(action):
|
||||||
agent_i_pos = self.agent_i_position(agent_i)
|
agent_i_pos = self.agent_i_position(agent_i)
|
||||||
@ -175,9 +176,10 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
dirt_props = DirtProperties()
|
dirt_props = DirtProperties()
|
||||||
factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props)
|
factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props)
|
||||||
|
n_actions = factory.action_space.n - 1
|
||||||
with MonitorCallback(factory):
|
with MonitorCallback(factory):
|
||||||
for epoch in range(100):
|
for epoch in range(100):
|
||||||
random_actions = [(random.randint(0, 8), random.randint(0, 8)) for _ in range(200)]
|
random_actions = [(random.randint(0, n_actions), random.randint(0, n_actions)) for _ in range(200)]
|
||||||
env_state, this_reward, done_bool, _ = factory.reset()
|
env_state, this_reward, done_bool, _ = factory.reset()
|
||||||
for agent_i_action in random_actions:
|
for agent_i_action in random_actions:
|
||||||
env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
|
env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
|
||||||
|
8
main.py
8
main.py
@ -45,7 +45,7 @@ def combine_runs(run_path: Union[str, PathLike]):
|
|||||||
df = pd.concat(df_list, ignore_index=True)
|
df = pd.concat(df_list, ignore_index=True)
|
||||||
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'})
|
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'})
|
||||||
|
|
||||||
df_group = df.groupby(['Episode', 'Run']).aggregate({col: 'mean' if col in ['dirt_amount',
|
df_group = df.groupby(['Episode', 'Run']).aggregate({col: 'sum' if col in ['dirt_amount',
|
||||||
'dirty_tiles'] else 'sum'
|
'dirty_tiles'] else 'sum'
|
||||||
for col in df.columns if
|
for col in df.columns if
|
||||||
col not in ['Episode', 'Run', 'train_step']
|
col not in ['Episode', 'Run', 'train_step']
|
||||||
@ -66,8 +66,8 @@ def combine_runs(run_path: Union[str, PathLike]):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
combine_runs('debug_out/PPO_1622128912')
|
# combine_runs('debug_out/PPO_1622128912')
|
||||||
exit()
|
# exit()
|
||||||
|
|
||||||
from stable_baselines3 import DQN, PPO
|
from stable_baselines3 import DQN, PPO
|
||||||
|
|
||||||
@ -78,7 +78,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
for seed in range(5):
|
for seed in range(5):
|
||||||
|
|
||||||
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, force_skip_render=True)
|
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, allow_diagonal_movement=False, allow_no_op=False)
|
||||||
|
|
||||||
model = PPO("MlpPolicy", env, verbose=1, ent_coef=0.0, seed=seed, device='cpu')
|
model = PPO("MlpPolicy", env, verbose=1, ent_coef=0.0, seed=seed, device='cpu')
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user