mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
Debugged Item Factory
This commit is contained in:
parent
50c0d90c77
commit
b09055d95d
@ -128,7 +128,7 @@ class BaseFactory(gym.Env):
|
|||||||
parsed_doors = h.one_hot_level(parsed_level, c.DOOR)
|
parsed_doors = h.one_hot_level(parsed_level, c.DOOR)
|
||||||
if np.any(parsed_doors):
|
if np.any(parsed_doors):
|
||||||
door_tiles = [floor.by_pos(pos) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL.value)]
|
door_tiles = [floor.by_pos(pos) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL.value)]
|
||||||
doors = Doors.from_tiles(door_tiles, self._level_shape, context=floor, is_blocking_light=True)
|
doors = Doors.from_tiles(door_tiles, self._level_shape, context=floor)
|
||||||
entities.update({c.DOORS: doors})
|
entities.update({c.DOORS: doors})
|
||||||
|
|
||||||
# Actions
|
# Actions
|
||||||
@ -137,7 +137,8 @@ class BaseFactory(gym.Env):
|
|||||||
self._actions.register_additional_items(additional_actions)
|
self._actions.register_additional_items(additional_actions)
|
||||||
|
|
||||||
# Agents
|
# Agents
|
||||||
agents = Agents.from_tiles(floor.empty_tiles[:self.n_agents], self._level_shape)
|
agents = Agents.from_tiles(floor.empty_tiles[:self.n_agents], self._level_shape,
|
||||||
|
individual_slices=not self.combin_agent_obs)
|
||||||
entities.update({c.AGENT: agents})
|
entities.update({c.AGENT: agents})
|
||||||
|
|
||||||
# All entities
|
# All entities
|
||||||
@ -152,10 +153,12 @@ class BaseFactory(gym.Env):
|
|||||||
return self._entities
|
return self._entities
|
||||||
|
|
||||||
def _init_obs_cube(self):
|
def _init_obs_cube(self):
|
||||||
arrays = self._entities.arrays
|
arrays = self._entities.observable_arrays
|
||||||
|
|
||||||
if self.omit_agent_in_obs and self.n_agents == 1:
|
if self.omit_agent_in_obs and self.n_agents == 1:
|
||||||
del arrays[c.AGENT]
|
del arrays[c.AGENT]
|
||||||
|
elif self.omit_agent_in_obs:
|
||||||
|
arrays[c.AGENT] = np.delete(arrays[c.AGENT], 0, axis=0)
|
||||||
obs_cube_z = sum([a.shape[0] if not self[key].is_per_agent else 1 for key, a in arrays.items()])
|
obs_cube_z = sum([a.shape[0] if not self[key].is_per_agent else 1 for key, a in arrays.items()])
|
||||||
self._obs_cube = np.zeros((obs_cube_z, *self._level_shape), dtype=np.float32)
|
self._obs_cube = np.zeros((obs_cube_z, *self._level_shape), dtype=np.float32)
|
||||||
|
|
||||||
@ -257,7 +260,7 @@ class BaseFactory(gym.Env):
|
|||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
|
|
||||||
def _get_observations(self) -> np.ndarray:
|
def _get_observations(self) -> np.ndarray:
|
||||||
state_array_dict = self._entities.arrays
|
state_array_dict = self._entities.obs_arrays
|
||||||
if self.n_agents == 1:
|
if self.n_agents == 1:
|
||||||
obs = self._build_per_agent_obs(self[c.AGENT][0], state_array_dict)
|
obs = self._build_per_agent_obs(self[c.AGENT][0], state_array_dict)
|
||||||
elif self.n_agents >= 2:
|
elif self.n_agents >= 2:
|
||||||
@ -268,11 +271,14 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
def _build_per_agent_obs(self, agent: Agent, state_array_dict) -> np.ndarray:
|
def _build_per_agent_obs(self, agent: Agent, state_array_dict) -> np.ndarray:
|
||||||
agent_pos_is_omitted = False
|
agent_pos_is_omitted = False
|
||||||
|
agent_omit_idx = None
|
||||||
if self.omit_agent_in_obs and self.n_agents == 1:
|
if self.omit_agent_in_obs and self.n_agents == 1:
|
||||||
del state_array_dict[c.AGENT]
|
del state_array_dict[c.AGENT]
|
||||||
elif self.omit_agent_in_obs and self.combin_agent_obs and self.n_agents > 1:
|
elif self.omit_agent_in_obs and self.combin_agent_obs and self.n_agents > 1:
|
||||||
state_array_dict[c.AGENT][0, agent.x, agent.y] -= agent.encoding
|
state_array_dict[c.AGENT][0, agent.x, agent.y] -= agent.encoding
|
||||||
agent_pos_is_omitted = True
|
agent_pos_is_omitted = True
|
||||||
|
elif self.omit_agent_in_obs and not self.combin_agent_obs and self.n_agents > 1:
|
||||||
|
agent_omit_idx = next((i for i, a in enumerate(self[c.AGENT]) if a == agent))
|
||||||
|
|
||||||
running_idx, shadowing_idxs, can_be_shadowed_idxs = 0, [], []
|
running_idx, shadowing_idxs, can_be_shadowed_idxs = 0, [], []
|
||||||
|
|
||||||
@ -284,8 +290,14 @@ class BaseFactory(gym.Env):
|
|||||||
z = 1
|
z = 1
|
||||||
self._obs_cube[running_idx: running_idx+z] = array[per_agent_idx]
|
self._obs_cube[running_idx: running_idx+z] = array[per_agent_idx]
|
||||||
else:
|
else:
|
||||||
z = array.shape[0]
|
if key == c.AGENT and agent_omit_idx is not None:
|
||||||
self._obs_cube[running_idx: running_idx+z] = array
|
z = array.shape[0] - 1
|
||||||
|
for array_idx in range(array.shape[0]):
|
||||||
|
self._obs_cube[running_idx: running_idx+z] = array[[x for x in range(array.shape[0])
|
||||||
|
if x != agent_omit_idx]]
|
||||||
|
else:
|
||||||
|
z = array.shape[0]
|
||||||
|
self._obs_cube[running_idx: running_idx+z] = array
|
||||||
# Define which OBS SLices cast a Shadow
|
# Define which OBS SLices cast a Shadow
|
||||||
if self[key].is_blocking_light:
|
if self[key].is_blocking_light:
|
||||||
for i in range(z):
|
for i in range(z):
|
||||||
@ -345,9 +357,13 @@ class BaseFactory(gym.Env):
|
|||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Additional Observation:
|
||||||
for additional_obs in self.additional_obs_build():
|
for additional_obs in self.additional_obs_build():
|
||||||
obs[running_idx:running_idx+additional_obs.shape[0]] = additional_obs
|
obs[running_idx:running_idx+additional_obs.shape[0]] = additional_obs
|
||||||
running_idx += additional_obs.shape[0]
|
running_idx += additional_obs.shape[0]
|
||||||
|
for additional_per_agent_obs in self.additional_per_agent_obs_build(agent):
|
||||||
|
obs[running_idx:running_idx + additional_per_agent_obs.shape[0]] = additional_per_agent_obs
|
||||||
|
running_idx += additional_per_agent_obs.shape[0]
|
||||||
|
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
@ -522,6 +538,10 @@ class BaseFactory(gym.Env):
|
|||||||
def additional_obs_build(self) -> List[np.ndarray]:
|
def additional_obs_build(self) -> List[np.ndarray]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def additional_per_agent_obs_build(self, agent) -> List[np.ndarray]:
|
||||||
|
return []
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def do_additional_reset(self) -> None:
|
def do_additional_reset(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from collections import defaultdict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@ -9,7 +10,7 @@ import itertools
|
|||||||
|
|
||||||
class Object:
|
class Object:
|
||||||
|
|
||||||
_u_idx = 0
|
_u_idx = defaultdict(lambda: 0)
|
||||||
|
|
||||||
def __bool__(self):
|
def __bool__(self):
|
||||||
return True
|
return True
|
||||||
@ -40,8 +41,8 @@ class Object:
|
|||||||
elif self._str_ident is not None and self._enum_ident is None:
|
elif self._str_ident is not None and self._enum_ident is None:
|
||||||
self._name = f'{self.__class__.__name__}[{self._str_ident}]'
|
self._name = f'{self.__class__.__name__}[{self._str_ident}]'
|
||||||
elif self._str_ident is None and self._enum_ident is None:
|
elif self._str_ident is None and self._enum_ident is None:
|
||||||
self._name = f'{self.__class__.__name__}#{self._u_idx}'
|
self._name = f'{self.__class__.__name__}#{self._u_idx[self.__class__.__name__]}'
|
||||||
Object._u_idx += 1
|
Object._u_idx[self.__class__.__name__] += 1
|
||||||
else:
|
else:
|
||||||
raise ValueError('Please use either of the idents.')
|
raise ValueError('Please use either of the idents.')
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ from typing import List, Union, Dict
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from environments.factory.base.objects import Entity, Tile, Agent, Door, Action, Wall
|
from environments.factory.base.objects import Entity, Tile, Agent, Door, Action, Wall, Object
|
||||||
from environments.utility_classes import MovementProperties
|
from environments.utility_classes import MovementProperties
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as c
|
||||||
@ -93,9 +93,13 @@ class EntityObjectRegister(ObjectRegister, ABC):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_tiles(cls, tiles, *args, **kwargs):
|
def from_tiles(cls, tiles, *args, **kwargs):
|
||||||
# objects_name = cls._accepted_objects.__name__
|
# objects_name = cls._accepted_objects.__name__
|
||||||
|
register_obj = cls(*args, **kwargs)
|
||||||
|
try:
|
||||||
|
del kwargs['individual_slices']
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
entities = [cls._accepted_objects(tile, str_ident=i, **kwargs)
|
entities = [cls._accepted_objects(tile, str_ident=i, **kwargs)
|
||||||
for i, tile in enumerate(tiles)]
|
for i, tile in enumerate(tiles)]
|
||||||
register_obj = cls(*args)
|
|
||||||
register_obj.register_additional_items(entities)
|
register_obj.register_additional_items(entities)
|
||||||
return register_obj
|
return register_obj
|
||||||
|
|
||||||
@ -139,10 +143,17 @@ class MovingEntityObjectRegister(EntityObjectRegister, ABC):
|
|||||||
except StopIteration:
|
except StopIteration:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def __delitem__(self, name):
|
||||||
|
idx = next(i for i, entity in enumerate(self) if entity.name == name)
|
||||||
|
del self._register[name]
|
||||||
|
if self.individual_slices:
|
||||||
|
self._array = np.delete(self._array, idx, axis=0)
|
||||||
|
|
||||||
def delete_item(self, item):
|
def delete_item(self, item):
|
||||||
if not isinstance(item, str):
|
self.delete_item_by_name(item.name)
|
||||||
item = item.name
|
|
||||||
del self._register[item]
|
def delete_item_by_name(self, name):
|
||||||
|
del self[name]
|
||||||
|
|
||||||
|
|
||||||
class Entities(Register):
|
class Entities(Register):
|
||||||
@ -150,9 +161,13 @@ class Entities(Register):
|
|||||||
_accepted_objects = EntityObjectRegister
|
_accepted_objects = EntityObjectRegister
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def arrays(self):
|
def observable_arrays(self):
|
||||||
return {key: val.as_array() for key, val in self.items() if val.is_observable}
|
return {key: val.as_array() for key, val in self.items() if val.is_observable}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def obs_arrays(self):
|
||||||
|
return {key: val.as_array() for key, val in self.items() if val.is_observable and not val.hide_from_obs_builder}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def names(self):
|
def names(self):
|
||||||
return list(self._register.keys())
|
return list(self._register.keys())
|
||||||
|
@ -27,15 +27,26 @@ def inventory_slice_name(agent_i):
|
|||||||
|
|
||||||
class Item(MoveableEntity):
|
class Item(MoveableEntity):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._auto_despawn = -1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def auto_despawn(self):
|
||||||
|
return self._auto_despawn
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def can_collide(self):
|
def can_collide(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
# Edit this if you want items to be drawn in the ops differntly
|
# Edit this if you want items to be drawn in the ops differently
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
def set_auto_despawn(self, auto_despawn):
|
||||||
|
self._auto_despawn = auto_despawn
|
||||||
|
|
||||||
|
|
||||||
class ItemRegister(MovingEntityObjectRegister):
|
class ItemRegister(MovingEntityObjectRegister):
|
||||||
|
|
||||||
@ -52,6 +63,11 @@ class ItemRegister(MovingEntityObjectRegister):
|
|||||||
items = [Item(tile) for tile in tiles]
|
items = [Item(tile) for tile in tiles]
|
||||||
self.register_additional_items(items)
|
self.register_additional_items(items)
|
||||||
|
|
||||||
|
def despawn_items(self, items: List[Item]):
|
||||||
|
items = [items] if isinstance(items, Item) else items
|
||||||
|
for item in items:
|
||||||
|
del self[item]
|
||||||
|
|
||||||
|
|
||||||
class Inventory(UserList):
|
class Inventory(UserList):
|
||||||
|
|
||||||
@ -142,16 +158,18 @@ class DropOffLocation(Entity):
|
|||||||
def encoding(self):
|
def encoding(self):
|
||||||
return ITEM_DROP_OFF
|
return ITEM_DROP_OFF
|
||||||
|
|
||||||
def __init__(self, *args, storage_size_until_full: int = 5, **kwargs):
|
def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs):
|
||||||
super(DropOffLocation, self).__init__(*args, **kwargs)
|
super(DropOffLocation, self).__init__(*args, **kwargs)
|
||||||
|
self.auto_item_despawn_interval = auto_item_despawn_interval
|
||||||
self.storage = deque(maxlen=storage_size_until_full or None)
|
self.storage = deque(maxlen=storage_size_until_full or None)
|
||||||
|
|
||||||
def place_item(self, item):
|
def place_item(self, item: Item):
|
||||||
if self.is_full:
|
if self.is_full:
|
||||||
raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
|
raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
|
||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
else:
|
else:
|
||||||
self.storage.append(item)
|
self.storage.append(item)
|
||||||
|
item.set_auto_despawn(self.auto_item_despawn_interval)
|
||||||
return c.VALID
|
return c.VALID
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -173,7 +191,7 @@ class DropOffLocations(EntityObjectRegister):
|
|||||||
|
|
||||||
class ItemProperties(NamedTuple):
|
class ItemProperties(NamedTuple):
|
||||||
n_items: int = 5 # How many items are there at the same time
|
n_items: int = 5 # How many items are there at the same time
|
||||||
spawn_frequency: int = 5 # Spawn Frequency in Steps
|
spawn_frequency: int = 10 # Spawn Frequency in Steps
|
||||||
n_drop_off_locations: int = 5 # How many DropOff locations are there at the same time
|
n_drop_off_locations: int = 5 # How many DropOff locations are there at the same time
|
||||||
max_dropoff_storage_size: int = 0 # How many items are needed until the drop off is full
|
max_dropoff_storage_size: int = 0 # How many items are needed until the drop off is full
|
||||||
max_agent_inventory_capacity: int = 5 # How many items are needed until the agent inventory is full
|
max_agent_inventory_capacity: int = 5 # How many items are needed until the agent inventory is full
|
||||||
@ -218,10 +236,10 @@ class ItemFactory(BaseFactory):
|
|||||||
super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories})
|
super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories})
|
||||||
return super_entities
|
return super_entities
|
||||||
|
|
||||||
def additional_obs_build(self) -> List[np.ndarray]:
|
def additional_per_agent_obs_build(self, agent) -> List[np.ndarray]:
|
||||||
super_additional_obs_build = super().additional_obs_build()
|
additional_per_agent_obs_build = super().additional_per_agent_obs_build(agent)
|
||||||
super_additional_obs_build.append(self[c.INVENTORY].as_array())
|
additional_per_agent_obs_build.append(self[c.INVENTORY].by_entity(agent).as_array())
|
||||||
return super_additional_obs_build
|
return additional_per_agent_obs_build
|
||||||
|
|
||||||
def do_item_action(self, agent: Agent):
|
def do_item_action(self, agent: Agent):
|
||||||
inventory = self[c.INVENTORY].by_entity(agent)
|
inventory = self[c.INVENTORY].by_entity(agent)
|
||||||
@ -274,10 +292,18 @@ class ItemFactory(BaseFactory):
|
|||||||
def do_additional_step(self) -> dict:
|
def do_additional_step(self) -> dict:
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
info_dict = super().do_additional_step()
|
info_dict = super().do_additional_step()
|
||||||
|
for item in list(self[c.ITEM].values()):
|
||||||
|
if item.auto_despawn >= 1:
|
||||||
|
item.set_auto_despawn(item.auto_despawn-1)
|
||||||
|
elif not item.auto_despawn:
|
||||||
|
self[c.ITEM].delete_item(item)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
if not self._next_item_spawn:
|
if not self._next_item_spawn:
|
||||||
self.trigger_item_spawn()
|
self.trigger_item_spawn()
|
||||||
else:
|
else:
|
||||||
self._next_item_spawn -= 1
|
self._next_item_spawn = max(0, self._next_item_spawn-1)
|
||||||
return info_dict
|
return info_dict
|
||||||
|
|
||||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
||||||
@ -309,7 +335,7 @@ class ItemFactory(BaseFactory):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
import random
|
import random
|
||||||
render = True
|
render = False
|
||||||
|
|
||||||
item_props = ItemProperties()
|
item_props = ItemProperties()
|
||||||
|
|
||||||
|
4
main.py
4
main.py
@ -46,8 +46,8 @@ def combine_runs(run_path: Union[str, PathLike]):
|
|||||||
value_vars=columns, var_name="Measurement",
|
value_vars=columns, var_name="Measurement",
|
||||||
value_name="Score")
|
value_name="Score")
|
||||||
|
|
||||||
if df_melted['Episode'].max() > 80:
|
if df_melted['Episode'].max() > 800:
|
||||||
skip_n = round(df_melted['Episode'].max() * 0.02, 2)
|
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||||
|
|
||||||
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted)
|
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted)
|
||||||
|
@ -16,7 +16,7 @@ model_map = dict(PPO=PPO, DQN=DQN, A2C=A2C)
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
model_name = 'A2C_1630414444'
|
model_name = 'PPO_1631029150'
|
||||||
run_id = 0
|
run_id = 0
|
||||||
seed=69
|
seed=69
|
||||||
out_path = Path(__file__).parent / 'debug_out'
|
out_path = Path(__file__).parent / 'debug_out'
|
||||||
@ -30,7 +30,7 @@ if __name__ == '__main__':
|
|||||||
max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05,
|
max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05,
|
||||||
dirt_smear_amount=0.5),
|
dirt_smear_amount=0.5),
|
||||||
combin_agent_slices_in_obs=True, omit_agent_slice_in_obs=True)
|
combin_agent_slices_in_obs=True, omit_agent_slice_in_obs=True)
|
||||||
with DirtFactory(**env_kwargs) as env:
|
with ItemFactory(**env_kwargs) as env:
|
||||||
|
|
||||||
# Edit THIS:
|
# Edit THIS:
|
||||||
env.seed(seed)
|
env.seed(seed)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user