mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-05 17:11:35 +02:00
fixed globalpos and 2rooms1door destination spawning
This commit is contained in:
@ -39,7 +39,15 @@ Agents:
|
|||||||
- Doors
|
- Doors
|
||||||
|
|
||||||
Entities:
|
Entities:
|
||||||
Destinations: { }
|
Destinations:
|
||||||
|
spawnrule:
|
||||||
|
SpawnDestinationsPerAgent:
|
||||||
|
coords_or_quantity:
|
||||||
|
Wolfgang:
|
||||||
|
- (6,12)
|
||||||
|
Sigmund:
|
||||||
|
- (6, 2)
|
||||||
|
|
||||||
Doors: { }
|
Doors: { }
|
||||||
GlobalPositions: { }
|
GlobalPositions: { }
|
||||||
|
|
||||||
@ -57,5 +65,5 @@ Rules:
|
|||||||
AssignGlobalPositions: { }
|
AssignGlobalPositions: { }
|
||||||
|
|
||||||
# Done Conditions
|
# Done Conditions
|
||||||
MaxStepsReached:
|
DoneAtMaxStepsReached:
|
||||||
max_steps: 10
|
max_steps: 10
|
||||||
|
@ -234,8 +234,12 @@ class AssignGlobalPositions(Rule):
|
|||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.level_shape = None
|
||||||
|
|
||||||
def on_reset(self, state, lvl_map):
|
def on_init(self, state, lvl_map):
|
||||||
|
self.level_shape = lvl_map.level_shape
|
||||||
|
|
||||||
|
def on_reset(self, state):
|
||||||
"""
|
"""
|
||||||
Assign global positions to agents when the environment is reset.
|
Assign global positions to agents when the environment is reset.
|
||||||
|
|
||||||
@ -248,7 +252,7 @@ class AssignGlobalPositions(Rule):
|
|||||||
"""
|
"""
|
||||||
from marl_factory_grid.environment.entity.util import GlobalPosition
|
from marl_factory_grid.environment.entity.util import GlobalPosition
|
||||||
for agent in state[c.AGENT]:
|
for agent in state[c.AGENT]:
|
||||||
gp = GlobalPosition(agent, lvl_map.level_shape)
|
gp = GlobalPosition(agent, self.level_shape)
|
||||||
state[c.GLOBALPOSITIONS].add_item(gp)
|
state[c.GLOBALPOSITIONS].add_item(gp)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ if __name__ == '__main__':
|
|||||||
ce.save_all(run_path / 'all_available_configs.yaml')
|
ce.save_all(run_path / 'all_available_configs.yaml')
|
||||||
|
|
||||||
# Path to config File
|
# Path to config File
|
||||||
path = Path('marl_factory_grid/configs/default_config.yaml')
|
path = Path('marl_factory_grid/configs/two_rooms_one_door.yaml')
|
||||||
|
|
||||||
# Env Init
|
# Env Init
|
||||||
factory = Factory(path)
|
factory = Factory(path)
|
||||||
@ -50,7 +50,7 @@ if __name__ == '__main__':
|
|||||||
if render:
|
if render:
|
||||||
factory.render()
|
factory.render()
|
||||||
action_spaces = factory.action_space
|
action_spaces = factory.action_space
|
||||||
agents = [TSPDirtAgent(factory, 0), TSPDirtAgent(factory, 1), TSPDirtAgent(factory, 2)]
|
# agents = [TSPDirtAgent(factory, 0), TSPDirtAgent(factory, 1), TSPDirtAgent(factory, 2)]
|
||||||
while not done:
|
while not done:
|
||||||
a = [randint(0, x.n - 1) for x in action_spaces]
|
a = [randint(0, x.n - 1) for x in action_spaces]
|
||||||
obs_type, _, reward, done, info = factory.step(a)
|
obs_type, _, reward, done, info = factory.step(a)
|
||||||
|
@ -9,10 +9,10 @@ from marl_factory_grid.environment.factory import Factory
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# Render at each step?
|
# Render at each step?
|
||||||
render = False
|
render = True
|
||||||
|
|
||||||
# Path to config File
|
# Path to config File
|
||||||
path = Path('marl_factory_grid/configs/test_config.yaml')
|
path = Path('marl_factory_grid/configs/two_rooms_one_door.yaml')
|
||||||
|
|
||||||
# Env Init
|
# Env Init
|
||||||
factory = Factory(path)
|
factory = Factory(path)
|
||||||
@ -23,7 +23,8 @@ if __name__ == '__main__':
|
|||||||
if render:
|
if render:
|
||||||
factory.render()
|
factory.render()
|
||||||
action_spaces = factory.action_space
|
action_spaces = factory.action_space
|
||||||
agents = [TSPDirtAgent(factory, 0), TSPItemAgent(factory, 1), TSPTargetAgent(factory, 2)]
|
# agents = [TSPDirtAgent(factory, 0), TSPItemAgent(factory, 1), TSPTargetAgent(factory, 2)]
|
||||||
|
agents = [TSPTargetAgent(factory, 0), TSPTargetAgent(factory, 1)]
|
||||||
while not done:
|
while not done:
|
||||||
a = [x.predict() for x in agents]
|
a = [x.predict() for x in agents]
|
||||||
obs_type, _, _, done, info = factory.step(a)
|
obs_type, _, _, done, info = factory.step(a)
|
||||||
|
Reference in New Issue
Block a user