mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 15:26:43 +02:00
Merge branch 'main' into unit_testing
# Conflicts: # marl_factory_grid/environment/factory.py # marl_factory_grid/utils/states.py
This commit is contained in:
commit
3a7592b285
@ -174,7 +174,7 @@ class BaseActorCritic:
|
|||||||
hidden_critic=out.get(nms.HIDDEN_CRITIC, None)
|
hidden_critic=out.get(nms.HIDDEN_CRITIC, None)
|
||||||
)
|
)
|
||||||
eps_rew += torch.tensor(reward)
|
eps_rew += torch.tensor(reward)
|
||||||
results.append(eps_rew.tolist() + [np.sum(eps_rew).item()] + [episode])
|
results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode])
|
||||||
episode += 1
|
episode += 1
|
||||||
agent_columns = [f'agent#{i}' for i in range(self.cfg['environment']['n_agents'])]
|
agent_columns = [f'agent#{i}' for i in range(self.cfg['environment']['n_agents'])]
|
||||||
results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode'])
|
results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode'])
|
||||||
|
@ -18,6 +18,7 @@ class Action(abc.ABC):
|
|||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||||
|
print()
|
||||||
return
|
return
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
@ -41,7 +41,7 @@ class Object:
|
|||||||
|
|
||||||
def __init__(self, str_ident: Union[str, None] = None, **kwargs):
|
def __init__(self, str_ident: Union[str, None] = None, **kwargs):
|
||||||
self._bound_entity = None
|
self._bound_entity = None
|
||||||
self._observers = []
|
self._observers = set()
|
||||||
self._str_ident = str_ident
|
self._str_ident = str_ident
|
||||||
self.u_int = self._identify_and_count_up()
|
self.u_int = self._identify_and_count_up()
|
||||||
self._collection = None
|
self._collection = None
|
||||||
@ -75,7 +75,7 @@ class Object:
|
|||||||
self._collection = collection
|
self._collection = collection
|
||||||
|
|
||||||
def add_observer(self, observer):
|
def add_observer(self, observer):
|
||||||
self.observers.append(observer)
|
self.observers.add(observer)
|
||||||
observer.notify_add_entity(self)
|
observer.notify_add_entity(self)
|
||||||
|
|
||||||
def del_observer(self, observer):
|
def del_observer(self, observer):
|
||||||
|
@ -69,23 +69,6 @@ class Factory(gym.Env):
|
|||||||
# expensive - don't use; unless required !
|
# expensive - don't use; unless required !
|
||||||
self._renderer = None
|
self._renderer = None
|
||||||
|
|
||||||
# reset env to initial state, preparing env for new episode.
|
|
||||||
# returns tuple where the first dict contains initial observation for each agent in the env
|
|
||||||
self.reset()
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
return self.state.entities[item]
|
|
||||||
|
|
||||||
def reset(self) -> (dict, dict):
|
|
||||||
if self.state is not None:
|
|
||||||
for entity_group in self.state.entities:
|
|
||||||
try:
|
|
||||||
entity_group[0].reset_uid()
|
|
||||||
except (AttributeError, TypeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
self.state = None
|
|
||||||
|
|
||||||
# Init entities
|
# Init entities
|
||||||
entities = self.map.do_init()
|
entities = self.map.do_init()
|
||||||
|
|
||||||
@ -101,7 +84,6 @@ class Factory(gym.Env):
|
|||||||
self.state = Gamestate(entities, parsed_agents_conf, env_rules, env_tests, self.map.level_shape,
|
self.state = Gamestate(entities, parsed_agents_conf, env_rules, env_tests, self.map.level_shape,
|
||||||
self.conf.env_seed, self.conf.verbose)
|
self.conf.env_seed, self.conf.verbose)
|
||||||
|
|
||||||
# All is set up, trigger entity init with variable pos
|
|
||||||
# All is set up, trigger additional init (after agent entity spawn etc)
|
# All is set up, trigger additional init (after agent entity spawn etc)
|
||||||
self.state.rules.do_all_init(self.state, self.map)
|
self.state.rules.do_all_init(self.state, self.map)
|
||||||
|
|
||||||
@ -110,6 +92,17 @@ class Factory(gym.Env):
|
|||||||
# Build initial observations for all agents
|
# Build initial observations for all agents
|
||||||
# noinspection PyAttributeOutsideInit
|
# noinspection PyAttributeOutsideInit
|
||||||
self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r)
|
self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r)
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return self.state.entities[item]
|
||||||
|
|
||||||
|
def reset(self) -> (dict, dict):
|
||||||
|
self.state.entities.reset()
|
||||||
|
|
||||||
|
# All is set up, trigger entity spawn with variable pos
|
||||||
|
self.state.rules.do_all_reset(self.state)
|
||||||
|
|
||||||
|
# Build initial observations for all agents
|
||||||
return self.obs_builder.refresh_and_build_for_all(self.state)
|
return self.obs_builder.refresh_and_build_for_all(self.state)
|
||||||
|
|
||||||
def manual_step_init(self) -> List[Result]:
|
def manual_step_init(self) -> List[Result]:
|
||||||
|
@ -2,7 +2,6 @@ from typing import List, Tuple, Union, Dict
|
|||||||
|
|
||||||
from marl_factory_grid.environment.entity.entity import Entity
|
from marl_factory_grid.environment.entity.entity import Entity
|
||||||
from marl_factory_grid.environment.groups.objects import Objects
|
from marl_factory_grid.environment.groups.objects import Objects
|
||||||
# noinspection PyProtectedMember
|
|
||||||
from marl_factory_grid.environment.entity.object import Object
|
from marl_factory_grid.environment.entity.object import Object
|
||||||
import marl_factory_grid.environment.constants as c
|
import marl_factory_grid.environment.constants as c
|
||||||
from marl_factory_grid.utils.results import Result
|
from marl_factory_grid.utils.results import Result
|
||||||
|
@ -31,9 +31,12 @@ class Entities(Objects):
|
|||||||
|
|
||||||
def __init__(self, floor_positions):
|
def __init__(self, floor_positions):
|
||||||
self._floor_positions = floor_positions
|
self._floor_positions = floor_positions
|
||||||
self.pos_dict = defaultdict(list)
|
self.pos_dict = None
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'{self.__class__.__name__}{[x for x in self]}'
|
||||||
|
|
||||||
def guests_that_can_collide(self, pos):
|
def guests_that_can_collide(self, pos):
|
||||||
return [x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
|
return [x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
|
||||||
|
|
||||||
@ -108,3 +111,12 @@ class Entities(Objects):
|
|||||||
|
|
||||||
def is_occupied(self, pos):
|
def is_occupied(self, pos):
|
||||||
return len([x for x in self.pos_dict[pos] if x.var_can_collide or x.var_is_blocking_pos]) >= 1
|
return len([x for x in self.pos_dict[pos] if x.var_can_collide or x.var_is_blocking_pos]) >= 1
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._observers = set(self)
|
||||||
|
self.pos_dict = defaultdict(list)
|
||||||
|
for entity_group in self:
|
||||||
|
entity_group.reset()
|
||||||
|
|
||||||
|
if hasattr(entity_group, "var_has_position") and entity_group.var_has_position:
|
||||||
|
entity_group.add_observer(self)
|
||||||
|
@ -44,7 +44,7 @@ class Objects:
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self._data = defaultdict(lambda: None)
|
self._data = defaultdict(lambda: None)
|
||||||
self._observers = [self]
|
self._observers = set(self)
|
||||||
self.pos_dict = defaultdict(list)
|
self.pos_dict = defaultdict(list)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
@ -59,6 +59,8 @@ class Objects:
|
|||||||
assert self._data[item.name] is None, f'{item.name} allready exists!!!'
|
assert self._data[item.name] is None, f'{item.name} allready exists!!!'
|
||||||
self._data.update({item.name: item})
|
self._data.update({item.name: item})
|
||||||
item.set_collection(self)
|
item.set_collection(self)
|
||||||
|
if hasattr(self, "var_has_position") and self.var_has_position:
|
||||||
|
item.add_observer(self)
|
||||||
for observer in self.observers:
|
for observer in self.observers:
|
||||||
observer.notify_add_entity(item)
|
observer.notify_add_entity(item)
|
||||||
return self
|
return self
|
||||||
@ -82,9 +84,8 @@ class Objects:
|
|||||||
|
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
def add_observer(self, observer):
|
def add_observer(self, observer):
|
||||||
self.observers.append(observer)
|
self.observers.add(observer)
|
||||||
for entity in self:
|
for entity in self:
|
||||||
if observer not in entity.observers:
|
|
||||||
entity.add_observer(observer)
|
entity.add_observer(observer)
|
||||||
|
|
||||||
def add_items(self, items: List[_entity]):
|
def add_items(self, items: List[_entity]):
|
||||||
@ -127,8 +128,7 @@ class Objects:
|
|||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]}
|
return f'{self.__class__.__name__}[{len(self)}]'
|
||||||
return f'{self.__class__.__name__}[{repr_dict}]'
|
|
||||||
|
|
||||||
def notify_del_entity(self, entity: Object):
|
def notify_del_entity(self, entity: Object):
|
||||||
try:
|
try:
|
||||||
@ -163,3 +163,9 @@ class Objects:
|
|||||||
return h.get_first_index(self, filter_by=lambda x: x.belongs_to_entity(entity))
|
return h.get_first_index(self, filter_by=lambda x: x.belongs_to_entity(entity))
|
||||||
except (StopIteration, AttributeError):
|
except (StopIteration, AttributeError):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._data = defaultdict(lambda: None)
|
||||||
|
self._observers = set(self)
|
||||||
|
self.pos_dict = defaultdict(list)
|
||||||
|
|
||||||
|
@ -23,3 +23,7 @@ class Walls(Collection):
|
|||||||
return super().by_pos(pos)[0]
|
return super().by_pos(pos)[0]
|
||||||
except IndexError:
|
except IndexError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ class Rule(abc.ABC):
|
|||||||
def on_init(self, state, lvl_map):
|
def on_init(self, state, lvl_map):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def on_reset(self):
|
def on_reset(self, state) -> List[TickResult]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def tick_pre_step(self, state) -> List[TickResult]:
|
def tick_pre_step(self, state) -> List[TickResult]:
|
||||||
@ -55,7 +55,7 @@ class SpawnEntity(Rule):
|
|||||||
self.collection = collection
|
self.collection = collection
|
||||||
self.ignore_blocking = ignore_blocking
|
self.ignore_blocking = ignore_blocking
|
||||||
|
|
||||||
def on_init(self, state, lvl_map) -> [TickResult]:
|
def on_reset(self, state) -> [TickResult]:
|
||||||
results = self.collection.trigger_spawn(state, ignore_blocking=self.ignore_blocking)
|
results = self.collection.trigger_spawn(state, ignore_blocking=self.ignore_blocking)
|
||||||
pos_str = f' on: {[x.pos for x in self.collection]}' if self.collection.var_has_position else ''
|
pos_str = f' on: {[x.pos for x in self.collection]}' if self.collection.var_has_position else ''
|
||||||
state.print(f'Initial {self.collection.__class__.__name__} were spawned{pos_str}')
|
state.print(f'Initial {self.collection.__class__.__name__} were spawned{pos_str}')
|
||||||
@ -68,8 +68,7 @@ class SpawnAgents(Rule):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
def on_reset(self, state):
|
||||||
# agents = Agents(lvl_map.size)
|
|
||||||
agents = state[c.AGENT]
|
agents = state[c.AGENT]
|
||||||
empty_positions = state.entities.empty_positions[:len(state.agents_conf)]
|
empty_positions = state.entities.empty_positions[:len(state.agents_conf)]
|
||||||
for agent_name, agent_conf in state.agents_conf.items():
|
for agent_name, agent_conf in state.agents_conf.items():
|
||||||
@ -101,9 +100,6 @@ class DoneAtMaxStepsReached(Rule):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.max_steps = max_steps
|
self.max_steps = max_steps
|
||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_check_done(self, state):
|
def on_check_done(self, state):
|
||||||
if self.max_steps <= state.curr_step:
|
if self.max_steps <= state.curr_step:
|
||||||
return [DoneResult(validity=c.VALID, identifier=self.name)]
|
return [DoneResult(validity=c.VALID, identifier=self.name)]
|
||||||
@ -115,7 +111,7 @@ class AssignGlobalPositions(Rule):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
def on_reset(self, state, lvl_map):
|
||||||
from marl_factory_grid.environment.entity.util import GlobalPosition
|
from marl_factory_grid.environment.entity.util import GlobalPosition
|
||||||
for agent in state[c.AGENT]:
|
for agent in state[c.AGENT]:
|
||||||
gp = GlobalPosition(agent, lvl_map.level_shape)
|
gp = GlobalPosition(agent, lvl_map.level_shape)
|
||||||
|
@ -127,30 +127,3 @@ class DoneAtBatteryDischarge(BatteryDecharge):
|
|||||||
return [DoneResult(self.name, validity=c.VALID, reward=self.reward_discharge_done)]
|
return [DoneResult(self.name, validity=c.VALID, reward=self.reward_discharge_done)]
|
||||||
else:
|
else:
|
||||||
return [DoneResult(self.name, validity=c.NOT_VALID)]
|
return [DoneResult(self.name, validity=c.NOT_VALID)]
|
||||||
|
|
||||||
|
|
||||||
class SpawnChargePods(Rule):
|
|
||||||
|
|
||||||
def __init__(self, n_pods: int, charge_rate: float = 0.4, multi_charge: bool = False):
|
|
||||||
"""
|
|
||||||
Spawn Chargepods in accordance to the given parameters.
|
|
||||||
|
|
||||||
:type n_pods: int
|
|
||||||
:param n_pods: How many charge pods are there?
|
|
||||||
:type charge_rate: float
|
|
||||||
:param charge_rate: How much juice does each use of the charge action top up?
|
|
||||||
:type multi_charge: bool
|
|
||||||
:param multi_charge: Whether multiple agents are able to charge at the same time.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.multi_charge = multi_charge
|
|
||||||
self.charge_rate = charge_rate
|
|
||||||
self.n_pods = n_pods
|
|
||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
|
||||||
pod_collection = state[b.CHARGE_PODS]
|
|
||||||
empty_positions = state.entities.empty_positions
|
|
||||||
pods = pod_collection.from_coordinates(empty_positions, entity_kwargs=dict(
|
|
||||||
multi_charge=self.multi_charge, charge_rate=self.charge_rate)
|
|
||||||
)
|
|
||||||
pod_collection.add_items(pods)
|
|
||||||
|
@ -34,7 +34,12 @@ class DirtPiles(Collection):
|
|||||||
self.coords_or_quantity = coords_or_quantity
|
self.coords_or_quantity = coords_or_quantity
|
||||||
self.initial_amount = initial_amount
|
self.initial_amount = initial_amount
|
||||||
|
|
||||||
def trigger_spawn(self, state, coords_or_quantity=0, amount=0) -> [Result]:
|
def trigger_spawn(self, state, coords_or_quantity=0, amount=0, ignore_blocking=False) -> [Result]:
|
||||||
|
if ignore_blocking:
|
||||||
|
print("##########################################")
|
||||||
|
print("Blocking should not be ignored for this Entity")
|
||||||
|
print("Exiting....")
|
||||||
|
exit()
|
||||||
coords_or_quantity = coords_or_quantity if coords_or_quantity else self.coords_or_quantity
|
coords_or_quantity = coords_or_quantity if coords_or_quantity else self.coords_or_quantity
|
||||||
n_new = int(abs(coords_or_quantity + (state.rng.uniform(-self.n_var, self.n_var))))
|
n_new = int(abs(coords_or_quantity + (state.rng.uniform(-self.n_var, self.n_var))))
|
||||||
n_new = state.get_n_random_free_positions(n_new)
|
n_new = state.get_n_random_free_positions(n_new)
|
||||||
|
@ -106,7 +106,7 @@ class SpawnDestinationsPerAgent(Rule):
|
|||||||
super(Rule, self).__init__()
|
super(Rule, self).__init__()
|
||||||
self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in coords_or_quantity.items()}
|
self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in coords_or_quantity.items()}
|
||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
def on_reset(self, state, lvl_map):
|
||||||
for (agent_name, position_list) in self.per_agent_positions.items():
|
for (agent_name, position_list) in self.per_agent_positions.items():
|
||||||
agent = h.get_first(state[c.AGENT], lambda x: agent_name in x.name)
|
agent = h.get_first(state[c.AGENT], lambda x: agent_name in x.name)
|
||||||
assert agent
|
assert agent
|
||||||
|
@ -15,7 +15,7 @@ class DoorUse(Action):
|
|||||||
# Check if agent really is standing on a door:
|
# Check if agent really is standing on a door:
|
||||||
e = state.entities.get_entities_near_pos(entity.pos)
|
e = state.entities.get_entities_near_pos(entity.pos)
|
||||||
try:
|
try:
|
||||||
# Only one door opens TODO introcude loop
|
# Only one door opens TODO introduce loop
|
||||||
door = next(x for x in e if x.name.startswith(d.DOOR))
|
door = next(x for x in e if x.name.startswith(d.DOOR))
|
||||||
valid = door.use()
|
valid = door.use()
|
||||||
state.print(f'{entity.name} just used a {door.name} at {door.pos}')
|
state.print(f'{entity.name} just used a {door.name} at {door.pos}')
|
||||||
|
@ -117,3 +117,7 @@ class Door(Entity):
|
|||||||
def _reset_timer(self):
|
def _reset_timer(self):
|
||||||
self._time_to_close = self._auto_close_interval
|
self._time_to_close = self._auto_close_interval
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._close()
|
||||||
|
self._reset_timer()
|
||||||
|
@ -23,3 +23,7 @@ class Doors(Collection):
|
|||||||
results.append(tick_result)
|
results.append(tick_result)
|
||||||
# TODO: Should return a Result object, not a random dict.
|
# TODO: Should return a Result object, not a random dict.
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
for door in self:
|
||||||
|
door.reset()
|
||||||
|
@ -40,6 +40,6 @@ class IndicateDoorAreaInObservation(Rule):
|
|||||||
# Could then be combined with the "Combine"-approach.
|
# Could then be combined with the "Combine"-approach.
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
def on_reset(self, state, lvl_map):
|
||||||
for door in state[d.DOORS]:
|
for door in state[d.DOORS]:
|
||||||
state[d.DOORS].add_items([DoorIndicator(x) for x in state.entities.neighboring_positions(door.pos)])
|
state[d.DOORS].add_items([DoorIndicator(x) for x in state.entities.neighboring_positions(door.pos)])
|
||||||
|
@ -1,32 +0,0 @@
|
|||||||
import random
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from marl_factory_grid.environment import constants as c
|
|
||||||
from marl_factory_grid.environment.rules import Rule
|
|
||||||
from marl_factory_grid.utils.results import TickResult
|
|
||||||
|
|
||||||
|
|
||||||
class AgentSingleZonePlacementBeta(Rule):
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
raise NotImplementedError()
|
|
||||||
# TODO!!!! Is this concept needed any more?
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
|
||||||
agents = state[c.AGENT]
|
|
||||||
if len(self.coordinates) == len(agents):
|
|
||||||
coordinates = self.coordinates
|
|
||||||
elif len(self.coordinates) > len(agents):
|
|
||||||
coordinates = random.choices(self.coordinates, k=len(agents))
|
|
||||||
else:
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
for agent, pos in zip(agents, coordinates):
|
|
||||||
agent.move(pos, state)
|
|
||||||
|
|
||||||
def tick_step(self, state):
|
|
||||||
return []
|
|
||||||
|
|
||||||
def tick_post_step(self, state) -> List[TickResult]:
|
|
||||||
return []
|
|
@ -3,8 +3,6 @@ from random import shuffle
|
|||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
from ...algorithms.static.utils import points_to_graph
|
|
||||||
from ...environment import constants as c
|
from ...environment import constants as c
|
||||||
from ...environment.actions import Action, ALL_BASEACTIONS
|
from ...environment.actions import Action, ALL_BASEACTIONS
|
||||||
from ...environment.entity.entity import Entity
|
from ...environment.entity.entity import Entity
|
||||||
@ -26,7 +24,6 @@ class Maintainer(Entity):
|
|||||||
self._next = []
|
self._next = []
|
||||||
self._last = []
|
self._last = []
|
||||||
self._last_serviced = 'None'
|
self._last_serviced = 'None'
|
||||||
self._floortile_graph = None
|
|
||||||
|
|
||||||
def tick(self, state):
|
def tick(self, state):
|
||||||
if found_objective := h.get_first(state[self.objective].by_pos(self.pos)):
|
if found_objective := h.get_first(state[self.objective].by_pos(self.pos)):
|
||||||
@ -41,17 +38,18 @@ class Maintainer(Entity):
|
|||||||
return action.do(self, state)
|
return action.do(self, state)
|
||||||
|
|
||||||
def get_move_action(self, state) -> Action:
|
def get_move_action(self, state) -> Action:
|
||||||
if not self._floortile_graph:
|
if self._path is None or not len(self._path):
|
||||||
state.print("Generating Floorgraph....")
|
|
||||||
self._floortile_graph = points_to_graph(state.entities.floorlist)
|
|
||||||
if self._path is None or not self._path:
|
|
||||||
if not self._next:
|
if not self._next:
|
||||||
self._next = list(state[self.objective].values()) + [Floor(*state.random_free_position)]
|
self._next = list(state[self.objective].values()) + [Floor(*state.random_free_position)]
|
||||||
shuffle(self._next)
|
shuffle(self._next)
|
||||||
self._last = []
|
self._last = []
|
||||||
self._last.append(self._next.pop())
|
self._last.append(self._next.pop())
|
||||||
state.print("Calculating shortest path....")
|
state.print("Calculating shortest path....")
|
||||||
self._path = self.calculate_route(self._last[-1])
|
self._path = self.calculate_route(self._last[-1], state.floortile_graph)
|
||||||
|
if not self._path:
|
||||||
|
self._last.append(self._next.pop())
|
||||||
|
state.print("Calculating shortest path.... Again....")
|
||||||
|
self._path = self.calculate_route(self._last[-1], state.floortile_graph)
|
||||||
|
|
||||||
if door := self._closed_door_in_path(state):
|
if door := self._closed_door_in_path(state):
|
||||||
state.print(f"{self} found {door} that is closed. Attempt to open.")
|
state.print(f"{self} found {door} that is closed. Attempt to open.")
|
||||||
@ -67,8 +65,8 @@ class Maintainer(Entity):
|
|||||||
raise EnvironmentError
|
raise EnvironmentError
|
||||||
return action_obj
|
return action_obj
|
||||||
|
|
||||||
def calculate_route(self, entity):
|
def calculate_route(self, entity, floortile_graph):
|
||||||
route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos)
|
route = nx.shortest_path(floortile_graph, self.pos, entity.pos)
|
||||||
return route[1:]
|
return route[1:]
|
||||||
|
|
||||||
def _closed_door_in_path(self, state):
|
def _closed_door_in_path(self, state):
|
||||||
|
@ -14,14 +14,8 @@ class Maintainers(Collection):
|
|||||||
var_is_blocking_light = False
|
var_is_blocking_light = False
|
||||||
var_has_position = True
|
var_has_position = True
|
||||||
|
|
||||||
def __init__(self, size, *args, coords_or_quantity: int = None,
|
def __init__(self, *args, **kwargs):
|
||||||
spawnrule: Union[None, Dict[str, dict]] = None,
|
super().__init__(*args, **kwargs)
|
||||||
**kwargs):
|
|
||||||
super(Collection, self).__init__(*args, **kwargs)
|
|
||||||
self._coords_or_quantity = coords_or_quantity
|
|
||||||
self.size = size
|
|
||||||
self._spawnrule = spawnrule
|
|
||||||
|
|
||||||
|
|
||||||
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
|
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
|
||||||
self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])
|
self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])
|
||||||
|
@ -11,19 +11,21 @@ class ZoneInit(Rule):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self._zones = list()
|
||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
def on_init(self, state, lvl_map):
|
||||||
zones = []
|
|
||||||
z_idx = 1
|
z_idx = 1
|
||||||
|
|
||||||
while z_idx:
|
while z_idx:
|
||||||
zone_positions = lvl_map.get_coordinates_for_symbol(z_idx)
|
zone_positions = lvl_map.get_coordinates_for_symbol(z_idx)
|
||||||
if len(zone_positions):
|
if len(zone_positions):
|
||||||
zones.append(Zone(zone_positions))
|
self._zones.append(Zone(zone_positions))
|
||||||
z_idx += 1
|
z_idx += 1
|
||||||
else:
|
else:
|
||||||
z_idx = 0
|
z_idx = 0
|
||||||
state[z.ZONES].add_items(zones)
|
|
||||||
|
def on_reset(self, state):
|
||||||
|
state[z.ZONES].add_items(self._zones)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
@ -32,7 +34,7 @@ class AgentSingleZonePlacement(Rule):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
def on_reset(self, state):
|
||||||
n_agents = len(state[c.AGENT])
|
n_agents = len(state[c.AGENT])
|
||||||
assert len(state[z.ZONES]) >= n_agents
|
assert len(state[z.ZONES]) >= n_agents
|
||||||
|
|
||||||
@ -48,19 +50,16 @@ class AgentSingleZonePlacement(Rule):
|
|||||||
class IndividualDestinationZonePlacement(Rule):
|
class IndividualDestinationZonePlacement(Rule):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
raise NotImplementedError("This is rpetty new, and needs to be debugged, after the zones")
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def on_init(self, state, lvl_map):
|
def on_reset(self, state):
|
||||||
for agent in state[c.AGENT]:
|
for agent in state[c.AGENT]:
|
||||||
self.trigger_destination_spawn(agent, state)
|
self.trigger_spawn(agent, state)
|
||||||
pass
|
|
||||||
return []
|
|
||||||
|
|
||||||
def tick_step(self, state):
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def trigger_destination_spawn(agent, state):
|
def trigger_spawn(agent, state):
|
||||||
agent_zones = state[z.ZONES].by_pos(agent.pos)
|
agent_zones = state[z.ZONES].by_pos(agent.pos)
|
||||||
other_zones = [x for x in state[z.ZONES] if x not in agent_zones]
|
other_zones = [x for x in state[z.ZONES] if x not in agent_zones]
|
||||||
already_has_destination = True
|
already_has_destination = True
|
||||||
|
@ -3,6 +3,7 @@ from typing import List, Tuple
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from marl_factory_grid.algorithms.static.utils import points_to_graph
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
from marl_factory_grid.environment.entity.entity import Entity
|
from marl_factory_grid.environment.entity.entity import Entity
|
||||||
from marl_factory_grid.environment.rules import Rule
|
from marl_factory_grid.environment.rules import Rule
|
||||||
@ -29,6 +30,12 @@ class StepRules:
|
|||||||
self.rules.append(item)
|
self.rules.append(item)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def do_all_reset(self, state):
|
||||||
|
for rule in self.rules:
|
||||||
|
if rule_reset_printline := rule.on_reset(state):
|
||||||
|
state.print(rule_reset_printline)
|
||||||
|
return c.VALID
|
||||||
|
|
||||||
def do_all_init(self, state, lvl_map):
|
def do_all_init(self, state, lvl_map):
|
||||||
for rule in self.rules:
|
for rule in self.rules:
|
||||||
if rule_init_printline := rule.on_init(state, lvl_map):
|
if rule_init_printline := rule.on_init(state, lvl_map):
|
||||||
@ -59,6 +66,13 @@ class StepRules:
|
|||||||
|
|
||||||
class Gamestate(object):
|
class Gamestate(object):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def floortile_graph(self):
|
||||||
|
if not self._floortile_graph:
|
||||||
|
self.print("Generating Floorgraph....")
|
||||||
|
self._floortile_graph = points_to_graph(self.entities.floorlist)
|
||||||
|
return self._floortile_graph
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def moving_entites(self):
|
def moving_entites(self):
|
||||||
return [y for x in self.entities for y in x if x.var_can_move]
|
return [y for x in self.entities for y in x if x.var_can_move]
|
||||||
@ -72,6 +86,7 @@ class Gamestate(object):
|
|||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.rng = np.random.default_rng(env_seed)
|
self.rng = np.random.default_rng(env_seed)
|
||||||
self.rules = StepRules(*rules)
|
self.rules = StepRules(*rules)
|
||||||
|
self._floortile_graph = None
|
||||||
self.tests = StepTests(*tests)
|
self.tests = StepTests(*tests)
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user