mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-09-15 23:37:14 +02:00
Machines
This commit is contained in:
@@ -98,3 +98,5 @@ class NorthWest(Move):
|
||||
Move4 = [North, East, South, West]
|
||||
# noinspection PyTypeChecker
|
||||
Move8 = Move4 + [NorthEast, SouthEast, SouthWest, NorthWest]
|
||||
|
||||
ALL_BASEACTIONS = Move8 + [Noop]
|
||||
|
@@ -9,15 +9,13 @@ WALL = 'Wall' # Identifier of Wall-objects and
|
||||
WALLS = 'Walls' # Identifier of Wall-objects and groups (groups).
|
||||
LEVEL = 'Level' # Identifier of Level-objects and groups (groups).
|
||||
AGENT = 'Agent' # Identifier of Agent-objects and groups (groups).
|
||||
AGENTS = 'Agents' # Identifier of Agent-objects and groups (groups).
|
||||
OTHERS = 'Other'
|
||||
COMBINED = 'Combined'
|
||||
GLOBAL_POSITION = 'GLOBAL_POSITION' # Identifier of the global position slice
|
||||
|
||||
GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice
|
||||
|
||||
# Attributes
|
||||
IS_BLOCKING_LIGHT = 'is_blocking_light'
|
||||
HAS_POSITION = 'has_position'
|
||||
IS_BLOCKING_LIGHT = 'var_is_blocking_light'
|
||||
HAS_POSITION = 'var_has_position'
|
||||
HAS_NO_POSITION = 'has_no_position'
|
||||
ALL = 'All'
|
||||
|
||||
|
@@ -1,6 +1,5 @@
|
||||
from typing import List, Union
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.actions import Action
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
@@ -8,6 +7,8 @@ from marl_factory_grid.utils import renderer
|
||||
from marl_factory_grid.utils.helpers import is_move
|
||||
from marl_factory_grid.utils.results import ActionResult, Result
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
|
||||
class Agent(Entity):
|
||||
|
||||
@@ -24,7 +25,7 @@ class Agent(Entity):
|
||||
return self._observations
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
def var_can_collide(self):
|
||||
return True
|
||||
|
||||
def step_result(self):
|
||||
|
@@ -1,15 +1,20 @@
|
||||
import abc
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.object import EnvObject
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from .. import constants as c
|
||||
from .object import EnvObject
|
||||
from ...utils.render import RenderEntity
|
||||
from ...utils.results import ActionResult
|
||||
|
||||
|
||||
class Entity(EnvObject, abc.ABC):
|
||||
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
|
||||
|
||||
@property
|
||||
def has_position(self):
|
||||
def state(self):
|
||||
return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return self.pos != c.VALUE_NO_POS
|
||||
|
||||
@property
|
||||
@@ -64,12 +69,13 @@ class Entity(EnvObject, abc.ABC):
|
||||
|
||||
def __init__(self, tile, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._status = None
|
||||
self._tile = tile
|
||||
tile.enter(self)
|
||||
|
||||
def summarize_state(self) -> dict:
|
||||
return dict(name=str(self.name), x=int(self.x), y=int(self.y),
|
||||
tile=str(self.tile.name), can_collide=bool(self.can_collide))
|
||||
tile=str(self.tile.name), can_collide=bool(self.var_can_collide))
|
||||
|
||||
@abc.abstractmethod
|
||||
def render(self):
|
||||
|
@@ -78,37 +78,37 @@ class EnvObject(Object):
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def is_blocking_light(self):
|
||||
def var_is_blocking_light(self):
|
||||
try:
|
||||
return self._collection.is_blocking_light or False
|
||||
return self._collection.var_is_blocking_light or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def can_move(self):
|
||||
def var_can_move(self):
|
||||
try:
|
||||
return self._collection.can_move or False
|
||||
return self._collection.var_can_move or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_blocking_pos(self):
|
||||
def var_is_blocking_pos(self):
|
||||
try:
|
||||
return self._collection.is_blocking_pos or False
|
||||
return self._collection.var_is_blocking_pos or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def has_position(self):
|
||||
def var_has_position(self):
|
||||
try:
|
||||
return self._collection.has_position or False
|
||||
return self._collection.var_has_position or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
def var_can_collide(self):
|
||||
try:
|
||||
return self._collection.can_collide or False
|
||||
return self._collection.var_can_collide or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
|
@@ -35,11 +35,11 @@ class GlobalPosition(BoundEntityMixin, EnvObject):
|
||||
@property
|
||||
def encoding(self):
|
||||
if self._normalized:
|
||||
return tuple(np.divide(self._bound_entity.pos, self._level_shape))
|
||||
return tuple(np.divide(self._bound_entity.pos, self._shape))
|
||||
else:
|
||||
return self.bound_entity.pos
|
||||
|
||||
def __init__(self, *args, normalized: bool = True, **kwargs):
|
||||
def __init__(self, level_shape, *args, normalized: bool = True, **kwargs):
|
||||
super(GlobalPosition, self).__init__(*args, **kwargs)
|
||||
self._level_shape = math.sqrt(self.size)
|
||||
self._normalized = normalized
|
||||
self._shape = level_shape
|
||||
|
@@ -11,23 +11,23 @@ from marl_factory_grid.utils import helpers as h
|
||||
class Floor(EnvObject):
|
||||
|
||||
@property
|
||||
def has_position(self):
|
||||
def var_has_position(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
def var_can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def can_move(self):
|
||||
def var_can_move(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_blocking_pos(self):
|
||||
def var_is_blocking_pos(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_blocking_light(self):
|
||||
def var_is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
@@ -51,7 +51,7 @@ class Floor(EnvObject):
|
||||
|
||||
@property
|
||||
def guests_that_can_collide(self):
|
||||
return [x for x in self.guests if x.can_collide]
|
||||
return [x for x in self.guests if x.var_can_collide]
|
||||
|
||||
@property
|
||||
def guests(self):
|
||||
@@ -67,7 +67,7 @@ class Floor(EnvObject):
|
||||
|
||||
@property
|
||||
def is_blocked(self):
|
||||
return any([x.is_blocking_pos for x in self.guests])
|
||||
return any([x.var_is_blocking_pos for x in self.guests])
|
||||
|
||||
def __init__(self, pos, **kwargs):
|
||||
super(Floor, self).__init__(**kwargs)
|
||||
@@ -86,7 +86,7 @@ class Floor(EnvObject):
|
||||
return bool(len(self._guests))
|
||||
|
||||
def enter(self, guest):
|
||||
if (guest.name not in self._guests and not self.is_blocked) and not (guest.is_blocking_pos and self.is_occupied()):
|
||||
if (guest.name not in self._guests and not self.is_blocked) and not (guest.var_is_blocking_pos and self.is_occupied()):
|
||||
self._guests.update({guest.name: guest})
|
||||
return c.VALID
|
||||
else:
|
||||
@@ -112,7 +112,7 @@ class Floor(EnvObject):
|
||||
class Wall(Floor):
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
def var_can_collide(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
@@ -123,9 +123,9 @@ class Wall(Floor):
|
||||
return RenderEntity(c.WALL, self.pos)
|
||||
|
||||
@property
|
||||
def is_blocking_pos(self):
|
||||
def var_is_blocking_pos(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_blocking_light(self):
|
||||
def var_is_blocking_light(self):
|
||||
return True
|
||||
|
@@ -19,7 +19,7 @@ from marl_factory_grid.utils.states import Gamestate
|
||||
REC_TAC = 'rec_'
|
||||
|
||||
|
||||
class BaseFactory(gym.Env):
|
||||
class Factory(gym.Env):
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
@@ -52,11 +52,15 @@ class BaseFactory(gym.Env):
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
def __init__(self, config_file: Union[str, PathLike]):
|
||||
def __init__(self, config_file: Union[str, PathLike], custom_modules_path: Union[None, PathLike] = None,
|
||||
custom_level_path: Union[None, PathLike] = None):
|
||||
self._config_file = config_file
|
||||
self.conf = FactoryConfigParser(self._config_file)
|
||||
self.conf = FactoryConfigParser(self._config_file, custom_modules_path)
|
||||
# Attribute Assignment
|
||||
self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt'
|
||||
if custom_level_path is not None:
|
||||
self.level_filepath = Path(custom_level_path)
|
||||
else:
|
||||
self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt'
|
||||
self._renderer = None # expensive - don't use; unless required !
|
||||
|
||||
parsed_entities = self.conf.load_entities()
|
||||
@@ -90,7 +94,7 @@ class BaseFactory(gym.Env):
|
||||
self.state.entities.add_item({c.AGENT: agents})
|
||||
|
||||
# All is set up, trigger additional init (after agent entity spawn etc)
|
||||
self.state.rules.do_all_init(self.state)
|
||||
self.state.rules.do_all_init(self.state, self.map)
|
||||
|
||||
# Observations
|
||||
# noinspection PyAttributeOutsideInit
|
||||
@@ -144,7 +148,7 @@ class BaseFactory(gym.Env):
|
||||
try:
|
||||
done_reason = next(x for x in done_check_results if x.validity)
|
||||
done = True
|
||||
self.state.print(f'Env done, Reason: {done_reason.name}.')
|
||||
self.state.print(f'Env done, Reason: {done_reason.identifier}.')
|
||||
except StopIteration:
|
||||
done = False
|
||||
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from marl_factory_grid.environment.entity.agent import Agent
|
||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||
from marl_factory_grid.environment.groups.mixins import PositionMixin
|
||||
from marl_factory_grid.environment.entity.agent import Agent
|
||||
|
||||
|
||||
class Agents(PositionMixin, EnvObjects):
|
||||
|
@@ -5,10 +5,10 @@ from marl_factory_grid.environment.entity.object import EnvObject
|
||||
class EnvObjects(Objects):
|
||||
|
||||
_entity = EnvObject
|
||||
is_blocking_light: bool = False
|
||||
can_collide: bool = False
|
||||
has_position: bool = False
|
||||
can_move: bool = False
|
||||
var_is_blocking_light: bool = False
|
||||
var_can_collide: bool = False
|
||||
var_has_position: bool = False
|
||||
var_can_move: bool = False
|
||||
|
||||
@property
|
||||
def encodings(self):
|
||||
@@ -19,7 +19,7 @@ class EnvObjects(Objects):
|
||||
self.size = size
|
||||
|
||||
def add_item(self, item: EnvObject):
|
||||
assert self.has_position or (len(self) <= self.size)
|
||||
assert self.var_has_position or (len(self) <= self.size)
|
||||
super(EnvObjects, self).add_item(item)
|
||||
return self
|
||||
|
||||
|
@@ -1,15 +1,19 @@
|
||||
from typing import List
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences,PyTypeChecker,PyArgumentList
|
||||
class PositionMixin:
|
||||
|
||||
_entity = Entity
|
||||
is_blocking_light: bool = True
|
||||
can_collide: bool = True
|
||||
has_position: bool = True
|
||||
var_is_blocking_light: bool = True
|
||||
var_can_collide: bool = True
|
||||
var_has_position: bool = True
|
||||
|
||||
def spawn(self, tiles: List[Floor]):
|
||||
self.add_items([self._entity(tile) for tile in tiles])
|
||||
|
||||
def render(self):
|
||||
return [y for y in [x.render() for x in self] if y is not None]
|
||||
@@ -81,8 +85,8 @@ class IsBoundMixin:
|
||||
class HasBoundedMixin:
|
||||
|
||||
@property
|
||||
def obs_names(self):
|
||||
return [x.name for x in self]
|
||||
def obs_pairs(self):
|
||||
return [(x.name, x) for x in self]
|
||||
|
||||
def by_entity(self, entity):
|
||||
try:
|
||||
|
@@ -4,6 +4,7 @@ from typing import List
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
import marl_factory_grid.environment.constants as c
|
||||
|
||||
|
||||
class Objects:
|
||||
@@ -116,12 +117,21 @@ class Objects:
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}[{dict(self._data)}]'
|
||||
|
||||
def spawn(self, n: int):
|
||||
self.add_items([self._entity() for _ in range(n)])
|
||||
return c.VALID
|
||||
|
||||
def despawn(self, items: List[Object]):
|
||||
items = [items] if isinstance(items, Object) else items
|
||||
for item in items:
|
||||
del self[item]
|
||||
|
||||
def notify_change_pos(self, entity: object):
|
||||
try:
|
||||
self.pos_dict[entity.last_pos].remove(entity)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
if entity.has_position:
|
||||
if entity.var_has_position:
|
||||
try:
|
||||
self.pos_dict[entity.pos].append(entity)
|
||||
except (ValueError, AttributeError):
|
||||
|
@@ -2,10 +2,11 @@ from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
from marl_factory_grid.environment.groups.mixins import HasBoundedMixin, PositionMixin
|
||||
from marl_factory_grid.environment.entity.util import GlobalPosition
|
||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||
from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundedMixin
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
from marl_factory_grid.modules.zones import Zone
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
@@ -44,7 +45,9 @@ class GlobalPositions(HasBoundedMixin, EnvObjects):
|
||||
super(GlobalPositions, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class Zones(Objects):
|
||||
class ZonesOLD(Objects):
|
||||
|
||||
_entity = Zone
|
||||
|
||||
@property
|
||||
def accounting_zones(self):
|
||||
|
@@ -30,8 +30,8 @@ class Walls(PositionMixin, EnvObjects):
|
||||
class Floors(Walls):
|
||||
_entity = Floor
|
||||
symbol = c.SYMBOL_FLOOR
|
||||
is_blocking_light: bool = False
|
||||
can_collide: bool = False
|
||||
var_is_blocking_light: bool = False
|
||||
var_can_collide: bool = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Floors, self).__init__(*args, **kwargs)
|
||||
|
@@ -17,7 +17,7 @@ class Rule(abc.ABC):
|
||||
def __repr__(self):
|
||||
return f'{self.name}'
|
||||
|
||||
def on_init(self, state):
|
||||
def on_init(self, state, lvl_map):
|
||||
return []
|
||||
|
||||
def on_reset(self):
|
||||
@@ -42,7 +42,7 @@ class MaxStepsReached(Rule):
|
||||
super().__init__()
|
||||
self.max_steps = max_steps
|
||||
|
||||
def on_init(self, state):
|
||||
def on_init(self, state, lvl_map):
|
||||
pass
|
||||
|
||||
def on_check_done(self, state):
|
||||
@@ -51,6 +51,20 @@ class MaxStepsReached(Rule):
|
||||
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
|
||||
|
||||
|
||||
class AssignGlobalPositions(Rule):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
from marl_factory_grid.environment.entity.util import GlobalPosition
|
||||
for agent in state[c.AGENT]:
|
||||
gp = GlobalPosition(lvl_map.level_shape)
|
||||
gp.bind_to(agent)
|
||||
state[c.GLOBALPOSITIONS].add_item(gp)
|
||||
return []
|
||||
|
||||
|
||||
class Collision(Rule):
|
||||
|
||||
def __init__(self, done_at_collisions: bool = False):
|
||||
|
Reference in New Issue
Block a user