Smaller fixes, now running.
This commit is contained in:
parent
84eb381307
commit
d0e1175ff1
@ -3,23 +3,24 @@ import time
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Union, Iterable, Dict
|
from typing import List, Union, Iterable, Dict
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
|
|
||||||
import yaml
|
|
||||||
from gym.wrappers import FrameStack
|
from gym.wrappers import FrameStack
|
||||||
|
|
||||||
from environments.factory.base.shadow_casting import Map
|
from environments.factory.base.shadow_casting import Map
|
||||||
from environments.factory.renderer import Renderer, RenderEntity
|
from environments.factory.renderer import Renderer, RenderEntity
|
||||||
from environments.helpers import Constants as c, Constants
|
from environments.helpers import Constants as c, Constants
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.factory.base.objects import Agent, Tile, Action
|
from environments.factory.base.objects import Agent, Tile, Action, Wall
|
||||||
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles
|
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles
|
||||||
from environments.utility_classes import MovementProperties
|
from environments.utility_classes import MovementProperties
|
||||||
|
|
||||||
REC_TAC = 'rec'
|
import simplejson
|
||||||
|
|
||||||
|
|
||||||
|
REC_TAC = 'rec_'
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyAttributeOutsideInit
|
# noinspection PyAttributeOutsideInit
|
||||||
@ -67,6 +68,8 @@ class BaseFactory(gym.Env):
|
|||||||
self.env_seed = env_seed
|
self.env_seed = env_seed
|
||||||
self.seed(env_seed)
|
self.seed(env_seed)
|
||||||
self._base_rng = np.random.default_rng(self.env_seed)
|
self._base_rng = np.random.default_rng(self.env_seed)
|
||||||
|
if isinstance(movement_properties, dict):
|
||||||
|
movement_properties = MovementProperties(**movement_properties)
|
||||||
self.movement_properties = movement_properties
|
self.movement_properties = movement_properties
|
||||||
self.level_name = level_name
|
self.level_name = level_name
|
||||||
self._level_shape = None
|
self._level_shape = None
|
||||||
@ -118,7 +121,7 @@ class BaseFactory(gym.Env):
|
|||||||
entities.update({c.FLOOR: floor})
|
entities.update({c.FLOOR: floor})
|
||||||
|
|
||||||
# NOPOS
|
# NOPOS
|
||||||
self.NO_POS_TILE = Tile(c.NO_POS.value)
|
self._NO_POS_TILE = Tile(c.NO_POS.value)
|
||||||
|
|
||||||
# Doors
|
# Doors
|
||||||
if self.parse_doors:
|
if self.parse_doors:
|
||||||
@ -175,7 +178,7 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
def step(self, actions):
|
def step(self, actions):
|
||||||
|
|
||||||
if self.n_agents == 1:
|
if self.n_agents == 1 and not isinstance(actions, list):
|
||||||
actions = [int(actions)]
|
actions = [int(actions)]
|
||||||
|
|
||||||
assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
|
assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
|
||||||
@ -470,16 +473,16 @@ class BaseFactory(gym.Env):
|
|||||||
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
|
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
|
||||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with filepath.open('w') as f:
|
with filepath.open('w') as f:
|
||||||
yaml.dump(d, f)
|
simplejson.dump(d, f, indent=4, namedtuple_as_object=True)
|
||||||
# pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)
|
|
||||||
|
|
||||||
def _summarize_state(self):
|
def _summarize_state(self):
|
||||||
summary = {f'{REC_TAC}_step': self._steps}
|
summary = {f'{REC_TAC}step': self._steps}
|
||||||
|
|
||||||
self[c.WALLS].summarize_state()
|
if self._steps == 0:
|
||||||
for entity in self._entities:
|
summary.update({f'{REC_TAC}{self[c.WALLS].name}': {self[c.WALLS].summarize_states()}})
|
||||||
if hasattr(entity, 'summarize_state'):
|
for entity_group in self._entities:
|
||||||
summary.update({f'{REC_TAC}_{entity.name}': entity.summarize_state()})
|
if not isinstance(entity_group, WallTiles):
|
||||||
|
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()})
|
||||||
return summary
|
return summary
|
||||||
|
|
||||||
def print(self, string):
|
def print(self, string):
|
||||||
|
@ -124,6 +124,9 @@ class Tile(Object):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.name}(@{self.pos})'
|
return f'{self.name}(@{self.pos})'
|
||||||
|
|
||||||
|
def summarize_state(self):
|
||||||
|
return dict(name=self.name, x=self.x, y=self.y)
|
||||||
|
|
||||||
|
|
||||||
class Wall(Tile):
|
class Wall(Tile):
|
||||||
pass
|
pass
|
||||||
@ -160,8 +163,9 @@ class Entity(Object):
|
|||||||
self._tile = tile
|
self._tile = tile
|
||||||
tile.enter(self)
|
tile.enter(self)
|
||||||
|
|
||||||
def summarize_state(self):
|
def summarize_state(self) -> dict:
|
||||||
return self.__dict__.copy()
|
return dict(name=str(self.name), x=int(self.x), y=int(self.y),
|
||||||
|
tile=str(self.tile.name), can_collide=bool(self.can_collide))
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.name}(@{self.pos})'
|
return f'{self.name}(@{self.pos})'
|
||||||
@ -180,6 +184,10 @@ class Door(Entity):
|
|||||||
def encoding(self):
|
def encoding(self):
|
||||||
return 1 if self.is_closed else 0.5
|
return 1 if self.is_closed else 0.5
|
||||||
|
|
||||||
|
@property
|
||||||
|
def str_state(self):
|
||||||
|
return 'open' if self.is_open else 'closed'
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def access_area(self):
|
def access_area(self):
|
||||||
return [node for node in self.connectivity.nodes
|
return [node for node in self.connectivity.nodes
|
||||||
@ -206,6 +214,11 @@ class Door(Entity):
|
|||||||
if not closed_on_init:
|
if not closed_on_init:
|
||||||
self._open()
|
self._open()
|
||||||
|
|
||||||
|
def summarize_state(self):
|
||||||
|
state_dict = super().summarize_state()
|
||||||
|
state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close))
|
||||||
|
return state_dict
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_closed(self):
|
def is_closed(self):
|
||||||
return self._state == c.CLOSED_DOOR
|
return self._state == c.CLOSED_DOOR
|
||||||
@ -296,3 +309,8 @@ class Agent(MoveableEntity):
|
|||||||
self.temp_valid = None
|
self.temp_valid = None
|
||||||
self.temp_action = None
|
self.temp_action = None
|
||||||
self.temp_light_map = None
|
self.temp_light_map = None
|
||||||
|
|
||||||
|
def summarize_state(self):
|
||||||
|
state_dict = super().summarize_state()
|
||||||
|
state_dict.update(valid=bool(self.temp_valid), action=str(self.temp_action))
|
||||||
|
return state_dict
|
||||||
|
@ -15,7 +15,7 @@ class Register:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return self.__class__.__name__
|
return f'{self.__class__.__name__}'
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self._register = dict()
|
self._register = dict()
|
||||||
@ -78,6 +78,9 @@ class ObjectRegister(Register):
|
|||||||
if self.individual_slices:
|
if self.individual_slices:
|
||||||
self._array = np.concatenate((self._array, np.zeros((1, *self._level_shape))))
|
self._array = np.concatenate((self._array, np.zeros((1, *self._level_shape))))
|
||||||
|
|
||||||
|
def summarize_states(self):
|
||||||
|
return [val.summarize_state() for val in self.values()]
|
||||||
|
|
||||||
|
|
||||||
class EntityObjectRegister(ObjectRegister, ABC):
|
class EntityObjectRegister(ObjectRegister, ABC):
|
||||||
|
|
||||||
@ -154,8 +157,8 @@ class Entities(Register):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Entities, self).__init__()
|
super(Entities, self).__init__()
|
||||||
|
|
||||||
def __iter__(self):
|
def iter_individual_entitites(self):
|
||||||
return iter([x for sublist in self.values() for x in sublist])
|
return iter((x for sublist in self.values() for x in sublist))
|
||||||
|
|
||||||
def register_item(self, other: dict):
|
def register_item(self, other: dict):
|
||||||
assert not any([key for key in other.keys() if key in self.keys()]), \
|
assert not any([key for key in other.keys() if key in self.keys()]), \
|
||||||
|
@ -167,6 +167,8 @@ class ItemProperties(NamedTuple):
|
|||||||
class DoubleTaskFactory(SimpleFactory):
|
class DoubleTaskFactory(SimpleFactory):
|
||||||
# noinspection PyMissingConstructor
|
# noinspection PyMissingConstructor
|
||||||
def __init__(self, item_properties: ItemProperties, *args, env_seed=time.time_ns(), **kwargs):
|
def __init__(self, item_properties: ItemProperties, *args, env_seed=time.time_ns(), **kwargs):
|
||||||
|
if isinstance(item_properties, dict):
|
||||||
|
item_properties = ItemProperties(**item_properties)
|
||||||
self.item_properties = item_properties
|
self.item_properties = item_properties
|
||||||
kwargs.update(env_seed=env_seed)
|
kwargs.update(env_seed=env_seed)
|
||||||
self._item_rng = np.random.default_rng(env_seed)
|
self._item_rng = np.random.default_rng(env_seed)
|
||||||
@ -210,7 +212,7 @@ class DoubleTaskFactory(SimpleFactory):
|
|||||||
elif item := self[c.ITEM].by_pos(agent.pos):
|
elif item := self[c.ITEM].by_pos(agent.pos):
|
||||||
try:
|
try:
|
||||||
inventory.append(item)
|
inventory.append(item)
|
||||||
item.move(self.NO_POS_TILE)
|
item.move(self._NO_POS_TILE)
|
||||||
return c.VALID
|
return c.VALID
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import time
|
import time
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
from typing import List, Union, NamedTuple, Dict
|
from typing import List, Union, NamedTuple, Dict
|
||||||
import random
|
import random
|
||||||
|
|
||||||
@ -12,7 +13,7 @@ from environments.factory.base.objects import Agent, Action, Entity, Tile
|
|||||||
from environments.factory.base.registers import Entities, MovingEntityObjectRegister
|
from environments.factory.base.registers import Entities, MovingEntityObjectRegister
|
||||||
|
|
||||||
from environments.factory.renderer import RenderEntity
|
from environments.factory.renderer import RenderEntity
|
||||||
from environments.utility_classes import MovementProperties
|
from environments.logging.recorder import RecorderCallback
|
||||||
|
|
||||||
|
|
||||||
CLEAN_UP_ACTION = h.EnvActions.CLEAN_UP
|
CLEAN_UP_ACTION = h.EnvActions.CLEAN_UP
|
||||||
@ -50,6 +51,11 @@ class Dirt(Entity):
|
|||||||
def set_new_amount(self, amount):
|
def set_new_amount(self, amount):
|
||||||
self._amount = amount
|
self._amount = amount
|
||||||
|
|
||||||
|
def summarize_state(self):
|
||||||
|
state_dict = super().summarize_state()
|
||||||
|
state_dict.update(amount=float(self.amount))
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
class DirtRegister(MovingEntityObjectRegister):
|
class DirtRegister(MovingEntityObjectRegister):
|
||||||
|
|
||||||
@ -127,6 +133,8 @@ class SimpleFactory(BaseFactory):
|
|||||||
return super_entities
|
return super_entities
|
||||||
|
|
||||||
def __init__(self, *args, dirt_properties: DirtProperties = DirtProperties(), env_seed=time.time_ns(), **kwargs):
|
def __init__(self, *args, dirt_properties: DirtProperties = DirtProperties(), env_seed=time.time_ns(), **kwargs):
|
||||||
|
if isinstance(dirt_properties, dict):
|
||||||
|
dirt_properties = DirtProperties(**dirt_properties)
|
||||||
self.dirt_properties = dirt_properties
|
self.dirt_properties = dirt_properties
|
||||||
self._dirt_rng = np.random.default_rng(env_seed)
|
self._dirt_rng = np.random.default_rng(env_seed)
|
||||||
self._dirt: DirtRegister
|
self._dirt: DirtRegister
|
||||||
@ -235,30 +243,41 @@ class SimpleFactory(BaseFactory):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
render = True
|
render = False
|
||||||
|
|
||||||
dirt_props = DirtProperties(1, 0.05, 0.1, 3, 1, 20, 0.0)
|
dirt_props = DirtProperties(1, 0.05, 0.1, 3, 1, 20, 0.0)
|
||||||
move_props = MovementProperties(True, True, False)
|
move_props = {'allow_square_movement': True,
|
||||||
|
'allow_diagonal_movement': False,
|
||||||
|
'allow_no_op': False} #MovementProperties(True, True, False)
|
||||||
|
|
||||||
|
with RecorderCallback(filepath=Path('debug_out') / f'recorder_xxxx.json', occupation_map=False,
|
||||||
|
trajectory_map=False) as recorder:
|
||||||
|
|
||||||
factory = SimpleFactory(n_agents=1, done_at_collision=False, frames_to_stack=0,
|
factory = SimpleFactory(n_agents=1, done_at_collision=False, frames_to_stack=0,
|
||||||
level_name='rooms', max_steps=400, combin_agent_obs=True,
|
level_name='rooms', max_steps=400, combin_agent_obs=True,
|
||||||
omit_agent_in_obs=True, parse_doors=False, pomdp_r=2,
|
omit_agent_in_obs=True, parse_doors=True, pomdp_r=3,
|
||||||
record_episodes=False, verbose=True, cast_shadows=False
|
record_episodes=True, verbose=True, cast_shadows=True,
|
||||||
|
movement_properties=move_props, dirt_properties=dirt_props
|
||||||
)
|
)
|
||||||
|
|
||||||
# noinspection DuplicatedCode
|
# noinspection DuplicatedCode
|
||||||
n_actions = factory.action_space.n - 1
|
n_actions = factory.action_space.n - 1
|
||||||
_ = factory.observation_space
|
_ = factory.observation_space
|
||||||
|
|
||||||
for epoch in range(100):
|
for epoch in range(4):
|
||||||
random_actions = [[random.randint(0, n_actions) for _ in range(factory.n_agents)] for _ in range(200)]
|
random_actions = [[random.randint(0, n_actions) for _
|
||||||
|
in range(factory.n_agents)] for _
|
||||||
|
in range(factory.max_steps+1)]
|
||||||
env_state = factory.reset()
|
env_state = factory.reset()
|
||||||
r = 0
|
r = 0
|
||||||
for agent_i_action in random_actions:
|
for agent_i_action in random_actions:
|
||||||
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
|
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
|
||||||
|
recorder.read_info(0, info_obj)
|
||||||
r += step_r
|
r += step_r
|
||||||
if render:
|
if render:
|
||||||
factory.render()
|
factory.render()
|
||||||
if done_bool:
|
if done_bool:
|
||||||
|
recorder.read_done(0, done_bool)
|
||||||
break
|
break
|
||||||
print(f'Factory run {epoch} done, reward is:\n {r}')
|
print(f'Factory run {epoch} done, reward is:\n {r}')
|
||||||
|
pass
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@ -9,6 +10,7 @@ from environments.factory.base.base_factory import REC_TAC
|
|||||||
from environments.helpers import IGNORED_DF_COLUMNS
|
from environments.helpers import IGNORED_DF_COLUMNS
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyAttributeOutsideInit
|
||||||
class RecorderCallback(BaseCallback):
|
class RecorderCallback(BaseCallback):
|
||||||
|
|
||||||
def __init__(self, filepath: Union[str, Path], occupation_map: bool = False, trajectory_map: bool = False):
|
def __init__(self, filepath: Union[str, Path], occupation_map: bool = False, trajectory_map: bool = False):
|
||||||
@ -16,66 +18,41 @@ class RecorderCallback(BaseCallback):
|
|||||||
self.trajectory_map = trajectory_map
|
self.trajectory_map = trajectory_map
|
||||||
self.occupation_map = occupation_map
|
self.occupation_map = occupation_map
|
||||||
self.filepath = Path(filepath)
|
self.filepath = Path(filepath)
|
||||||
self._recorder_dict = dict()
|
self._recorder_dict = defaultdict(dict)
|
||||||
self._recorder_df = pd.DataFrame()
|
self._recorder_json_list = list()
|
||||||
self.do_record: bool
|
self.do_record: bool
|
||||||
self.started = False
|
self.started = False
|
||||||
self.closed = False
|
self.closed = False
|
||||||
|
|
||||||
def _on_step(self) -> bool:
|
def read_info(self, env_idx, info: dict):
|
||||||
if self.do_record and self.started:
|
if info_dict := {key.replace(REC_TAC, ''): val for key, val in info.items() if key.startswith(f'{REC_TAC}')}:
|
||||||
for _, info in enumerate(self.locals.get('infos', [])):
|
info_dict.update(episode=(self.num_timesteps + env_idx))
|
||||||
self._recorder_dict[self.num_timesteps] = {key: val for key, val in info.items()
|
self._recorder_dict[env_idx][len(self._recorder_dict[env_idx])] = info_dict
|
||||||
if not key.startswith(f'{REC_TAC}_')}
|
else:
|
||||||
|
pass
|
||||||
|
return
|
||||||
|
|
||||||
for env_idx, done in list(enumerate(self.locals.get('dones', []))) + \
|
def read_done(self, env_idx, done):
|
||||||
list(enumerate(self.locals.get('done', []))):
|
|
||||||
if done:
|
if done:
|
||||||
env_monitor_df = pd.DataFrame.from_dict(self._recorder_dict, orient='index')
|
self._recorder_json_list.append(json.dumps(self._recorder_dict[env_idx]))
|
||||||
self._recorder_dict = dict()
|
self._recorder_dict[env_idx] = dict()
|
||||||
columns = [col for col in env_monitor_df.columns if col not in IGNORED_DF_COLUMNS]
|
|
||||||
env_monitor_df = env_monitor_df.aggregate(
|
|
||||||
{col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
|
|
||||||
)
|
|
||||||
env_monitor_df['episode'] = len(self._recorder_df)
|
|
||||||
self._recorder_df = self._recorder_df.append([env_monitor_df])
|
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
else:
|
|
||||||
pass
|
|
||||||
return True
|
|
||||||
|
|
||||||
def __enter__(self):
|
def start(self, force=False):
|
||||||
self._on_training_start()
|
if (hasattr(self.training_env, 'record_episodes') and self.training_env.record_episodes) or force:
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
self._on_training_end()
|
|
||||||
|
|
||||||
def _on_training_start(self) -> None:
|
|
||||||
if self.started:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
if hasattr(self.training_env, 'record_episodes'):
|
|
||||||
if self.training_env.record_episodes:
|
|
||||||
self.do_record = True
|
self.do_record = True
|
||||||
self.filepath.parent.mkdir(exist_ok=True, parents=True)
|
self.filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||||
self.started = True
|
self.started = True
|
||||||
else:
|
else:
|
||||||
self.do_record = False
|
self.do_record = False
|
||||||
else:
|
|
||||||
self.do_record = False
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _on_training_end(self) -> None:
|
def stop(self):
|
||||||
if self.closed:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
if self.do_record and self.started:
|
if self.do_record and self.started:
|
||||||
# self.out_file.unlink(missing_ok=True)
|
# self.out_file.unlink(missing_ok=True)
|
||||||
with self.filepath.open('w') as f:
|
with self.filepath.open('w') as f:
|
||||||
json_df = self._recorder_df.to_json(orient="table")
|
json_list = self._recorder_json_list
|
||||||
parsed = json.loads(json_df)
|
json.dump(json_list, f, indent=4)
|
||||||
json.dump(parsed, f, indent=4)
|
|
||||||
|
|
||||||
if self.occupation_map:
|
if self.occupation_map:
|
||||||
print('Recorder files were dumped to disk, now plotting the occupation map...')
|
print('Recorder files were dumped to disk, now plotting the occupation map...')
|
||||||
@ -87,3 +64,36 @@ class RecorderCallback(BaseCallback):
|
|||||||
self.started = False
|
self.started = False
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _on_step(self) -> bool:
|
||||||
|
if self.do_record and self.started:
|
||||||
|
for env_idx, info in enumerate(self.locals.get('infos', [])):
|
||||||
|
self.read_info(env_idx, info)
|
||||||
|
|
||||||
|
for env_idx, done in list(
|
||||||
|
enumerate(self.locals.get('dones', []))) + list(
|
||||||
|
enumerate(self.locals.get('done', []))):
|
||||||
|
self.read_done(env_idx, done)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
return True
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.start(force=True)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.stop()
|
||||||
|
|
||||||
|
def _on_training_start(self) -> None:
|
||||||
|
if self.started:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
self.start()
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _on_training_end(self) -> None:
|
||||||
|
if self.closed:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
self.stop()
|
||||||
|
19
main.py
19
main.py
@ -45,8 +45,8 @@ def combine_runs(run_path: Union[str, PathLike]):
|
|||||||
value_vars=columns, var_name="Measurement",
|
value_vars=columns, var_name="Measurement",
|
||||||
value_name="Score")
|
value_name="Score")
|
||||||
|
|
||||||
if df_melted['Episode'].max() > 100:
|
if df_melted['Episode'].max() > 80:
|
||||||
skip_n = round(df_melted['Episode'].max() * 0.01)
|
skip_n = round(df_melted['Episode'].max() * 0.02, 2)
|
||||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||||
|
|
||||||
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted)
|
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted)
|
||||||
@ -72,15 +72,18 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List
|
|||||||
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run', 'model': 'Model'})
|
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run', 'model': 'Model'})
|
||||||
columns = [col for col in df.columns if col in parameter]
|
columns = [col for col in df.columns if col in parameter]
|
||||||
|
|
||||||
roll_n = 40
|
last_episode_to_report = min(df.groupby(['Model'])['Episode'].max())
|
||||||
|
df = df[df['Episode'] < last_episode_to_report]
|
||||||
|
|
||||||
|
roll_n = 40
|
||||||
non_overlapp_window = df.groupby(['Model', 'Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
|
non_overlapp_window = df.groupby(['Model', 'Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
|
||||||
|
|
||||||
df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run', 'Model'],
|
df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run', 'Model'],
|
||||||
value_vars=columns, var_name="Measurement",
|
value_vars=columns, var_name="Measurement",
|
||||||
value_name="Score")
|
value_name="Score")
|
||||||
|
|
||||||
if df_melted['Episode'].max() > 100:
|
if df_melted['Episode'].max() > 100:
|
||||||
skip_n = round(df_melted['Episode'].max() * 0.01)
|
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||||
|
|
||||||
style = 'Measurement' if len(columns) > 1 else None
|
style = 'Measurement' if len(columns) > 1 else None
|
||||||
@ -113,10 +116,10 @@ if __name__ == '__main__':
|
|||||||
max_local_amount=1, spawn_frequency=3, max_spawn_ratio=0.05,
|
max_local_amount=1, spawn_frequency=3, max_spawn_ratio=0.05,
|
||||||
dirt_smear_amount=0.0, agent_can_interact=True)
|
dirt_smear_amount=0.0, agent_can_interact=True)
|
||||||
item_props = ItemProperties(n_items=5, agent_can_interact=True)
|
item_props = ItemProperties(n_items=5, agent_can_interact=True)
|
||||||
move_props = MovementProperties(allow_diagonal_movement=False,
|
move_props = MovementProperties(allow_diagonal_movement=True,
|
||||||
allow_square_movement=True,
|
allow_square_movement=True,
|
||||||
allow_no_op=False)
|
allow_no_op=False)
|
||||||
train_steps = 1e5
|
train_steps = 8e5
|
||||||
time_stamp = int(time.time())
|
time_stamp = int(time.time())
|
||||||
|
|
||||||
out_path = None
|
out_path = None
|
||||||
@ -129,14 +132,14 @@ if __name__ == '__main__':
|
|||||||
dirt_properties=dirt_props,
|
dirt_properties=dirt_props,
|
||||||
movement_properties=move_props,
|
movement_properties=move_props,
|
||||||
pomdp_r=2, max_steps=400, parse_doors=True,
|
pomdp_r=2, max_steps=400, parse_doors=True,
|
||||||
level_name='simple', frames_to_stack=6,
|
level_name='rooms', frames_to_stack=3,
|
||||||
omit_agent_in_obs=True, combin_agent_obs=True, record_episodes=False,
|
omit_agent_in_obs=True, combin_agent_obs=True, record_episodes=False,
|
||||||
cast_shadows=True, doors_have_area=False, env_seed=seed, verbose=False,
|
cast_shadows=True, doors_have_area=False, env_seed=seed, verbose=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if modeL_type.__name__ in ["PPO", "A2C"]:
|
if modeL_type.__name__ in ["PPO", "A2C"]:
|
||||||
kwargs = dict(ent_coef=0.01)
|
kwargs = dict(ent_coef=0.01)
|
||||||
env = SubprocVecEnv([make_env(env_kwargs) for _ in range(6)], start_method="spawn")
|
env = SubprocVecEnv([make_env(env_kwargs) for _ in range(10)], start_method="spawn")
|
||||||
elif modeL_type.__name__ in ["RegDQN", "DQN", "QRDQN"]:
|
elif modeL_type.__name__ in ["RegDQN", "DQN", "QRDQN"]:
|
||||||
env = make_env(env_kwargs)()
|
env = make_env(env_kwargs)()
|
||||||
kwargs = dict(buffer_size=50000,
|
kwargs = dict(buffer_size=50000,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user