mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 15:26:43 +02:00
Smaller adjustments
This commit is contained in:
parent
0161197cd8
commit
e5dd49f0f0
@ -26,8 +26,14 @@ class BaseFactory(gym.Env):
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
agent_slice = self.n_agents if self.omit_agent_slice_in_obs else 0
|
||||
agent_slice = (self.n_agents - 1) if self.combin_agent_slices_in_obs else agent_slice
|
||||
if self.combin_agent_slices_in_obs:
|
||||
agent_slice = 1
|
||||
else: # not self.combin_agent_slices_in_obs:
|
||||
if self.omit_agent_slice_in_obs:
|
||||
agent_slice = self.n_agents - 1
|
||||
else: # not self.omit_agent_slice_in_obs:
|
||||
agent_slice = self.n_agents
|
||||
|
||||
if self.pomdp_radius:
|
||||
shape = (self._obs_cube.shape[0] - agent_slice, self.pomdp_radius * 2 + 1, self.pomdp_radius * 2 + 1)
|
||||
space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
||||
@ -289,7 +295,7 @@ class BaseFactory(gym.Env):
|
||||
return obs
|
||||
else:
|
||||
if self.omit_agent_slice_in_obs:
|
||||
obs_new = obs[[key for key, val in self._slices.items() if c.AGENT.value not in val.name]]
|
||||
obs_new = obs[[key for key, val in self._slices.items() if val.name != agent.name]]
|
||||
return obs_new
|
||||
else:
|
||||
return obs
|
||||
|
@ -146,6 +146,7 @@ class Entity(Object):
|
||||
def __init__(self, identifier, tile: Tile, **kwargs):
|
||||
super(Entity, self).__init__(identifier, **kwargs)
|
||||
self._tile = tile
|
||||
tile.enter(self)
|
||||
|
||||
def summarize_state(self):
|
||||
return self.__dict__.copy()
|
||||
|
@ -190,7 +190,7 @@ class SimpleFactory(BaseFactory):
|
||||
reward -= 0.00
|
||||
else:
|
||||
# self.print('collision')
|
||||
reward -= 0.05
|
||||
reward -= 0.01
|
||||
self.print(f'{agent.name} just hit the wall at {agent.pos}.')
|
||||
info_dict.update({f'{agent.name}_vs_LEVEL': 1})
|
||||
|
||||
@ -224,8 +224,12 @@ class SimpleFactory(BaseFactory):
|
||||
if __name__ == '__main__':
|
||||
render = True
|
||||
|
||||
move_props = MovementProperties(allow_diagonal_movement=True, allow_square_movement=True)
|
||||
dirt_props = DirtProperties(dirt_smear_amount=0.2)
|
||||
dirt_props = DirtProperties(clean_amount=3, gain_amount=0.3, max_global_amount=20,
|
||||
max_local_amount=2, spawn_frequency=3, max_spawn_ratio=0.05,
|
||||
dirt_smear_amount=0.0)
|
||||
move_props = MovementProperties(allow_diagonal_movement=True,
|
||||
allow_square_movement=True,
|
||||
allow_no_op=False)
|
||||
factory = SimpleFactory(movement_properties=move_props, dirt_properties=dirt_props, n_agents=1,
|
||||
combin_agent_slices_in_obs=False, level_name='rooms', parse_doors=True,
|
||||
pomdp_radius=3)
|
||||
|
8
main.py
8
main.py
@ -92,9 +92,9 @@ if __name__ == '__main__':
|
||||
from algorithms.reg_dqn import RegDQN
|
||||
# from sb3_contrib import QRDQN
|
||||
|
||||
dirt_props = DirtProperties(clean_amount=1, gain_amount=0.3, max_global_amount=20,
|
||||
max_local_amount=2, spawn_frequency=3, max_spawn_ratio=0.05,
|
||||
dirt_smear_amount=0.2)
|
||||
dirt_props = DirtProperties(clean_amount=6, gain_amount=1, max_global_amount=30,
|
||||
max_local_amount=5, spawn_frequency=5, max_spawn_ratio=0.05,
|
||||
dirt_smear_amount=0.0)
|
||||
move_props = MovementProperties(allow_diagonal_movement=True,
|
||||
allow_square_movement=True,
|
||||
allow_no_op=False)
|
||||
@ -136,7 +136,7 @@ if __name__ == '__main__':
|
||||
)]
|
||||
)
|
||||
|
||||
model.learn(total_timesteps=int(5e5), callback=callbacks)
|
||||
model.learn(total_timesteps=int(2e5), callback=callbacks)
|
||||
|
||||
save_path = out_path / f'model_{identifier}.zip'
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user