Redone the spawn procedute and destination objects

This commit is contained in:
Steffen Illium
2023-10-11 16:36:48 +02:00
parent e64fa84ef1
commit e326a95bf4
32 changed files with 266 additions and 146 deletions

@ -1,6 +1,9 @@
import abc
from random import shuffle
from typing import List
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.utils import helpers as h
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import rewards as r, constants as c
@ -36,6 +39,40 @@ class Rule(abc.ABC):
return []
class SpawnAgents(Rule):
def __init__(self):
super().__init__()
pass
def on_init(self, state, lvl_map):
agent_conf = state.agents_conf
# agents = Agents(lvl_map.size)
agents = state[c.AGENT]
empty_tiles = state[c.FLOORS].empty_tiles[:len(agent_conf)]
for agent_name in agent_conf:
actions = agent_conf[agent_name]['actions'].copy()
observations = agent_conf[agent_name]['observations'].copy()
positions = agent_conf[agent_name]['positions'].copy()
if positions:
shuffle(positions)
while True:
try:
tile = state[c.FLOORS].by_pos(positions.pop())
except IndexError as e:
raise ValueError(f'It was not possible to spawn an Agent on the available position: '
f'\n{agent_name[agent_name]["positions"].copy()}')
try:
agents.add_item(Agent(actions, observations, tile, str_ident=agent_name))
except AssertionError:
state.print(f'No valid pos:{tile.pos} for {agent_name}')
continue
break
else:
agents.add_item(Agent(actions, observations, empty_tiles.pop(), str_ident=agent_name))
pass
class MaxStepsReached(Rule):
def __init__(self, max_steps: int = 500):
@ -91,6 +128,8 @@ class Collision(Rule):
return results
def on_check_done(self, state) -> List[DoneResult]:
if self.curr_done and self.done_at_collisions:
inter_entity_collision_detected = self.curr_done and self.done_at_collisions
move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT])
if inter_entity_collision_detected or move_failed:
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]