Item debugging and New Entities
This commit is contained in:
@ -157,7 +157,7 @@ class BaseFactory(gym.Env):
|
|||||||
entities.register_additional_items([self._doors])
|
entities.register_additional_items([self._doors])
|
||||||
|
|
||||||
if additional_entities := self.additional_entities:
|
if additional_entities := self.additional_entities:
|
||||||
entities.register_additional_items([additional_entities])
|
entities.register_additional_items(additional_entities)
|
||||||
|
|
||||||
return entities
|
return entities
|
||||||
|
|
||||||
|
@ -93,6 +93,10 @@ class Register:
|
|||||||
|
|
||||||
class EntityRegister(Register):
|
class EntityRegister(Register):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def positions(self):
|
||||||
|
return [agent.pos for agent in self]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(EntityRegister, self).__init__()
|
super(EntityRegister, self).__init__()
|
||||||
self._tiles = dict()
|
self._tiles = dict()
|
||||||
@ -150,7 +154,7 @@ class FloorTiles(EntityRegister):
|
|||||||
return tiles
|
return tiles
|
||||||
|
|
||||||
|
|
||||||
class Agents(Register):
|
class Agents(EntityRegister):
|
||||||
|
|
||||||
_accepted_objects = Agent
|
_accepted_objects = Agent
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ from environments.factory.simple_factory import SimpleFactory
|
|||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as c
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.factory.base.objects import Agent, Slice, Entity, Action
|
from environments.factory.base.objects import Agent, Slice, Entity, Action
|
||||||
from environments.factory.base.registers import Entities
|
from environments.factory.base.registers import Entities, Register, EntityRegister
|
||||||
|
|
||||||
from environments.factory.renderer import RenderEntity
|
from environments.factory.renderer import RenderEntity
|
||||||
|
|
||||||
@ -29,22 +29,31 @@ def inventory_slice_name(agent_i):
|
|||||||
class DropOffLocation(Entity):
|
class DropOffLocation(Entity):
|
||||||
|
|
||||||
def __init__(self, *args, storage_size_until_full: int = 5, **kwargs):
|
def __init__(self, *args, storage_size_until_full: int = 5, **kwargs):
|
||||||
super(DropOffLocation, self).__init__(DROP_OFF, *args, **kwargs)
|
super(DropOffLocation, self).__init__(*args, **kwargs)
|
||||||
self.storage = deque(maxlen=storage_size_until_full)
|
self.storage = deque(maxlen=storage_size_until_full)
|
||||||
|
|
||||||
def place_item(self, item):
|
def place_item(self, item):
|
||||||
|
if self.is_full:
|
||||||
|
raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
|
||||||
|
return False
|
||||||
|
else:
|
||||||
self.storage.append(item)
|
self.storage.append(item)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_full(self):
|
def is_full(self):
|
||||||
return self.storage.maxlen == len(self.storage)
|
return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage)
|
||||||
|
|
||||||
|
|
||||||
|
class DropOffLocations(EntityRegister):
|
||||||
|
_accepted_objects = DropOffLocation
|
||||||
|
|
||||||
|
|
||||||
class ItemProperties(NamedTuple):
|
class ItemProperties(NamedTuple):
|
||||||
n_items: int = 1 # How many items are there at the same time
|
n_items: int = 5 # How many items are there at the same time
|
||||||
spawn_frequency: int = 5 # Spawn Frequency in Steps
|
spawn_frequency: int = 5 # Spawn Frequency in Steps
|
||||||
max_dropoff_storage_size: int = 5 # How many items are needed until the drop off is full
|
n_drop_off_locations: int = 5 # How many DropOff locations are there at the same time
|
||||||
|
max_dropoff_storage_size: int = 0 # How many items are needed until the drop off is full
|
||||||
max_agent_storage_size: int = 5 # How many items are needed until the agent inventory is full
|
max_agent_storage_size: int = 5 # How many items are needed until the agent inventory is full
|
||||||
agent_can_interact: bool = True # Whether agents have the possibility to interact with the domain items
|
agent_can_interact: bool = True # Whether agents have the possibility to interact with the domain items
|
||||||
|
|
||||||
@ -69,7 +78,8 @@ class DoubleTaskFactory(SimpleFactory):
|
|||||||
@property
|
@property
|
||||||
def additional_entities(self) -> Union[Entities, List[Entities]]:
|
def additional_entities(self) -> Union[Entities, List[Entities]]:
|
||||||
super_entities = super(self._super, self).additional_entities
|
super_entities = super(self._super, self).additional_entities
|
||||||
return super_entities
|
self._drop_offs = self.spawn_drop_off_location()
|
||||||
|
return super_entities + [self._drop_offs]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def additional_slices(self) -> Union[Slice, List[Slice]]:
|
def additional_slices(self) -> Union[Slice, List[Slice]]:
|
||||||
@ -89,6 +99,7 @@ class DoubleTaskFactory(SimpleFactory):
|
|||||||
# Flush per agent inventory state
|
# Flush per agent inventory state
|
||||||
for agent in self._agents:
|
for agent in self._agents:
|
||||||
agent_slice_idx = self._slices.get_idx_by_name(inventory_slice_name(agent.name))
|
agent_slice_idx = self._slices.get_idx_by_name(inventory_slice_name(agent.name))
|
||||||
|
# Hard reset the Inventory Stat in OBS cube
|
||||||
self._slices[agent_slice_idx].slice[:] = 0
|
self._slices[agent_slice_idx].slice[:] = 0
|
||||||
if len(agent.inventory) > 0:
|
if len(agent.inventory) > 0:
|
||||||
max_x = self.pomdp_r if self.pomdp_r else self._level_shape[0]
|
max_x = self.pomdp_r if self.pomdp_r else self._level_shape[0]
|
||||||
@ -111,7 +122,8 @@ class DoubleTaskFactory(SimpleFactory):
|
|||||||
if item := item_slice[agent.pos]:
|
if item := item_slice[agent.pos]:
|
||||||
if item == ITEM_DROP_OFF:
|
if item == ITEM_DROP_OFF:
|
||||||
if agent.inventory:
|
if agent.inventory:
|
||||||
valid = self._item_drop_off.place_item(agent.inventory.pop(0))
|
drop_off = self._drop_offs.by_pos(agent.pos)
|
||||||
|
valid = drop_off.place_item(agent.inventory.pop(0))
|
||||||
return valid
|
return valid
|
||||||
else:
|
else:
|
||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
@ -142,7 +154,6 @@ class DoubleTaskFactory(SimpleFactory):
|
|||||||
|
|
||||||
def do_additional_reset(self) -> None:
|
def do_additional_reset(self) -> None:
|
||||||
super(self._super, self).do_additional_reset()
|
super(self._super, self).do_additional_reset()
|
||||||
self.spawn_drop_off_location()
|
|
||||||
self.spawn_items(self.item_properties.n_items)
|
self.spawn_items(self.item_properties.n_items)
|
||||||
self._next_item_spawn = self.item_properties.spawn_frequency
|
self._next_item_spawn = self.item_properties.spawn_frequency
|
||||||
for agent in self._agents:
|
for agent in self._agents:
|
||||||
@ -151,9 +162,9 @@ class DoubleTaskFactory(SimpleFactory):
|
|||||||
def do_additional_step(self) -> dict:
|
def do_additional_step(self) -> dict:
|
||||||
info_dict = super(self._super, self).do_additional_step()
|
info_dict = super(self._super, self).do_additional_step()
|
||||||
if not self._next_item_spawn:
|
if not self._next_item_spawn:
|
||||||
if item_to_spawn := (self.item_properties.n_items -
|
if item_to_spawns := max(0, (self.item_properties.n_items -
|
||||||
(np.sum(self._slices.by_enum(c.ITEM).slice.astype(bool)) - 1)):
|
(np.sum(self._slices.by_enum(c.ITEM).slice.astype(bool)) - 1))):
|
||||||
self.spawn_items(item_to_spawn)
|
self.spawn_items(item_to_spawns)
|
||||||
self._next_item_spawn = self.item_properties.spawn_frequency
|
self._next_item_spawn = self.item_properties.spawn_frequency
|
||||||
else:
|
else:
|
||||||
self.print('No Items are spawning, limit is reached.')
|
self.print('No Items are spawning, limit is reached.')
|
||||||
@ -162,17 +173,18 @@ class DoubleTaskFactory(SimpleFactory):
|
|||||||
return info_dict
|
return info_dict
|
||||||
|
|
||||||
def spawn_drop_off_location(self):
|
def spawn_drop_off_location(self):
|
||||||
single_empty_tile = self._tiles.empty_tiles[0]
|
empty_tiles = self._tiles.empty_tiles[:self.item_properties.n_drop_off_locations]
|
||||||
self._item_drop_off = DropOffLocation(single_empty_tile,
|
drop_offs = DropOffLocations.from_tiles(empty_tiles,
|
||||||
storage_size_until_full=self.item_properties.max_dropoff_storage_size)
|
storage_size_until_full=self.item_properties.max_dropoff_storage_size)
|
||||||
single_empty_tile.enter(self._item_drop_off)
|
xs, ys = zip(*[drop_off.pos for drop_off in drop_offs])
|
||||||
self._slices.by_enum(c.ITEM).slice[single_empty_tile.pos] = ITEM_DROP_OFF
|
self._slices.by_enum(c.ITEM).slice[xs, ys] = ITEM_DROP_OFF
|
||||||
|
return drop_offs
|
||||||
|
|
||||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
||||||
reward, info_dict = super(self._super, self).calculate_additional_reward(agent)
|
reward, info_dict = super(self._super, self).calculate_additional_reward(agent)
|
||||||
if self._is_item_action(agent.temp_action):
|
if self._is_item_action(agent.temp_action):
|
||||||
if agent.temp_valid:
|
if agent.temp_valid:
|
||||||
if agent.pos == self._item_drop_off.pos:
|
if agent.pos in self._drop_offs.positions:
|
||||||
info_dict.update({f'{agent.name}_item_dropoff': 1})
|
info_dict.update({f'{agent.name}_item_dropoff': 1})
|
||||||
|
|
||||||
reward += 1
|
reward += 1
|
||||||
@ -195,8 +207,9 @@ class DoubleTaskFactory(SimpleFactory):
|
|||||||
def spawn_items(self, n_items):
|
def spawn_items(self, n_items):
|
||||||
tiles = self._tiles.empty_tiles[:n_items]
|
tiles = self._tiles.empty_tiles[:n_items]
|
||||||
item_slice = self._slices.by_enum(c.ITEM).slice
|
item_slice = self._slices.by_enum(c.ITEM).slice
|
||||||
for idx, tile in enumerate(tiles, start=1):
|
# when all items should be 1
|
||||||
item_slice[tile.pos] = idx
|
xs, ys = zip(*[tile.pos for tile in tiles])
|
||||||
|
item_slice[xs, ys] = 1
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ from stable_baselines3 import PPO
|
|||||||
from stable_baselines3.common.evaluation import evaluate_policy
|
from stable_baselines3.common.evaluation import evaluate_policy
|
||||||
|
|
||||||
from environments.factory.simple_factory import DirtProperties, SimpleFactory
|
from environments.factory.simple_factory import DirtProperties, SimpleFactory
|
||||||
|
from environments.factory.double_task_factory import ItemProperties, DoubleTaskFactory
|
||||||
|
|
||||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||||
warnings.filterwarnings('ignore', category=UserWarning)
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
@ -14,7 +15,7 @@ warnings.filterwarnings('ignore', category=UserWarning)
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
model_name = 'A2C_1627491061'
|
model_name = 'A2C_1629467677'
|
||||||
run_id = 0
|
run_id = 0
|
||||||
out_path = Path(__file__).parent / 'debug_out'
|
out_path = Path(__file__).parent / 'debug_out'
|
||||||
model_path = out_path / model_name
|
model_path = out_path / model_name
|
||||||
@ -26,7 +27,7 @@ if __name__ == '__main__':
|
|||||||
max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05,
|
max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05,
|
||||||
dirt_smear_amount=0.5),
|
dirt_smear_amount=0.5),
|
||||||
combin_agent_slices_in_obs=True, omit_agent_slice_in_obs=True)
|
combin_agent_slices_in_obs=True, omit_agent_slice_in_obs=True)
|
||||||
with SimpleFactory(**env_kwargs) as env:
|
with DoubleTaskFactory(**env_kwargs) as env:
|
||||||
|
|
||||||
# Edit THIS:
|
# Edit THIS:
|
||||||
model_files = list(natsorted((model_path / f'{run_id}_{model_name}').rglob('model_*.zip')))
|
model_files = list(natsorted((model_path / f'{run_id}_{model_name}').rglob('model_*.zip')))
|
||||||
|
Reference in New Issue
Block a user