mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-09-15 23:37:14 +02:00
Machines
This commit is contained in:
@@ -0,0 +1,7 @@
|
||||
from .batteries import *
|
||||
from .clean_up import *
|
||||
from .destinations import *
|
||||
from .doors import *
|
||||
from .items import *
|
||||
from .machines import *
|
||||
from .maintenance import *
|
||||
|
@@ -8,4 +8,4 @@ WEST = 'west'
|
||||
NORTHEAST = 'north_east'
|
||||
SOUTHEAST = 'south_east'
|
||||
SOUTHWEST = 'south_west'
|
||||
NORTHWEST = 'north_west'
|
||||
NORTHWEST = 'north_west'
|
||||
|
@@ -8,7 +8,7 @@ class TemplateRule(Rule):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(TemplateRule, self).__init__(*args, **kwargs)
|
||||
|
||||
def on_init(self, state):
|
||||
def on_init(self, state, lvl_map):
|
||||
pass
|
||||
|
||||
def tick_pre_step(self, state) -> List[TickResult]:
|
||||
|
@@ -0,0 +1,4 @@
|
||||
from .actions import BtryCharge
|
||||
from .entitites import ChargePod, Battery
|
||||
from .groups import ChargePods, Batteries
|
||||
from .rules import BtryDoneAtDischarge, Btry
|
||||
|
@@ -13,18 +13,13 @@ class Batteries(HasBoundedMixin, EnvObjects):
|
||||
def obs_tag(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
@property
|
||||
def obs_pairs(self):
|
||||
return [(x.name, x) for x in self]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Batteries, self).__init__(*args, **kwargs)
|
||||
|
||||
def spawn_batteries(self, agents, initial_charge_level):
|
||||
def spawn(self, agents, initial_charge_level):
|
||||
batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
|
||||
self.add_items(batteries)
|
||||
|
||||
|
||||
class ChargePods(PositionMixin, EnvObjects):
|
||||
|
||||
_entity = ChargePod
|
||||
|
@@ -13,8 +13,8 @@ class Btry(Rule):
|
||||
self.per_action_costs = per_action_costs
|
||||
self.initial_charge = initial_charge
|
||||
|
||||
def on_init(self, state):
|
||||
state[b.BATTERIES].spawn_batteries(state[c.AGENT], self.initial_charge)
|
||||
def on_init(self, state, lvl_map):
|
||||
state[b.BATTERIES].spawn(state[c.AGENT], self.initial_charge)
|
||||
|
||||
def tick_pre_step(self, state) -> List[TickResult]:
|
||||
pass
|
||||
|
@@ -0,0 +1,6 @@
|
||||
from .actions import CleanUp
|
||||
from .entitites import DirtPile
|
||||
from .groups import DirtPiles
|
||||
from .rule_respawn import DirtRespawnRule
|
||||
from .rule_smear_on_move import DirtSmearOnMove
|
||||
from .rule_done_on_all_clean import DirtAllCleanDone
|
||||
|
@@ -7,6 +7,22 @@ from marl_factory_grid.modules.clean_up import constants as d
|
||||
|
||||
class DirtPile(Entity):
|
||||
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def amount(self):
|
||||
return self._amount
|
||||
|
@@ -31,7 +31,7 @@ class DirtPiles(PositionMixin, EnvObjects):
|
||||
self.max_global_amount = max_global_amount
|
||||
self.max_local_amount = max_local_amount
|
||||
|
||||
def spawn_dirt(self, then_dirty_tiles, amount) -> bool:
|
||||
def spawn(self, then_dirty_tiles, amount) -> bool:
|
||||
if isinstance(then_dirty_tiles, Floor):
|
||||
then_dirty_tiles = [then_dirty_tiles]
|
||||
for tile in then_dirty_tiles:
|
||||
@@ -57,7 +57,7 @@ class DirtPiles(PositionMixin, EnvObjects):
|
||||
var = self.dirt_spawn_r_var
|
||||
new_spawn = abs(self.initial_dirt_ratio + (state.rng.uniform(-var, var) if initial_spawn else 0))
|
||||
n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt)))
|
||||
return self.spawn_dirt(free_for_dirt[:n_dirt_tiles], self.initial_amount)
|
||||
return self.spawn(free_for_dirt[:n_dirt_tiles], self.initial_amount)
|
||||
|
||||
def __repr__(self):
|
||||
s = super(DirtPiles, self).__repr__()
|
||||
|
@@ -11,7 +11,7 @@ class DirtRespawnRule(Rule):
|
||||
self.spawn_freq = spawn_freq
|
||||
self._next_dirt_spawn = spawn_freq
|
||||
|
||||
def on_init(self, state) -> str:
|
||||
def on_init(self, state, lvl_map) -> str:
|
||||
state[d.DIRT].trigger_dirt_spawn(state, initial_spawn=True)
|
||||
return f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}'
|
||||
|
||||
|
@@ -18,7 +18,7 @@ class DirtSmearOnMove(Rule):
|
||||
if is_move(entity.state.identifier) and entity.state.validity == c.VALID:
|
||||
if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos):
|
||||
if smeared_dirt := round(old_pos_dirt.amount * self.smear_amount, 2):
|
||||
if state[d.DIRT].spawn_dirt(entity.tile, amount=smeared_dirt):
|
||||
if state[d.DIRT].spawn(entity.tile, amount=smeared_dirt):
|
||||
results.append(TickResult(identifier=self.name, entity=entity,
|
||||
reward=0, validity=c.VALID))
|
||||
return results
|
||||
|
@@ -0,0 +1,4 @@
|
||||
from .actions import DestAction
|
||||
from .entitites import Destination
|
||||
from .groups import ReachedDestinations, Destinations
|
||||
from .rules import DestinationDone, DestinationReach, DestinationSpawn
|
||||
|
@@ -62,7 +62,7 @@ class DestinationSpawn(Rule):
|
||||
self.n_dests = n_dests
|
||||
self.spawn_mode = spawn_mode
|
||||
|
||||
def on_init(self, state):
|
||||
def on_init(self, state, lvl_map):
|
||||
# noinspection PyAttributeOutsideInit
|
||||
self._dest_spawn_timer = self.spawn_frequency
|
||||
self.trigger_destination_spawn(self.n_dests, state)
|
||||
|
@@ -0,0 +1,4 @@
|
||||
from .actions import DoorUse
|
||||
from .entitites import Door, DoorIndicator
|
||||
from .groups import Doors
|
||||
from .rule_door_auto_close import DoorAutoClose
|
||||
|
@@ -1,10 +1,9 @@
|
||||
from typing import Union
|
||||
|
||||
from marl_factory_grid.environment.actions import Action
|
||||
from marl_factory_grid.utils.results import ActionResult
|
||||
|
||||
from marl_factory_grid.modules.doors import constants as d, rewards as r
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils.results import ActionResult
|
||||
|
||||
|
||||
class DoorUse(Action):
|
||||
|
@@ -22,15 +22,15 @@ class DoorIndicator(Entity):
|
||||
class Door(Entity):
|
||||
|
||||
@property
|
||||
def is_blocking_pos(self):
|
||||
def var_is_blocking_pos(self):
|
||||
return False if self.is_open else True
|
||||
|
||||
@property
|
||||
def is_blocking_light(self):
|
||||
def var_is_blocking_light(self):
|
||||
return False if self.is_open else True
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
def var_can_collide(self):
|
||||
return False if self.is_open else True
|
||||
|
||||
@property
|
||||
@@ -42,12 +42,14 @@ class Door(Entity):
|
||||
return 'open' if self.is_open else 'closed'
|
||||
|
||||
def __init__(self, *args, closed_on_init=True, auto_close_interval=10, indicate_area=False, **kwargs):
|
||||
self._state = d.STATE_CLOSED
|
||||
self._status = d.STATE_CLOSED
|
||||
super(Door, self).__init__(*args, **kwargs)
|
||||
self.auto_close_interval = auto_close_interval
|
||||
self.time_to_close = 0
|
||||
if not closed_on_init:
|
||||
self._open()
|
||||
else:
|
||||
self._close()
|
||||
if indicate_area:
|
||||
self._collection.add_items([DoorIndicator(x) for x in self.tile.neighboring_floor])
|
||||
|
||||
@@ -58,22 +60,22 @@ class Door(Entity):
|
||||
|
||||
@property
|
||||
def is_closed(self):
|
||||
return self._state == d.STATE_CLOSED
|
||||
return self._status == d.STATE_CLOSED
|
||||
|
||||
@property
|
||||
def is_open(self):
|
||||
return self._state == d.STATE_OPEN
|
||||
return self._status == d.STATE_OPEN
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
return self._state
|
||||
return self._status
|
||||
|
||||
def render(self):
|
||||
name, state = 'door_open' if self.is_open else 'door_closed', 'blank'
|
||||
return RenderEntity(name, self.pos, 1, 'none', state, self.identifier_int + 1)
|
||||
|
||||
def use(self):
|
||||
if self._state == d.STATE_OPEN:
|
||||
if self._status == d.STATE_OPEN:
|
||||
self._close()
|
||||
else:
|
||||
self._open()
|
||||
@@ -90,8 +92,8 @@ class Door(Entity):
|
||||
return c.NOT_VALID
|
||||
|
||||
def _open(self):
|
||||
self._state = d.STATE_OPEN
|
||||
self._status = d.STATE_OPEN
|
||||
self.time_to_close = self.auto_close_interval
|
||||
|
||||
def _close(self):
|
||||
self._state = d.STATE_CLOSED
|
||||
self._status = d.STATE_CLOSED
|
||||
|
0
marl_factory_grid/modules/factory/__init__.py
Normal file
0
marl_factory_grid/modules/factory/__init__.py
Normal file
32
marl_factory_grid/modules/factory/rules.py
Normal file
32
marl_factory_grid/modules/factory/rules.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import random
|
||||
from typing import List, Union
|
||||
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils.results import TickResult
|
||||
|
||||
|
||||
class AgentSingleZonePlacementBeta(Rule):
    """Experimental rule that moves every agent onto one of a fixed set of positions.

    The candidate positions have to be handed over at construction time.
    """

    def __init__(self, coordinates=None):
        """
        :param coordinates: Iterable of positions the agents may be placed on. Each position is
                            resolved with ``state[c.FLOOR].by_pos`` in ``on_init``. Defaults to
                            no positions at all.
        """
        super().__init__()
        # Stored as a list so that len() and random.sample() work on any iterable input.
        self.coordinates = list(coordinates) if coordinates is not None else []

    def on_init(self, state, lvl_map):
        """Move each agent to one of the configured coordinates.

        :param state: Game state; ``state[c.AGENT]`` and ``state[c.FLOOR]`` are read.
        :param lvl_map: Unused by this rule.
        :raises ValueError: If fewer coordinates than agents are configured.
        """
        agents = state[c.AGENT]
        n_agents = len(agents)
        n_coordinates = len(self.coordinates)
        if n_coordinates < n_agents:
            raise ValueError(f'{n_agents} agents need placing, but only {n_coordinates} coordinates were given.')
        if n_coordinates == n_agents:
            coordinates = self.coordinates
        else:
            # Draw without replacement so that no two agents are sent to the same position.
            coordinates = random.sample(self.coordinates, k=n_agents)
        tiles = [state[c.FLOOR].by_pos(pos) for pos in coordinates]
        for agent, tile in zip(agents, tiles):
            agent.move(tile)

    def tick_step(self, state):
        """Nothing to do per step; returns an empty result list."""
        return []

    def tick_post_step(self, state) -> List[TickResult]:
        """Nothing to do after a step; returns an empty result list."""
        return []
