mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-12-20 05:56:07 +01:00
Changed Collision Check
This commit is contained in:
@@ -1,7 +1,6 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
@@ -74,6 +73,7 @@ class Checkpointer(object):
|
|||||||
def save_experiment(self, name: str, model):
|
def save_experiment(self, name: str, model):
|
||||||
cpt_path = self.path / f'checkpoint_{self.__current_checkpoint}'
|
cpt_path = self.path / f'checkpoint_{self.__current_checkpoint}'
|
||||||
cpt_path.mkdir(exist_ok=True, parents=True)
|
cpt_path.mkdir(exist_ok=True, parents=True)
|
||||||
|
import torch
|
||||||
torch.save(model.state_dict(), cpt_path / f'{name}.pt')
|
torch.save(model.state_dict(), cpt_path / f'{name}.pt')
|
||||||
|
|
||||||
def step(self, to_save):
|
def step(self, to_save):
|
||||||
|
|||||||
@@ -28,9 +28,10 @@ class Action(abc.ABC):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'Action[{self._identifier}]'
|
return f'Action[{self._identifier}]'
|
||||||
|
|
||||||
def get_result(self, validity, entity):
|
def get_result(self, validity, entity, action_introduced_collision=False):
|
||||||
reward = self.valid_reward if validity else self.fail_reward
|
reward = self.valid_reward if validity else self.fail_reward
|
||||||
return ActionResult(self.__class__.__name__, validity, reward=reward, entity=entity)
|
return ActionResult(self.__class__.__name__, validity, reward=reward, entity=entity,
|
||||||
|
action_introduced_collision=action_introduced_collision)
|
||||||
|
|
||||||
|
|
||||||
class Noop(Action):
|
class Noop(Action):
|
||||||
@@ -50,24 +51,24 @@ class Move(Action, abc.ABC):
|
|||||||
|
|
||||||
def do(self, entity, state):
|
def do(self, entity, state):
|
||||||
new_pos = self._calc_new_pos(entity.pos)
|
new_pos = self._calc_new_pos(entity.pos)
|
||||||
|
collision = False
|
||||||
if state.check_move_validity(entity, new_pos):
|
if state.check_move_validity(entity, new_pos):
|
||||||
valid = entity.move(new_pos, state)
|
valid = entity.move(new_pos, state)
|
||||||
# Aftermath Collision Check
|
# Aftermath Collision Check
|
||||||
if len([x for x in state.entities.by_pos(entity.pos) if x.var_can_collide]) > 1:
|
if len([x for x in state.entities.by_pos(entity.pos) if x.var_can_collide]) > 1:
|
||||||
# The entity did move, but there was something to collide with...
|
# The entity did move, but there was something to collide with...
|
||||||
# Is then reported as a non-valid move, which did work.
|
collision = True
|
||||||
valid = False
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# There is no place to go, propably collision
|
# There is no place to go, propably collision
|
||||||
# This is currently handeld by the WatchCollisions rule, so that it can be switched on and off by conf.yml
|
# This is currently handeld by the WatchCollisions rule, so that it can be switched on and off by conf.yml
|
||||||
# return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION)
|
|
||||||
valid = c.NOT_VALID
|
valid = c.NOT_VALID
|
||||||
|
collision = True
|
||||||
if valid:
|
if valid:
|
||||||
state.print(f'{entity.name} just moved to {entity.pos}.')
|
state.print(f'{entity.name} just moved to {entity.pos}.')
|
||||||
else:
|
else:
|
||||||
state.print(f'{entity.name} just tried to move to {new_pos} but either failed or hat a Collision.')
|
state.print(f'{entity.name} just tried to move to {new_pos} but either failed or hat a Collision.')
|
||||||
return self.get_result(valid, entity)
|
return self.get_result(valid, entity, action_introduced_collision=collision)
|
||||||
|
|
||||||
def _calc_new_pos(self, pos):
|
def _calc_new_pos(self, pos):
|
||||||
x_diff, y_diff = MOVEMAP[self._identifier]
|
x_diff, y_diff = MOVEMAP[self._identifier]
|
||||||
|
|||||||
@@ -200,6 +200,11 @@ class WatchCollisions(Rule):
|
|||||||
self.curr_done = False
|
self.curr_done = False
|
||||||
pos_with_collisions = state.get_collision_positions()
|
pos_with_collisions = state.get_collision_positions()
|
||||||
results = list()
|
results = list()
|
||||||
|
for agent in state[c.AGENT]:
|
||||||
|
a_s = agent.state
|
||||||
|
if h.is_move(a_s.identifier) and a_s.action_introduced_collision:
|
||||||
|
results.append(TickResult(entity=agent, identifier=c.COLLISION,
|
||||||
|
reward=self.reward, validity=c.VALID))
|
||||||
for pos in pos_with_collisions:
|
for pos in pos_with_collisions:
|
||||||
guests = [x for x in state.entities.pos_dict[pos] if x.var_can_collide]
|
guests = [x for x in state.entities.pos_dict[pos] if x.var_can_collide]
|
||||||
if len(guests) >= 2:
|
if len(guests) >= 2:
|
||||||
@@ -210,6 +215,7 @@ class WatchCollisions(Rule):
|
|||||||
)
|
)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
if not any([x.entity == guest for x in results]):
|
||||||
results.append(TickResult(entity=guest, identifier=c.COLLISION,
|
results.append(TickResult(entity=guest, identifier=c.COLLISION,
|
||||||
reward=self.reward, validity=c.VALID))
|
reward=self.reward, validity=c.VALID))
|
||||||
self.curr_done = True if self.done_at_collisions else False
|
self.curr_done = True if self.done_at_collisions else False
|
||||||
@@ -218,7 +224,9 @@ class WatchCollisions(Rule):
|
|||||||
def on_check_done(self, state) -> List[DoneResult]:
|
def on_check_done(self, state) -> List[DoneResult]:
|
||||||
if self.done_at_collisions:
|
if self.done_at_collisions:
|
||||||
inter_entity_collision_detected = self.curr_done
|
inter_entity_collision_detected = self.curr_done
|
||||||
move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT])
|
collision_in_step = any(h.is_move(x.state.identifier) and x.state.action_introduced_collision
|
||||||
if inter_entity_collision_detected or move_failed:
|
for x in state[c.AGENT]
|
||||||
|
)
|
||||||
|
if inter_entity_collision_detected or collision_in_step:
|
||||||
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=self.reward_at_done)]
|
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=self.reward_at_done)]
|
||||||
return []
|
return []
|
||||||
|
|||||||
@@ -58,9 +58,16 @@ class Result:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ActionResult(Result):
|
class ActionResult(Result):
|
||||||
|
def __init__(self, *args, action_introduced_collision: bool = False, **kwargs):
|
||||||
"""
|
"""
|
||||||
A specific Result class representing outcomes of actions.
|
A specific Result class representing outcomes of actions.
|
||||||
|
|
||||||
|
:param action_introduced_collision: Wether the action did introduce a colision between agents or other entities.
|
||||||
|
These need to be able to collide.
|
||||||
"""
|
"""
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.action_introduced_collision = action_introduced_collision
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user