mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
New Szenario "Two_Rooms_One_Door"
This commit is contained in:
parent
836495a884
commit
9135a69da6
47
marl_factory_grid/configs/two_rooms_one_door.yaml
Normal file
47
marl_factory_grid/configs/two_rooms_one_door.yaml
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
Agents:
|
||||||
|
Wolfgang:
|
||||||
|
Actions:
|
||||||
|
- Move8
|
||||||
|
- Noop
|
||||||
|
- DestAction
|
||||||
|
- DoorUse
|
||||||
|
Observations:
|
||||||
|
- Walls
|
||||||
|
- Other
|
||||||
|
- Doors
|
||||||
|
- BoundDestination
|
||||||
|
Sigmund:
|
||||||
|
Actions:
|
||||||
|
- Move8
|
||||||
|
- Noop
|
||||||
|
- DestAction
|
||||||
|
- DoorUse
|
||||||
|
Observations:
|
||||||
|
- Combined:
|
||||||
|
- Other
|
||||||
|
- Walls
|
||||||
|
- BoundDestination
|
||||||
|
- Doors
|
||||||
|
Entities:
|
||||||
|
BoundDestinations: {}
|
||||||
|
ReachedDestinations: {}
|
||||||
|
Doors: {}
|
||||||
|
GlobalPositions: {}
|
||||||
|
Zones: {}
|
||||||
|
|
||||||
|
General:
|
||||||
|
env_seed: 69
|
||||||
|
individual_rewards: true
|
||||||
|
level_name: two_rooms
|
||||||
|
pomdp_r: 3
|
||||||
|
verbose: false
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
Collision:
|
||||||
|
done_at_collisions: false
|
||||||
|
AssignGlobalPositions: {}
|
||||||
|
DoorAutoClose:
|
||||||
|
close_frequency: 10
|
||||||
|
ZoneInit: {}
|
||||||
|
AgentSingleZonePlacement: {}
|
||||||
|
IndividualDestinationZonePlacement: {}
|
@ -4,7 +4,7 @@ DEFAULTS = 'Defaults'
|
|||||||
SELF = 'Self'
|
SELF = 'Self'
|
||||||
PLACEHOLDER = 'Placeholder'
|
PLACEHOLDER = 'Placeholder'
|
||||||
FLOOR = 'Floor' # Identifier of Floor-objects and groups (groups).
|
FLOOR = 'Floor' # Identifier of Floor-objects and groups (groups).
|
||||||
FLOORS = 'Floors' # Identifier of Floor-objects and groups (groups).
|
FLOORS = 'Floors' # Identifier of Floor-objects and groups (groups).
|
||||||
WALL = 'Wall' # Identifier of Wall-objects and groups (groups).
|
WALL = 'Wall' # Identifier of Wall-objects and groups (groups).
|
||||||
WALLS = 'Walls' # Identifier of Wall-objects and groups (groups).
|
WALLS = 'Walls' # Identifier of Wall-objects and groups (groups).
|
||||||
LEVEL = 'Level' # Identifier of Level-objects and groups (groups).
|
LEVEL = 'Level' # Identifier of Level-objects and groups (groups).
|
||||||
|
@ -12,6 +12,22 @@ from marl_factory_grid.environment import constants as c
|
|||||||
|
|
||||||
class Agent(Entity):
|
class Agent(Entity):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_is_blocking_light(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_can_move(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_is_blocking_pos(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_has_position(self):
|
||||||
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def obs_tag(self):
|
def obs_tag(self):
|
||||||
return self.name
|
return self.name
|
||||||
|
@ -9,7 +9,7 @@ class BoundEntityMixin:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return f'{self.__class__.__name__}({self._bound_entity.name})'
|
return f'{self.__class__.__name__}({self.bound_entity.name})'
|
||||||
|
|
||||||
def belongs_to_entity(self, entity):
|
def belongs_to_entity(self, entity):
|
||||||
return entity == self.bound_entity
|
return entity == self.bound_entity
|
||||||
|
@ -21,7 +21,7 @@ class Object:
|
|||||||
def name(self):
|
def name(self):
|
||||||
if self._str_ident is not None:
|
if self._str_ident is not None:
|
||||||
return f'{self.__class__.__name__}[{self._str_ident}]'
|
return f'{self.__class__.__name__}[{self._str_ident}]'
|
||||||
return f'{self.__class__.__name__}#{self.identifier_int}'
|
return f'{self.__class__.__name__}#{self.u_int}'
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def identifier(self):
|
def identifier(self):
|
||||||
@ -30,10 +30,14 @@ class Object:
|
|||||||
else:
|
else:
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
|
def reset_uid(self):
|
||||||
|
self._u_idx = defaultdict(lambda: 0)
|
||||||
|
return True
|
||||||
|
|
||||||
def __init__(self, str_ident: Union[str, None] = None, **kwargs):
|
def __init__(self, str_ident: Union[str, None] = None, **kwargs):
|
||||||
self._observers = []
|
self._observers = []
|
||||||
self._str_ident = str_ident
|
self._str_ident = str_ident
|
||||||
self.identifier_int = self._identify_and_count_up()
|
self.u_int = self._identify_and_count_up()
|
||||||
self._collection = None
|
self._collection = None
|
||||||
|
|
||||||
if kwargs:
|
if kwargs:
|
||||||
|
@ -30,10 +30,6 @@ class Floor(EnvObject):
|
|||||||
def var_is_blocking_light(self):
|
def var_is_blocking_light(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
|
||||||
def neighboring_floor_pos(self):
|
|
||||||
return [x.pos for x in self.neighboring_floor]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def neighboring_floor(self):
|
def neighboring_floor(self):
|
||||||
if self._neighboring_floor:
|
if self._neighboring_floor:
|
||||||
|
@ -78,12 +78,19 @@ class Factory(gym.Env):
|
|||||||
return self.state.entities[item]
|
return self.state.entities[item]
|
||||||
|
|
||||||
def reset(self) -> (dict, dict):
|
def reset(self) -> (dict, dict):
|
||||||
|
if hasattr(self, 'state'):
|
||||||
|
for entity_group in self.state.entities:
|
||||||
|
try:
|
||||||
|
entity_group[0].reset_uid()
|
||||||
|
except (AttributeError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
self.state = None
|
self.state = None
|
||||||
|
|
||||||
# Init entity:
|
# Init entity:
|
||||||
entities = self.map.do_init()
|
entities = self.map.do_init()
|
||||||
|
|
||||||
# Grab all rules:
|
# Grab all )rules:
|
||||||
rules = self.conf.load_rules()
|
rules = self.conf.load_rules()
|
||||||
|
|
||||||
# Agents
|
# Agents
|
||||||
|
@ -41,6 +41,9 @@ class Entities(Objects):
|
|||||||
val.add_observer(self)
|
val.add_observer(self)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def __contains__(self, item):
|
||||||
|
return item in self._data
|
||||||
|
|
||||||
def __delitem__(self, name):
|
def __delitem__(self, name):
|
||||||
assert_str = 'This group of entity does not exist in this collection!'
|
assert_str = 'This group of entity does not exist in this collection!'
|
||||||
assert any([key for key in name.keys() if key in self.keys()]), assert_str
|
assert any([key for key in name.keys() if key in self.keys()]), assert_str
|
||||||
@ -51,7 +54,10 @@ class Entities(Objects):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def obs_pairs(self):
|
def obs_pairs(self):
|
||||||
return [y for x in self for y in x.obs_pairs]
|
try:
|
||||||
|
return [y for x in self for y in x.obs_pairs]
|
||||||
|
except AttributeError:
|
||||||
|
print('OhOh (debug me)')
|
||||||
|
|
||||||
def by_pos(self, pos: (int, int)):
|
def by_pos(self, pos: (int, int)):
|
||||||
return self.pos_dict[pos]
|
return self.pos_dict[pos]
|
||||||
|
@ -82,7 +82,7 @@ class IsBoundMixin:
|
|||||||
|
|
||||||
|
|
||||||
# noinspection PyUnresolvedReferences,PyTypeChecker
|
# noinspection PyUnresolvedReferences,PyTypeChecker
|
||||||
class HasBoundedMixin:
|
class HasBoundMixin:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def obs_pairs(self):
|
def obs_pairs(self):
|
||||||
|
@ -37,7 +37,7 @@ class Objects:
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self._data = defaultdict(lambda: None)
|
self._data = defaultdict(lambda: None)
|
||||||
self._observers = list()
|
self._observers = [self]
|
||||||
self.pos_dict = defaultdict(list)
|
self.pos_dict = defaultdict(list)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
@ -52,6 +52,7 @@ class Objects:
|
|||||||
assert self._data[item.name] is None, f'{item.name} allready exists!!!'
|
assert self._data[item.name] is None, f'{item.name} allready exists!!!'
|
||||||
self._data.update({item.name: item})
|
self._data.update({item.name: item})
|
||||||
item.set_collection(self)
|
item.set_collection(self)
|
||||||
|
# self.notify_add_entity(item)
|
||||||
for observer in self.observers:
|
for observer in self.observers:
|
||||||
observer.notify_add_entity(item)
|
observer.notify_add_entity(item)
|
||||||
return self
|
return self
|
||||||
@ -96,8 +97,6 @@ class Objects:
|
|||||||
except StopIteration:
|
except StopIteration:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
if isinstance(item, (int, np.int64, np.int32)):
|
if isinstance(item, (int, np.int64, np.int32)):
|
||||||
if item < 0:
|
if item < 0:
|
||||||
|
@ -4,7 +4,7 @@ import numpy as np
|
|||||||
|
|
||||||
from marl_factory_grid.environment.entity.util import GlobalPosition
|
from marl_factory_grid.environment.entity.util import GlobalPosition
|
||||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||||
from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundedMixin
|
from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundMixin
|
||||||
from marl_factory_grid.environment.groups.objects import Objects
|
from marl_factory_grid.environment.groups.objects import Objects
|
||||||
from marl_factory_grid.modules.zones import Zone
|
from marl_factory_grid.modules.zones import Zone
|
||||||
from marl_factory_grid.utils import helpers as h
|
from marl_factory_grid.utils import helpers as h
|
||||||
@ -35,7 +35,7 @@ class Combined(PositionMixin, EnvObjects):
|
|||||||
return [(name, None) for name in self.names]
|
return [(name, None) for name in self.names]
|
||||||
|
|
||||||
|
|
||||||
class GlobalPositions(HasBoundedMixin, EnvObjects):
|
class GlobalPositions(HasBoundMixin, EnvObjects):
|
||||||
|
|
||||||
_entity = GlobalPosition
|
_entity = GlobalPosition
|
||||||
is_blocking_light = False,
|
is_blocking_light = False,
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||||
from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundedMixin
|
from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundMixin
|
||||||
from marl_factory_grid.modules.batteries.entitites import ChargePod, Battery
|
from marl_factory_grid.modules.batteries.entitites import ChargePod, Battery
|
||||||
|
|
||||||
|
|
||||||
class Batteries(HasBoundedMixin, EnvObjects):
|
class Batteries(HasBoundMixin, EnvObjects):
|
||||||
|
|
||||||
_entity = Battery
|
_entity = Battery
|
||||||
is_blocking_light: bool = False
|
is_blocking_light: bool = False
|
||||||
|
@ -13,7 +13,9 @@ class DestAction(Action):
|
|||||||
super().__init__(d.DESTINATION)
|
super().__init__(d.DESTINATION)
|
||||||
|
|
||||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||||
if destination := state[d.DESTINATION].by_pos(entity.pos):
|
dest_entities = d.DESTINATION if d.DESTINATION in state else d.BOUNDDESTINATION
|
||||||
|
assert dest_entities
|
||||||
|
if destination := state[dest_entities].by_pos(entity.pos):
|
||||||
valid = destination.do_wait_action(entity)
|
valid = destination.do_wait_action(entity)
|
||||||
state.print(f'{entity.name} just waited at {entity.pos}')
|
state.print(f'{entity.name} just waited at {entity.pos}')
|
||||||
else:
|
else:
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
|
|
||||||
# Destination Env
|
# Destination Env
|
||||||
DESTINATION = 'Destinations'
|
DESTINATION = 'Destinations'
|
||||||
|
BOUNDDESTINATION = 'BoundDestinations'
|
||||||
DEST_SYMBOL = 1
|
DEST_SYMBOL = 1
|
||||||
DEST_REACHED_REWARD = 0.5
|
DEST_REACHED_REWARD = 0.5
|
||||||
DEST_REACHED = 'ReachedDestinations'
|
DEST_REACHED = 'ReachedDestinations'
|
||||||
|
@ -3,12 +3,19 @@ from collections import defaultdict
|
|||||||
from marl_factory_grid.environment.entity.agent import Agent
|
from marl_factory_grid.environment.entity.agent import Agent
|
||||||
from marl_factory_grid.environment.entity.entity import Entity
|
from marl_factory_grid.environment.entity.entity import Entity
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
|
from marl_factory_grid.environment.entity.mixin import BoundEntityMixin
|
||||||
from marl_factory_grid.utils.render import RenderEntity
|
from marl_factory_grid.utils.render import RenderEntity
|
||||||
from marl_factory_grid.modules.destinations import constants as d
|
from marl_factory_grid.modules.destinations import constants as d
|
||||||
|
|
||||||
|
|
||||||
class Destination(Entity):
|
class Destination(Entity):
|
||||||
|
|
||||||
|
var_can_move = False
|
||||||
|
var_can_collide = False
|
||||||
|
var_has_position = True
|
||||||
|
var_is_blocking_pos = False
|
||||||
|
var_is_blocking_light = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def any_agent_has_dwelled(self):
|
def any_agent_has_dwelled(self):
|
||||||
return bool(len(self._per_agent_times))
|
return bool(len(self._per_agent_times))
|
||||||
@ -49,3 +56,21 @@ class Destination(Entity):
|
|||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
return RenderEntity(d.DESTINATION, self.pos)
|
return RenderEntity(d.DESTINATION, self.pos)
|
||||||
|
|
||||||
|
|
||||||
|
class BoundDestination(BoundEntityMixin, Destination):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
return d.DEST_SYMBOL
|
||||||
|
|
||||||
|
def __init__(self, entity, *args, **kwargs):
|
||||||
|
self.bind_to(entity)
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_considered_reached(self):
|
||||||
|
agent_at_position = any(self.bound_entity == x for x in self.tile.guests_that_can_collide)
|
||||||
|
return (agent_at_position and not self.dwell_time) \
|
||||||
|
or any(x == 0 for x in self._per_agent_times[self.bound_entity.name])
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||||
from marl_factory_grid.environment.groups.mixins import PositionMixin
|
from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundMixin
|
||||||
from marl_factory_grid.modules.destinations.entitites import Destination
|
from marl_factory_grid.modules.destinations.entitites import Destination, BoundDestination
|
||||||
|
|
||||||
|
|
||||||
class Destinations(PositionMixin, EnvObjects):
|
class Destinations(PositionMixin, EnvObjects):
|
||||||
@ -16,6 +16,14 @@ class Destinations(PositionMixin, EnvObjects):
|
|||||||
return super(Destinations, self).__repr__()
|
return super(Destinations, self).__repr__()
|
||||||
|
|
||||||
|
|
||||||
|
class BoundDestinations(HasBoundMixin, Destinations):
|
||||||
|
|
||||||
|
_entity = BoundDestination
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class ReachedDestinations(Destinations):
|
class ReachedDestinations(Destinations):
|
||||||
_entity = Destination
|
_entity = Destination
|
||||||
is_blocking_light = False
|
is_blocking_light = False
|
||||||
|
@ -72,7 +72,7 @@ class Door(Entity):
|
|||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
name, state = 'door_open' if self.is_open else 'door_closed', 'blank'
|
name, state = 'door_open' if self.is_open else 'door_closed', 'blank'
|
||||||
return RenderEntity(name, self.pos, 1, 'none', state, self.identifier_int + 1)
|
return RenderEntity(name, self.pos, 1, 'none', state, self.u_int + 1)
|
||||||
|
|
||||||
def use(self):
|
def use(self):
|
||||||
if self._status == d.STATE_OPEN:
|
if self._status == d.STATE_OPEN:
|
||||||
|
@ -14,12 +14,6 @@ class Doors(PositionMixin, EnvObjects):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Doors, self).__init__(*args, can_collide=True, **kwargs)
|
super(Doors, self).__init__(*args, can_collide=True, **kwargs)
|
||||||
|
|
||||||
def get_near_position(self, position: (int, int)) -> Union[None, Door]:
|
|
||||||
try:
|
|
||||||
return next(door for door in self if position in door.tile.neighboring_floor_pos)
|
|
||||||
except StopIteration:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def tick_doors(self):
|
def tick_doors(self):
|
||||||
result_dict = dict()
|
result_dict = dict()
|
||||||
for door in self:
|
for door in self:
|
||||||
|
@ -2,7 +2,7 @@ from typing import List
|
|||||||
|
|
||||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||||
from marl_factory_grid.environment.groups.objects import Objects
|
from marl_factory_grid.environment.groups.objects import Objects
|
||||||
from marl_factory_grid.environment.groups.mixins import PositionMixin, IsBoundMixin, HasBoundedMixin
|
from marl_factory_grid.environment.groups.mixins import PositionMixin, IsBoundMixin, HasBoundMixin
|
||||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||||
from marl_factory_grid.environment.entity.agent import Agent
|
from marl_factory_grid.environment.entity.agent import Agent
|
||||||
from marl_factory_grid.modules.items.entitites import Item, DropOffLocation
|
from marl_factory_grid.modules.items.entitites import Item, DropOffLocation
|
||||||
@ -46,7 +46,7 @@ class Inventory(IsBoundMixin, EnvObjects):
|
|||||||
self._collection = collection
|
self._collection = collection
|
||||||
|
|
||||||
|
|
||||||
class Inventories(HasBoundedMixin, Objects):
|
class Inventories(HasBoundMixin, Objects):
|
||||||
|
|
||||||
_entity = Inventory
|
_entity = Inventory
|
||||||
var_can_move = False
|
var_can_move = False
|
||||||
|
13
marl_factory_grid/modules/levels/two_rooms.txt
Normal file
13
marl_factory_grid/modules/levels/two_rooms.txt
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
###############
|
||||||
|
#111111#222222#
|
||||||
|
#111111#222222#
|
||||||
|
#111111#222222#
|
||||||
|
#111111#222222#
|
||||||
|
#111111#222222#
|
||||||
|
#111111D222222#
|
||||||
|
#111111#222222#
|
||||||
|
#111111#222222#
|
||||||
|
#111111#222222#
|
||||||
|
#111111#222222#
|
||||||
|
#111111#222222#
|
||||||
|
###############
|
@ -12,6 +12,10 @@ from marl_factory_grid.modules.doors import constants as d
|
|||||||
|
|
||||||
class Zone(Object):
|
class Zone(Object):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def positions(self):
|
||||||
|
return [x.pos for x in self.tiles]
|
||||||
|
|
||||||
def __init__(self, tiles: List[Floor], *args, **kwargs):
|
def __init__(self, tiles: List[Floor], *args, **kwargs):
|
||||||
super(Zone, self).__init__(*args, **kwargs)
|
super(Zone, self).__init__(*args, **kwargs)
|
||||||
self.tiles = tiles
|
self.tiles = tiles
|
||||||
|
@ -10,3 +10,15 @@ class Zones(Objects):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Zones, self).__init__(*args, can_collide=True, **kwargs)
|
super(Zones, self).__init__(*args, can_collide=True, **kwargs)
|
||||||
|
|
||||||
|
def by_pos(self, pos):
|
||||||
|
return self.pos_dict[pos]
|
||||||
|
|
||||||
|
def notify_add_entity(self, entity: Zone):
|
||||||
|
self.pos_dict.update({key: [entity] for key in entity.positions})
|
||||||
|
return True
|
||||||
|
|
||||||
|
def notify_del_entity(self, entity: Zone):
|
||||||
|
for pos in entity.positions:
|
||||||
|
self.pos_dict[pos].remove(entity)
|
||||||
|
return True
|
||||||
|
@ -1,26 +1,38 @@
|
|||||||
from random import choices
|
from random import choices, choice
|
||||||
|
|
||||||
from marl_factory_grid.environment.rules import Rule
|
from . import constants as z, Zone
|
||||||
from marl_factory_grid.environment import constants as c
|
from ..destinations import constants as d
|
||||||
from marl_factory_grid.modules.zones import Zone
|
from ..destinations.entitites import BoundDestination
|
||||||
from . import constants as z
|
from ...environment.rules import Rule
|
||||||
|
from ...environment import constants as c
|
||||||
|
|
||||||
|
|
||||||
|
class ZoneInit(Rule):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def on_init(self, state, lvl_map):
|
||||||
|
zones = []
|
||||||
|
z_idx = 1
|
||||||
|
|
||||||
|
while z_idx:
|
||||||
|
zone_positions = lvl_map.get_coordinates_for_symbol(z_idx)
|
||||||
|
if len(zone_positions):
|
||||||
|
zones.append(Zone([state[c.FLOOR].by_pos(pos) for pos in zone_positions]))
|
||||||
|
z_idx += 1
|
||||||
|
else:
|
||||||
|
z_idx = 0
|
||||||
|
state[z.ZONES].add_items(zones)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
class AgentSingleZonePlacement(Rule):
|
class AgentSingleZonePlacement(Rule):
|
||||||
|
|
||||||
def __init__(self, n_zones=3):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.n_zones = n_zones
|
|
||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
def on_init(self, state, lvl_map):
|
||||||
zones = []
|
|
||||||
|
|
||||||
for z_idx in range(1, self.n_zones):
|
|
||||||
zone_positions = lvl_map.get_coordinates_for_symbol(z_idx)
|
|
||||||
assert len(zone_positions)
|
|
||||||
zones.append(Zone([state[c.FLOOR].by_pos(pos) for pos in zone_positions]))
|
|
||||||
state[z.ZONES].add_items(zones)
|
|
||||||
|
|
||||||
n_agents = len(state[c.AGENT])
|
n_agents = len(state[c.AGENT])
|
||||||
assert len(state[z.ZONES]) >= n_agents
|
assert len(state[z.ZONES]) >= n_agents
|
||||||
|
|
||||||
@ -31,3 +43,32 @@ class AgentSingleZonePlacement(Rule):
|
|||||||
|
|
||||||
def tick_step(self, state):
|
def tick_step(self, state):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class IndividualDestinationZonePlacement(Rule):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def on_init(self, state, lvl_map):
|
||||||
|
for agent in state[c.AGENT]:
|
||||||
|
self.trigger_destination_spawn(agent, state)
|
||||||
|
pass
|
||||||
|
return []
|
||||||
|
|
||||||
|
def tick_step(self, state):
|
||||||
|
return []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def trigger_destination_spawn(agent, state):
|
||||||
|
agent_zones = state[z.ZONES].by_pos(agent.pos)
|
||||||
|
other_zones = [x for x in state[z.ZONES] if x not in agent_zones]
|
||||||
|
already_has_destination = True
|
||||||
|
while already_has_destination:
|
||||||
|
tile = choice(other_zones).random_tile
|
||||||
|
if state[d.BOUNDDESTINATION].by_pos(tile.pos) is None:
|
||||||
|
already_has_destination = False
|
||||||
|
destination = BoundDestination(agent, tile)
|
||||||
|
state[d.BOUNDDESTINATION].add_item(destination)
|
||||||
|
continue
|
||||||
|
return c.VALID
|
||||||
|
@ -76,6 +76,9 @@ class Gamestate(object):
|
|||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return iter(e for e in self.entities.values())
|
return iter(e for e in self.entities.values())
|
||||||
|
|
||||||
|
def __contains__(self, item):
|
||||||
|
return item in self.entities
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'
|
return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user