|
@@ -0,0 +1,4 @@
|
||||
from .actions import ItemAction
|
||||
from .entitites import Item, DropOffLocation
|
||||
from .groups import DropOffLocations, Items, Inventory, Inventories
|
||||
from .rules import ItemRules
|
||||
|
@@ -8,6 +8,8 @@ from marl_factory_grid.modules.items import constants as i
|
||||
|
||||
class Item(Entity):
|
||||
|
||||
var_can_collide = False
|
||||
|
||||
def render(self):
|
||||
return RenderEntity(i.ITEM, self.tile.pos) if self.pos != c.VALUE_NO_POS else None
|
||||
|
||||
@@ -38,6 +40,22 @@ class Item(Entity):
|
||||
|
||||
class DropOffLocation(Entity):
|
||||
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return True
|
||||
|
||||
def render(self):
|
||||
return RenderEntity(i.DROP_OFF, self.tile.pos)
|
||||
|
||||
|
@@ -17,15 +17,6 @@ class Items(PositionMixin, EnvObjects):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def spawn_items(self, tiles: List[Floor]):
|
||||
items = [self._entity(tile) for tile in tiles]
|
||||
self.add_items(items)
|
||||
|
||||
def despawn_items(self, items: List[Item]):
|
||||
items = [items] if isinstance(items, Item) else items
|
||||
for item in items:
|
||||
del self[item]
|
||||
|
||||
|
||||
class Inventory(IsBoundMixin, EnvObjects):
|
||||
|
||||
@@ -58,11 +49,7 @@ class Inventory(IsBoundMixin, EnvObjects):
|
||||
class Inventories(HasBoundedMixin, Objects):
|
||||
|
||||
_entity = Inventory
|
||||
can_move = False
|
||||
|
||||
@property
|
||||
def obs_pairs(self):
|
||||
return [(x.name, x) for x in self]
|
||||
var_can_move = False
|
||||
|
||||
def __init__(self, size, *args, **kwargs):
|
||||
super(Inventories, self).__init__(*args, **kwargs)
|
||||
@@ -70,7 +57,7 @@ class Inventories(HasBoundedMixin, Objects):
|
||||
self._obs = None
|
||||
self._lazy_eval_transforms = []
|
||||
|
||||
def spawn_inventories(self, agents):
|
||||
def spawn(self, agents):
|
||||
inventories = [self._entity(agent, self.size,)
|
||||
for _, agent in enumerate(agents)]
|
||||
self.add_items(inventories)
|
||||
|
@@ -18,7 +18,7 @@ class ItemRules(Rule):
|
||||
self.max_dropoff_storage_size = max_dropoff_storage_size
|
||||
self.n_locations = n_locations
|
||||
|
||||
def on_init(self, state):
|
||||
def on_init(self, state, lvl_map):
|
||||
self.trigger_drop_off_location_spawn(state)
|
||||
self._next_item_spawn = self.spawn_frequency
|
||||
self.trigger_inventory_spawn(state)
|
||||
@@ -42,7 +42,7 @@ class ItemRules(Rule):
|
||||
def trigger_item_spawn(self, state):
|
||||
if item_to_spawns := max(0, (self.n_items - len(state[i.ITEM]))):
|
||||
empty_tiles = state[c.FLOOR].empty_tiles[:item_to_spawns]
|
||||
state[i.ITEM].spawn_items(empty_tiles)
|
||||
state[i.ITEM].spawn(empty_tiles)
|
||||
self._next_item_spawn = self.spawn_frequency
|
||||
state.print(f'{item_to_spawns} new items have been spawned; next spawn in {self._next_item_spawn}')
|
||||
return len(empty_tiles)
|
||||
@@ -52,7 +52,7 @@ class ItemRules(Rule):
|
||||
|
||||
@staticmethod
|
||||
def trigger_inventory_spawn(state):
|
||||
state[i.INVENTORY].spawn_inventories(state[c.AGENT])
|
||||
state[i.INVENTORY].spawn(state[c.AGENT])
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
for item in list(state[i.ITEM].values()):
|
||||
|
@@ -0,0 +1,3 @@
|
||||
from .entitites import Machine
|
||||
from .groups import Machines
|
||||
from .rules import MachineRule
|
||||
|
25
marl_factory_grid/modules/machines/actions.py
Normal file
25
marl_factory_grid/modules/machines/actions.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Union
|
||||
|
||||
from marl_factory_grid.environment.actions import Action
|
||||
from marl_factory_grid.utils.results import ActionResult
|
||||
|
||||
from marl_factory_grid.modules.machines import constants as m, rewards as r
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
|
||||
class MachineAction(Action):
    """Action with which an entity maintains the machine located at its own position."""

    def __init__(self):
        super(MachineAction, self).__init__(m.MACHINE_ACTION)

    def do(self, entity, state) -> Union[None, ActionResult]:
        """Try to maintain the machine at ``entity.pos``.

        :param entity: The acting entity; its ``pos`` selects the machine.
        :param state: Game state; ``state[m.MACHINES]`` is queried by position.
        :return: An ``ActionResult`` carrying ``r.MAINTAIN_VALID`` when ``machine.maintain()`` is truthy,
                 otherwise ``r.MAINTAIN_FAIL`` (also when no machine is found at the position).
        """
        machine = state[m.MACHINES].by_pos(entity.pos)
        if not machine:
            # No machine here: the action itself is invalid.
            return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.MAINTAIN_FAIL)

        valid = machine.maintain()
        reward = r.MAINTAIN_VALID if valid else r.MAINTAIN_FAIL
        return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=reward)
