Verified Collision Checks and Rendering

This commit is contained in:
Steffen Illium
2023-11-23 11:43:04 +01:00
parent 142bbb2c0c
commit 2f29ef703c
6 changed files with 14 additions and 6 deletions

View File

@@ -10,6 +10,7 @@ Agents:
- Doors - Doors
- Other - Other
- DirtPiles - DirtPiles
Clones: 8
Juergen: Juergen:
Actions: Actions:

View File

@@ -135,8 +135,10 @@ class Agent(Entity):
i = self.collection.idx_by_entity(self) i = self.collection.idx_by_entity(self)
assert i is not None assert i is not None
curr_state = self.state curr_state = self.state
name = c.AGENT
if curr_state.identifier == c.COLLISION: if curr_state.identifier == c.COLLISION:
render_state = renderer.STATE_COLLISION name = renderer.STATE_COLLISION
render_state=None
elif curr_state.validity: elif curr_state.validity:
if curr_state.identifier == c.NOOP: if curr_state.identifier == c.NOOP:
render_state = renderer.STATE_IDLE render_state = renderer.STATE_IDLE
@@ -147,4 +149,4 @@ class Agent(Entity):
else: else:
render_state = renderer.STATE_INVALID render_state = renderer.STATE_INVALID
return RenderEntity(c.AGENT, self.pos, 1, 'none', render_state, i + 1, real_name=self.name) return RenderEntity(name, self.pos, 1, 'none', render_state, i + 1, real_name=self.name)

View File

@@ -27,6 +27,10 @@ class Agents(Collection):
def var_has_position(self): def var_has_position(self):
return True return True
@property
def var_can_collide(self):
return True
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)

View File

@@ -198,14 +198,14 @@ class WatchCollisions(Rule):
def tick_post_step(self, state) -> List[TickResult]: def tick_post_step(self, state) -> List[TickResult]:
self.curr_done = False self.curr_done = False
pos_with_collisions = state.get_collision_positions()
results = list() results = list()
for agent in state[c.AGENT]: for agent in state[c.AGENT]:
a_s = agent.state a_s = agent.state
if h.is_move(a_s.identifier) and a_s.action_introduced_collision: if h.is_move(a_s.identifier) and a_s.action_introduced_collision:
results.append(TickResult(entity=agent, identifier=c.COLLISION, results.append(TickResult(entity=agent, identifier=c.COLLISION,
reward=self.reward, validity=c.VALID)) reward=self.reward, validity=c.VALID))
for pos in pos_with_collisions:
for pos in state.get_collision_positions():
guests = [x for x in state.entities.pos_dict[pos] if x.var_can_collide] guests = [x for x in state.entities.pos_dict[pos] if x.var_can_collide]
if len(guests) >= 2: if len(guests) >= 2:
for i, guest in enumerate(guests): for i, guest in enumerate(guests):

View File

@@ -2,6 +2,7 @@ from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.groups.collection import Collection from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.modules.clean_up.entitites import DirtPile from marl_factory_grid.modules.clean_up.entitites import DirtPile
from marl_factory_grid.utils.results import Result from marl_factory_grid.utils.results import Result
from marl_factory_grid.utils import helpers as h
class DirtPiles(Collection): class DirtPiles(Collection):
@@ -82,7 +83,7 @@ class DirtPiles(Collection):
for idx, (pos, a) in enumerate(zip(n_new, amounts)): for idx, (pos, a) in enumerate(zip(n_new, amounts)):
if not self.global_amount > self.max_global_amount: if not self.global_amount > self.max_global_amount:
if dirt := self.by_pos(pos): if dirt := self.by_pos(pos):
dirt = next(dirt.iter()) dirt = h.get_first(dirt)
new_value = dirt.amount + a new_value = dirt.amount + a
dirt.set_new_amount(new_value) dirt.set_new_amount(new_value)
else: else:

View File

@@ -36,7 +36,7 @@ class ItemAction(Action):
:rtype: ActionResult :rtype: ActionResult
""" """
reward = self.valid_drop_off_reward if validity else self.failed_drop_off_reward reward = self.valid_drop_off_reward if validity else self.failed_drop_off_reward
return ActionResult(self.__name__, validity, reward=reward, entity=entity) return ActionResult(self.__class__.__name__, validity, reward=reward, entity=entity)
def do(self, entity, state) -> Union[None, ActionResult]: def do(self, entity, state) -> Union[None, ActionResult]:
inventory = state[i.INVENTORY].by_entity(entity) inventory = state[i.INVENTORY].by_entity(entity)