new observation properties for testing of technical limitations
This commit is contained in:
@ -3,6 +3,7 @@ from collections import deque, UserList
|
||||
from enum import Enum
|
||||
from typing import List, Union, NamedTuple, Dict
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.helpers import Constants as c
|
||||
@ -18,13 +19,6 @@ NO_ITEM = 0
|
||||
ITEM_DROP_OFF = 1
|
||||
|
||||
|
||||
def inventory_slice_name(agent_i):
|
||||
if isinstance(agent_i, int):
|
||||
return f'{c.INVENTORY.name}_{c.AGENT.value}#{agent_i}'
|
||||
else:
|
||||
return f'{c.INVENTORY.name}_{agent_i}'
|
||||
|
||||
|
||||
class Item(MoveableEntity):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@ -77,7 +71,7 @@ class Inventory(UserList):
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.agent.name
|
||||
return f'{self.__class__.__name__}({self.agent.name})'
|
||||
|
||||
def __init__(self, pomdp_r: int, level_shape: (int, int), agent: Agent, capacity: int):
|
||||
super(Inventory, self).__init__()
|
||||
@ -111,7 +105,8 @@ class Inventory(UserList):
|
||||
|
||||
def summarize_state(self, **kwargs):
|
||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||
attr_dict.update({val.name: val.summarize_state(**kwargs) for val in self})
|
||||
attr_dict.update(dict(items={val.name: val.summarize_state(**kwargs) for val in self}))
|
||||
attr_dict.update(dict(name=self.name))
|
||||
return attr_dict
|
||||
|
||||
|
||||
@ -149,6 +144,11 @@ class Inventories(ObjectRegister):
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
# as dict with additional nesting
|
||||
# return dict(items=super(Inventories, self).summarize_states())
|
||||
return super(Inventories, self).summarize_states(n_steps=n_steps)
|
||||
|
||||
|
||||
class DropOffLocation(Entity):
|
||||
|
||||
@ -194,6 +194,9 @@ class DropOffLocations(EntityObjectRegister):
|
||||
self._array[0, item.x, item.y] = item.encoding
|
||||
return self._array
|
||||
|
||||
def __repr__(self):
|
||||
super(DropOffLocations, self).__repr__()
|
||||
|
||||
|
||||
class ItemProperties(NamedTuple):
|
||||
n_items: int = 5 # How many items are there at the same time
|
||||
@ -207,13 +210,13 @@ class ItemProperties(NamedTuple):
|
||||
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
||||
class ItemFactory(BaseFactory):
|
||||
# noinspection PyMissingConstructor
|
||||
def __init__(self, *args, item_properties: ItemProperties = ItemProperties(), env_seed=time.time_ns(), **kwargs):
|
||||
if isinstance(item_properties, dict):
|
||||
item_properties = ItemProperties(**item_properties)
|
||||
self.item_properties = item_properties
|
||||
def __init__(self, *args, item_prop: ItemProperties = ItemProperties(), env_seed=time.time_ns(), **kwargs):
|
||||
if isinstance(item_prop, dict):
|
||||
item_prop = ItemProperties(**item_prop)
|
||||
self.item_prop = item_prop
|
||||
kwargs.update(env_seed=env_seed)
|
||||
self._item_rng = np.random.default_rng(env_seed)
|
||||
assert (item_properties.n_items <= ((1 + kwargs.get('pomdp_r', 0) * 2) ** 2)) or not kwargs.get('pomdp_r', 0)
|
||||
assert (item_prop.n_items <= ((1 + kwargs.get('_pomdp_r', 0) * 2) ** 2)) or not kwargs.get('_pomdp_r', 0)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
@ -228,16 +231,19 @@ class ItemFactory(BaseFactory):
|
||||
# noinspection PyUnresolvedReferences
|
||||
super_entities = super().additional_entities
|
||||
|
||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_properties.n_drop_off_locations]
|
||||
drop_offs = DropOffLocations.from_tiles(empty_tiles, self._level_shape,
|
||||
storage_size_until_full=self.item_properties.max_dropoff_storage_size)
|
||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_drop_off_locations]
|
||||
drop_offs = DropOffLocations.from_tiles(
|
||||
empty_tiles, self._level_shape,
|
||||
entity_kwargs=dict(
|
||||
storage_size_until_full=self.item_prop.max_dropoff_storage_size)
|
||||
)
|
||||
item_register = ItemRegister(self._level_shape)
|
||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_properties.n_items]
|
||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_items]
|
||||
item_register.spawn_items(empty_tiles)
|
||||
|
||||
inventories = Inventories(self._level_shape if not self.pomdp_r else ((self.pomdp_diameter,) * 2))
|
||||
inventories.spawn_inventories(self[c.AGENT], self.pomdp_r,
|
||||
self.item_properties.max_agent_inventory_capacity)
|
||||
inventories = Inventories(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2))
|
||||
inventories.spawn_inventories(self[c.AGENT], self._pomdp_r,
|
||||
self.item_prop.max_agent_inventory_capacity)
|
||||
|
||||
super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories})
|
||||
return super_entities
|
||||
@ -270,7 +276,7 @@ class ItemFactory(BaseFactory):
|
||||
valid = super().do_additional_actions(agent, action)
|
||||
if valid is None:
|
||||
if action == h.EnvActions.ITEM_ACTION:
|
||||
if self.item_properties.agent_can_interact:
|
||||
if self.item_prop.agent_can_interact:
|
||||
valid = self.do_item_action(agent)
|
||||
return valid
|
||||
else:
|
||||
@ -283,14 +289,14 @@ class ItemFactory(BaseFactory):
|
||||
def do_additional_reset(self) -> None:
|
||||
# noinspection PyUnresolvedReferences
|
||||
super().do_additional_reset()
|
||||
self._next_item_spawn = self.item_properties.spawn_frequency
|
||||
self._next_item_spawn = self.item_prop.spawn_frequency
|
||||
self.trigger_item_spawn()
|
||||
|
||||
def trigger_item_spawn(self):
|
||||
if item_to_spawns := max(0, (self.item_properties.n_items - len(self[c.ITEM]))):
|
||||
if item_to_spawns := max(0, (self.item_prop.n_items - len(self[c.ITEM]))):
|
||||
empty_tiles = self[c.FLOOR].empty_tiles[:item_to_spawns]
|
||||
self[c.ITEM].spawn_items(empty_tiles)
|
||||
self._next_item_spawn = self.item_properties.spawn_frequency
|
||||
self._next_item_spawn = self.item_prop.spawn_frequency
|
||||
self.print(f'{item_to_spawns} new items have been spawned; next spawn in {self._next_item_spawn}')
|
||||
else:
|
||||
self.print('No Items are spawning, limit is reached.')
|
||||
@ -351,30 +357,41 @@ class ItemFactory(BaseFactory):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import random
|
||||
from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties
|
||||
|
||||
render = True
|
||||
|
||||
item_props = ItemProperties()
|
||||
item_probs = ItemProperties()
|
||||
|
||||
factory = ItemFactory(item_properties=item_props, n_agents=3, done_at_collision=False, frames_to_stack=0,
|
||||
level_name='rooms', max_steps=4000,
|
||||
omit_agent_in_obs=True, parse_doors=True, pomdp_r=3,
|
||||
record_episodes=False, verbose=False
|
||||
obs_props = ObservationProperties(render_agents=ARO.LEVEL, omit_agent_self=True, pomdp_r=2)
|
||||
|
||||
move_props = {'allow_square_movement': True,
|
||||
'allow_diagonal_movement': False,
|
||||
'allow_no_op': False}
|
||||
|
||||
factory = ItemFactory(n_agents=3, done_at_collision=False,
|
||||
level_name='rooms', max_steps=400,
|
||||
obs_prop=obs_props, parse_doors=True,
|
||||
record_episodes=True, verbose=True,
|
||||
mv_prop=move_props, item_prop=item_probs
|
||||
)
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
n_actions = factory.action_space.n - 1
|
||||
_ = factory.observation_space
|
||||
|
||||
for epoch in range(100):
|
||||
random_actions = [[random.randint(0, n_actions) for _ in range(factory.n_agents)] for _ in range(200)]
|
||||
for epoch in range(4):
|
||||
random_actions = [[random.randint(0, n_actions) for _
|
||||
in range(factory.n_agents)] for _
|
||||
in range(factory.max_steps + 1)]
|
||||
env_state = factory.reset()
|
||||
rew = 0
|
||||
r = 0
|
||||
for agent_i_action in random_actions:
|
||||
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
|
||||
rew += step_r
|
||||
r += step_r
|
||||
if render:
|
||||
factory.render()
|
||||
if done_bool:
|
||||
break
|
||||
print(f'Factory run {epoch} done, reward is:\n {rew}')
|
||||
print(f'Factory run {epoch} done, reward is:\n {r}')
|
||||
pass
|
||||
|
Reference in New Issue
Block a user