|
||||
|
||||
|
||||
|
@@ -2,6 +2,8 @@
|
||||
MACHINES = 'Machines'
|
||||
MACHINE = 'Machine'
|
||||
|
||||
MACHINE_ACTION = 'Maintain'
|
||||
|
||||
STATE_WORK = 'working'
|
||||
STATE_IDLE = 'idling'
|
||||
STATE_MAINTAIN = 'maintenance'
|
||||
|
@@ -2,27 +2,43 @@ from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils.results import TickResult
|
||||
from marl_factory_grid.modules.machines import constants as m, rewards as r
|
||||
|
||||
from . import constants as m
|
||||
|
||||
|
||||
class Machine(Entity):
|
||||
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return self._encodings[self.state]
|
||||
return self._encodings[self.status]
|
||||
|
||||
def __init__(self, *args, work_interval: int = 10, pause_interval: int = 15, **kwargs):
|
||||
super(Machine, self).__init__(*args, **kwargs)
|
||||
self._intervals = dict({m.STATE_IDLE: pause_interval, m.STATE_WORK: work_interval})
|
||||
self._encodings = dict({m.STATE_IDLE: pause_interval, m.STATE_WORK: work_interval})
|
||||
|
||||
self.state = m.STATE_IDLE
|
||||
self.status = m.STATE_IDLE
|
||||
self.health = 100
|
||||
self._counter = 0
|
||||
self.__delattr__('move')
|
||||
|
||||
def maintain(self):
|
||||
if self.state == m.STATE_WORK:
|
||||
if self.status == m.STATE_WORK:
|
||||
return c.NOT_VALID
|
||||
if self.health <= 98:
|
||||
self.health = 100
|
||||
@@ -31,10 +47,10 @@ class Machine(Entity):
|
||||
return c.NOT_VALID
|
||||
|
||||
def tick(self):
|
||||
if self.state == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]):
|
||||
return TickResult(self.name, c.VALID, r.NONE, self)
|
||||
elif self.state == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]):
|
||||
self.state = m.STATE_WORK
|
||||
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]):
|
||||
return TickResult(identifier=self.name, validity=c.VALID, reward=0, entity=self)
|
||||
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]):
|
||||
self.status = m.STATE_WORK
|
||||
self.reset_counter()
|
||||
return None
|
||||
elif self._counter:
|
||||
@@ -42,12 +58,12 @@ class Machine(Entity):
|
||||
self.health -= 1
|
||||
return None
|
||||
else:
|
||||
self.state = m.STATE_WORK if self.state == m.STATE_IDLE else m.STATE_IDLE
|
||||
self.status = m.STATE_WORK if self.status == m.STATE_IDLE else m.STATE_IDLE
|
||||
self.reset_counter()
|
||||
return None
|
||||
|
||||
def reset_counter(self):
|
||||
self._counter = self._intervals[self.state]
|
||||
self._counter = self._intervals[self.status]
|
||||
|
||||
def render(self):
|
||||
return RenderEntity(m.MACHINE, self.pos)
|
||||
|
@@ -1,6 +1,7 @@
|
||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||
from marl_factory_grid.environment.groups.mixins import PositionMixin
|
||||
from marl_factory_grid.modules.machines.entitites import Machine
|
||||
|
||||
from .entitites import Machine
|
||||
|
||||
|
||||
class Machines(PositionMixin, EnvObjects):
|
||||
|
BIN
marl_factory_grid/modules/machines/machine.png
Normal file
BIN
marl_factory_grid/modules/machines/machine.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 8.5 KiB |
@@ -12,7 +12,7 @@ class MachineRule(Rule):
|
||||
super(MachineRule, self).__init__()
|
||||
self.n_machines = n_machines
|
||||
|
||||
def on_init(self, state):
|
||||
def on_init(self, state, lvl_map):
|
||||
empty_tiles = state[c.FLOOR].empty_tiles[:self.n_machines]
|
||||
state[m.MACHINES].add_items(Machine(tile) for tile in empty_tiles)
|
||||
|
||||
@@ -27,3 +27,9 @@ class MachineRule(Rule):
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
pass
|
||||
|
||||
|
||||
class DoneOnBreakRule(Rule):
    """Placeholder rule: its done-check is not implemented yet."""

    def on_check_done(self, state) -> List[DoneResult]:
        """Perform no check; returns ``None``."""
        return None
