mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-23 12:01:36 +02:00
no more tiles no more floor
This commit is contained in:
@ -42,12 +42,12 @@ class Move(Action, abc.ABC):
|
||||
|
||||
def do(self, entity, state):
|
||||
new_pos = self._calc_new_pos(entity.pos)
|
||||
if state.check_move_validity(entity, new_pos): # next_tile := state[c.FLOOR].by_pos(new_pos):
|
||||
if state.check_move_validity(entity, new_pos):
|
||||
# noinspection PyUnresolvedReferences
|
||||
move_validity = entity.move(new_pos, state)
|
||||
reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward)
|
||||
else: # There is no floor, propably collision
|
||||
else: # There is no place to go, propably collision
|
||||
# This is currently handeld by the Collision rule, so that it can be switched on and off by conf.yml
|
||||
# return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION)
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=0)
|
||||
|
@ -3,15 +3,13 @@ DANGER_ZONE = 'x' # Dange Zone tile _identifier fo
|
||||
DEFAULTS = 'Defaults'
|
||||
SELF = 'Self'
|
||||
PLACEHOLDER = 'Placeholder'
|
||||
FLOOR = 'Floor' # Identifier of Floor-objects and groups (groups).
|
||||
FLOORS = 'Floors' # Identifier of Floor-objects and groups (groups).
|
||||
WALL = 'Wall' # Identifier of Wall-objects and groups (groups).
|
||||
WALLS = 'Walls' # Identifier of Wall-objects and groups (groups).
|
||||
LEVEL = 'Level' # Identifier of Level-objects and groups (groups).
|
||||
AGENT = 'Agent' # Identifier of Agent-objects and groups (groups).
|
||||
OTHERS = 'Other'
|
||||
COMBINED = 'Combined'
|
||||
GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice
|
||||
GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice
|
||||
|
||||
# Attributes
|
||||
IS_BLOCKING_LIGHT = 'var_is_blocking_light'
|
||||
@ -32,7 +30,7 @@ VALUE_NO_POS = (-9999, -9999) # Invalid Position value used in the e
|
||||
|
||||
ACTION = 'action' # Identifier of Action-objects and groups (groups).
|
||||
COLLISION = 'Collision' # Identifier to use in the context of collitions.
|
||||
LAST_POS = 'LAST_POS' # Identifiert for retrieving an enitites last pos.
|
||||
# LAST_POS = 'LAST_POS' # Identifiert for retrieving an enitites last pos.
|
||||
VALIDITY = 'VALIDITY' # Identifiert for retrieving the Validity of Action, Tick, etc. ...
|
||||
|
||||
# Actions
|
||||
|
@ -2,7 +2,7 @@ from typing import List, Union
|
||||
|
||||
from marl_factory_grid.environment.actions import Action
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||
from marl_factory_grid.utils import renderer
|
||||
from marl_factory_grid.utils.helpers import is_move
|
||||
from marl_factory_grid.utils.results import ActionResult, Result
|
||||
|
@ -1,8 +1,10 @@
|
||||
import abc
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .. import constants as c
|
||||
from .object import EnvObject
|
||||
from ...utils.render import RenderEntity
|
||||
from ...utils.utility_classes import RenderEntity
|
||||
from ...utils.results import ActionResult
|
||||
|
||||
|
||||
@ -30,33 +32,32 @@ class Entity(EnvObject, abc.ABC):
|
||||
return self._pos
|
||||
|
||||
@property
|
||||
def tile(self):
|
||||
return self._tile # wall_n_floors funktionalität
|
||||
|
||||
# @property
|
||||
# def last_tile(self):
|
||||
# try:
|
||||
# return self._last_tile
|
||||
# except AttributeError:
|
||||
# # noinspection PyAttributeOutsideInit
|
||||
# self._last_tile = None
|
||||
# return self._last_tile
|
||||
def last_pos(self):
|
||||
try:
|
||||
return self._last_pos
|
||||
except AttributeError:
|
||||
# noinspection PyAttributeOutsideInit
|
||||
self._last_pos = c.VALUE_NO_POS
|
||||
return self._last_pos
|
||||
|
||||
@property
|
||||
def direction_of_view(self):
|
||||
last_x, last_y = self._last_pos
|
||||
curr_x, curr_y = self.pos
|
||||
return last_x - curr_x, last_y - curr_y
|
||||
if self._last_pos != c.VALUE_NO_POS:
|
||||
return 0, 0
|
||||
else:
|
||||
return np.subtract(self._last_pos, self.pos)
|
||||
|
||||
def move(self, next_pos, state):
|
||||
next_pos = next_pos
|
||||
curr_pos = self._pos
|
||||
if not_same_pos := curr_pos != next_pos:
|
||||
if valid := state.check_move_validity(self, next_pos):
|
||||
self._pos = next_pos
|
||||
self._last_pos = curr_pos
|
||||
for observer in self.observers:
|
||||
observer.notify_change_pos(self)
|
||||
observer.notify_del_entity(self)
|
||||
self._view_directory = curr_pos[0]-next_pos[0], curr_pos[1]-next_pos[1]
|
||||
self._pos = next_pos
|
||||
for observer in self.observers:
|
||||
observer.notify_add_entity(self)
|
||||
return valid
|
||||
return not_same_pos
|
||||
|
||||
@ -64,6 +65,7 @@ class Entity(EnvObject, abc.ABC):
|
||||
super().__init__(**kwargs)
|
||||
self._status = None
|
||||
self._pos = pos
|
||||
self._last_pos = pos
|
||||
if bind_to:
|
||||
try:
|
||||
self.bind_to(bind_to)
|
||||
|
@ -4,7 +4,7 @@ import numpy as np
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.object import EnvObject
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
|
||||
|
||||
@ -30,17 +30,6 @@ class Floor(EnvObject):
|
||||
def var_is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def neighboring_floor(self):
|
||||
if self._neighboring_floor:
|
||||
pass
|
||||
else:
|
||||
self._neighboring_floor = [x for x in [self._collection.by_pos(np.add(self.pos, pos))
|
||||
for pos in h.POS_MASK.reshape(-1, 2)
|
||||
if not np.all(pos == [0, 0])]
|
||||
if x]
|
||||
return self._neighboring_floor
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return c.VALUE_OCCUPIED_CELL
|
||||
|
@ -197,7 +197,7 @@ class Factory(gym.Env):
|
||||
del rewards['global']
|
||||
reward = [rewards[agent.name] for agent in self.state[c.AGENT]]
|
||||
reward = [x + global_rewards for x in reward]
|
||||
self.state.print(f"rewards are {rewards}")
|
||||
self.state.print(f"Individual rewards are {dict(rewards)}")
|
||||
return reward, combined_info_dict, done
|
||||
else:
|
||||
reward = sum(rewards.values())
|
||||
@ -220,7 +220,7 @@ class Factory(gym.Env):
|
||||
|
||||
def summarize_header(self):
|
||||
header = {'rec_step': self.state.curr_step}
|
||||
for entity_group in (x for x in self.state if x.name in ['Walls', 'Floors', 'DropOffLocations', 'ChargePods']):
|
||||
for entity_group in (x for x in self.state if x.name in ['Walls', 'DropOffLocations', 'ChargePods']):
|
||||
header.update({f'rec{entity_group.name}': entity_group.summarize_states()})
|
||||
return header
|
||||
|
||||
@ -229,7 +229,7 @@ class Factory(gym.Env):
|
||||
|
||||
# Todo: Protobuff Compatibility Section #######
|
||||
# for entity_group in (x for x in self.state if x.name not in [c.WALLS, c.FLOORS]):
|
||||
for entity_group in (x for x in self.state if x.name not in [c.FLOORS]):
|
||||
for entity_group in self.state:
|
||||
summary.update({entity_group.name.lower(): entity_group.summarize_states()})
|
||||
# TODO Section End ########
|
||||
for key in list(summary.keys()):
|
||||
|
@ -1,5 +1,6 @@
|
||||
from collections import defaultdict
|
||||
from operator import itemgetter
|
||||
from random import shuffle
|
||||
from typing import Dict
|
||||
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
@ -13,7 +14,7 @@ class Entities(Objects):
|
||||
def neighboring_positions(pos):
|
||||
return (POS_MASK + pos).reshape(-1, 2)
|
||||
|
||||
def get_near_pos(self, pos):
|
||||
def get_entities_near_pos(self, pos):
|
||||
return [y for x in itemgetter(*(tuple(x) for x in self.neighboring_positions(pos)))(self.pos_dict) for y in x]
|
||||
|
||||
def render(self):
|
||||
@ -38,11 +39,17 @@ class Entities(Objects):
|
||||
def guests_that_can_collide(self, pos):
|
||||
return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
|
||||
|
||||
def empty_tiles(self):
|
||||
return[key for key in self.floorlist if not any(self.pos_dict[key])]
|
||||
@property
|
||||
def empty_positions(self):
|
||||
empty_positions= [key for key in self.floorlist if self.pos_dict[key]]
|
||||
shuffle(empty_positions)
|
||||
return empty_positions
|
||||
|
||||
def occupied_tiles(self): # positions that are not empty
|
||||
return[key for key in self.floorlist if any(self.pos_dict[key])]
|
||||
@property
|
||||
def occupied_positions(self): # positions that are not empty
|
||||
empty_positions = [key for key in self.floorlist if self.pos_dict[key]]
|
||||
shuffle(empty_positions)
|
||||
return empty_positions
|
||||
|
||||
def is_blocked(self):
|
||||
return[key for key, val in self.pos_dict.items() if any([x.var_is_blocking_pos for x in val])]
|
||||
|
@ -37,7 +37,11 @@ class PositionMixin:
|
||||
|
||||
def __delitem__(self, name):
|
||||
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
|
||||
obj.tile.leave(obj) # observer notify?
|
||||
try:
|
||||
for observer in obj.observers:
|
||||
observer.notify_del_entity(obj)
|
||||
except AttributeError:
|
||||
pass
|
||||
super().__delitem__(name)
|
||||
|
||||
def by_pos(self, pos: (int, int)):
|
||||
|
@ -103,6 +103,9 @@ class Objects:
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def by_name(self, name):
|
||||
return next(x for x in self if x.name == name)
|
||||
|
||||
def __getitem__(self, item):
|
||||
if isinstance(item, (int, np.int64, np.int32)):
|
||||
if item < 0:
|
||||
@ -120,7 +123,7 @@ class Objects:
|
||||
raise TypeError
|
||||
|
||||
def __repr__(self):
|
||||
repr_dict = { key: val for key, val in self._data.items() if key not in [c.WALLS, c.FLOORS]}
|
||||
repr_dict = { key: val for key, val in self._data.items() if key not in [c.WALLS]}
|
||||
return f'{self.__class__.__name__}[{repr_dict}]'
|
||||
|
||||
def spawn(self, n: int):
|
||||
@ -132,22 +135,25 @@ class Objects:
|
||||
for item in items:
|
||||
del self[item]
|
||||
|
||||
def notify_change_pos(self, entity: object):
|
||||
try:
|
||||
self.pos_dict[entity.last_pos].remove(entity)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
if entity.var_has_position:
|
||||
try:
|
||||
self.pos_dict[entity.pos].append(entity)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
# def notify_change_pos(self, entity: object):
|
||||
# try:
|
||||
# self.pos_dict[entity.last_pos].remove(entity)
|
||||
# except (ValueError, AttributeError):
|
||||
# pass
|
||||
# if entity.var_has_position:
|
||||
# try:
|
||||
# self.pos_dict[entity.pos].append(entity)
|
||||
# except (ValueError, AttributeError):
|
||||
# pass
|
||||
|
||||
def notify_del_entity(self, entity: Object):
|
||||
try:
|
||||
entity.del_observer(self)
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
self.pos_dict[entity.pos].remove(entity)
|
||||
except (ValueError, AttributeError):
|
||||
except (AttributeError, ValueError, IndexError):
|
||||
pass
|
||||
|
||||
def notify_add_entity(self, entity: Object):
|
||||
|
@ -15,6 +15,7 @@ class Walls(PositionMixin, EnvObjects):
|
||||
super(Walls, self).__init__(*args, **kwargs)
|
||||
self._value = c.VALUE_OCCUPIED_CELL
|
||||
|
||||
#ToDo: Do we need this? Move to spawn methode?
|
||||
# @classmethod
|
||||
# def from_coordinates(cls, argwhere_coordinates, *args, **kwargs):
|
||||
# tiles = cls(*args, **kwargs)
|
@ -49,7 +49,7 @@ class SpawnAgents(Rule):
|
||||
agent_conf = state.agents_conf
|
||||
# agents = Agents(lvl_map.size)
|
||||
agents = state[c.AGENT]
|
||||
empty_tiles = state[c.FLOORS].empty_tiles[:len(agent_conf)]
|
||||
empty_positions = state.entities.empty_positions[:len(agent_conf)]
|
||||
for agent_name in agent_conf:
|
||||
actions = agent_conf[agent_name]['actions'].copy()
|
||||
observations = agent_conf[agent_name]['observations'].copy()
|
||||
@ -58,18 +58,17 @@ class SpawnAgents(Rule):
|
||||
shuffle(positions)
|
||||
while True:
|
||||
try:
|
||||
tile = state[c.FLOORS].by_pos(positions.pop())
|
||||
pos = positions.pop()
|
||||
except IndexError as e:
|
||||
raise ValueError(f'It was not possible to spawn an Agent on the available position: '
|
||||
f'\n{agent_name[agent_name]["positions"].copy()}')
|
||||
try:
|
||||
agents.add_item(Agent(actions, observations, tile, str_ident=agent_name))
|
||||
except AssertionError:
|
||||
state.print(f'No valid pos:{tile.pos} for {agent_name}')
|
||||
if agents.by_pos(pos) and state.check_pos_validity(pos):
|
||||
continue
|
||||
else:
|
||||
agents.add_item(Agent(actions, observations, pos, str_ident=agent_name))
|
||||
break
|
||||
else:
|
||||
agents.add_item(Agent(actions, observations, empty_tiles.pop(), str_ident=agent_name))
|
||||
agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name))
|
||||
pass
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user