diff --git a/environments/factory/base/base_factory.py b/environments/factory/base/base_factory.py index c2b92cd..633a02b 100644 --- a/environments/factory/base/base_factory.py +++ b/environments/factory/base/base_factory.py @@ -157,7 +157,7 @@ class BaseFactory(gym.Env): entities.register_additional_items([self._doors]) if additional_entities := self.additional_entities: - entities.register_additional_items([additional_entities]) + entities.register_additional_items(additional_entities) return entities diff --git a/environments/factory/base/registers.py b/environments/factory/base/registers.py index 6167dbe..18e884d 100644 --- a/environments/factory/base/registers.py +++ b/environments/factory/base/registers.py @@ -93,6 +93,10 @@ class Register: class EntityRegister(Register): + @property + def positions(self): + return [agent.pos for agent in self] + def __init__(self): super(EntityRegister, self).__init__() self._tiles = dict() @@ -150,7 +154,7 @@ class FloorTiles(EntityRegister): return tiles -class Agents(Register): +class Agents(EntityRegister): _accepted_objects = Agent diff --git a/environments/factory/double_task_factory.py b/environments/factory/double_task_factory.py index d5d8dc6..af8dc9c 100644 --- a/environments/factory/double_task_factory.py +++ b/environments/factory/double_task_factory.py @@ -8,7 +8,7 @@ from environments.factory.simple_factory import SimpleFactory from environments.helpers import Constants as c from environments import helpers as h from environments.factory.base.objects import Agent, Slice, Entity, Action -from environments.factory.base.registers import Entities +from environments.factory.base.registers import Entities, Register, EntityRegister from environments.factory.renderer import RenderEntity @@ -29,22 +29,31 @@ def inventory_slice_name(agent_i): class DropOffLocation(Entity): def __init__(self, *args, storage_size_until_full: int = 5, **kwargs): - super(DropOffLocation, self).__init__(DROP_OFF, *args, **kwargs) + super(DropOffLocation, self).__init__(*args, **kwargs) self.storage = deque(maxlen=storage_size_until_full) def place_item(self, item): - self.storage.append(item) - return True + if self.is_full: + raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.") + return False + else: + self.storage.append(item) + return True @property def is_full(self): - return self.storage.maxlen == len(self.storage) + return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage) + + +class DropOffLocations(EntityRegister): + _accepted_objects = DropOffLocation class ItemProperties(NamedTuple): - n_items: int = 1 # How many items are there at the same time + n_items: int = 5 # How many items are there at the same time spawn_frequency: int = 5 # Spawn Frequency in Steps - max_dropoff_storage_size: int = 5 # How many items are needed until the drop off is full + n_drop_off_locations: int = 5 # How many DropOff locations are there at the same time + max_dropoff_storage_size: int = 0 # How many items are needed until the drop off is full max_agent_storage_size: int = 5 # How many items are needed until the agent inventory is full agent_can_interact: bool = True # Whether agents have the possibility to interact with the domain items @@ -69,7 +78,8 @@ class DoubleTaskFactory(SimpleFactory): @property def additional_entities(self) -> Union[Entities, List[Entities]]: super_entities = super(self._super, self).additional_entities - return super_entities + self._drop_offs = self.spawn_drop_off_location() + return super_entities + [self._drop_offs] @property def additional_slices(self) -> Union[Slice, List[Slice]]: @@ -89,6 +99,7 @@ class DoubleTaskFactory(SimpleFactory): # Flush per agent inventory state for agent in self._agents: agent_slice_idx = self._slices.get_idx_by_name(inventory_slice_name(agent.name)) + # Hard reset the Inventory Stat in OBS cube self._slices[agent_slice_idx].slice[:] = 0 if len(agent.inventory) > 0: max_x = self.pomdp_r if self.pomdp_r else self._level_shape[0] @@ -111,7 +122,8 @@ class DoubleTaskFactory(SimpleFactory): if item := item_slice[agent.pos]: if item == ITEM_DROP_OFF: if agent.inventory: - valid = self._item_drop_off.place_item(agent.inventory.pop(0)) + drop_off = self._drop_offs.by_pos(agent.pos) + valid = drop_off.place_item(agent.inventory.pop(0)) return valid else: return c.NOT_VALID @@ -142,7 +154,6 @@ class DoubleTaskFactory(SimpleFactory): def do_additional_reset(self) -> None: super(self._super, self).do_additional_reset() - self.spawn_drop_off_location() self.spawn_items(self.item_properties.n_items) self._next_item_spawn = self.item_properties.spawn_frequency for agent in self._agents: @@ -151,9 +162,9 @@ class DoubleTaskFactory(SimpleFactory): def do_additional_step(self) -> dict: info_dict = super(self._super, self).do_additional_step() if not self._next_item_spawn: - if item_to_spawn := (self.item_properties.n_items - - (np.sum(self._slices.by_enum(c.ITEM).slice.astype(bool)) - 1)): - self.spawn_items(item_to_spawn) + if item_to_spawns := max(0, (self.item_properties.n_items - + (np.sum(self._slices.by_enum(c.ITEM).slice.astype(bool)) - 1))): + self.spawn_items(item_to_spawns) self._next_item_spawn = self.item_properties.spawn_frequency else: self.print('No Items are spawning, limit is reached.') @@ -162,17 +173,18 @@ class DoubleTaskFactory(SimpleFactory): return info_dict def spawn_drop_off_location(self): - single_empty_tile = self._tiles.empty_tiles[0] - self._item_drop_off = DropOffLocation(single_empty_tile, - storage_size_until_full=self.item_properties.max_dropoff_storage_size) - single_empty_tile.enter(self._item_drop_off) - self._slices.by_enum(c.ITEM).slice[single_empty_tile.pos] = ITEM_DROP_OFF + empty_tiles = self._tiles.empty_tiles[:self.item_properties.n_drop_off_locations] + drop_offs = DropOffLocations.from_tiles(empty_tiles, + storage_size_until_full=self.item_properties.max_dropoff_storage_size) + xs, ys = zip(*[drop_off.pos for drop_off in drop_offs]) + self._slices.by_enum(c.ITEM).slice[xs, ys] = ITEM_DROP_OFF + return drop_offs def calculate_additional_reward(self, agent: Agent) -> (int, dict): reward, info_dict = super(self._super, self).calculate_additional_reward(agent) if self._is_item_action(agent.temp_action): if agent.temp_valid: - if agent.pos == self._item_drop_off.pos: + if agent.pos in self._drop_offs.positions: info_dict.update({f'{agent.name}_item_dropoff': 1}) reward += 1 @@ -195,8 +207,9 @@ class DoubleTaskFactory(SimpleFactory): def spawn_items(self, n_items): tiles = self._tiles.empty_tiles[:n_items] item_slice = self._slices.by_enum(c.ITEM).slice - for idx, tile in enumerate(tiles, start=1): - item_slice[tile.pos] = idx + # when all items should be 1 + xs, ys = zip(*[tile.pos for tile in tiles]) + item_slice[xs, ys] = 1 pass diff --git a/reload_agent.py b/reload_agent.py index 41f3cfa..74f68d5 100644 --- a/reload_agent.py +++ b/reload_agent.py @@ -7,6 +7,7 @@ from stable_baselines3 import PPO from stable_baselines3.common.evaluation import evaluate_policy from environments.factory.simple_factory import DirtProperties, SimpleFactory +from environments.factory.double_task_factory import ItemProperties, DoubleTaskFactory warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=UserWarning) @@ -14,7 +15,7 @@ warnings.filterwarnings('ignore', category=UserWarning) if __name__ == '__main__': - model_name = 'A2C_1627491061' + model_name = 'A2C_1629467677' run_id = 0 out_path = Path(__file__).parent / 'debug_out' model_path = out_path / model_name @@ -26,7 +27,7 @@ if __name__ == '__main__': max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05, dirt_smear_amount=0.5), combin_agent_slices_in_obs=True, omit_agent_slice_in_obs=True) - with SimpleFactory(**env_kwargs) as env: + with DoubleTaskFactory(**env_kwargs) as env: # Edit THIS: model_files = list(natsorted((model_path / f'{run_id}_{model_name}').rglob('model_*.zip')))