|
2
marl_factory_grid/modules/maintenance/__init__.py
Normal file
2
marl_factory_grid/modules/maintenance/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .entities import Maintainer
|
||||
from .groups import Maintainers
|
3
marl_factory_grid/modules/maintenance/constants.py
Normal file
3
marl_factory_grid/modules/maintenance/constants.py
Normal file
@@ -0,0 +1,3 @@
|
||||
MAINTAINER = 'Maintainer' # TEMPLATE _identifier. Define your own!
|
||||
MAINTAINERS = 'Maintainers' # TEMPLATE _identifier. Define your own!
|
||||
|
102
marl_factory_grid/modules/maintenance/entities.py
Normal file
102
marl_factory_grid/modules/maintenance/entities.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
from ...algorithms.static.utils import points_to_graph
|
||||
from ...environment import constants as c
|
||||
from ...environment.actions import Action, ALL_BASEACTIONS
|
||||
from ...environment.entity.entity import Entity
|
||||
from ..doors import constants as do
|
||||
from ..maintenance import constants as mi
|
||||
from ...utils.helpers import MOVEMAP
|
||||
from ...utils.render import RenderEntity
|
||||
from ...utils.states import Gamestate
|
||||
|
||||
|
||||
class Maintainer(Entity):
    """Scripted entity that walks from objective to objective and applies ``action`` on arrival.

    Routes are shortest paths on a graph built once from the floor positions that exist when the
    entity is created.
    """

    @property
    def var_can_collide(self):
        return True

    @property
    def var_can_move(self):
        # NOTE(review): this entity is moved every tick through base actions, yet reports that it
        # cannot move - confirm that this flag is intended.
        return False

    @property
    def var_is_blocking_light(self):
        return False

    @property
    def var_has_position(self):
        return True

    def __init__(self, state: Gamestate, objective: str, action: Action, *args, **kwargs):
        """
        :param state: Game state; only ``state[c.FLOOR].positions`` is read to build the routing graph.
        :param objective: Key of the group in ``state`` whose members are visited.
        :param action: Action executed when standing on a not yet serviced objective.
        :param args: Forwarded to ``Entity``.
        :param kwargs: Forwarded to ``Entity``.
        """
        super().__init__(*args, **kwargs)
        self.action = action
        self.actions = [x() for x in ALL_BASEACTIONS]
        self.objective = objective
        self._path = None            # remaining positions on the way to the current target
        self._next = []              # objectives still to be visited in this round
        self._last = []              # objectives already targeted in this round
        self._last_serviced = 'None'
        self._floortile_graph = points_to_graph(state[c.FLOOR].positions)

    def tick(self, state):
        """Service the objective at the current position or take one step towards the next one.

        :return: ``None`` when an objective was serviced in this tick, otherwise the result of the
                 executed movement action.
        """
        found_objective = state[self.objective].by_pos(self.pos)
        if found_objective and found_objective.name != self._last_serviced:
            self.action.do(self, state)
            self._last_serviced = found_objective.name
            return None
        # Nothing to service here (or it was serviced already): keep moving.
        return self.get_move_action(state).do(self, state)

    def get_move_action(self, state) -> Action:
        """Pick the base action for this tick, planning a new route first if the current one is used up.

        :raises EnvironmentError: If the chosen action name matches none of ``self.actions``.
        """
        if not self._path:
            if not self._next:
                # All objectives were visited: start a new round.
                self._next = list(state[self.objective].values())
                self._last = []
            self._last.append(self._next.pop())
            self._path = self.calculate_route(self._last[-1])

        door = self._door_is_close()
        if door and door.is_closed:
            action = do.ACTION_DOOR_USE
        else:
            action = self._predict_move(state)

        # Translate the action name into the matching action object.
        try:
            return next(x for x in self.actions if x.name == action)
        except StopIteration as err:
            raise EnvironmentError(f'No base action named "{action}" is available to {self.name}.') from err

    def calculate_route(self, entity):
        """Return the shortest path from the own position to ``entity.pos``, excluding the own position."""
        route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos)
        return route[1:]

    def _door_is_close(self):
        """Return the first guest on a neighboring floor tile whose name contains ``do.DOOR``, else ``None``."""
        try:
            return next(y for x in self.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
        except StopIteration:
            return None

    def _predict_move(self, state):
        """Return the name of the action that advances one step along the current path.

        Returns ``c.NOOP`` when the path is empty or the next tile holds a colliding guest.
        """
        if not self._path:
            # The route is empty when the target lies on the current position; stay put
            # instead of indexing into an empty list.
            return c.NOOP
        next_pos = self._path[0]
        if len(state[c.FLOOR].by_pos(next_pos).guests_that_can_collide) > 0:
            # Wait until the next tile is free; the step stays on the path.
            return c.NOOP
        next_pos = self._path.pop(0)
        diff = np.subtract(next_pos, self.pos)
        # Retrieve the action whose position offset matches the step to take.
        return next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff))

    def render(self):
        return RenderEntity(mi.MAINTAINER, self.pos)
