Item debugging and New Entities

This commit is contained in:
Steffen Illium
2021-08-23 18:39:50 +02:00
parent d5e4d44823
commit c3d4925653
4 changed files with 43 additions and 25 deletions

View File

@ -157,7 +157,7 @@ class BaseFactory(gym.Env):
entities.register_additional_items([self._doors]) entities.register_additional_items([self._doors])
if additional_entities := self.additional_entities: if additional_entities := self.additional_entities:
entities.register_additional_items([additional_entities]) entities.register_additional_items(additional_entities)
return entities return entities

View File

@ -93,6 +93,10 @@ class Register:
class EntityRegister(Register): class EntityRegister(Register):
@property
def positions(self):
return [agent.pos for agent in self]
def __init__(self): def __init__(self):
super(EntityRegister, self).__init__() super(EntityRegister, self).__init__()
self._tiles = dict() self._tiles = dict()
@ -150,7 +154,7 @@ class FloorTiles(EntityRegister):
return tiles return tiles
class Agents(Register): class Agents(EntityRegister):
_accepted_objects = Agent _accepted_objects = Agent

View File

@ -8,7 +8,7 @@ from environments.factory.simple_factory import SimpleFactory
from environments.helpers import Constants as c from environments.helpers import Constants as c
from environments import helpers as h from environments import helpers as h
from environments.factory.base.objects import Agent, Slice, Entity, Action from environments.factory.base.objects import Agent, Slice, Entity, Action
from environments.factory.base.registers import Entities from environments.factory.base.registers import Entities, Register, EntityRegister
from environments.factory.renderer import RenderEntity from environments.factory.renderer import RenderEntity
@ -29,22 +29,31 @@ def inventory_slice_name(agent_i):
class DropOffLocation(Entity): class DropOffLocation(Entity):
def __init__(self, *args, storage_size_until_full: int = 5, **kwargs): def __init__(self, *args, storage_size_until_full: int = 5, **kwargs):
super(DropOffLocation, self).__init__(DROP_OFF, *args, **kwargs) super(DropOffLocation, self).__init__(*args, **kwargs)
self.storage = deque(maxlen=storage_size_until_full) self.storage = deque(maxlen=storage_size_until_full)
def place_item(self, item): def place_item(self, item):
self.storage.append(item) if self.is_full:
return True raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
return False
else:
self.storage.append(item)
return True
@property @property
def is_full(self): def is_full(self):
return self.storage.maxlen == len(self.storage) return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage)
class DropOffLocations(EntityRegister):
_accepted_objects = DropOffLocation
class ItemProperties(NamedTuple): class ItemProperties(NamedTuple):
n_items: int = 1 # How many items are there at the same time n_items: int = 5 # How many items are there at the same time
spawn_frequency: int = 5 # Spawn Frequency in Steps spawn_frequency: int = 5 # Spawn Frequency in Steps
max_dropoff_storage_size: int = 5 # How many items are needed until the drop off is full n_drop_off_locations: int = 5 # How many DropOff locations are there at the same time
max_dropoff_storage_size: int = 0 # How many items are needed until the drop off is full
max_agent_storage_size: int = 5 # How many items are needed until the agent inventory is full max_agent_storage_size: int = 5 # How many items are needed until the agent inventory is full
agent_can_interact: bool = True # Whether agents have the possibility to interact with the domain items agent_can_interact: bool = True # Whether agents have the possibility to interact with the domain items
@ -69,7 +78,8 @@ class DoubleTaskFactory(SimpleFactory):
@property @property
def additional_entities(self) -> Union[Entities, List[Entities]]: def additional_entities(self) -> Union[Entities, List[Entities]]:
super_entities = super(self._super, self).additional_entities super_entities = super(self._super, self).additional_entities
return super_entities self._drop_offs = self.spawn_drop_off_location()
return super_entities + [self._drop_offs]
@property @property
def additional_slices(self) -> Union[Slice, List[Slice]]: def additional_slices(self) -> Union[Slice, List[Slice]]:
@ -89,6 +99,7 @@ class DoubleTaskFactory(SimpleFactory):
# Flush per agent inventory state # Flush per agent inventory state
for agent in self._agents: for agent in self._agents:
agent_slice_idx = self._slices.get_idx_by_name(inventory_slice_name(agent.name)) agent_slice_idx = self._slices.get_idx_by_name(inventory_slice_name(agent.name))
# Hard reset the Inventory Stat in OBS cube
self._slices[agent_slice_idx].slice[:] = 0 self._slices[agent_slice_idx].slice[:] = 0
if len(agent.inventory) > 0: if len(agent.inventory) > 0:
max_x = self.pomdp_r if self.pomdp_r else self._level_shape[0] max_x = self.pomdp_r if self.pomdp_r else self._level_shape[0]
@ -111,7 +122,8 @@ class DoubleTaskFactory(SimpleFactory):
if item := item_slice[agent.pos]: if item := item_slice[agent.pos]:
if item == ITEM_DROP_OFF: if item == ITEM_DROP_OFF:
if agent.inventory: if agent.inventory:
valid = self._item_drop_off.place_item(agent.inventory.pop(0)) drop_off = self._drop_offs.by_pos(agent.pos)
valid = drop_off.place_item(agent.inventory.pop(0))
return valid return valid
else: else:
return c.NOT_VALID return c.NOT_VALID
@ -142,7 +154,6 @@ class DoubleTaskFactory(SimpleFactory):
def do_additional_reset(self) -> None: def do_additional_reset(self) -> None:
super(self._super, self).do_additional_reset() super(self._super, self).do_additional_reset()
self.spawn_drop_off_location()
self.spawn_items(self.item_properties.n_items) self.spawn_items(self.item_properties.n_items)
self._next_item_spawn = self.item_properties.spawn_frequency self._next_item_spawn = self.item_properties.spawn_frequency
for agent in self._agents: for agent in self._agents:
@ -151,9 +162,9 @@ class DoubleTaskFactory(SimpleFactory):
def do_additional_step(self) -> dict: def do_additional_step(self) -> dict:
info_dict = super(self._super, self).do_additional_step() info_dict = super(self._super, self).do_additional_step()
if not self._next_item_spawn: if not self._next_item_spawn:
if item_to_spawn := (self.item_properties.n_items - if item_to_spawns := max(0, (self.item_properties.n_items -
(np.sum(self._slices.by_enum(c.ITEM).slice.astype(bool)) - 1)): (np.sum(self._slices.by_enum(c.ITEM).slice.astype(bool)) - 1))):
self.spawn_items(item_to_spawn) self.spawn_items(item_to_spawns)
self._next_item_spawn = self.item_properties.spawn_frequency self._next_item_spawn = self.item_properties.spawn_frequency
else: else:
self.print('No Items are spawning, limit is reached.') self.print('No Items are spawning, limit is reached.')
@ -162,17 +173,18 @@ class DoubleTaskFactory(SimpleFactory):
return info_dict return info_dict
def spawn_drop_off_location(self): def spawn_drop_off_location(self):
single_empty_tile = self._tiles.empty_tiles[0] empty_tiles = self._tiles.empty_tiles[:self.item_properties.n_drop_off_locations]
self._item_drop_off = DropOffLocation(single_empty_tile, drop_offs = DropOffLocations.from_tiles(empty_tiles,
storage_size_until_full=self.item_properties.max_dropoff_storage_size) storage_size_until_full=self.item_properties.max_dropoff_storage_size)
single_empty_tile.enter(self._item_drop_off) xs, ys = zip(*[drop_off.pos for drop_off in drop_offs])
self._slices.by_enum(c.ITEM).slice[single_empty_tile.pos] = ITEM_DROP_OFF self._slices.by_enum(c.ITEM).slice[xs, ys] = ITEM_DROP_OFF
return drop_offs
def calculate_additional_reward(self, agent: Agent) -> (int, dict): def calculate_additional_reward(self, agent: Agent) -> (int, dict):
reward, info_dict = super(self._super, self).calculate_additional_reward(agent) reward, info_dict = super(self._super, self).calculate_additional_reward(agent)
if self._is_item_action(agent.temp_action): if self._is_item_action(agent.temp_action):
if agent.temp_valid: if agent.temp_valid:
if agent.pos == self._item_drop_off.pos: if agent.pos in self._drop_offs.positions:
info_dict.update({f'{agent.name}_item_dropoff': 1}) info_dict.update({f'{agent.name}_item_dropoff': 1})
reward += 1 reward += 1
@ -195,8 +207,9 @@ class DoubleTaskFactory(SimpleFactory):
def spawn_items(self, n_items): def spawn_items(self, n_items):
tiles = self._tiles.empty_tiles[:n_items] tiles = self._tiles.empty_tiles[:n_items]
item_slice = self._slices.by_enum(c.ITEM).slice item_slice = self._slices.by_enum(c.ITEM).slice
for idx, tile in enumerate(tiles, start=1): # when all items should be 1
item_slice[tile.pos] = idx xs, ys = zip(*[tile.pos for tile in tiles])
item_slice[xs, ys] = 1
pass pass

