mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 14:56:43 +02:00
Verified Collision Checks and Rendering
This commit is contained in:
parent
142bbb2c0c
commit
2f29ef703c
@ -10,6 +10,7 @@ Agents:
|
||||
- Doors
|
||||
- Other
|
||||
- DirtPiles
|
||||
Clones: 8
|
||||
|
||||
Juergen:
|
||||
Actions:
|
||||
|
@ -135,8 +135,10 @@ class Agent(Entity):
|
||||
i = self.collection.idx_by_entity(self)
|
||||
assert i is not None
|
||||
curr_state = self.state
|
||||
name = c.AGENT
|
||||
if curr_state.identifier == c.COLLISION:
|
||||
render_state = renderer.STATE_COLLISION
|
||||
name = renderer.STATE_COLLISION
|
||||
render_state=None
|
||||
elif curr_state.validity:
|
||||
if curr_state.identifier == c.NOOP:
|
||||
render_state = renderer.STATE_IDLE
|
||||
@ -147,4 +149,4 @@ class Agent(Entity):
|
||||
else:
|
||||
render_state = renderer.STATE_INVALID
|
||||
|
||||
return RenderEntity(c.AGENT, self.pos, 1, 'none', render_state, i + 1, real_name=self.name)
|
||||
return RenderEntity(name, self.pos, 1, 'none', render_state, i + 1, real_name=self.name)
|
||||
|
@ -27,6 +27,10 @@ class Agents(Collection):
|
||||
def var_has_position(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
return True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
@ -198,14 +198,14 @@ class WatchCollisions(Rule):
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
self.curr_done = False
|
||||
pos_with_collisions = state.get_collision_positions()
|
||||
results = list()
|
||||
for agent in state[c.AGENT]:
|
||||
a_s = agent.state
|
||||
if h.is_move(a_s.identifier) and a_s.action_introduced_collision:
|
||||
results.append(TickResult(entity=agent, identifier=c.COLLISION,
|
||||
reward=self.reward, validity=c.VALID))
|
||||
for pos in pos_with_collisions:
|
||||
|
||||
for pos in state.get_collision_positions():
|
||||
guests = [x for x in state.entities.pos_dict[pos] if x.var_can_collide]
|
||||
if len(guests) >= 2:
|
||||
for i, guest in enumerate(guests):
|
||||
|
@ -2,6 +2,7 @@ from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.groups.collection import Collection
|
||||
from marl_factory_grid.modules.clean_up.entitites import DirtPile
|
||||
from marl_factory_grid.utils.results import Result
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
|
||||
|
||||
class DirtPiles(Collection):
|
||||
@ -82,7 +83,7 @@ class DirtPiles(Collection):
|
||||
for idx, (pos, a) in enumerate(zip(n_new, amounts)):
|
||||
if not self.global_amount > self.max_global_amount:
|
||||
if dirt := self.by_pos(pos):
|
||||
dirt = next(dirt.iter())
|
||||
dirt = h.get_first(dirt)
|
||||
new_value = dirt.amount + a
|
||||
dirt.set_new_amount(new_value)
|
||||
else:
|
||||
|
@ -36,7 +36,7 @@ class ItemAction(Action):
|
||||
:rtype: ActionResult
|
||||
"""
|
||||
reward = self.valid_drop_off_reward if validity else self.failed_drop_off_reward
|
||||
return ActionResult(self.__name__, validity, reward=reward, entity=entity)
|
||||
return ActionResult(self.__class__.__name__, validity, reward=reward, entity=entity)
|
||||
|
||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||
inventory = state[i.INVENTORY].by_entity(entity)
|
||||
|
Loading…
x
Reference in New Issue
Block a user