|
27
marl_factory_grid/modules/maintenance/groups.py
Normal file
27
marl_factory_grid/modules/maintenance/groups.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from typing import List
|
||||
|
||||
from .entities import Maintainer
|
||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||
from marl_factory_grid.environment.groups.mixins import PositionMixin
|
||||
from ..machines.actions import MachineAction
|
||||
from ...utils.render import RenderEntity
|
||||
from ...utils.states import Gamestate
|
||||
|
||||
from ..machines import constants as mc
|
||||
from . import constants as mi
|
||||
|
||||
|
||||
class Maintainers(PositionMixin, EnvObjects):
    """Positioned collection holding the ``Maintainer`` entities."""

    _entity = Maintainer
    var_can_collide = True
    var_can_move = True
    var_is_blocking_light = False
    var_has_position = True

    def __init__(self, *args, **kwargs):
        super(Maintainers, self).__init__(*args, **kwargs)

    def spawn(self, tiles: List[Floor], state: Gamestate):
        """Create one maintainer per given tile and add all of them to this collection.

        :param tiles: Tiles to place the maintainers on.
        :param state: Game state handed on to every created maintainer.
        """
        new_maintainers = []
        for tile in tiles:
            # Each maintainer gets its own MachineAction and targets the machines group.
            maintainer = self._entity(state, mc.MACHINES, MachineAction(), tile)
            new_maintainers.append(maintainer)
        self.add_items(new_maintainers)
