mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-09-13 22:44:00 +02:00
yaml working
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import pickle
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Iterable, NamedTuple
|
||||
|
||||
@@ -89,7 +89,8 @@ class Actions(Register):
|
||||
self.allow_diagonal_movement = movement_properties.allow_diagonal_movement
|
||||
self.allow_square_movement = movement_properties.allow_square_movement
|
||||
# FIXME: There is a bug in helpers because there actions are ints. and the order matters.
|
||||
assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), "There is a bug in helpers!!!"
|
||||
assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), \
|
||||
"There is a bug in helpers!!!"
|
||||
super(Actions, self).__init__()
|
||||
|
||||
if self.allow_square_movement:
|
||||
@@ -109,11 +110,14 @@ class StateSlice(Register):
|
||||
self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]])
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit
|
||||
class BaseFactory(gym.Env):
|
||||
|
||||
# def __setattr__(self, key, value):
|
||||
# if isinstance(value, dict):
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if isinstance(value, dict):
|
||||
super(BaseFactory, self).__setattr__(key, Namespace(**value))
|
||||
else:
|
||||
super(BaseFactory, self).__setattr__(key, value)
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
@@ -199,8 +203,8 @@ class BaseFactory(gym.Env):
|
||||
obs_padded = np.full((obs.shape[0], self.pomdp_radius ** 2 + 1, self.pomdp_radius ** 2 + 1), 1)
|
||||
a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0]
|
||||
obs_padded[:,
|
||||
abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1],
|
||||
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
|
||||
abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1],
|
||||
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
|
||||
obs = obs_padded
|
||||
else:
|
||||
obs = self._state
|
||||
@@ -323,9 +327,11 @@ class BaseFactory(gym.Env):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_params(self, filepath: Path):
|
||||
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
|
||||
# noinspection PyProtectedMember
|
||||
d = {key: val._asdict() if hasattr(val, '_asdict') else val for key, val in self.__dict__.items()
|
||||
if not key.startswith('_') and not key.startswith('__')}
|
||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with filepath.open('wb') as f:
|
||||
# yaml.dump(d, f)
|
||||
pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
with filepath.open('w') as f:
|
||||
yaml.dump(d, f)
|
||||
# pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
@@ -185,17 +185,21 @@ class SimpleFactory(BaseFactory):
|
||||
|
||||
if __name__ == '__main__':
|
||||
render = True
|
||||
import yaml
|
||||
with Path(r'C:\Users\steff\projects\f_iks\debug_out\yaml.txt').open('r') as f:
|
||||
env_kwargs = yaml.load(f)
|
||||
factory = SimpleFactory(**env_kwargs)
|
||||
|
||||
dirt_props = DirtProperties()
|
||||
move_props = MovementProperties(allow_diagonal_movement=False, allow_no_op=False)
|
||||
factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props, movement_properties=move_props, level='rooms',
|
||||
pomdp_radius=2)
|
||||
# dirt_props = DirtProperties()
|
||||
# move_props = MovementProperties(allow_diagonal_movement=False, allow_no_op=False)
|
||||
# factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props, movement_properties=move_props, level='rooms',
|
||||
# pomdp_radius=2)
|
||||
|
||||
n_actions = factory.action_space.n - 1
|
||||
|
||||
for epoch in range(100):
|
||||
random_actions = [(random.randint(0, n_actions), random.randint(0, n_actions)) for _ in range(200)]
|
||||
env_state, this_reward, done_bool, _ = factory.reset()
|
||||
random_actions = [[random.randint(0, n_actions) for _ in range(factory.n_agents)] for _ in range(200)]
|
||||
env_state = factory.reset()
|
||||
for agent_i_action in random_actions:
|
||||
env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
|
||||
if render:
|
||||
|
Reference in New Issue
Block a user