mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-09-17 16:12:00 +02:00
Remove BoundDestination Object
New Variable 'var_can_be_bound' Observations adjusted accordingly
This commit is contained in:
@@ -8,13 +8,18 @@ from marl_factory_grid.utils.render import RenderEntity
|
||||
from marl_factory_grid.modules.destinations import constants as d
|
||||
|
||||
|
||||
class Destination(Entity):
|
||||
class Destination(BoundEntityMixin, Entity):
|
||||
|
||||
var_can_move = False
|
||||
var_can_collide = False
|
||||
var_has_position = True
|
||||
var_is_blocking_pos = False
|
||||
var_is_blocking_light = False
|
||||
var_can_be_bound = True # Introduce this globally!
|
||||
|
||||
@property
|
||||
def was_reached(self):
|
||||
return self._was_reached
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
@@ -22,6 +27,7 @@ class Destination(Entity):
|
||||
|
||||
def __init__(self, *args, action_counts=0, **kwargs):
|
||||
super(Destination, self).__init__(*args, **kwargs)
|
||||
self._was_reached = False
|
||||
self.action_counts = action_counts
|
||||
self._per_agent_actions = defaultdict(lambda: 0)
|
||||
|
||||
@@ -30,9 +36,15 @@ class Destination(Entity):
|
||||
return c.VALID
|
||||
|
||||
@property
|
||||
def is_considered_reached(self):
|
||||
agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
|
||||
return agent_at_position or any(x >= self.action_counts for x in self._per_agent_actions.values())
|
||||
def has_just_been_reached(self):
|
||||
if self.was_reached:
|
||||
return False
|
||||
agent_at_position = any(self.bound_entity == x for x in self.tile.guests_that_can_collide)
|
||||
if self.bound_entity:
|
||||
return ((agent_at_position and not self.action_counts)
|
||||
or self._per_agent_actions[self.bound_entity.name] >= self.action_counts >= 1)
|
||||
else:
|
||||
return agent_at_position or any(x >= self.action_counts for x in self._per_agent_actions.values())
|
||||
|
||||
def agent_did_action(self, agent: Agent):
|
||||
return self._per_agent_actions[agent.name] >= self.action_counts
|
||||
@@ -44,21 +56,10 @@ class Destination(Entity):
|
||||
return state_summary
|
||||
|
||||
def render(self):
|
||||
return RenderEntity(d.DESTINATION, self.pos)
|
||||
if self.was_reached:
|
||||
return None
|
||||
else:
|
||||
return RenderEntity(d.DESTINATION, self.pos)
|
||||
|
||||
|
||||
class BoundDestination(BoundEntityMixin, Destination):
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return d.DEST_SYMBOL
|
||||
|
||||
def __init__(self, entity, *args, **kwargs):
|
||||
self.bind_to(entity)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def is_considered_reached(self):
|
||||
agent_at_position = any(self.bound_entity == x for x in self.tile.guests_that_can_collide)
|
||||
return ((agent_at_position and not self.action_counts)
|
||||
or self._per_agent_actions[self.bound_entity.name] >= self.action_counts >= 1)
|
||||
def mark_as_reached(self):
|
||||
self._was_reached = True
|
||||
|
Reference in New Issue
Block a user