|
BIN
marl_factory_grid/modules/maintenance/maintainer.png
Normal file
BIN
marl_factory_grid/modules/maintenance/maintainer.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 22 KiB |
1
marl_factory_grid/modules/maintenance/rewards.py
Normal file
1
marl_factory_grid/modules/maintenance/rewards.py
Normal file
@@ -0,0 +1 @@
|
||||
MAINTAINER_COLLISION_REWARD = -5
|
39
marl_factory_grid/modules/maintenance/rules.py
Normal file
39
marl_factory_grid/modules/maintenance/rules.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import List
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.utils.results import TickResult, DoneResult
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from . import rewards as r
|
||||
from . import constants as M
|
||||
from marl_factory_grid.utils.states import Gamestate
|
||||
|
||||
|
||||
class MaintenanceRule(Rule):
    """Spawns maintainers, ticks them every step and reports agents that share a position with one."""

    def __init__(self, n_maintainer: int = 1, *args, **kwargs):
        """
        :param n_maintainer: Number of maintainers to spawn on initialisation.
        :param args: Forwarded to ``Rule``.
        :param kwargs: Forwarded to ``Rule``.
        """
        super(MaintenanceRule, self).__init__(*args, **kwargs)
        self.n_maintainer = n_maintainer

    def on_init(self, state: Gamestate, lvl_map):
        """Spawn the maintainers on the first ``n_maintainer`` empty floor tiles."""
        spawn_tiles = state[c.FLOOR].empty_tiles[:self.n_maintainer]
        state[M.MAINTAINERS].spawn(spawn_tiles, state)

    def tick_pre_step(self, state) -> List[TickResult]:
        """Nothing to do before a step; returns an empty list as annotated."""
        return []

    def tick_step(self, state) -> List[TickResult]:
        """Let every maintainer act once; their individual results are not reported."""
        for maintainer in state[M.MAINTAINERS]:
            maintainer.tick(state)
        return []

    def tick_post_step(self, state) -> List[TickResult]:
        """Nothing to do after a step; returns an empty list as annotated."""
        return []

    def on_check_done(self, state) -> List[DoneResult]:
        """Return one ``DoneResult`` per agent that stands on a maintainer position."""
        agents = list(state[c.AGENT].values())
        m_pos = state[M.MAINTAINERS].positions
        return [
            DoneResult(entity=agent, validity=c.VALID, identifier=self.name,
                       reward=r.MAINTAINER_COLLISION_REWARD)
            for agent in agents if agent.pos in m_pos
        ]
