Merge branch 'main' into unit_testing

# Conflicts:
#	marl_factory_grid/environment/factory.py
#	marl_factory_grid/utils/states.py
This commit is contained in:
Chanumask
2023-11-10 14:26:45 +01:00
22 changed files with 100 additions and 129 deletions

View File

@@ -18,6 +18,7 @@ class Action(abc.ABC):
@abc.abstractmethod
def do(self, entity, state) -> Union[None, ActionResult]:
print()
return
def __repr__(self):

View File

@@ -41,7 +41,7 @@ class Object:
def __init__(self, str_ident: Union[str, None] = None, **kwargs):
self._bound_entity = None
self._observers = []
self._observers = set()
self._str_ident = str_ident
self.u_int = self._identify_and_count_up()
self._collection = None
@@ -75,7 +75,7 @@ class Object:
self._collection = collection
def add_observer(self, observer):
self.observers.append(observer)
self.observers.add(observer)
observer.notify_add_entity(self)
def del_observer(self, observer):

View File

@@ -69,23 +69,6 @@ class Factory(gym.Env):
# expensive - don't use; unless required !
self._renderer = None
# reset env to initial state, preparing env for new episode.
# returns tuple where the first dict contains initial observation for each agent in the env
self.reset()
def __getitem__(self, item):
return self.state.entities[item]
def reset(self) -> (dict, dict):
if self.state is not None:
for entity_group in self.state.entities:
try:
entity_group[0].reset_uid()
except (AttributeError, TypeError):
pass
self.state = None
# Init entities
entities = self.map.do_init()
@@ -101,7 +84,6 @@ class Factory(gym.Env):
self.state = Gamestate(entities, parsed_agents_conf, env_rules, env_tests, self.map.level_shape,
self.conf.env_seed, self.conf.verbose)
# All is set up, trigger entity init with variable pos
# All is set up, trigger additional init (after agent entity spawn etc)
self.state.rules.do_all_init(self.state, self.map)
@@ -110,6 +92,17 @@ class Factory(gym.Env):
# Build initial observations for all agents
# noinspection PyAttributeOutsideInit
self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r)
def __getitem__(self, item):
return self.state.entities[item]
def reset(self) -> (dict, dict):
self.state.entities.reset()
# All is set up, trigger entity spawn with variable pos
self.state.rules.do_all_reset(self.state)
# Build initial observations for all agents
return self.obs_builder.refresh_and_build_for_all(self.state)
def manual_step_init(self) -> List[Result]:

View File

@@ -2,7 +2,6 @@ from typing import List, Tuple, Union, Dict
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.groups.objects import Objects
# noinspection PyProtectedMember
from marl_factory_grid.environment.entity.object import Object
import marl_factory_grid.environment.constants as c
from marl_factory_grid.utils.results import Result

View File

@@ -31,9 +31,12 @@ class Entities(Objects):
def __init__(self, floor_positions):
self._floor_positions = floor_positions
self.pos_dict = defaultdict(list)
self.pos_dict = None
super().__init__()
def __repr__(self):
return f'{self.__class__.__name__}{[x for x in self]}'
def guests_that_can_collide(self, pos):
return [x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
@@ -108,3 +111,12 @@ class Entities(Objects):
def is_occupied(self, pos):
return len([x for x in self.pos_dict[pos] if x.var_can_collide or x.var_is_blocking_pos]) >= 1
def reset(self):
self._observers = set(self)
self.pos_dict = defaultdict(list)
for entity_group in self:
entity_group.reset()
if hasattr(entity_group, "var_has_position") and entity_group.var_has_position:
entity_group.add_observer(self)

View File

@@ -44,7 +44,7 @@ class Objects:
def __init__(self, *args, **kwargs):
self._data = defaultdict(lambda: None)
self._observers = [self]
self._observers = set(self)
self.pos_dict = defaultdict(list)
def __len__(self):
@@ -59,6 +59,8 @@ class Objects:
assert self._data[item.name] is None, f'{item.name} allready exists!!!'
self._data.update({item.name: item})
item.set_collection(self)
if hasattr(self, "var_has_position") and self.var_has_position:
item.add_observer(self)
for observer in self.observers:
observer.notify_add_entity(item)
return self
@@ -82,10 +84,9 @@ class Objects:
# noinspection PyUnresolvedReferences
def add_observer(self, observer):
self.observers.append(observer)
self.observers.add(observer)
for entity in self:
if observer not in entity.observers:
entity.add_observer(observer)
entity.add_observer(observer)
def add_items(self, items: List[_entity]):
for item in items:
@@ -127,8 +128,7 @@ class Objects:
raise TypeError
def __repr__(self):
repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]}
return f'{self.__class__.__name__}[{repr_dict}]'
return f'{self.__class__.__name__}[{len(self)}]'
def notify_del_entity(self, entity: Object):
try:
@@ -163,3 +163,9 @@ class Objects:
return h.get_first_index(self, filter_by=lambda x: x.belongs_to_entity(entity))
except (StopIteration, AttributeError):
return None
def reset(self):
self._data = defaultdict(lambda: None)
self._observers = set(self)
self.pos_dict = defaultdict(list)

View File

@@ -23,3 +23,7 @@ class Walls(Collection):
return super().by_pos(pos)[0]
except IndexError:
return None
def reset(self):
pass

View File

@@ -23,7 +23,7 @@ class Rule(abc.ABC):
def on_init(self, state, lvl_map):
return []
def on_reset(self):
def on_reset(self, state) -> List[TickResult]:
return []
def tick_pre_step(self, state) -> List[TickResult]:
@@ -55,7 +55,7 @@ class SpawnEntity(Rule):
self.collection = collection
self.ignore_blocking = ignore_blocking
def on_init(self, state, lvl_map) -> [TickResult]:
def on_reset(self, state) -> [TickResult]:
results = self.collection.trigger_spawn(state, ignore_blocking=self.ignore_blocking)
pos_str = f' on: {[x.pos for x in self.collection]}' if self.collection.var_has_position else ''
state.print(f'Initial {self.collection.__class__.__name__} were spawned{pos_str}')
@@ -68,8 +68,7 @@ class SpawnAgents(Rule):
super().__init__()
pass
def on_init(self, state, lvl_map):
# agents = Agents(lvl_map.size)
def on_reset(self, state):
agents = state[c.AGENT]
empty_positions = state.entities.empty_positions[:len(state.agents_conf)]
for agent_name, agent_conf in state.agents_conf.items():
@@ -101,9 +100,6 @@ class DoneAtMaxStepsReached(Rule):
super().__init__()
self.max_steps = max_steps
def on_init(self, state, lvl_map):
pass
def on_check_done(self, state):
if self.max_steps <= state.curr_step:
return [DoneResult(validity=c.VALID, identifier=self.name)]
@@ -115,7 +111,7 @@ class AssignGlobalPositions(Rule):
def __init__(self):
super().__init__()
def on_init(self, state, lvl_map):
def on_reset(self, state, lvl_map):
from marl_factory_grid.environment.entity.util import GlobalPosition
for agent in state[c.AGENT]:
gp = GlobalPosition(agent, lvl_map.level_shape)