mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 14:56:43 +02:00
No more Monitor,
env hparams pickeling, pomdp, now training and learning
This commit is contained in:
parent
55b409c72f
commit
ff9846eb54
@ -1,3 +1,4 @@
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Iterable
|
||||
|
||||
@ -101,28 +102,21 @@ class BaseFactory(gym.Env):
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
if self.pomdp_size:
|
||||
return spaces.Box(low=0, high=1, shape=(self.state.shape[0], self.pomdp_size,
|
||||
self.pomdp_size), dtype=np.float32)
|
||||
if self.pomdp_radius:
|
||||
return spaces.Box(low=0, high=1, shape=(self._state.shape[0], self.pomdp_radius * 2 + 1,
|
||||
self.pomdp_radius * 2 + 1), dtype=np.float32)
|
||||
else:
|
||||
space = spaces.Box(low=0, high=1, shape=self.state.shape, dtype=np.float32)
|
||||
# space = spaces.MultiBinary(np.prod(self.state.shape))
|
||||
# space = spaces.Dict({
|
||||
# 'level': spaces.MultiBinary(np.prod(self.state[0].shape)),
|
||||
# 'agent_n': spaces.Discrete(np.prod(self.state[1].shape)),
|
||||
# 'dirt': spaces.Box(low=0, high=1, shape=self.state[2].shape, dtype=np.float32)
|
||||
# })
|
||||
space = spaces.Box(low=0, high=1, shape=self._state.shape, dtype=np.float32)
|
||||
return space
|
||||
|
||||
@property
|
||||
def movement_actions(self):
|
||||
return self._actions.movement_actions
|
||||
|
||||
def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_size: Union[None, int] = None, **kwargs):
|
||||
def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None, **kwargs):
|
||||
self.n_agents = n_agents
|
||||
self.max_steps = max_steps
|
||||
assert pomdp_size is None or (pomdp_size is not None and pomdp_size % 2 == 1)
|
||||
self.pomdp_size = pomdp_size
|
||||
self.pomdp_radius = pomdp_radius
|
||||
|
||||
self.done_at_collision = False
|
||||
_actions = Actions(allow_square_movement=kwargs.get('allow_square_movement', True),
|
||||
@ -130,10 +124,10 @@ class BaseFactory(gym.Env):
|
||||
allow_no_op=kwargs.get('allow_no_op', True))
|
||||
self._actions = _actions + self.additional_actions
|
||||
|
||||
self.level = h.one_hot_level(
|
||||
self._level = h.one_hot_level(
|
||||
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
|
||||
)
|
||||
self.state_slices = StateSlice(n_agents)
|
||||
self._state_slices = StateSlice(n_agents)
|
||||
self.reset()
|
||||
|
||||
@property
|
||||
@ -150,48 +144,35 @@ class BaseFactory(gym.Env):
|
||||
|
||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
||||
self.steps = 0
|
||||
self.agent_states = []
|
||||
self._agent_states = []
|
||||
# Agent placement ...
|
||||
agents = np.zeros((self.n_agents, *self.level.shape), dtype=np.int8)
|
||||
floor_tiles = np.argwhere(self.level == h.IS_FREE_CELL)
|
||||
agents = np.zeros((self.n_agents, *self._level.shape), dtype=np.int8)
|
||||
floor_tiles = np.argwhere(self._level == h.IS_FREE_CELL)
|
||||
# ... on random positions
|
||||
np.random.shuffle(floor_tiles)
|
||||
for i, (x, y) in enumerate(floor_tiles[:self.n_agents]):
|
||||
agents[i, x, y] = h.IS_OCCUPIED_CELL
|
||||
agent_state = AgentState(i, -1)
|
||||
agent_state.update(pos=[x, y])
|
||||
self.agent_states.append(agent_state)
|
||||
self._agent_states.append(agent_state)
|
||||
# state.shape = level, agent 1,..., agent n,
|
||||
self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0)
|
||||
self._state = np.concatenate((np.expand_dims(self._level, axis=0), agents), axis=0)
|
||||
# Returns State
|
||||
return self._return_state()
|
||||
|
||||
def _return_state(self):
|
||||
if self.pomdp_size:
|
||||
pos = self.agent_states[0].pos
|
||||
if self.pomdp_radius:
|
||||
pos = self._agent_states[0].pos
|
||||
# pos = [agent_state.pos for agent_state in self.agent_states]
|
||||
# obs = [] ... list comprehension... pos per agent
|
||||
offset = self.pomdp_size // 2
|
||||
x0, x1 = max(0, pos[0] - offset), pos[0] + offset + 1
|
||||
y0, y1 = max(0, pos[1] - offset), pos[1] + offset + 1
|
||||
obs = self.state[:, x0:x1, y0:y1]
|
||||
if obs.shape[1] != self.pomdp_size or obs.shape[2] != self.pomdp_size:
|
||||
obs_padded = np.zeros((obs.shape[0], self.pomdp_size, self.pomdp_size))
|
||||
try:
|
||||
a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0]
|
||||
except IndexError:
|
||||
print('Shiiiiiit')
|
||||
try:
|
||||
obs_padded[:, abs(a_pos[0]-offset):abs(a_pos[0]-offset)+obs.shape[1], abs(a_pos[1]-offset):abs(a_pos[1]-offset)+obs.shape[2]] = obs
|
||||
except ValueError:
|
||||
print('Shiiiiiit')
|
||||
assert all(np.argwhere(obs_padded[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0] == (3,3))
|
||||
|
||||
|
||||
|
||||
obs = obs_padded
|
||||
npad = [(0, 0)] + [(self.pomdp_radius, self.pomdp_radius)] * (self._state.ndim - 1)
|
||||
x_roll = self.pomdp_radius-pos[0]
|
||||
y_roll = self.pomdp_radius-pos[1]
|
||||
padded_state = np.pad(self._state, pad_width=npad, mode='constant', constant_values=0)
|
||||
padded_state = np.roll(np.roll(padded_state, x_roll, axis=1), y_roll, axis=2)
|
||||
obs = padded_state[:, :self.pomdp_radius * 2 + 1, :self.pomdp_radius * 2 + 1]
|
||||
else:
|
||||
obs = self.state
|
||||
obs = self._state
|
||||
return obs
|
||||
|
||||
def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||
@ -217,12 +198,12 @@ class BaseFactory(gym.Env):
|
||||
agent_i_state.update(pos=pos, action_valid=valid)
|
||||
agent_states.append(agent_i_state)
|
||||
|
||||
for i, collision_vec in enumerate(self.check_all_collisions(agent_states, self.state.shape[0])):
|
||||
for i, collision_vec in enumerate(self.check_all_collisions(agent_states, self._state.shape[0])):
|
||||
agent_states[i].update(collision_vector=collision_vec)
|
||||
if self.done_at_collision and collision_vec.any():
|
||||
done = True
|
||||
|
||||
self.agent_states = agent_states
|
||||
self._agent_states = agent_states
|
||||
reward, info = self.calculate_reward(agent_states)
|
||||
|
||||
if self.steps >= self.max_steps:
|
||||
@ -250,7 +231,7 @@ class BaseFactory(gym.Env):
|
||||
def check_collisions(self, agent_state: AgentState) -> np.ndarray:
|
||||
pos_x, pos_y = agent_state.pos
|
||||
# FixMe: We need to find a way to spare out some dimensions, eg. an info dimension etc... a[?,]
|
||||
collisions_vec = self.state[:, pos_x, pos_y].copy() # "vertical fiber" at position of agent i
|
||||
collisions_vec = self._state[:, pos_x, pos_y].copy() # "vertical fiber" at position of agent i
|
||||
collisions_vec[h.AGENT_START_IDX + agent_state.i] = h.IS_FREE_CELL # no self-collisions
|
||||
if agent_state.action_valid:
|
||||
# ToDo: Place a function hook here
|
||||
@ -262,11 +243,11 @@ class BaseFactory(gym.Env):
|
||||
|
||||
def do_move(self, agent_i: int, old_pos: (int, int), new_pos: (int, int)) -> None:
|
||||
(x, y), (x_new, y_new) = old_pos, new_pos
|
||||
self.state[agent_i + h.AGENT_START_IDX, x, y] = h.IS_FREE_CELL
|
||||
self.state[agent_i + h.AGENT_START_IDX, x_new, y_new] = h.IS_OCCUPIED_CELL
|
||||
self._state[agent_i + h.AGENT_START_IDX, x, y] = h.IS_FREE_CELL
|
||||
self._state[agent_i + h.AGENT_START_IDX, x_new, y_new] = h.IS_OCCUPIED_CELL
|
||||
|
||||
def move_or_colide(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||
old_pos, new_pos, valid = h.check_agent_move(state=self.state,
|
||||
old_pos, new_pos, valid = h.check_agent_move(state=self._state,
|
||||
dim=agent_i + h.AGENT_START_IDX,
|
||||
action=action)
|
||||
if valid:
|
||||
@ -278,7 +259,7 @@ class BaseFactory(gym.Env):
|
||||
return old_pos, valid
|
||||
|
||||
def agent_i_position(self, agent_i: int) -> (int, int):
|
||||
positions = np.argwhere(self.state[h.AGENT_START_IDX+agent_i] == h.IS_OCCUPIED_CELL)
|
||||
positions = np.argwhere(self._state[h.AGENT_START_IDX + agent_i] == h.IS_OCCUPIED_CELL)
|
||||
assert positions.shape[0] == 1
|
||||
pos_x, pos_y = positions[0] # a.flatten()
|
||||
return pos_x, pos_y
|
||||
@ -288,13 +269,13 @@ class BaseFactory(gym.Env):
|
||||
assert isinstance(excluded_slices, (int, list))
|
||||
excluded_slices = excluded_slices if isinstance(excluded_slices, list) else [excluded_slices]
|
||||
|
||||
state = self.state
|
||||
state = self._state
|
||||
|
||||
if excluded_slices:
|
||||
# Todo: Is there a cleaner way?
|
||||
inds = list(range(self.state.shape[0]))
|
||||
inds = list(range(self._state.shape[0]))
|
||||
excluded_slices = [inds[x] if x < 0 else x for x in excluded_slices]
|
||||
state = self.state[[x for x in inds if x not in excluded_slices]]
|
||||
state = self._state[[x for x in inds if x not in excluded_slices]]
|
||||
|
||||
free_cells = np.argwhere(state.sum(0) == h.IS_FREE_CELL)
|
||||
np.random.shuffle(free_cells)
|
||||
@ -306,3 +287,9 @@ class BaseFactory(gym.Env):
|
||||
|
||||
def render(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_params(self, filepath: Path):
|
||||
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') or not key.startswith('__')}
|
||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with filepath.open('wb') as f:
|
||||
pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
@ -34,27 +34,27 @@ class SimpleFactory(BaseFactory):
|
||||
return self._actions[action] == CLEAN_UP_ACTION
|
||||
|
||||
def __init__(self, *args, dirt_properties: DirtProperties, verbose=False, **kwargs):
|
||||
self._dirt_properties = dirt_properties
|
||||
self.dirt_properties = dirt_properties
|
||||
self.verbose = verbose
|
||||
self.max_dirt = 20
|
||||
super(SimpleFactory, self).__init__(*args, **kwargs)
|
||||
self.state_slices.register_additional_items('dirt')
|
||||
self.renderer = None # expensive - don't use it when not required !
|
||||
self._state_slices.register_additional_items('dirt')
|
||||
self._renderer = None # expensive - don't use it when not required !
|
||||
|
||||
def render(self):
|
||||
|
||||
if not self.renderer: # lazy init
|
||||
height, width = self.state.shape[1:]
|
||||
self.renderer = Renderer(width, height, view_radius=2)
|
||||
if not self._renderer: # lazy init
|
||||
height, width = self._state.shape[1:]
|
||||
self._renderer = Renderer(width, height, view_radius=self.pomdp_radius)
|
||||
|
||||
dirt = [Entity('dirt', [x, y], min(0.15+self.state[DIRT_INDEX, x, y], 1.5), 'scale')
|
||||
for x, y in np.argwhere(self.state[DIRT_INDEX] > h.IS_FREE_CELL)]
|
||||
walls = [Entity('wall', pos) for pos in np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL)]
|
||||
dirt = [Entity('dirt', [x, y], min(0.15 + self._state[DIRT_INDEX, x, y], 1.5), 'scale')
|
||||
for x, y in np.argwhere(self._state[DIRT_INDEX] > h.IS_FREE_CELL)]
|
||||
walls = [Entity('wall', pos) for pos in np.argwhere(self._state[h.LEVEL_IDX] > h.IS_FREE_CELL)]
|
||||
|
||||
def asset_str(agent):
|
||||
if any([x is None for x in [self.state_slices[j] for j in agent.collisions]]):
|
||||
if any([x is None for x in [self._state_slices[j] for j in agent.collisions]]):
|
||||
print('error')
|
||||
cols = ' '.join([self.state_slices[j] for j in agent.collisions])
|
||||
cols = ' '.join([self._state_slices[j] for j in agent.collisions])
|
||||
if 'agent' in cols:
|
||||
return 'agent_collision'
|
||||
elif not agent.action_valid or 'level' in cols or 'agent' in cols:
|
||||
@ -65,38 +65,38 @@ class SimpleFactory(BaseFactory):
|
||||
return f'agent{agent.i + 1}'
|
||||
|
||||
agents = {f'agent{i+1}': [Entity(asset_str(agent), agent.pos)]
|
||||
for i, agent in enumerate(self.agent_states)}
|
||||
self.renderer.render(OrderedDict(dirt=dirt, wall=walls, **agents))
|
||||
for i, agent in enumerate(self._agent_states)}
|
||||
self._renderer.render(OrderedDict(dirt=dirt, wall=walls, **agents))
|
||||
|
||||
def spawn_dirt(self) -> None:
|
||||
if not np.argwhere(self.state[DIRT_INDEX] != h.IS_FREE_CELL).shape[0] > self._dirt_properties.max_global_amount:
|
||||
if not np.argwhere(self._state[DIRT_INDEX] != h.IS_FREE_CELL).shape[0] > self.dirt_properties.max_global_amount:
|
||||
free_for_dirt = self.free_cells(excluded_slices=DIRT_INDEX)
|
||||
|
||||
# randomly distribute dirt across the grid
|
||||
n_dirt_tiles = int(random.uniform(0, self._dirt_properties.max_spawn_ratio) * len(free_for_dirt))
|
||||
n_dirt_tiles = int(random.uniform(0, self.dirt_properties.max_spawn_ratio) * len(free_for_dirt))
|
||||
for x, y in free_for_dirt[:n_dirt_tiles]:
|
||||
new_value = self.state[DIRT_INDEX, x, y] + self._dirt_properties.gain_amount
|
||||
self.state[DIRT_INDEX, x, y] = max(new_value, self._dirt_properties.max_local_amount)
|
||||
new_value = self._state[DIRT_INDEX, x, y] + self.dirt_properties.gain_amount
|
||||
self._state[DIRT_INDEX, x, y] = max(new_value, self.dirt_properties.max_local_amount)
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
def clean_up(self, pos: (int, int)) -> ((int, int), bool):
|
||||
new_dirt_amount = self.state[DIRT_INDEX][pos] - self._dirt_properties.clean_amount
|
||||
new_dirt_amount = self._state[DIRT_INDEX][pos] - self.dirt_properties.clean_amount
|
||||
cleanup_was_sucessfull: bool
|
||||
if self.state[DIRT_INDEX][pos] == h.IS_FREE_CELL:
|
||||
if self._state[DIRT_INDEX][pos] == h.IS_FREE_CELL:
|
||||
cleanup_was_sucessfull = False
|
||||
return pos, cleanup_was_sucessfull
|
||||
else:
|
||||
cleanup_was_sucessfull = True
|
||||
self.state[DIRT_INDEX][pos] = max(new_dirt_amount, h.IS_FREE_CELL)
|
||||
self._state[DIRT_INDEX][pos] = max(new_dirt_amount, h.IS_FREE_CELL)
|
||||
return pos, cleanup_was_sucessfull
|
||||
|
||||
def step(self, actions):
|
||||
_, r, done, info = super(SimpleFactory, self).step(actions)
|
||||
if not self.next_dirt_spawn:
|
||||
self.spawn_dirt()
|
||||
self.next_dirt_spawn = self._dirt_properties.spawn_frequency
|
||||
self.next_dirt_spawn = self.dirt_properties.spawn_frequency
|
||||
else:
|
||||
self.next_dirt_spawn -= 1
|
||||
obs = self._return_state()
|
||||
@ -115,17 +115,17 @@ class SimpleFactory(BaseFactory):
|
||||
|
||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
||||
_ = super().reset() # state, reward, done, info ... =
|
||||
dirt_slice = np.zeros((1, *self.state.shape[1:]))
|
||||
self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice
|
||||
dirt_slice = np.zeros((1, *self._state.shape[1:]))
|
||||
self._state = np.concatenate((self._state, dirt_slice)) # dirt is now the last slice
|
||||
self.spawn_dirt()
|
||||
self.next_dirt_spawn = self._dirt_properties.spawn_frequency
|
||||
self.next_dirt_spawn = self.dirt_properties.spawn_frequency
|
||||
obs = self._return_state()
|
||||
return obs
|
||||
|
||||
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
|
||||
# TODO: What reward to use?
|
||||
current_dirt_amount = self.state[DIRT_INDEX].sum()
|
||||
dirty_tiles = np.argwhere(self.state[DIRT_INDEX] != h.IS_FREE_CELL).shape[0]
|
||||
current_dirt_amount = self._state[DIRT_INDEX].sum()
|
||||
dirty_tiles = np.argwhere(self._state[DIRT_INDEX] != h.IS_FREE_CELL).shape[0]
|
||||
info_dict = dict()
|
||||
|
||||
try:
|
||||
@ -137,11 +137,13 @@ class SimpleFactory(BaseFactory):
|
||||
for agent_state in agent_states:
|
||||
cols = agent_state.collisions
|
||||
|
||||
list_of_collisions = [self.state_slices[entity] for entity in cols
|
||||
if entity != self.state_slices.by_name("dirt")]
|
||||
list_of_collisions = [self._state_slices[entity] for entity in cols
|
||||
if entity != self._state_slices.by_name("dirt")]
|
||||
|
||||
if list_of_collisions:
|
||||
self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
|
||||
f'{list_of_collisions}')
|
||||
|
||||
self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
|
||||
f'{list_of_collisions}')
|
||||
if self._is_clean_up_action(agent_state.action):
|
||||
if agent_state.action_valid:
|
||||
reward += 1
|
||||
@ -155,19 +157,19 @@ class SimpleFactory(BaseFactory):
|
||||
|
||||
elif self._is_moving_action(agent_state.action):
|
||||
if agent_state.action_valid:
|
||||
info_dict.update(movement=1)
|
||||
# info_dict.update(movement=1)
|
||||
reward -= 0.00
|
||||
else:
|
||||
info_dict.update(collision=1)
|
||||
# info_dict.update(collision=1)
|
||||
# self.print('collision')
|
||||
reward -= 0.00
|
||||
|
||||
else:
|
||||
info_dict.update(collision=1)
|
||||
info_dict.update(no_op=1)
|
||||
reward -= 0.00
|
||||
|
||||
for entity in cols:
|
||||
if entity != self.state_slices.by_name("dirt"):
|
||||
info_dict.update({f'agent_{agent_state.i}_vs_{self.state_slices[entity]}': 1})
|
||||
for entity in list_of_collisions:
|
||||
info_dict.update({f'agent_{agent_state.i}_vs_{self._state_slices.by_name(entity)}': 1})
|
||||
|
||||
info_dict.update(dirt_amount=current_dirt_amount)
|
||||
info_dict.update(dirty_tile_count=dirty_tiles)
|
||||
|
@ -9,7 +9,7 @@ AGENT_START_IDX = 1
|
||||
IS_FREE_CELL = 0
|
||||
IS_OCCUPIED_CELL = 1
|
||||
TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
|
||||
IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index']
|
||||
IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count']
|
||||
|
||||
|
||||
# Utility functions
|
||||
|
@ -32,9 +32,6 @@ def plot(filepath, ext='png', **kwargs):
|
||||
|
||||
def prepare_plot(filepath, results_df, ext='png'):
|
||||
|
||||
_ = sns.lineplot(data=results_df, x='Episode', y='Score', hue='Measurement', ci='sd')
|
||||
|
||||
# %%
|
||||
sns.set_theme(palette=PALETTE, style='whitegrid')
|
||||
font_size = 16
|
||||
tex_fonts = {
|
||||
@ -50,6 +47,8 @@ def prepare_plot(filepath, results_df, ext='png'):
|
||||
"ytick.labelsize": font_size - 2
|
||||
}
|
||||
|
||||
sns.lineplot(data=results_df, x='Episode', y='Score', hue='Measurement', ci='sd')
|
||||
|
||||
try:
|
||||
plot(filepath, ext=ext, **tex_fonts)
|
||||
except (FileNotFoundError, RuntimeError):
|
||||
|
22
main.py
22
main.py
@ -26,17 +26,9 @@ def combine_runs(run_path: Union[str, PathLike]):
|
||||
monitor_df = pickle.load(f)
|
||||
|
||||
monitor_df['run'] = run
|
||||
|
||||
monitor_df = monitor_df.fillna(0)
|
||||
|
||||
#for column in list(df.columns):
|
||||
# if column not in ['episode', 'run', 'step', 'train_step']:
|
||||
# if 'clean' in column or '_vs_' in column:
|
||||
# df[f'{column}_sum_roll'] = df[column].rolling(window=50, min_periods=1).sum()
|
||||
# else:
|
||||
# df[f'{column}_mean_roll'] = df[column].rolling(window=50, min_periods=1).mean()
|
||||
|
||||
df_list.append(monitor_df)
|
||||
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'})
|
||||
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
|
||||
@ -53,20 +45,17 @@ def combine_runs(run_path: Union[str, PathLike]):
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# combine_runs('debug_out/PPO_1622399010')
|
||||
# exit()
|
||||
|
||||
from stable_baselines3 import PPO, DQN, A2C
|
||||
dirt_props = DirtProperties()
|
||||
time_stamp = int(time.time())
|
||||
|
||||
out_path = None
|
||||
|
||||
for modeL_type in [A2C, PPO]:
|
||||
for modeL_type in [A2C, PPO, DQN]:
|
||||
for seed in range(5):
|
||||
|
||||
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_size=7,
|
||||
allow_diagonal_movement=False, allow_no_op=False)
|
||||
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400,
|
||||
allow_diagonal_movement=False, allow_no_op=False, verbose=True)
|
||||
|
||||
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu')
|
||||
|
||||
@ -80,11 +69,12 @@ if __name__ == '__main__':
|
||||
MonitorCallback(env, filepath=out_path / f'monitor_{identifier}.pick', plotting=False)]
|
||||
)
|
||||
|
||||
model.learn(total_timesteps=int(5e5), callback=callbacks)
|
||||
model.learn(total_timesteps=int(2e5), callback=callbacks)
|
||||
|
||||
save_path = out_path / f'model_{identifier}.zip'
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
model.save(save_path)
|
||||
env.save_params(out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.pick')
|
||||
|
||||
if out_path:
|
||||
combine_runs(out_path.parent)
|
||||
|
@ -1,3 +1,4 @@
|
||||
import pickle
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
@ -12,10 +13,15 @@ warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
dirt_props = DirtProperties()
|
||||
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props)
|
||||
|
||||
out_path = Path(r'C:\Users\steff\projects\f_iks\debug_out\PPO_1622485791\1_PPO_1622485791')
|
||||
out_path = Path(r'C:\Users\steff\projects\f_iks\debug_out\A2C_1622557712')
|
||||
with (out_path / f'env_{out_path.name}.pick').open('rb') as f:
|
||||
env_kwargs = pickle.load(f)
|
||||
env = SimpleFactory(**env_kwargs)
|
||||
|
||||
# Edit THIS:
|
||||
model_path = out_path
|
||||
|
||||
model_files = list(natsorted(out_path.rglob('*.zip')))
|
||||
this_model = model_files[0]
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user