|
3
marl_factory_grid/modules/zones/__init__.py
Normal file
3
marl_factory_grid/modules/zones/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .entitites import Zone
|
||||
from .groups import Zones
|
||||
from .rules import AgentSingleZonePlacement
|
4
marl_factory_grid/modules/zones/constants.py
Normal file
4
marl_factory_grid/modules/zones/constants.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Names / Identifiers
|
||||
|
||||
ZONES = 'Zones' # Identifier of Zone-objects and groups (groups).
|
||||
ZONE = 'Zone' # -||-
|
21
marl_factory_grid/modules/zones/entitites.py
Normal file
21
marl_factory_grid/modules/zones/entitites.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
from marl_factory_grid.modules.doors import constants as d
|
||||
|
||||
|
||||
class Zone(Object):
    """Object that bundles a list of floor tiles."""

    def __init__(self, tiles: List[Floor], *args, **kwargs):
        """
        :param tiles: The floor tiles that make up this zone.
        :param args: Forwarded to ``Object``.
        :param kwargs: Forwarded to ``Object``.
        """
        super().__init__(*args, **kwargs)
        self.tiles = tiles

    @property
    def random_tile(self):
        """One tile of this zone, picked with ``random.choice`` on every access."""
        picked_tile = random.choice(self.tiles)
        return picked_tile
