mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-12-06 15:40:37 +01:00
new rules, new spawn logic, small fixes, default and narrow corridor debugged
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import abc
|
||||
from random import shuffle
|
||||
from typing import List
|
||||
from typing import List, Collection, Union
|
||||
|
||||
from marl_factory_grid.environment.entity.agent import Agent
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
@@ -39,6 +39,29 @@ class Rule(abc.ABC):
|
||||
return []
|
||||
|
||||
|
||||
class SpawnEntity(Rule):
|
||||
|
||||
@property
|
||||
def _collection(self) -> Collection:
|
||||
return Collection()
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return f'{self.__class__.__name__}({self.collection.name})'
|
||||
|
||||
def __init__(self, collection, coords_or_quantity, ignore_blocking=False):
|
||||
super().__init__()
|
||||
self.coords_or_quantity = coords_or_quantity
|
||||
self.collection = collection
|
||||
self.ignore_blocking = ignore_blocking
|
||||
|
||||
def on_init(self, state, lvl_map) -> [TickResult]:
|
||||
results = self.collection.trigger_spawn(state, ignore_blocking=self.ignore_blocking)
|
||||
pos_str = f' on: {[x.pos for x in self.collection]}' if self.collection.var_has_position else ''
|
||||
state.print(f'Initial {self.collection.__class__.__name__} were spawned{pos_str}')
|
||||
return results
|
||||
|
||||
|
||||
class SpawnAgents(Rule):
|
||||
|
||||
def __init__(self):
|
||||
@@ -46,14 +69,14 @@ class SpawnAgents(Rule):
|
||||
pass
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
agent_conf = state.agents_conf
|
||||
# agents = Agents(lvl_map.size)
|
||||
agents = state[c.AGENT]
|
||||
empty_positions = state.entities.empty_positions()[:len(agent_conf)]
|
||||
for agent_name in agent_conf:
|
||||
actions = agent_conf[agent_name]['actions'].copy()
|
||||
observations = agent_conf[agent_name]['observations'].copy()
|
||||
positions = agent_conf[agent_name]['positions'].copy()
|
||||
empty_positions = state.entities.empty_positions[:len(state.agents_conf)]
|
||||
for agent_name, agent_conf in state.agents_conf.items():
|
||||
actions = agent_conf['actions'].copy()
|
||||
observations = agent_conf['observations'].copy()
|
||||
positions = agent_conf['positions'].copy()
|
||||
other = agent_conf['other'].copy()
|
||||
if positions:
|
||||
shuffle(positions)
|
||||
while True:
|
||||
@@ -61,18 +84,18 @@ class SpawnAgents(Rule):
|
||||
pos = positions.pop()
|
||||
except IndexError:
|
||||
raise ValueError(f'It was not possible to spawn an Agent on the available position: '
|
||||
f'\n{agent_name[agent_name]["positions"].copy()}')
|
||||
if agents.by_pos(pos) and state.check_pos_validity(pos):
|
||||
f'\n{agent_conf["positions"].copy()}')
|
||||
if bool(agents.by_pos(pos)) or not state.check_pos_validity(pos):
|
||||
continue
|
||||
else:
|
||||
agents.add_item(Agent(actions, observations, pos, str_ident=agent_name))
|
||||
agents.add_item(Agent(actions, observations, pos, str_ident=agent_name, **other))
|
||||
break
|
||||
else:
|
||||
agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name))
|
||||
agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name, **other))
|
||||
pass
|
||||
|
||||
|
||||
class MaxStepsReached(Rule):
|
||||
class DoneAtMaxStepsReached(Rule):
|
||||
|
||||
def __init__(self, max_steps: int = 500):
|
||||
super().__init__()
|
||||
@@ -83,8 +106,8 @@ class MaxStepsReached(Rule):
|
||||
|
||||
def on_check_done(self, state):
|
||||
if self.max_steps <= state.curr_step:
|
||||
return [DoneResult(validity=c.VALID, identifier=self.name, reward=0)]
|
||||
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
|
||||
return [DoneResult(validity=c.VALID, identifier=self.name)]
|
||||
return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]
|
||||
|
||||
|
||||
class AssignGlobalPositions(Rule):
|
||||
@@ -101,7 +124,7 @@ class AssignGlobalPositions(Rule):
|
||||
return []
|
||||
|
||||
|
||||
class Collision(Rule):
|
||||
class WatchCollisions(Rule):
|
||||
|
||||
def __init__(self, done_at_collisions: bool = False):
|
||||
super().__init__()
|
||||
@@ -132,4 +155,4 @@ class Collision(Rule):
|
||||
move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT])
|
||||
if inter_entity_collision_detected or move_failed:
|
||||
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)]
|
||||
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
|
||||
return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]
|
||||
|
||||
Reference in New Issue
Block a user