mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-12-06 15:40:37 +01:00
added changes from code submission branch and coin entity
This commit is contained in:
@@ -168,14 +168,25 @@ class SpawnEntity(Rule):
|
||||
return results
|
||||
|
||||
|
||||
def _get_position(spawn_rule, positions, empty_positions, positions_pointer):
|
||||
"""
|
||||
Internal usage, selects positions based on rule.
|
||||
"""
|
||||
if spawn_rule and spawn_rule == "random":
|
||||
position = random.choice(([x for x in positions if x in empty_positions]))
|
||||
elif spawn_rule and spawn_rule == "order":
|
||||
position = ([x for x in positions if x in empty_positions])[positions_pointer]
|
||||
else:
|
||||
position = h.get_first([x for x in positions if x in empty_positions])
|
||||
return position
|
||||
|
||||
|
||||
class SpawnAgents(Rule):
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
Finds suitable spawn positions according to the given spawn rule, creates agents with these positions and adds
|
||||
them to state.agents.
|
||||
"""
|
||||
super().__init__()
|
||||
pass
|
||||
@@ -183,8 +194,9 @@ class SpawnAgents(Rule):
|
||||
def on_reset(self, state):
|
||||
spawn_rule = None
|
||||
for rule in state.rules.rules:
|
||||
if isinstance(rule, marl_factory_grid.environment.rules.AgentSpawnRule):
|
||||
if isinstance(rule, AgentSpawnRule):
|
||||
spawn_rule = rule.spawn_rule
|
||||
break
|
||||
|
||||
if not hasattr(state, 'agent_spawn_positions'):
|
||||
state.agent_spawn_positions = []
|
||||
@@ -200,7 +212,7 @@ class SpawnAgents(Rule):
|
||||
other = agent_conf['other'].copy()
|
||||
positions_pointer = agent_conf['pos_pointer']
|
||||
|
||||
if position := self._get_position(spawn_rule, positions, empty_positions, positions_pointer):
|
||||
if position := _get_position(spawn_rule, positions, empty_positions, positions_pointer):
|
||||
assert state.check_pos_validity(position), 'smth went wrong....'
|
||||
agents.add_item(Agent(actions, observations, position, str_ident=agent_name, **other))
|
||||
state.agent_spawn_positions.append(position)
|
||||
@@ -213,21 +225,13 @@ class SpawnAgents(Rule):
|
||||
state.agent_spawn_positions.append(chosen_position)
|
||||
return []
|
||||
|
||||
def _get_position(self, spawn_rule, positions, empty_positions, positions_pointer):
|
||||
if spawn_rule and spawn_rule == "random":
|
||||
position = random.choice(([x for x in positions if x in empty_positions]))
|
||||
elif spawn_rule and spawn_rule == "order":
|
||||
position = ([x for x in positions if x in empty_positions])[positions_pointer]
|
||||
else:
|
||||
position = h.get_first([x for x in positions if x in empty_positions])
|
||||
|
||||
return position
|
||||
|
||||
class AgentSpawnRule(Rule):
|
||||
def __init__(self, spawn_rule):
|
||||
self.spawn_rule = spawn_rule
|
||||
super().__init__()
|
||||
|
||||
|
||||
class DoneAtMaxStepsReached(Rule):
|
||||
|
||||
def __init__(self, max_steps: int = 500):
|
||||
|
||||
Reference in New Issue
Block a user