mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-23 12:01:36 +02:00
New Szenario "Two_Rooms_One_Door"
This commit is contained in:
@ -62,11 +62,14 @@ class Object:
|
||||
|
||||
def add_observer(self, observer):
|
||||
self.observers.append(observer)
|
||||
observer.notify_change_pos(self)
|
||||
observer.notify_add_entity(self)
|
||||
|
||||
def del_observer(self, observer):
|
||||
self.observers.remove(observer)
|
||||
|
||||
def summarize_state(self):
|
||||
return dict()
|
||||
|
||||
|
||||
class EnvObject(Object):
|
||||
|
||||
@ -128,3 +131,6 @@ class EnvObject(Object):
|
||||
self._collection.delete_env_object(self)
|
||||
self._collection = other_collection
|
||||
return self._collection == other_collection
|
||||
|
||||
def summarize_state(self):
|
||||
return dict(name=str(self.name))
|
||||
|
@ -16,8 +16,6 @@ import marl_factory_grid.environment.constants as c
|
||||
|
||||
from marl_factory_grid.utils.states import Gamestate
|
||||
|
||||
REC_TAC = 'rec_'
|
||||
|
||||
|
||||
class Factory(gym.Env):
|
||||
|
||||
@ -44,11 +42,6 @@ class Factory(gym.Env):
|
||||
config_dict = yaml.safe_load(config_path.open())
|
||||
return config_dict
|
||||
|
||||
@property
|
||||
def summarize_header(self):
|
||||
summary_dict = self._summarize_state(stateless_entities=True)
|
||||
return summary_dict
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
@ -125,9 +118,6 @@ class Factory(gym.Env):
|
||||
info = reward_info
|
||||
|
||||
info.update(step_reward=sum(reward), step=self.state.curr_step)
|
||||
# TODO:
|
||||
# if self._record_episodes:
|
||||
# info.update(self._summarize_state())
|
||||
|
||||
obs, reset_info = self.obs_builder.refresh_and_build_for_all(self.state)
|
||||
info.update(reset_info)
|
||||
@ -171,14 +161,6 @@ class Factory(gym.Env):
|
||||
self.state.print(f"reward is {reward}")
|
||||
return reward, combined_info_dict, done
|
||||
|
||||
def start_recording(self):
|
||||
self.conf.do_record = True
|
||||
return self.conf.do_record
|
||||
|
||||
def stop_recording(self):
|
||||
self.conf.do_record = False
|
||||
return not self.conf.do_record
|
||||
|
||||
# noinspection PyGlobalUndefined
|
||||
def render(self, mode='human'):
|
||||
if not self._renderer: # lazy init
|
||||
@ -193,12 +175,23 @@ class Factory(gym.Env):
|
||||
render_entity.aux = self.obs_builder.curr_lightmaps[render_entity.real_name]
|
||||
return self._renderer.render(render_entities)
|
||||
|
||||
def _summarize_state(self, stateless_entities=False):
|
||||
summary = {f'{REC_TAC}step': self.state.curr_step}
|
||||
def summarize_header(self):
|
||||
header = {'rec_step': self.state.curr_step}
|
||||
for entity_group in (x for x in self.state if x.name in ['Walls', 'Floors', 'DropOffLocations', 'ChargePods']):
|
||||
header.update({f'rec{entity_group.name}': entity_group.summarize_states()})
|
||||
return header
|
||||
|
||||
for entity_group in self.state:
|
||||
if entity_group.is_stateless == stateless_entities:
|
||||
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()})
|
||||
def summarize_state(self):
|
||||
summary = {'step': self.state.curr_step}
|
||||
|
||||
# Todo: Protobuff Compatibility Section #######
|
||||
# for entity_group in (x for x in self.state if x.name not in [c.WALLS, c.FLOORS]):
|
||||
for entity_group in (x for x in self.state if x.name not in [c.FLOORS]):
|
||||
summary.update({entity_group.name.lower(): entity_group.summarize_states()})
|
||||
# TODO Section End ########
|
||||
for key in list(summary.keys()):
|
||||
if key not in ['step', 'walls', 'doors', 'agents', 'items', 'dirtPiles', 'batteries']:
|
||||
del summary[key]
|
||||
return summary
|
||||
|
||||
def print(self, string):
|
||||
|
@ -23,9 +23,6 @@ class EnvObjects(Objects):
|
||||
super(EnvObjects, self).add_item(item)
|
||||
return self
|
||||
|
||||
def summarize_states(self):
|
||||
return [entity.summarize_state() for entity in self.values()]
|
||||
|
||||
def delete_env_object(self, env_object: EnvObject):
|
||||
del self[env_object.name]
|
||||
|
||||
|
@ -45,7 +45,8 @@ class PositionMixin:
|
||||
def by_pos(self, pos: (int, int)):
|
||||
pos = tuple(pos)
|
||||
try:
|
||||
return next(e for e in self if e.pos == pos)
|
||||
return self.pos_dict[pos]
|
||||
# return next(e for e in self if e.pos == pos)
|
||||
except StopIteration:
|
||||
pass
|
||||
except ValueError:
|
||||
|
@ -144,7 +144,13 @@ class Objects:
|
||||
|
||||
def notify_add_entity(self, entity: Object):
|
||||
try:
|
||||
entity.add_observer(self)
|
||||
if self not in entity.observers:
|
||||
entity.add_observer(self)
|
||||
self.pos_dict[entity.pos].append(entity)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
def summarize_states(self):
|
||||
# FIXME PROTOBUFF
|
||||
# return [e.summarize_state() for e in self]
|
||||
return [e.summarize_state() for e in self]
|
||||
|
@ -43,38 +43,3 @@ class GlobalPositions(HasBoundMixin, EnvObjects):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(GlobalPositions, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class ZonesOLD(Objects):
|
||||
|
||||
_entity = Zone
|
||||
|
||||
@property
|
||||
def accounting_zones(self):
|
||||
return [self[idx] for idx, name in self.items() if name != c.DANGER_ZONE]
|
||||
|
||||
def __init__(self, parsed_level):
|
||||
raise NotImplementedError('This needs a Rework')
|
||||
super(Zones, self).__init__()
|
||||
slices = list()
|
||||
self._accounting_zones = list()
|
||||
self._danger_zones = list()
|
||||
for symbol in np.unique(parsed_level):
|
||||
if symbol == c.VALUE_OCCUPIED_CELL:
|
||||
continue
|
||||
elif symbol == c.DANGER_ZONE:
|
||||
self + symbol
|
||||
slices.append(h.one_hot_level(parsed_level, symbol))
|
||||
self._danger_zones.append(symbol)
|
||||
else:
|
||||
self + symbol
|
||||
slices.append(h.one_hot_level(parsed_level, symbol))
|
||||
self._accounting_zones.append(symbol)
|
||||
|
||||
self._zone_slices = np.stack(slices)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self._zone_slices[item]
|
||||
|
||||
def add_items(self, other: Union[str, List[str]]):
|
||||
raise AttributeError('You are not allowed to add additional Zones in runtime.')
|
||||
|
@ -26,6 +26,12 @@ class Walls(PositionMixin, EnvObjects):
|
||||
def from_tiles(cls, tiles, *args, **kwargs):
|
||||
raise RuntimeError()
|
||||
|
||||
def by_pos(self, pos: (int, int)):
|
||||
try:
|
||||
return super().by_pos(pos)[0]
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
|
||||
class Floors(Walls):
|
||||
_entity = Floor
|
||||
|
Reference in New Issue
Block a user