mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
multi_agent observation when n_agent more then 1
This commit is contained in:
parent
62c141aa1c
commit
cf2378a734
@ -1,6 +1,6 @@
|
|||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Union, Iterable, NamedTuple
|
from typing import List, Union, Iterable
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -9,115 +9,7 @@ from gym import spaces
|
|||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
|
from environments.utility_classes import Actions, StateSlice, AgentState, MovementProperties
|
||||||
|
|
||||||
class MovementProperties(NamedTuple):
|
|
||||||
allow_square_movement: bool = True
|
|
||||||
allow_diagonal_movement: bool = False
|
|
||||||
allow_no_op: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class Entity():
|
|
||||||
|
|
||||||
@property
|
|
||||||
def pos(self):
|
|
||||||
return self._pos
|
|
||||||
|
|
||||||
def __init__(self, pos):
|
|
||||||
self._pos = pos
|
|
||||||
|
|
||||||
|
|
||||||
class AgentState:
|
|
||||||
|
|
||||||
def __init__(self, i: int, action: int):
|
|
||||||
self.i = i
|
|
||||||
self.action = action
|
|
||||||
|
|
||||||
self.collision_vector = None
|
|
||||||
self.action_valid = None
|
|
||||||
self.pos = None
|
|
||||||
self.info = {}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def collisions(self):
|
|
||||||
return np.argwhere(self.collision_vector != 0).flatten()
|
|
||||||
|
|
||||||
def update(self, **kwargs): # is this hacky?? o.0
|
|
||||||
for key, value in kwargs.items():
|
|
||||||
if hasattr(self, key):
|
|
||||||
self.__setattr__(key, value)
|
|
||||||
else:
|
|
||||||
raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}')
|
|
||||||
|
|
||||||
|
|
||||||
class Register:
|
|
||||||
|
|
||||||
@property
|
|
||||||
def n(self):
|
|
||||||
return len(self)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._register = dict()
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self._register)
|
|
||||||
|
|
||||||
def __add__(self, other: Union[str, List[str]]):
|
|
||||||
other = other if isinstance(other, list) else [other]
|
|
||||||
assert all([isinstance(x, str) for x in other]), f'All item names have to be of type {str}.'
|
|
||||||
self._register.update({key+len(self._register): value for key, value in enumerate(other)})
|
|
||||||
return self
|
|
||||||
|
|
||||||
def register_additional_items(self, other: Union[str, List[str]]):
|
|
||||||
self_with_additional_items = self + other
|
|
||||||
return self_with_additional_items
|
|
||||||
|
|
||||||
def keys(self):
|
|
||||||
return self._register.keys()
|
|
||||||
|
|
||||||
def items(self):
|
|
||||||
return self._register.items()
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
return self._register[item]
|
|
||||||
|
|
||||||
def by_name(self, item):
|
|
||||||
return list(self._register.keys())[list(self._register.values()).index(item)]
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f'{self.__class__.__name__}({self._register})'
|
|
||||||
|
|
||||||
|
|
||||||
class Actions(Register):
|
|
||||||
|
|
||||||
@property
|
|
||||||
def movement_actions(self):
|
|
||||||
return self._movement_actions
|
|
||||||
|
|
||||||
def __init__(self, movement_properties: MovementProperties):
|
|
||||||
self.allow_no_op = movement_properties.allow_no_op
|
|
||||||
self.allow_diagonal_movement = movement_properties.allow_diagonal_movement
|
|
||||||
self.allow_square_movement = movement_properties.allow_square_movement
|
|
||||||
# FIXME: There is a bug in helpers because there actions are ints. and the order matters.
|
|
||||||
assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), \
|
|
||||||
"There is a bug in helpers!!!"
|
|
||||||
super(Actions, self).__init__()
|
|
||||||
|
|
||||||
if self.allow_square_movement:
|
|
||||||
self + ['north', 'east', 'south', 'west']
|
|
||||||
if self.allow_diagonal_movement:
|
|
||||||
self + ['north-east', 'south-east', 'south-west', 'north-west']
|
|
||||||
self._movement_actions = self._register.copy()
|
|
||||||
if self.allow_no_op:
|
|
||||||
self + 'no-op'
|
|
||||||
|
|
||||||
|
|
||||||
class StateSlice(Register):
|
|
||||||
|
|
||||||
def __init__(self, n_agents: int):
|
|
||||||
super(StateSlice, self).__init__()
|
|
||||||
offset = 1
|
|
||||||
self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]])
|
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyAttributeOutsideInit
|
# noinspection PyAttributeOutsideInit
|
||||||
@ -148,9 +40,11 @@ class BaseFactory(gym.Env):
|
|||||||
def movement_actions(self):
|
def movement_actions(self):
|
||||||
return self._actions.movement_actions
|
return self._actions.movement_actions
|
||||||
|
|
||||||
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None,
|
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = 0,
|
||||||
movement_properties: MovementProperties = MovementProperties(),
|
movement_properties: MovementProperties = MovementProperties(),
|
||||||
|
combin_agent_slices_in_obs: bool = False,
|
||||||
omit_agent_slice_in_obs=False, **kwargs):
|
omit_agent_slice_in_obs=False, **kwargs):
|
||||||
|
assert combin_agent_slices_in_obs != omit_agent_slice_in_obs, 'Both options are exclusive'
|
||||||
|
|
||||||
self.movement_properties = movement_properties
|
self.movement_properties = movement_properties
|
||||||
self.level_name = level_name
|
self.level_name = level_name
|
||||||
@ -158,6 +52,7 @@ class BaseFactory(gym.Env):
|
|||||||
self.n_agents = n_agents
|
self.n_agents = n_agents
|
||||||
self.max_steps = max_steps
|
self.max_steps = max_steps
|
||||||
self.pomdp_radius = pomdp_radius
|
self.pomdp_radius = pomdp_radius
|
||||||
|
self.combin_agent_slices_in_obs = combin_agent_slices_in_obs
|
||||||
self.omit_agent_slice_in_obs = omit_agent_slice_in_obs
|
self.omit_agent_slice_in_obs = omit_agent_slice_in_obs
|
||||||
|
|
||||||
self.done_at_collision = False
|
self.done_at_collision = False
|
||||||
@ -185,7 +80,7 @@ class BaseFactory(gym.Env):
|
|||||||
raise NotImplementedError('Please register additional actions ')
|
raise NotImplementedError('Please register additional actions ')
|
||||||
|
|
||||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
def reset(self) -> (np.ndarray, int, bool, dict):
|
||||||
self.steps = 0
|
self._steps = 0
|
||||||
self._agent_states = []
|
self._agent_states = []
|
||||||
# Agent placement ...
|
# Agent placement ...
|
||||||
agents = np.zeros((self.n_agents, *self._level.shape), dtype=np.int8)
|
agents = np.zeros((self.n_agents, *self._level.shape), dtype=np.int8)
|
||||||
@ -202,17 +97,25 @@ class BaseFactory(gym.Env):
|
|||||||
# Returns State
|
# Returns State
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _return_state(self):
|
def _get_observations(self) -> np.ndarray:
|
||||||
|
if self.n_agents == 1:
|
||||||
|
obs = self._build_per_agent_obs(0)
|
||||||
|
elif self.n_agents >= 2:
|
||||||
|
obs = np.stack([self._build_per_agent_obs(agent_i) for agent_i in range(self.n_agents)])
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def _build_per_agent_obs(self, agent_i: int) -> np.ndarray:
|
||||||
if self.pomdp_radius:
|
if self.pomdp_radius:
|
||||||
pos = self._agent_states[0].pos
|
global_pos = self._agent_states[agent_i].pos
|
||||||
# pos = [agent_state.pos for agent_state in self.agent_states]
|
x0, x1 = max(0, global_pos[0] - self.pomdp_radius), global_pos[0] + self.pomdp_radius + 1
|
||||||
# obs = [] ... list comprehension... pos per agent
|
y0, y1 = max(0, global_pos[1] - self.pomdp_radius), global_pos[1] + self.pomdp_radius + 1
|
||||||
x0, x1 = max(0, pos[0] - self.pomdp_radius), pos[0] + self.pomdp_radius + 1
|
|
||||||
y0, y1 = max(0, pos[1] - self.pomdp_radius), pos[1] + self.pomdp_radius + 1
|
|
||||||
obs = self._state[:, x0:x1, y0:y1]
|
obs = self._state[:, x0:x1, y0:y1]
|
||||||
if obs.shape[1] != self.pomdp_radius * 2 + 1 or obs.shape[2] != self.pomdp_radius * 2 + 1:
|
if obs.shape[1] != self.pomdp_radius * 2 + 1 or obs.shape[2] != self.pomdp_radius * 2 + 1:
|
||||||
obs_padded = np.full((obs.shape[0], self.pomdp_radius * 2 + 1, self.pomdp_radius * 2 + 1), 1)
|
obs_padded = np.full((obs.shape[0], self.pomdp_radius * 2 + 1, self.pomdp_radius * 2 + 1), 1)
|
||||||
a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0]
|
try:
|
||||||
|
a_pos = np.argwhere(obs[h.AGENT_START_IDX + agent_i] == h.IS_OCCUPIED_CELL)[0]
|
||||||
|
except IndexError:
|
||||||
|
print('NO')
|
||||||
obs_padded[:,
|
obs_padded[:,
|
||||||
abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1],
|
abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1],
|
||||||
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
|
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
|
||||||
@ -223,7 +126,13 @@ class BaseFactory(gym.Env):
|
|||||||
obs_new = obs[[key for key, val in self._state_slices.items() if 'agent' not in val]]
|
obs_new = obs[[key for key, val in self._state_slices.items() if 'agent' not in val]]
|
||||||
return obs_new
|
return obs_new
|
||||||
else:
|
else:
|
||||||
return obs
|
if self.combin_agent_slices_in_obs:
|
||||||
|
agent_obs = np.sum(obs[[key for key, val in self._state_slices.items() if 'agent' in val]],
|
||||||
|
axis=0, keepdims=True)
|
||||||
|
obs = np.concatenate((obs[:h.AGENT_START_IDX], agent_obs, obs[h.AGENT_START_IDX+self.n_agents:]))
|
||||||
|
return obs
|
||||||
|
else:
|
||||||
|
return obs
|
||||||
|
|
||||||
def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -231,16 +140,16 @@ class BaseFactory(gym.Env):
|
|||||||
def step(self, actions):
|
def step(self, actions):
|
||||||
actions = [actions] if isinstance(actions, int) or np.isscalar(actions) else actions
|
actions = [actions] if isinstance(actions, int) or np.isscalar(actions) else actions
|
||||||
assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
|
assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
|
||||||
self.steps += 1
|
self._steps += 1
|
||||||
done = False
|
done = False
|
||||||
|
|
||||||
# Move this in a seperate function?
|
# Move this in a seperate function?
|
||||||
agent_states = list()
|
agent_states = list()
|
||||||
for agent_i, action in enumerate(actions):
|
for agent_i, action in enumerate(actions):
|
||||||
agent_i_state = AgentState(agent_i, action)
|
agent_i_state = AgentState(agent_i, action)
|
||||||
if self._is_moving_action(action):
|
if self._actions.is_moving_action(action):
|
||||||
pos, valid = self.move_or_colide(agent_i, action)
|
pos, valid = self.move_or_colide(agent_i, action)
|
||||||
elif self._is_no_op(action):
|
elif self._actions.is_no_op(action):
|
||||||
pos, valid = self.agent_i_position(agent_i), True
|
pos, valid = self.agent_i_position(agent_i), True
|
||||||
else:
|
else:
|
||||||
pos, valid = self.do_additional_actions(agent_i, action)
|
pos, valid = self.do_additional_actions(agent_i, action)
|
||||||
@ -256,24 +165,18 @@ class BaseFactory(gym.Env):
|
|||||||
self._agent_states = agent_states
|
self._agent_states = agent_states
|
||||||
reward, info = self.calculate_reward(agent_states)
|
reward, info = self.calculate_reward(agent_states)
|
||||||
|
|
||||||
if self.steps >= self.max_steps:
|
if self._steps >= self.max_steps:
|
||||||
done = True
|
done = True
|
||||||
|
|
||||||
info.update(step_reward=reward, step=self.steps)
|
info.update(step_reward=reward, step=self._steps)
|
||||||
|
|
||||||
return None, reward, done, info
|
return None, reward, done, info
|
||||||
|
|
||||||
def _is_moving_action(self, action):
|
|
||||||
return action in self._actions.movement_actions
|
|
||||||
|
|
||||||
def _is_no_op(self, action):
|
|
||||||
return self._actions[action] == 'no-op'
|
|
||||||
|
|
||||||
def check_all_collisions(self, agent_states: List[AgentState], collisions: int) -> np.ndarray:
|
def check_all_collisions(self, agent_states: List[AgentState], collisions: int) -> np.ndarray:
|
||||||
collision_vecs = np.zeros((len(agent_states), collisions)) # n_agents x n_slices
|
collision_vecs = np.zeros((len(agent_states), collisions)) # n_agents x n_slices
|
||||||
for agent_state in agent_states:
|
for agent_state in agent_states:
|
||||||
# Register only collisions of moving agents
|
# Register only collisions of moving agents
|
||||||
if self._is_moving_action(agent_state.action):
|
if self._actions.is_moving_action(agent_state.action):
|
||||||
collision_vecs[agent_state.i] = self.check_collisions(agent_state)
|
collision_vecs[agent_state.i] = self.check_collisions(agent_state)
|
||||||
return collision_vecs
|
return collision_vecs
|
||||||
|
|
||||||
|
@ -1,16 +1,14 @@
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Union, NamedTuple
|
from typing import List, Union, NamedTuple
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from environments.factory.base_factory import BaseFactory, AgentState, MovementProperties
|
from environments.factory.base_factory import BaseFactory
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
|
|
||||||
from environments.logging.monitor import MonitorCallback
|
|
||||||
from environments.factory.renderer import Renderer, Entity
|
from environments.factory.renderer import Renderer, Entity
|
||||||
|
from environments.utility_classes import AgentState, MovementProperties
|
||||||
|
|
||||||
DIRT_INDEX = -1
|
DIRT_INDEX = -1
|
||||||
CLEAN_UP_ACTION = 'clean_up'
|
CLEAN_UP_ACTION = 'clean_up'
|
||||||
@ -25,13 +23,16 @@ class DirtProperties(NamedTuple):
|
|||||||
max_global_amount: int = 20 # Max dirt amount in the whole environment.
|
max_global_amount: int = 20 # Max dirt amount in the whole environment.
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyAttributeOutsideInit
|
||||||
class SimpleFactory(BaseFactory):
|
class SimpleFactory(BaseFactory):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def additional_actions(self) -> Union[str, List[str]]:
|
def additional_actions(self) -> Union[str, List[str]]:
|
||||||
return CLEAN_UP_ACTION
|
return CLEAN_UP_ACTION
|
||||||
|
|
||||||
def _is_clean_up_action(self, action):
|
def _is_clean_up_action(self, action: Union[str, int]):
|
||||||
|
if isinstance(action, str):
|
||||||
|
action = self._actions.by_name(action)
|
||||||
return self._actions[action] == CLEAN_UP_ACTION
|
return self._actions[action] == CLEAN_UP_ACTION
|
||||||
|
|
||||||
def __init__(self, *args, dirt_properties: DirtProperties, verbose=False, **kwargs):
|
def __init__(self, *args, dirt_properties: DirtProperties, verbose=False, **kwargs):
|
||||||
@ -47,9 +48,9 @@ class SimpleFactory(BaseFactory):
|
|||||||
height, width = self._state.shape[1:]
|
height, width = self._state.shape[1:]
|
||||||
self._renderer = Renderer(width, height, view_radius=self.pomdp_radius)
|
self._renderer = Renderer(width, height, view_radius=self.pomdp_radius)
|
||||||
|
|
||||||
dirt = [Entity('dirt', [x, y], min(0.15 + self._state[DIRT_INDEX, x, y], 1.5), 'scale')
|
dirt = [Entity('dirt', [x, y], min(0.15 + self._state[DIRT_INDEX, x, y], 1.5), 'scale')
|
||||||
for x, y in np.argwhere(self._state[DIRT_INDEX] > h.IS_FREE_CELL)]
|
for x, y in np.argwhere(self._state[DIRT_INDEX] > h.IS_FREE_CELL)]
|
||||||
walls = [Entity('wall', pos) for pos in np.argwhere(self._state[h.LEVEL_IDX] > h.IS_FREE_CELL)]
|
walls = [Entity('wall', pos) for pos in np.argwhere(self._state[h.LEVEL_IDX] > h.IS_FREE_CELL)]
|
||||||
|
|
||||||
def asset_str(agent):
|
def asset_str(agent):
|
||||||
if any([x is None for x in [self._state_slices[j] for j in agent.collisions]]):
|
if any([x is None for x in [self._state_slices[j] for j in agent.collisions]]):
|
||||||
@ -93,17 +94,18 @@ class SimpleFactory(BaseFactory):
|
|||||||
return pos, cleanup_was_sucessfull
|
return pos, cleanup_was_sucessfull
|
||||||
|
|
||||||
def step(self, actions):
|
def step(self, actions):
|
||||||
_, r, done, info = super(SimpleFactory, self).step(actions)
|
_, reward, done, info = super(SimpleFactory, self).step(actions)
|
||||||
if not self._next_dirt_spawn:
|
if not self._next_dirt_spawn:
|
||||||
self.spawn_dirt()
|
self.spawn_dirt()
|
||||||
self._next_dirt_spawn = self.dirt_properties.spawn_frequency
|
self._next_dirt_spawn = self.dirt_properties.spawn_frequency
|
||||||
else:
|
else:
|
||||||
self._next_dirt_spawn -= 1
|
self._next_dirt_spawn -= 1
|
||||||
obs = self._return_state()
|
|
||||||
return obs, r, done, info
|
obs = self._get_observations()
|
||||||
|
return obs, reward, done, info
|
||||||
|
|
||||||
def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||||
if action != self._is_moving_action(action):
|
if action != self._actions.is_moving_action(action):
|
||||||
if self._is_clean_up_action(action):
|
if self._is_clean_up_action(action):
|
||||||
agent_i_pos = self.agent_i_position(agent_i)
|
agent_i_pos = self.agent_i_position(agent_i)
|
||||||
_, valid = self.clean_up(agent_i_pos)
|
_, valid = self.clean_up(agent_i_pos)
|
||||||
@ -119,7 +121,7 @@ class SimpleFactory(BaseFactory):
|
|||||||
self._state = np.concatenate((self._state, dirt_slice)) # dirt is now the last slice
|
self._state = np.concatenate((self._state, dirt_slice)) # dirt is now the last slice
|
||||||
self.spawn_dirt()
|
self.spawn_dirt()
|
||||||
self._next_dirt_spawn = self.dirt_properties.spawn_frequency
|
self._next_dirt_spawn = self.dirt_properties.spawn_frequency
|
||||||
obs = self._return_state()
|
obs = self._get_observations()
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
|
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
|
||||||
@ -141,7 +143,7 @@ class SimpleFactory(BaseFactory):
|
|||||||
if entity != self._state_slices.by_name("dirt")]
|
if entity != self._state_slices.by_name("dirt")]
|
||||||
|
|
||||||
if list_of_collisions:
|
if list_of_collisions:
|
||||||
self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
|
self.print(f't = {self._steps}\tAgent {agent_state.i} has collisions with '
|
||||||
f'{list_of_collisions}')
|
f'{list_of_collisions}')
|
||||||
|
|
||||||
if self._is_clean_up_action(agent_state.action):
|
if self._is_clean_up_action(agent_state.action):
|
||||||
@ -155,7 +157,7 @@ class SimpleFactory(BaseFactory):
|
|||||||
f'at {agent_state.pos}, but was unsucsessfull.')
|
f'at {agent_state.pos}, but was unsucsessfull.')
|
||||||
info_dict.update(failed_cleanup_attempt=1)
|
info_dict.update(failed_cleanup_attempt=1)
|
||||||
|
|
||||||
elif self._is_moving_action(agent_state.action):
|
elif self._actions.is_moving_action(agent_state.action):
|
||||||
if agent_state.action_valid:
|
if agent_state.action_valid:
|
||||||
# info_dict.update(movement=1)
|
# info_dict.update(movement=1)
|
||||||
reward -= 0.00
|
reward -= 0.00
|
||||||
@ -185,10 +187,11 @@ class SimpleFactory(BaseFactory):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
render = True
|
render = True
|
||||||
import yaml
|
|
||||||
with Path(r'C:\Users\steff\projects\f_iks\debug_out\yaml.txt').open('r') as f:
|
move_props = MovementProperties(allow_diagonal_movement=True, allow_square_movement=True)
|
||||||
env_kwargs = yaml.load(f)
|
dirt_props = DirtProperties()
|
||||||
factory = SimpleFactory(**env_kwargs)
|
factory = SimpleFactory(movement_properties=move_props, dirt_properties=dirt_props, n_agents=2,
|
||||||
|
combin_agent_slices_in_obs=True, omit_agent_slice_in_obs=False)
|
||||||
|
|
||||||
# dirt_props = DirtProperties()
|
# dirt_props = DirtProperties()
|
||||||
# move_props = MovementProperties(allow_diagonal_movement=False, allow_no_op=False)
|
# move_props = MovementProperties(allow_diagonal_movement=False, allow_no_op=False)
|
||||||
@ -200,10 +203,12 @@ if __name__ == '__main__':
|
|||||||
for epoch in range(100):
|
for epoch in range(100):
|
||||||
random_actions = [[random.randint(0, n_actions) for _ in range(factory.n_agents)] for _ in range(200)]
|
random_actions = [[random.randint(0, n_actions) for _ in range(factory.n_agents)] for _ in range(200)]
|
||||||
env_state = factory.reset()
|
env_state = factory.reset()
|
||||||
|
r = 0
|
||||||
for agent_i_action in random_actions:
|
for agent_i_action in random_actions:
|
||||||
env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
|
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
|
||||||
|
r += step_r
|
||||||
if render:
|
if render:
|
||||||
factory.render()
|
factory.render()
|
||||||
if done_bool:
|
if done_bool:
|
||||||
break
|
break
|
||||||
print(f'Factory run {epoch} done, reward is:\n {reward}')
|
print(f'Factory run {epoch} done, reward is:\n {r}')
|
||||||
|
@ -32,13 +32,15 @@ def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None)
|
|||||||
hue_order = sorted(list(df[hue].unique()))
|
hue_order = sorted(list(df[hue].unique()))
|
||||||
try:
|
try:
|
||||||
sns.set(rc={'text.usetex': True}, style='whitegrid')
|
sns.set(rc={'text.usetex': True}, style='whitegrid')
|
||||||
_ = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
|
lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
|
||||||
hue_order=hue_order, hue=hue, style=style)
|
hue_order=hue_order, hue=hue, style=style)
|
||||||
|
lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
||||||
plot(filepath, ext=ext) # plot raises errors not lineplot!
|
plot(filepath, ext=ext) # plot raises errors not lineplot!
|
||||||
except (FileNotFoundError, RuntimeError):
|
except (FileNotFoundError, RuntimeError):
|
||||||
print('Struggling to plot Figure using LaTeX - going back to normal.')
|
print('Struggling to plot Figure using LaTeX - going back to normal.')
|
||||||
plt.close('all')
|
plt.close('all')
|
||||||
sns.set(rc={'text.usetex': False}, style='whitegrid')
|
sns.set(rc={'text.usetex': False}, style='whitegrid')
|
||||||
sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
|
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
|
||||||
ci=95, palette=PALETTE, hue_order=hue_order)
|
ci=95, palette=PALETTE, hue_order=hue_order)
|
||||||
|
lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
||||||
plot(filepath, ext=ext)
|
plot(filepath, ext=ext)
|
||||||
|
127
environments/utility_classes.py
Normal file
127
environments/utility_classes.py
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
from typing import Union, List, NamedTuple
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class MovementProperties(NamedTuple):
|
||||||
|
allow_square_movement: bool = True
|
||||||
|
allow_diagonal_movement: bool = False
|
||||||
|
allow_no_op: bool = False
|
||||||
|
|
||||||
|
# Preperations for Entities (not used yet)
|
||||||
|
class Entity:
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pos(self):
|
||||||
|
return self._pos
|
||||||
|
|
||||||
|
@property
|
||||||
|
def identifier(self):
|
||||||
|
return self._identifier
|
||||||
|
|
||||||
|
def __init__(self, identifier, pos):
|
||||||
|
self._pos = pos
|
||||||
|
self._identifier = identifier
|
||||||
|
|
||||||
|
|
||||||
|
class AgentState:
|
||||||
|
|
||||||
|
def __init__(self, i: int, action: int):
|
||||||
|
self.i = i
|
||||||
|
self.action = action
|
||||||
|
|
||||||
|
self.collision_vector = None
|
||||||
|
self.action_valid = None
|
||||||
|
self.pos = None
|
||||||
|
self.info = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def collisions(self):
|
||||||
|
return np.argwhere(self.collision_vector != 0).flatten()
|
||||||
|
|
||||||
|
def update(self, **kwargs): # is this hacky?? o.0
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if hasattr(self, key):
|
||||||
|
self.__setattr__(key, value)
|
||||||
|
else:
|
||||||
|
raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}')
|
||||||
|
|
||||||
|
|
||||||
|
class Register:
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n(self):
|
||||||
|
return len(self)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._register = dict()
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._register)
|
||||||
|
|
||||||
|
def __add__(self, other: Union[str, List[str]]):
|
||||||
|
other = other if isinstance(other, list) else [other]
|
||||||
|
assert all([isinstance(x, str) for x in other]), f'All item names have to be of type {str}.'
|
||||||
|
self._register.update({key+len(self._register): value for key, value in enumerate(other)})
|
||||||
|
return self
|
||||||
|
|
||||||
|
def register_additional_items(self, other: Union[str, List[str]]):
|
||||||
|
self_with_additional_items = self + other
|
||||||
|
return self_with_additional_items
|
||||||
|
|
||||||
|
def keys(self):
|
||||||
|
return self._register.keys()
|
||||||
|
|
||||||
|
def items(self):
|
||||||
|
return self._register.items()
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return self._register[item]
|
||||||
|
|
||||||
|
def by_name(self, item):
|
||||||
|
return list(self._register.keys())[list(self._register.values()).index(item)]
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'{self.__class__.__name__}({self._register})'
|
||||||
|
|
||||||
|
|
||||||
|
class Actions(Register):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def movement_actions(self):
|
||||||
|
return self._movement_actions
|
||||||
|
|
||||||
|
def __init__(self, movement_properties: MovementProperties):
|
||||||
|
self.allow_no_op = movement_properties.allow_no_op
|
||||||
|
self.allow_diagonal_movement = movement_properties.allow_diagonal_movement
|
||||||
|
self.allow_square_movement = movement_properties.allow_square_movement
|
||||||
|
# FIXME: There is a bug in helpers because there actions are ints. and the order matters.
|
||||||
|
# assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), \
|
||||||
|
# "There is a bug in helpers!!!"
|
||||||
|
super(Actions, self).__init__()
|
||||||
|
|
||||||
|
if self.allow_square_movement:
|
||||||
|
self + ['north', 'east', 'south', 'west']
|
||||||
|
if self.allow_diagonal_movement:
|
||||||
|
self + ['north_east', 'south_east', 'south_west', 'north_west']
|
||||||
|
self._movement_actions = self._register.copy()
|
||||||
|
if self.allow_no_op:
|
||||||
|
self + 'no-op'
|
||||||
|
|
||||||
|
def is_moving_action(self, action: Union[str, int]):
|
||||||
|
if isinstance(action, str):
|
||||||
|
return action in self.movement_actions.values()
|
||||||
|
else:
|
||||||
|
return self[action] in self.movement_actions.values()
|
||||||
|
|
||||||
|
def is_no_op(self, action: Union[str, int]):
|
||||||
|
if isinstance(action, str):
|
||||||
|
action = self.by_name(action)
|
||||||
|
return self[action] == 'no-op'
|
||||||
|
|
||||||
|
|
||||||
|
class StateSlice(Register):
|
||||||
|
|
||||||
|
def __init__(self, n_agents: int):
|
||||||
|
super(StateSlice, self).__init__()
|
||||||
|
offset = 1
|
||||||
|
self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]])
|
4
main.py
4
main.py
@ -94,7 +94,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
dirt_props = DirtProperties(clean_amount=3, gain_amount=0.2, max_global_amount=30,
|
dirt_props = DirtProperties(clean_amount=3, gain_amount=0.2, max_global_amount=30,
|
||||||
max_local_amount=5, spawn_frequency=3)
|
max_local_amount=5, spawn_frequency=3)
|
||||||
move_props = MovementProperties(allow_diagonal_movement=False,
|
move_props = MovementProperties(allow_diagonal_movement=True,
|
||||||
allow_square_movement=True,
|
allow_square_movement=True,
|
||||||
allow_no_op=False)
|
allow_no_op=False)
|
||||||
time_stamp = int(time.time())
|
time_stamp = int(time.time())
|
||||||
@ -123,7 +123,7 @@ if __name__ == '__main__':
|
|||||||
[MonitorCallback(filepath=out_path / f'monitor_{identifier}.pick', plotting=False)]
|
[MonitorCallback(filepath=out_path / f'monitor_{identifier}.pick', plotting=False)]
|
||||||
)
|
)
|
||||||
|
|
||||||
model.learn(total_timesteps=int(5e5), callback=callbacks)
|
model.learn(total_timesteps=int(1e5), callback=callbacks)
|
||||||
|
|
||||||
save_path = out_path / f'model_{identifier}.zip'
|
save_path = out_path / f'model_{identifier}.zip'
|
||||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
75
main_test.py
Normal file
75
main_test.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
# foreign imports
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
import yaml
|
||||||
|
from natsort import natsorted
|
||||||
|
|
||||||
|
from stable_baselines3.common.callbacks import CallbackList
|
||||||
|
from stable_baselines3 import PPO, DQN, A2C
|
||||||
|
|
||||||
|
# our imports
|
||||||
|
from environments.factory.simple_factory import SimpleFactory
|
||||||
|
from environments.logging.monitor import MonitorCallback
|
||||||
|
from algorithms.reg_dqn import RegDQN
|
||||||
|
from main import compare_runs, combine_runs
|
||||||
|
|
||||||
|
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||||
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
|
model_mapping = dict(A2C=A2C, PPO=PPO, DQN=DQN, RegDQN=RegDQN)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
# get n policies pi_1, ..., pi_n trained in single agent setting
|
||||||
|
# rewards = []
|
||||||
|
# repeat for x eval runs
|
||||||
|
# total reward = rollout game for y steps with n policies in multi-agent setting
|
||||||
|
# rewards += [total reward]
|
||||||
|
# boxplot total rewards
|
||||||
|
|
||||||
|
run_id = '1623078961'
|
||||||
|
model_name = 'PPO'
|
||||||
|
|
||||||
|
# -----------------------
|
||||||
|
out_path = Path(__file__).parent / 'debug_out'
|
||||||
|
|
||||||
|
# from sb3_contrib import QRDQN
|
||||||
|
model_path = out_path / f'{model_name}_{run_id}'
|
||||||
|
model_files = list(natsorted(model_path.rglob('model_*.zip')))
|
||||||
|
this_model = model_files[0]
|
||||||
|
render = True
|
||||||
|
|
||||||
|
model = model_mapping[model_name].load(this_model)
|
||||||
|
|
||||||
|
for seed in range(3):
|
||||||
|
with (model_path / f'env_{model_path.name}.yaml').open('r') as f:
|
||||||
|
env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
env_kwargs.update(n_agents=2)
|
||||||
|
env = SimpleFactory(**env_kwargs)
|
||||||
|
|
||||||
|
exp_out_path = model_path / 'exp'
|
||||||
|
callbacks = CallbackList(
|
||||||
|
[MonitorCallback(filepath=exp_out_path / f'future_exp_name', plotting=True)]
|
||||||
|
)
|
||||||
|
|
||||||
|
n_actions = env.action_space.n
|
||||||
|
|
||||||
|
for epoch in range(100):
|
||||||
|
observations = env.reset()
|
||||||
|
if render:
|
||||||
|
env.render()
|
||||||
|
done_bool = False
|
||||||
|
r = 0
|
||||||
|
while not done_bool:
|
||||||
|
actions = [model.predict(obs, deterministic=False)[0] for obs in observations]
|
||||||
|
|
||||||
|
obs, r, done_bool, info_obj = env.step(actions)
|
||||||
|
if render:
|
||||||
|
env.render()
|
||||||
|
if done_bool:
|
||||||
|
break
|
||||||
|
print(f'Factory run {epoch} done, reward is:\n {r}')
|
||||||
|
|
||||||
|
if out_path:
|
||||||
|
combine_runs(out_path.parent)
|
Loading…
x
Reference in New Issue
Block a user