|
12
marl_factory_grid/modules/zones/groups.py
Normal file
12
marl_factory_grid/modules/zones/groups.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
from marl_factory_grid.modules.zones import Zone
|
||||
|
||||
|
||||
class Zones(Objects):
    """Collection of ``Zone`` objects."""

    symbol = None
    _entity = Zone
    var_can_move = False

    def __init__(self, *args, **kwargs):
        # The parent is always constructed with can_collide=True.
        super().__init__(*args, can_collide=True, **kwargs)
|
33
marl_factory_grid/modules/zones/rules.py
Normal file
33
marl_factory_grid/modules/zones/rules.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from random import choices
|
||||
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.modules.zones import Zone
|
||||
from . import constants as z
|
||||
|
||||
|
||||
class AgentSingleZonePlacement(Rule):
    """Builds zones from numeric level-map symbols and places every agent in a zone of its own."""

    def __init__(self, n_zones=3):
        """
        :param n_zones: Exclusive upper bound of the zone symbols read from the level map.
        """
        super().__init__()
        self.n_zones = n_zones

    def on_init(self, state, lvl_map):
        """Create the zones and move each agent onto a random tile of a distinct zone.

        :param state: Game state; ``state[c.FLOOR]``, ``state[c.AGENT]`` and ``state[z.ZONES]`` are used.
        :param lvl_map: Level map queried with ``get_coordinates_for_symbol`` for every zone symbol.
        :return: An empty list.
        :raises AssertionError: If a zone symbol has no coordinates or there are fewer zones than agents.
        """
        # Imported locally: only ``choices`` is imported at module level.
        from random import sample

        zones = []
        # NOTE(review): range(1, n_zones) covers the symbols 1 .. n_zones - 1, i.e. one zone less
        # than ``n_zones`` suggests - confirm against the level files before changing the bound.
        for z_idx in range(1, self.n_zones):
            zone_positions = lvl_map.get_coordinates_for_symbol(z_idx)
            assert len(zone_positions), f'No coordinates found for zone symbol {z_idx}.'
            zones.append(Zone([state[c.FLOOR].by_pos(pos) for pos in zone_positions]))
        state[z.ZONES].add_items(zones)

        n_agents = len(state[c.AGENT])
        n_available = len(state[z.ZONES])
        assert n_available >= n_agents, f'{n_agents} agents cannot be spread over {n_available} zones.'

        # Draw without replacement so that every agent ends up in a zone of its own.
        z_idxs = sample(range(n_available), k=n_agents)
        for agent in state[c.AGENT]:
            agent.move(state[z.ZONES][z_idxs.pop()].random_tile)
        return []

    def tick_step(self, state):
        """Nothing to do per step; returns an empty result list."""
        return []
|
Reference in New Issue
Block a user