mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-26 05:01:36 +02:00
Machines
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
from marl_factory_grid.environment.entity.agent import Agent
|
||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||
from marl_factory_grid.environment.groups.mixins import PositionMixin
|
||||
from marl_factory_grid.environment.entity.agent import Agent
|
||||
|
||||
|
||||
class Agents(PositionMixin, EnvObjects):
|
||||
|
@ -5,10 +5,10 @@ from marl_factory_grid.environment.entity.object import EnvObject
|
||||
class EnvObjects(Objects):
|
||||
|
||||
_entity = EnvObject
|
||||
is_blocking_light: bool = False
|
||||
can_collide: bool = False
|
||||
has_position: bool = False
|
||||
can_move: bool = False
|
||||
var_is_blocking_light: bool = False
|
||||
var_can_collide: bool = False
|
||||
var_has_position: bool = False
|
||||
var_can_move: bool = False
|
||||
|
||||
@property
|
||||
def encodings(self):
|
||||
@ -19,7 +19,7 @@ class EnvObjects(Objects):
|
||||
self.size = size
|
||||
|
||||
def add_item(self, item: EnvObject):
|
||||
assert self.has_position or (len(self) <= self.size)
|
||||
assert self.var_has_position or (len(self) <= self.size)
|
||||
super(EnvObjects, self).add_item(item)
|
||||
return self
|
||||
|
||||
|
@ -1,15 +1,19 @@
|
||||
from typing import List
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences,PyTypeChecker,PyArgumentList
|
||||
class PositionMixin:
|
||||
|
||||
_entity = Entity
|
||||
is_blocking_light: bool = True
|
||||
can_collide: bool = True
|
||||
has_position: bool = True
|
||||
var_is_blocking_light: bool = True
|
||||
var_can_collide: bool = True
|
||||
var_has_position: bool = True
|
||||
|
||||
def spawn(self, tiles: List[Floor]):
|
||||
self.add_items([self._entity(tile) for tile in tiles])
|
||||
|
||||
def render(self):
|
||||
return [y for y in [x.render() for x in self] if y is not None]
|
||||
@ -81,8 +85,8 @@ class IsBoundMixin:
|
||||
class HasBoundedMixin:
|
||||
|
||||
@property
|
||||
def obs_names(self):
|
||||
return [x.name for x in self]
|
||||
def obs_pairs(self):
|
||||
return [(x.name, x) for x in self]
|
||||
|
||||
def by_entity(self, entity):
|
||||
try:
|
||||
|
@ -4,6 +4,7 @@ from typing import List
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
import marl_factory_grid.environment.constants as c
|
||||
|
||||
|
||||
class Objects:
|
||||
@ -116,12 +117,21 @@ class Objects:
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}[{dict(self._data)}]'
|
||||
|
||||
def spawn(self, n: int):
|
||||
self.add_items([self._entity() for _ in range(n)])
|
||||
return c.VALID
|
||||
|
||||
def despawn(self, items: List[Object]):
|
||||
items = [items] if isinstance(items, Object) else items
|
||||
for item in items:
|
||||
del self[item]
|
||||
|
||||
def notify_change_pos(self, entity: object):
|
||||
try:
|
||||
self.pos_dict[entity.last_pos].remove(entity)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
if entity.has_position:
|
||||
if entity.var_has_position:
|
||||
try:
|
||||
self.pos_dict[entity.pos].append(entity)
|
||||
except (ValueError, AttributeError):
|
||||
|
@ -2,10 +2,11 @@ from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
from marl_factory_grid.environment.groups.mixins import HasBoundedMixin, PositionMixin
|
||||
from marl_factory_grid.environment.entity.util import GlobalPosition
|
||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||
from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundedMixin
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
from marl_factory_grid.modules.zones import Zone
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
@ -44,7 +45,9 @@ class GlobalPositions(HasBoundedMixin, EnvObjects):
|
||||
super(GlobalPositions, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class Zones(Objects):
|
||||
class ZonesOLD(Objects):
|
||||
|
||||
_entity = Zone
|
||||
|
||||
@property
|
||||
def accounting_zones(self):
|
||||
|
@ -30,8 +30,8 @@ class Walls(PositionMixin, EnvObjects):
|
||||
class Floors(Walls):
|
||||
_entity = Floor
|
||||
symbol = c.SYMBOL_FLOOR
|
||||
is_blocking_light: bool = False
|
||||
can_collide: bool = False
|
||||
var_is_blocking_light: bool = False
|
||||
var_can_collide: bool = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Floors, self).__init__(*args, **kwargs)
|
||||
|
Reference in New Issue
Block a user