mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-24 04:11:36 +02:00
Merge branch 'main' into refactor_rename
# Conflicts: # marl_factory_grid/environment/entity/entity.py # marl_factory_grid/modules/destinations/entitites.py # marl_factory_grid/modules/doors/entitites.py # marl_factory_grid/modules/items/groups.py
This commit is contained in:
@ -4,7 +4,7 @@ from marl_factory_grid.environment.entity.agent import Agent
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.mixin import BoundEntityMixin
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||
from marl_factory_grid.modules.destinations import constants as d
|
||||
|
||||
|
||||
@ -34,7 +34,6 @@ class Destination(Entity):
|
||||
def var_can_be_bound(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def was_reached(self):
|
||||
return self._was_reached
|
||||
|
||||
@ -52,12 +51,10 @@ class Destination(Entity):
|
||||
self._per_agent_actions[agent.name] += 1
|
||||
return c.VALID
|
||||
|
||||
@property
|
||||
def has_just_been_reached(self):
|
||||
if self.was_reached:
|
||||
def has_just_been_reached(self, state):
|
||||
if self.was_reached():
|
||||
return False
|
||||
agent_at_position = any(
|
||||
c.AGENT.lower() in x.name.lower() for x in state.entities.pos_dict[self.pos] if x.var_can_collide)
|
||||
agent_at_position = any(state[c.AGENT].by_pos(self.pos))
|
||||
|
||||
if self.bound_entity:
|
||||
return ((agent_at_position and not self.action_counts)
|
||||
@ -75,7 +72,7 @@ class Destination(Entity):
|
||||
return state_summary
|
||||
|
||||
def render(self):
|
||||
if self.was_reached:
|
||||
if self.was_reached():
|
||||
return None
|
||||
else:
|
||||
return RenderEntity(d.DESTINATION, self.pos)
|
||||
|
@ -16,28 +16,29 @@ class DestinationReachAll(Rule):
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
results = []
|
||||
reached = False
|
||||
for dest in state[d.DESTINATION]:
|
||||
if dest.has_just_been_reached and not dest.was_reached:
|
||||
# Dest has just been reached, some agent needs to stand here, grab any first.
|
||||
if dest.has_just_been_reached(state) and not dest.was_reached():
|
||||
# Dest has just been reached, some agent needs to stand here
|
||||
for agent in state[c.AGENT].by_pos(dest.pos):
|
||||
if dest.bound_entity:
|
||||
if dest.bound_entity == agent:
|
||||
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
|
||||
reached = True
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
|
||||
state.print(f'{dest.name} is reached now, mark as reached...')
|
||||
dest.mark_as_reached()
|
||||
reached = True
|
||||
else:
|
||||
pass
|
||||
if reached:
|
||||
state.print(f'{dest.name} is reached now, mark as reached...')
|
||||
dest.mark_as_reached()
|
||||
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
|
||||
return results
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
return []
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
if all(x.was_reached for x in state[d.DESTINATION]):
|
||||
if all(x.was_reached() for x in state[d.DESTINATION]):
|
||||
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
|
||||
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
|
||||
|
||||
@ -48,7 +49,7 @@ class DestinationReachAny(DestinationReachAll):
|
||||
super(DestinationReachAny, self).__init__()
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
if any(x.was_reached for x in state[d.DESTINATION]):
|
||||
if any(x.was_reached() for x in state[d.DESTINATION]):
|
||||
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
|
||||
return []
|
||||
|
||||
@ -63,7 +64,7 @@ class DestinationSpawn(Rule):
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
# noinspection PyAttributeOutsideInit
|
||||
self.trigger_destination_spawn(self.n_dests, state)
|
||||
state[d.DESTINATION].trigger_destination_spawn(self.n_dests, state)
|
||||
pass
|
||||
|
||||
def tick_pre_step(self, state) -> List[TickResult]:
|
||||
@ -72,24 +73,14 @@ class DestinationSpawn(Rule):
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
if n_dest_spawn := max(0, self.n_dests - len(state[d.DESTINATION])):
|
||||
if self.spawn_mode == d.MODE_GROUPED and n_dest_spawn == self.n_dests:
|
||||
validity = self.trigger_destination_spawn(n_dest_spawn, state)
|
||||
validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state)
|
||||
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
|
||||
elif self.spawn_mode == d.MODE_SINGLE and n_dest_spawn:
|
||||
validity = self.trigger_destination_spawn(n_dest_spawn, state)
|
||||
validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state)
|
||||
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
|
||||
else:
|
||||
pass
|
||||
|
||||
def trigger_destination_spawn(self, n_dests, state):
|
||||
empty_positions = state[c.FLOORS].empty_tiles[:n_dests]
|
||||
if destinations := [Destination(pos) for pos in empty_positions]:
|
||||
state[d.DESTINATION].add_items(destinations)
|
||||
state.print(f'{n_dests} new destinations have been spawned')
|
||||
return c.VALID
|
||||
else:
|
||||
state.print('No Destiantions are spawning, limit is reached.')
|
||||
return c.NOT_VALID
|
||||
|
||||
|
||||
class FixedDestinationSpawn(Rule):
|
||||
def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]):
|
||||
@ -99,11 +90,17 @@ class FixedDestinationSpawn(Rule):
|
||||
def on_init(self, state, lvl_map):
|
||||
for (agent_name, position_list) in self.per_agent_positions.items():
|
||||
agent = next(x for x in state[c.AGENT] if agent_name in x.name) # Fixme: Ugly AF
|
||||
position_list = position_list.copy()
|
||||
shuffle(position_list)
|
||||
while True:
|
||||
pos = position_list.pop()
|
||||
if pos != agent.pos and not state[d.DESTINATION].by_pos(pos):
|
||||
destination = Destination(state[c.FLOORS].by_pos(pos), bind_to=agent)
|
||||
try:
|
||||
pos = position_list.pop()
|
||||
except IndexError:
|
||||
print(f"Could not spawn Destinations at: {self.per_agent_positions[agent_name]}")
|
||||
print(f'Check your agent palcement: {state[c.AGENT]} ... Exit ...')
|
||||
exit(9999)
|
||||
if (not pos == agent.pos) and (not state[d.DESTINATION].by_pos(pos)):
|
||||
destination = Destination(pos, bind_to=agent)
|
||||
break
|
||||
else:
|
||||
continue
|
||||
|
Reference in New Issue
Block a user