View File

@ -7,6 +7,7 @@ from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.evaluation import evaluate_policy
from environments.factory.simple_factory import DirtProperties, SimpleFactory from environments.factory.simple_factory import DirtProperties, SimpleFactory
from environments.factory.double_task_factory import ItemProperties, DoubleTaskFactory
warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning) warnings.filterwarnings('ignore', category=UserWarning)
@ -14,7 +15,7 @@ warnings.filterwarnings('ignore', category=UserWarning)
if __name__ == '__main__': if __name__ == '__main__':
model_name = 'A2C_1627491061' model_name = 'A2C_1629467677'
run_id = 0 run_id = 0
out_path = Path(__file__).parent / 'debug_out' out_path = Path(__file__).parent / 'debug_out'
model_path = out_path / model_name model_path = out_path / model_name
@ -26,7 +27,7 @@ if __name__ == '__main__':
max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05, max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05,
dirt_smear_amount=0.5), dirt_smear_amount=0.5),
combin_agent_slices_in_obs=True, omit_agent_slice_in_obs=True) combin_agent_slices_in_obs=True, omit_agent_slice_in_obs=True)
with SimpleFactory(**env_kwargs) as env: with DoubleTaskFactory(**env_kwargs) as env:
# Edit THIS: # Edit THIS:
model_files = list(natsorted((model_path / f'{run_id}_{model_name}').rglob('model_*.zip'))) model_files = list(natsorted((model_path / f'{run_id}_{model_name}').rglob('model_*.zip')))