mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-24 04:11:36 +02:00
new rules, new spawn logic, small fixes, default and narrow corridor debugged
This commit is contained in:
@ -1,48 +1,35 @@
|
||||
from random import shuffle
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
|
||||
from ...algorithms.static.utils import points_to_graph
|
||||
from ...environment import constants as c
|
||||
from ...environment.actions import Action, ALL_BASEACTIONS
|
||||
from ...environment.entity.entity import Entity
|
||||
from ..doors import constants as do
|
||||
from ..maintenance import constants as mi
|
||||
from ...utils.helpers import MOVEMAP
|
||||
from ...utils.utility_classes import RenderEntity
|
||||
from ...utils.states import Gamestate
|
||||
from ...utils import helpers as h
|
||||
from ...utils.utility_classes import RenderEntity, Floor
|
||||
from ..doors import DoorUse
|
||||
|
||||
|
||||
class Maintainer(Entity):
|
||||
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return True
|
||||
|
||||
def __init__(self, state: Gamestate, objective: str, action: Action, *args, **kwargs):
|
||||
def __init__(self, objective: str, action: Action, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.action = action
|
||||
self.actions = [x() for x in ALL_BASEACTIONS]
|
||||
self.actions = [x() for x in ALL_BASEACTIONS] + [DoorUse()]
|
||||
self.objective = objective
|
||||
self._path = None
|
||||
self._next = []
|
||||
self._last = []
|
||||
self._last_serviced = 'None'
|
||||
self._floortile_graph = points_to_graph(state.entities.floorlist)
|
||||
self._floortile_graph = None
|
||||
|
||||
def tick(self, state):
|
||||
if found_objective := state[self.objective].by_pos(self.pos):
|
||||
if found_objective := h.get_first(state[self.objective].by_pos(self.pos)):
|
||||
if found_objective.name != self._last_serviced:
|
||||
self.action.do(self, state)
|
||||
self._last_serviced = found_objective.name
|
||||
@ -54,24 +41,27 @@ class Maintainer(Entity):
|
||||
return action.do(self, state)
|
||||
|
||||
def get_move_action(self, state) -> Action:
|
||||
if not self._floortile_graph:
|
||||
state.print("Generating Floorgraph....")
|
||||
self._floortile_graph = points_to_graph(state.entities.floorlist)
|
||||
if self._path is None or not self._path:
|
||||
if not self._next:
|
||||
self._next = list(state[self.objective].values())
|
||||
self._next = list(state[self.objective].values()) + [Floor(*state.random_free_position)]
|
||||
shuffle(self._next)
|
||||
self._last = []
|
||||
self._last.append(self._next.pop())
|
||||
state.print("Calculating shortest path....")
|
||||
self._path = self.calculate_route(self._last[-1])
|
||||
|
||||
if door := self._door_is_close(state):
|
||||
if door.is_closed:
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
action = do.ACTION_DOOR_USE
|
||||
else:
|
||||
action = self._predict_move(state)
|
||||
if door := self._closed_door_in_path(state):
|
||||
state.print(f"{self} found {door} that is closed. Attempt to open.")
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
action = do.ACTION_DOOR_USE
|
||||
else:
|
||||
action = self._predict_move(state)
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
try:
|
||||
action_obj = next(x for x in self.actions if x.name == action)
|
||||
action_obj = h.get_first(self.actions, lambda x: x.name == action)
|
||||
except (StopIteration, UnboundLocalError):
|
||||
print('Will not happen')
|
||||
raise EnvironmentError
|
||||
@ -81,11 +71,10 @@ class Maintainer(Entity):
|
||||
route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos)
|
||||
return route[1:]
|
||||
|
||||
def _door_is_close(self, state):
|
||||
state.print("Found a door that is close.")
|
||||
try:
|
||||
return next(y for x in state.entities.neighboring_positions(self.state.pos) for y in state.entities.pos_dict[x] if do.DOOR in y.name)
|
||||
except StopIteration:
|
||||
def _closed_door_in_path(self, state):
|
||||
if self._path:
|
||||
return h.get_first(state[do.DOORS].by_pos(self._path[0]), lambda x: x.is_closed)
|
||||
else:
|
||||
return None
|
||||
|
||||
def _predict_move(self, state):
|
||||
@ -96,7 +85,7 @@ class Maintainer(Entity):
|
||||
next_pos = self._path.pop(0)
|
||||
diff = np.subtract(next_pos, self.pos)
|
||||
# Retrieve action based on the pos dif (like in: What do I have to do to get there?)
|
||||
action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff))
|
||||
action = next(action for action, pos_diff in h.MOVEMAP.items() if np.all(diff == pos_diff))
|
||||
return action
|
||||
|
||||
def render(self):
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Union, List, Tuple
|
||||
from typing import Union, List, Tuple, Dict
|
||||
|
||||
from marl_factory_grid.environment.groups.collection import Collection
|
||||
from .entities import Maintainer
|
||||
@ -10,25 +10,21 @@ from ...utils.states import Gamestate
|
||||
class Maintainers(Collection):
|
||||
_entity = Maintainer
|
||||
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
return True
|
||||
var_can_collide = True
|
||||
var_can_move = True
|
||||
var_is_blocking_light = False
|
||||
var_has_position = True
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return True
|
||||
def __init__(self, size, *args, coords_or_quantity: int = None,
|
||||
spawnrule: Union[None, Dict[str, dict]] = None,
|
||||
**kwargs):
|
||||
super(Collection, self).__init__(*args, **kwargs)
|
||||
self._coords_or_quantity = coords_or_quantity
|
||||
self.size = size
|
||||
self._spawnrule = spawnrule
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
|
||||
state = entity_args[0]
|
||||
self.add_items([self._entity(state, mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])
|
||||
self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])
|
||||
|
@ -4,29 +4,24 @@ from marl_factory_grid.utils.results import TickResult, DoneResult
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from . import rewards as r
|
||||
from . import constants as M
|
||||
from marl_factory_grid.utils.states import Gamestate
|
||||
|
||||
|
||||
class MaintenanceRule(Rule):
|
||||
class MoveMaintainers(Rule):
|
||||
|
||||
def __init__(self, n_maintainer: int = 1, *args, **kwargs):
|
||||
super(MaintenanceRule, self).__init__(*args, **kwargs)
|
||||
self.n_maintainer = n_maintainer
|
||||
|
||||
def on_init(self, state: Gamestate, lvl_map):
|
||||
state[M.MAINTAINERS].spawn(state.entities.empty_positions[:self.n_maintainer], state)
|
||||
pass
|
||||
|
||||
def tick_pre_step(self, state) -> List[TickResult]:
|
||||
pass
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
for maintainer in state[M.MAINTAINERS]:
|
||||
maintainer.tick(state)
|
||||
# Todo: Return a Result Object.
|
||||
return []
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
pass
|
||||
|
||||
class DoneAtMaintainerCollision(Rule):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
agents = list(state[c.AGENT].values())
|
||||
|
Reference in New Issue
Block a user