Accounting zones for BaseFactory: mark them in level.txt by integer, 'x' for danger zones
@@ -7,20 +7,15 @@ import numpy as np
 from gym import spaces
 
 import yaml
+from gym.wrappers import FrameStack
 
 from environments import helpers as h
-from environments.utility_classes import Actions, StateSlice, AgentState, MovementProperties
+from environments.utility_classes import Actions, StateSlice, AgentState, MovementProperties, Zones
 
 
 # noinspection PyAttributeOutsideInit
 class BaseFactory(gym.Env):
 
-    def __setattr__(self, key, value):
-        if isinstance(value, dict):
-            super(BaseFactory, self).__setattr__(key, Namespace(**value))
-        else:
-            super(BaseFactory, self).__setattr__(key, value)
-
     @property
     def action_space(self):
         return spaces.Discrete(self._actions.n)
@@ -40,11 +35,20 @@ class BaseFactory(gym.Env):
     def movement_actions(self):
         return self._actions.movement_actions
 
+    def __enter__(self):
+        return self if self.frames_to_stack == 0 else FrameStack(self, self.frames_to_stack)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
     def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = 0,
                  movement_properties: MovementProperties = MovementProperties(),
-                 combin_agent_slices_in_obs: bool = False,
+                 combin_agent_slices_in_obs: bool = False, frames_to_stack=0,
                  omit_agent_slice_in_obs=False, **kwargs):
-        assert combin_agent_slices_in_obs != omit_agent_slice_in_obs, 'Both options are exclusive'
+        assert (combin_agent_slices_in_obs != omit_agent_slice_in_obs) or \
+               (not combin_agent_slices_in_obs and not omit_agent_slice_in_obs), \
+            'Both options are exclusive'
+        assert frames_to_stack != 1 and frames_to_stack >= 0, "'frames_to_stack' cannot be negative or 1."
 
         self.movement_properties = movement_properties
         self.level_name = level_name
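The new __enter__/__exit__ pair makes the factory usable as a context manager: entering returns either the bare environment (frames_to_stack == 0) or a gym FrameStack wrapper, and exiting closes the factory. A minimal usage sketch, assuming the SimpleFactory subclass and DirtProperties defaults that appear later in this diff:

    from environments.factory.simple_factory import SimpleFactory, DirtProperties

    with SimpleFactory(n_agents=1, dirt_properties=DirtProperties(), frames_to_stack=4) as env:
        # `env` is a gym.wrappers.FrameStack here; with frames_to_stack=0 it would be
        # the factory itself. frames_to_stack=1 is rejected by the assert in __init__.
        obs = env.reset()
    # __exit__ has called close() on the factory.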
@@ -54,17 +58,19 @@ class BaseFactory(gym.Env):
         self.pomdp_radius = pomdp_radius
         self.combin_agent_slices_in_obs = combin_agent_slices_in_obs
         self.omit_agent_slice_in_obs = omit_agent_slice_in_obs
+        self.frames_to_stack = frames_to_stack
 
         self.done_at_collision = False
         _actions = Actions(self.movement_properties)
         self._actions = _actions + self.additional_actions
 
-        self._level = h.one_hot_level(
-            h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{self.level_name}.txt')
-        )
+        level_filepath = Path(__file__).parent / h.LEVELS_DIR / f'{self.level_name}.txt'
+        parsed_level = h.parse_level(level_filepath)
+        self._level = h.one_hot_level(parsed_level)
         self._state_slices = StateSlice(n_agents)
         if 'additional_slices' in kwargs:
             self._state_slices.register_additional_items(kwargs.get('additional_slices'))
+        self._zones = Zones(parsed_level)
         self.reset()
 
     @property
@@ -259,7 +265,7 @@ class BaseFactory(gym.Env):
         # d = {key: val._asdict() if hasattr(val, '_asdict') else val for key, val in self.__dict__.items()
         d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
         filepath.parent.mkdir(parents=True, exist_ok=True)
-
+        super(BaseFactory, self).save_params()
         with filepath.open('w') as f:
             yaml.dump(d, f)
             # pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)
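save_params collects every public attribute into a plain dict and dumps it as YAML; main_test.py below reloads exactly such a file to rebuild the environment. A minimal round-trip sketch, with a hypothetical filepath and attribute dict:

    import yaml
    from pathlib import Path

    filepath = Path('debug_out') / 'env_params.yaml'  # hypothetical location
    d = {'n_agents': 1, 'level_name': 'rooms'}        # public attrs, filtered as above
    filepath.parent.mkdir(parents=True, exist_ok=True)
    with filepath.open('w') as f:
        yaml.dump(d, f)

    with filepath.open('r') as f:
        env_kwargs = yaml.load(f, Loader=yaml.FullLoader)  # same loader main_test.py uses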
@@ -1,13 +1,13 @@
 ###############
-#------#------#
-#---#--#------#
-#--------#----#
-#------#------#
-#------#------#
-###-#######-###
-#----##-------#
-#-----#----#--#
-#-------------#
-#-----#-------#
-#-----#-------#
+#333x33#444444#
+#333#33#444444#
+#333333xx#4444#
+#333333#444444#
+#333333#444444#
+###x#######x###
+#1111##2222222#
+#11111#2222#22#
+#11111x2222222#
+#11111#2222222#
+#11111#2222222#
 ###############
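Walls stay '#', each accounting zone is now a distinct integer symbol (1-4 above), and 'x' marks danger-zone cells, apparently at the doorways between rooms. The Zones register added below turns each non-wall symbol into a binary mask via h.one_hot_level(parsed_level, symbol); a rough sketch of that encoding, using a hypothetical miniature level:

    import numpy as np

    # Hypothetical 3x5 level in the new notation.
    parsed_level = np.array([list(row) for row in ('#####', '#1x2#', '#####')])

    # One binary mask per zone symbol, mirroring h.one_hot_level(parsed_level, symbol).
    masks = {s: (parsed_level == s).astype(np.int8)
             for s in np.unique(parsed_level) if s != '#'}
    print(masks['x'])  # 1 only where a danger-zone cell sits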
@@ -8,7 +8,7 @@ from environments.factory.base_factory import BaseFactory
 from environments import helpers as h
 
 from environments.factory.renderer import Renderer, Entity
-from environments.utility_classes import AgentState, MovementProperties
+from environments.utility_classes import AgentState, MovementProperties, Register
 
 DIRT_INDEX = -1
 CLEAN_UP_ACTION = 'clean_up'
@@ -39,8 +39,8 @@ class SimpleFactory(BaseFactory):
         self.dirt_properties = dirt_properties
         self.verbose = verbose
         self.max_dirt = 20
-        super(SimpleFactory, self).__init__(*args, additional_slices='dirt', **kwargs)
         self._renderer = None  # expensive - don't use it when not required !
+        super(SimpleFactory, self).__init__(*args, additional_slices='dirt', **kwargs)
 
     def render(self):
 
@@ -79,7 +79,6 @@ class SimpleFactory(BaseFactory):
         for x, y in free_for_dirt[:n_dirt_tiles]:
             new_value = self._state[DIRT_INDEX, x, y] + self.dirt_properties.gain_amount
             self._state[DIRT_INDEX, x, y] = max(new_value, self.dirt_properties.max_local_amount)
-
         else:
             pass
 
@@ -126,10 +125,11 @@ class SimpleFactory(BaseFactory):
         return obs
 
     def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
-        # TODO: What reward to use?
+        info_dict = dict()
         current_dirt_amount = self._state[DIRT_INDEX].sum()
         dirty_tiles = np.argwhere(self._state[DIRT_INDEX] != h.IS_FREE_CELL).shape[0]
-        info_dict = dict()
+        info_dict.update(dirt_amount=current_dirt_amount)
+        info_dict.update(dirty_tile_count=dirty_tiles)
 
         try:
             # penalty = current_dirt_amount
@@ -156,14 +156,13 @@ class SimpleFactory(BaseFactory):
                     reward -= 0.01
                     self.print(f'Agent {agent_state.i} just tried to clean up some dirt '
                                f'at {agent_state.pos}, but was unsucsessfull.')
-                    info_dict.update(failed_cleanup_attempt=1)
+                    info_dict.update({f'agent_{agent_state.i}_failed_action': 1})
 
             elif self._actions.is_moving_action(agent_state.action):
                 if agent_state.action_valid:
                     # info_dict.update(movement=1)
                     reward -= 0.00
                 else:
-                    # info_dict.update(collision=1)
                     # self.print('collision')
                     reward -= 0.01
 
@@ -172,10 +171,9 @@ class SimpleFactory(BaseFactory):
                 reward -= 0.00
 
             for entity in list_of_collisions:
+                entity = 'agent' if 'agent' in entity else entity
                 info_dict.update({f'agent_{agent_state.i}_vs_{entity}': 1})
 
-        info_dict.update(dirt_amount=current_dirt_amount)
-        info_dict.update(dirty_tile_count=dirty_tiles)
         self.print(f"reward is {reward}")
         # Potential based rewards ->
         # track the last reward , minus the current reward = potential
@@ -191,8 +189,8 @@ if __name__ == '__main__':
 
     move_props = MovementProperties(allow_diagonal_movement=True, allow_square_movement=True)
    dirt_props = DirtProperties()
-    factory = SimpleFactory(movement_properties=move_props, dirt_properties=dirt_props, n_agents=2,
-                            combin_agent_slices_in_obs=True, omit_agent_slice_in_obs=False)
+    factory = SimpleFactory(movement_properties=move_props, dirt_properties=dirt_props, n_agents=10,
+                            combin_agent_slices_in_obs=True, omit_agent_slice_in_obs=False, level_name='rooms')
 
     # dirt_props = DirtProperties()
     # move_props = MovementProperties(allow_diagonal_movement=False, allow_no_op=False)
@@ -6,6 +6,7 @@ from pathlib import Path
 
 # Constants
 WALL = '#'
+DANGER_ZONE = 'x'
 LEVELS_DIR = 'levels'
 LEVEL_IDX = 0
 AGENT_START_IDX = 1
@@ -1,6 +1,8 @@
 from typing import Union, List, NamedTuple
 import numpy as np
 
+from environments import helpers as h
+
 
 class MovementProperties(NamedTuple):
     allow_square_movement: bool = True
@@ -123,5 +125,38 @@ class StateSlice(Register):
 
     def __init__(self, n_agents: int):
         super(StateSlice, self).__init__()
-        offset = 1
+        offset = 1  # AGENT_START_IDX
         self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]])
+
+
+class Zones(Register):
+
+    @property
+    def danger_zone(self):
+        return self._zone_slices[self.by_name(h.DANGER_ZONE)]
+
+    @property
+    def accounting_zones(self):
+        return [self[idx] for idx, name in self.items() if name != h.DANGER_ZONE]
+
+    def __init__(self, parsed_level):
+        super(Zones, self).__init__()
+        slices = list()
+        self._accounting_zones = list()
+        self._danger_zones = list()
+        for symbol in np.unique(parsed_level):
+            if symbol == h.WALL:
+                continue
+            elif symbol == h.DANGER_ZONE:
+                self + symbol
+                slices.append(h.one_hot_level(parsed_level, symbol))
+                self._danger_zones.append(symbol)
+            else:
+                self + symbol
+                slices.append(h.one_hot_level(parsed_level, symbol))
+                self._accounting_zones.append(symbol)
+
+        self._zone_slices = np.stack(slices)
+
+    def __getitem__(self, item):
+        return self._zone_slices[item]
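A Zones instance is built once in BaseFactory.__init__ from the parsed level and can then be queried per symbol. A hedged usage sketch; by_name and items() are assumed from the Register base class, whose full API is not shown in this diff, and the level path mirrors the one constructed in BaseFactory.__init__:

    from pathlib import Path
    from environments import helpers as h
    from environments.utility_classes import Zones

    level_filepath = Path('environments/factory') / h.LEVELS_DIR / 'rooms.txt'  # assumed layout
    zones = Zones(h.parse_level(level_filepath))

    danger_mask = zones.danger_zone          # binary slice covering all 'x' cells
    account_masks = zones.accounting_zones   # one slice per integer zone symbol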
main.py
@@ -93,7 +93,7 @@ if __name__ == '__main__':
     # from sb3_contrib import QRDQN
 
     dirt_props = DirtProperties(clean_amount=3, gain_amount=0.2, max_global_amount=30,
-                                max_local_amount=5, spawn_frequency=3)
+                                max_local_amount=5, spawn_frequency=1, max_spawn_ratio=0.05)
     move_props = MovementProperties(allow_diagonal_movement=True,
                                     allow_square_movement=True,
                                     allow_no_op=False)
@@ -104,31 +104,29 @@ if __name__ == '__main__':
     for modeL_type in [PPO, A2C]:  # , RegDQN, DQN]:
         for seed in range(3):
 
-            env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=3, max_steps=400,
-                                movement_properties=move_props, level_name='rooms',
-                                omit_agent_slice_in_obs=True)
-
-            # env = FrameStack(env, 4)
-
-            kwargs = dict(ent_coef=0.01) if isinstance(modeL_type, (PPO, A2C)) else {}
-            model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs)
-
-            out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'
-
-            # identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
-            identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
-            out_path /= identifier
-
-            callbacks = CallbackList(
-                [MonitorCallback(filepath=out_path / f'monitor_{identifier}.pick', plotting=False)]
-            )
-
-            model.learn(total_timesteps=int(1e5), callback=callbacks)
-
-            save_path = out_path / f'model_{identifier}.zip'
-            save_path.parent.mkdir(parents=True, exist_ok=True)
-            model.save(save_path)
-            env.save_params(out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.yaml')
+            with SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400,
+                               movement_properties=move_props, level_name='rooms', frames_to_stack=4,
+                               omit_agent_slice_in_obs=False, combin_agent_slices_in_obs=True) as env:
+
+                kwargs = dict(ent_coef=0.01) if isinstance(modeL_type, (PPO, A2C)) else {}
+                model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs)
+
+                out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'
+
+                # identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
+                identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
+                out_path /= identifier
+
+                callbacks = CallbackList(
+                    [MonitorCallback(filepath=out_path / f'monitor_{identifier}.pick', plotting=False)]
+                )
+
+                model.learn(total_timesteps=int(1e5), callback=callbacks)
+
+                save_path = out_path / f'model_{identifier}.zip'
+                save_path.parent.mkdir(parents=True, exist_ok=True)
+                model.save(save_path)
+                env.save_params(out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.yaml')
 
     if out_path:
         combine_runs(out_path.parent)
main_test.py
@@ -3,13 +3,14 @@ import warnings
 
 from pathlib import Path
 import yaml
+from gym.wrappers import FrameStack
 from natsort import natsorted
 
 from stable_baselines3.common.callbacks import CallbackList
 from stable_baselines3 import PPO, DQN, A2C
 
 # our imports
-from environments.factory.simple_factory import SimpleFactory
+from environments.factory.simple_factory import SimpleFactory, DirtProperties
 from environments.logging.monitor import MonitorCallback
 from algorithms.reg_dqn import RegDQN
 from main import compare_runs, combine_runs
@@ -28,7 +29,7 @@ if __name__ == '__main__':
     # rewards += [total reward]
     # boxplot total rewards
 
-    run_id = '1623078961'
+    run_id = '1623241962'
     model_name = 'PPO'
 
     # -----------------------
@@ -45,9 +46,13 @@ if __name__ == '__main__':
     for seed in range(3):
         with (model_path / f'env_{model_path.name}.yaml').open('r') as f:
             env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
-        env_kwargs.update(n_agents=2)
+        dirt_props = DirtProperties(clean_amount=3, gain_amount=0.2, max_global_amount=30,
+                                    max_local_amount=3, spawn_frequency=1, max_spawn_ratio=0.05)
+        env_kwargs.update(n_agents=1, dirt_properties=dirt_props)
         env = SimpleFactory(**env_kwargs)
+
+        env = FrameStack(env, 4)
 
         exp_out_path = model_path / 'exp'
         callbacks = CallbackList(
             [MonitorCallback(filepath=exp_out_path / f'future_exp_name', plotting=True)]
@@ -58,13 +63,19 @@ if __name__ == '__main__':
         for epoch in range(100):
             observations = env.reset()
             if render:
-                env.render()
+                if isinstance(env, FrameStack):
+                    env.env.render()
+                else:
+                    env.render()
             done_bool = False
             r = 0
             while not done_bool:
-                actions = [model.predict(obs, deterministic=False)[0] for obs in observations]
+                if env.n_agents > 1:
+                    actions = [model.predict(obs, deterministic=False)[0] for obs in observations]
+                else:
+                    actions = model.predict(observations, deterministic=False)[0]
 
-                obs, r, done_bool, info_obj = env.step(actions)
+                observations, r, done_bool, info_obj = env.step(actions)
                 if render:
                     env.render()
                 if done_bool:
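The render branch above reaches through the wrapper as env.env.render(), presumably because the FrameStack wrapper did not forward the factory's custom render() as expected in the gym version used. A more generic sketch of the same unwrap, relying only on gym.Wrapper exposing the wrapped environment as .env:

    # Walk down a chain of gym.Wrapper instances to the innermost environment.
    inner = env
    while hasattr(inner, 'env'):
        inner = inner.env
    inner.render()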
@@ -21,14 +21,14 @@ if __name__ == '__main__':
 
     with (model_path / f'env_{model_name}.yaml').open('r') as f:
         env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
-    env = SimpleFactory(level_name='rooms', **env_kwargs)
+    with SimpleFactory(level_name='rooms', **env_kwargs) as env:
 
         # Edit THIS:
         model_files = list(natsorted((model_path / f'{run_id}_{model_name}').rglob('model_*.zip')))
         this_model = model_files[0]
 
-    model = PPO.load(this_model)
-    evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=False, render=True)
-    print(evaluation_result)
-
-    env.close()
+        model = PPO.load(this_model)
+        evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=False, render=True)
+        print(evaluation_result)