recoder adaption
This commit is contained in:
parent
4c21a0af7c
commit
696e520862
1
.gitignore
vendored
1
.gitignore
vendored
@ -701,3 +701,4 @@ $RECYCLE.BIN/
|
|||||||
*.lnk
|
*.lnk
|
||||||
|
|
||||||
# End of https://www.toptal.com/developers/gitignore/api/linux,unity,macos,python,windows,pycharm,notepadpp,visualstudiocode,latex
|
# End of https://www.toptal.com/developers/gitignore/api/linux,unity,macos,python,windows,pycharm,notepadpp,visualstudiocode,latex
|
||||||
|
/studies/e_1/
|
||||||
|
0
__init__.py
Normal file
0
__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
import os, sys; sys.path.append(os.path.dirname(os.path.realpath(__file__)))
|
@ -13,8 +13,8 @@ from environments.factory.base.shadow_casting import Map
|
|||||||
from environments.factory.renderer import Renderer, RenderEntity
|
from environments.factory.renderer import Renderer, RenderEntity
|
||||||
from environments.helpers import Constants as c, Constants
|
from environments.helpers import Constants as c, Constants
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.factory.base.objects import Agent, Tile, Action, Wall
|
from environments.factory.base.objects import Agent, Tile, Action
|
||||||
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles
|
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders
|
||||||
from environments.utility_classes import MovementProperties
|
from environments.utility_classes import MovementProperties
|
||||||
|
|
||||||
import simplejson
|
import simplejson
|
||||||
@ -58,7 +58,7 @@ class BaseFactory(gym.Env):
|
|||||||
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_r: Union[None, int] = 0,
|
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_r: Union[None, int] = 0,
|
||||||
movement_properties: MovementProperties = MovementProperties(), parse_doors=False,
|
movement_properties: MovementProperties = MovementProperties(), parse_doors=False,
|
||||||
combin_agent_obs: bool = False, frames_to_stack=0, record_episodes=False,
|
combin_agent_obs: bool = False, frames_to_stack=0, record_episodes=False,
|
||||||
omit_agent_in_obs=False, done_at_collision=False, cast_shadows=True,
|
omit_agent_in_obs=False, done_at_collision=False, cast_shadows=True, additional_agent_placeholder=None,
|
||||||
verbose=False, doors_have_area=True, env_seed=time.time_ns(), **kwargs):
|
verbose=False, doors_have_area=True, env_seed=time.time_ns(), **kwargs):
|
||||||
assert frames_to_stack != 1 and frames_to_stack >= 0, "'frames_to_stack' cannot be negative or 1."
|
assert frames_to_stack != 1 and frames_to_stack >= 0, "'frames_to_stack' cannot be negative or 1."
|
||||||
if kwargs:
|
if kwargs:
|
||||||
@ -74,6 +74,7 @@ class BaseFactory(gym.Env):
|
|||||||
self.level_name = level_name
|
self.level_name = level_name
|
||||||
self._level_shape = None
|
self._level_shape = None
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
|
self.additional_agent_placeholder = additional_agent_placeholder
|
||||||
self._renderer = None # expensive - don't use it when not required !
|
self._renderer = None # expensive - don't use it when not required !
|
||||||
self._entities = Entities()
|
self._entities = Entities()
|
||||||
|
|
||||||
@ -141,6 +142,14 @@ class BaseFactory(gym.Env):
|
|||||||
individual_slices=not self.combin_agent_obs)
|
individual_slices=not self.combin_agent_obs)
|
||||||
entities.update({c.AGENT: agents})
|
entities.update({c.AGENT: agents})
|
||||||
|
|
||||||
|
if self.additional_agent_placeholder is not None:
|
||||||
|
|
||||||
|
# Empty Observations with either [0, 1, N(0, 1)]
|
||||||
|
placeholder = PlaceHolders.from_tiles([self._NO_POS_TILE], self._level_shape,
|
||||||
|
fill_value=self.additional_agent_placeholder)
|
||||||
|
|
||||||
|
entities.update({c.AGENT_PLACEHOLDER: placeholder})
|
||||||
|
|
||||||
# All entities
|
# All entities
|
||||||
self._entities = Entities()
|
self._entities = Entities()
|
||||||
self._entities.register_additional_items(entities)
|
self._entities.register_additional_items(entities)
|
||||||
@ -155,10 +164,12 @@ class BaseFactory(gym.Env):
|
|||||||
def _init_obs_cube(self):
|
def _init_obs_cube(self):
|
||||||
arrays = self._entities.observable_arrays
|
arrays = self._entities.observable_arrays
|
||||||
|
|
||||||
|
# FIXME: Move logic to Register
|
||||||
if self.omit_agent_in_obs and self.n_agents == 1:
|
if self.omit_agent_in_obs and self.n_agents == 1:
|
||||||
del arrays[c.AGENT]
|
del arrays[c.AGENT]
|
||||||
elif self.omit_agent_in_obs:
|
# This does not seem to be necesarry, because this case is allready handled by the Agent Register Class
|
||||||
arrays[c.AGENT] = np.delete(arrays[c.AGENT], 0, axis=0)
|
# elif self.omit_agent_in_obs:
|
||||||
|
# arrays[c.AGENT] = np.delete(arrays[c.AGENT], 0, axis=0)
|
||||||
obs_cube_z = sum([a.shape[0] if not self[key].is_per_agent else 1 for key, a in arrays.items()])
|
obs_cube_z = sum([a.shape[0] if not self[key].is_per_agent else 1 for key, a in arrays.items()])
|
||||||
self._obs_cube = np.zeros((obs_cube_z, *self._level_shape), dtype=np.float32)
|
self._obs_cube = np.zeros((obs_cube_z, *self._level_shape), dtype=np.float32)
|
||||||
|
|
||||||
@ -273,6 +284,7 @@ class BaseFactory(gym.Env):
|
|||||||
agent_pos_is_omitted = False
|
agent_pos_is_omitted = False
|
||||||
agent_omit_idx = None
|
agent_omit_idx = None
|
||||||
if self.omit_agent_in_obs and self.n_agents == 1:
|
if self.omit_agent_in_obs and self.n_agents == 1:
|
||||||
|
# There is only a single agent and we want to omit the agent obs, so just remove the array.
|
||||||
del state_array_dict[c.AGENT]
|
del state_array_dict[c.AGENT]
|
||||||
elif self.omit_agent_in_obs and self.combin_agent_obs and self.n_agents > 1:
|
elif self.omit_agent_in_obs and self.combin_agent_obs and self.n_agents > 1:
|
||||||
state_array_dict[c.AGENT][0, agent.x, agent.y] -= agent.encoding
|
state_array_dict[c.AGENT][0, agent.x, agent.y] -= agent.encoding
|
||||||
@ -295,6 +307,9 @@ class BaseFactory(gym.Env):
|
|||||||
for array_idx in range(array.shape[0]):
|
for array_idx in range(array.shape[0]):
|
||||||
self._obs_cube[running_idx: running_idx+z] = array[[x for x in range(array.shape[0])
|
self._obs_cube[running_idx: running_idx+z] = array[[x for x in range(array.shape[0])
|
||||||
if x != agent_omit_idx]]
|
if x != agent_omit_idx]]
|
||||||
|
elif key == c.AGENT and self.omit_agent_in_obs and self.combin_agent_obs:
|
||||||
|
z = 1
|
||||||
|
self._obs_cube[running_idx: running_idx + z] = array
|
||||||
else:
|
else:
|
||||||
z = array.shape[0]
|
z = array.shape[0]
|
||||||
self._obs_cube[running_idx: running_idx+z] = array
|
self._obs_cube[running_idx: running_idx+z] = array
|
||||||
@ -499,12 +514,8 @@ class BaseFactory(gym.Env):
|
|||||||
def _summarize_state(self):
|
def _summarize_state(self):
|
||||||
summary = {f'{REC_TAC}step': self._steps}
|
summary = {f'{REC_TAC}step': self._steps}
|
||||||
|
|
||||||
if self._steps == 0:
|
|
||||||
summary.update({f'{REC_TAC}{self[c.WALLS].name}': {self[c.WALLS].summarize_states()},
|
|
||||||
'FactoryName': self.__class__.__name__})
|
|
||||||
for entity_group in self._entities:
|
for entity_group in self._entities:
|
||||||
if not isinstance(entity_group, WallTiles):
|
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states(n_steps=self._steps)})
|
||||||
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()})
|
|
||||||
return summary
|
return summary
|
||||||
|
|
||||||
def print(self, string):
|
def print(self, string):
|
||||||
|
@ -93,11 +93,11 @@ class Entity(Object):
|
|||||||
return self._tile
|
return self._tile
|
||||||
|
|
||||||
def __init__(self, tile, **kwargs):
|
def __init__(self, tile, **kwargs):
|
||||||
super(Entity, self).__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self._tile = tile
|
self._tile = tile
|
||||||
tile.enter(self)
|
tile.enter(self)
|
||||||
|
|
||||||
def summarize_state(self) -> dict:
|
def summarize_state(self, **_) -> dict:
|
||||||
return dict(name=str(self.name), x=int(self.x), y=int(self.y),
|
return dict(name=str(self.name), x=int(self.x), y=int(self.y),
|
||||||
tile=str(self.tile.name), can_collide=bool(self.can_collide))
|
tile=str(self.tile.name), can_collide=bool(self.can_collide))
|
||||||
|
|
||||||
@ -125,7 +125,7 @@ class MoveableEntity(Entity):
|
|||||||
return last_x-curr_x, last_y-curr_y
|
return last_x-curr_x, last_y-curr_y
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(MoveableEntity, self).__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._last_tile = None
|
self._last_tile = None
|
||||||
|
|
||||||
def move(self, next_tile):
|
def move(self, next_tile):
|
||||||
@ -143,11 +143,34 @@ class MoveableEntity(Entity):
|
|||||||
class Action(Object):
|
class Action(Object):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Action, self).__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class PlaceHolder(MoveableEntity):
|
class PlaceHolder(MoveableEntity):
|
||||||
pass
|
|
||||||
|
def __init__(self, *args, fill_value=0, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._fill_value = fill_value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def last_tile(self):
|
||||||
|
return self.tile
|
||||||
|
|
||||||
|
@property
|
||||||
|
def direction_of_view(self):
|
||||||
|
return self.pos
|
||||||
|
|
||||||
|
@property
|
||||||
|
def can_collide(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
return c.NO_POS.value[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return "PlaceHolder"
|
||||||
|
|
||||||
|
|
||||||
class Tile(Object):
|
class Tile(Object):
|
||||||
@ -203,8 +226,8 @@ class Tile(Object):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.name}(@{self.pos})'
|
return f'{self.name}(@{self.pos})'
|
||||||
|
|
||||||
def summarize_state(self):
|
def summarize_state(self, **_):
|
||||||
return dict(name=self.name, x=self.x, y=self.y)
|
return dict(name=self.name, x=int(self.x), y=int(self.y))
|
||||||
|
|
||||||
|
|
||||||
class Wall(Tile):
|
class Wall(Tile):
|
||||||
@ -254,8 +277,8 @@ class Door(Entity):
|
|||||||
if not closed_on_init:
|
if not closed_on_init:
|
||||||
self._open()
|
self._open()
|
||||||
|
|
||||||
def summarize_state(self):
|
def summarize_state(self, **kwargs):
|
||||||
state_dict = super().summarize_state()
|
state_dict = super().summarize_state(**kwargs)
|
||||||
state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close))
|
state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close))
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
@ -315,7 +338,7 @@ class Agent(MoveableEntity):
|
|||||||
self.temp_action = None
|
self.temp_action = None
|
||||||
self.temp_light_map = None
|
self.temp_light_map = None
|
||||||
|
|
||||||
def summarize_state(self):
|
def summarize_state(self, **kwargs):
|
||||||
state_dict = super().summarize_state()
|
state_dict = super().summarize_state(**kwargs)
|
||||||
state_dict.update(valid=bool(self.temp_valid), action=str(self.temp_action))
|
state_dict.update(valid=bool(self.temp_valid), action=str(self.temp_action))
|
||||||
return state_dict
|
return state_dict
|
||||||
|
@ -81,8 +81,8 @@ class ObjectRegister(Register):
|
|||||||
if self.individual_slices:
|
if self.individual_slices:
|
||||||
self._array = np.concatenate((self._array, np.zeros((1, *self._array.shape[1:]))))
|
self._array = np.concatenate((self._array, np.zeros((1, *self._array.shape[1:]))))
|
||||||
|
|
||||||
def summarize_states(self):
|
def summarize_states(self, n_steps=None):
|
||||||
return [val.summarize_state() for val in self.values()]
|
return [val.summarize_state(n_steps=n_steps) for val in self.values()]
|
||||||
|
|
||||||
|
|
||||||
class EntityObjectRegister(ObjectRegister, ABC):
|
class EntityObjectRegister(ObjectRegister, ABC):
|
||||||
@ -156,23 +156,25 @@ class MovingEntityObjectRegister(EntityObjectRegister, ABC):
|
|||||||
del self[name]
|
del self[name]
|
||||||
|
|
||||||
|
|
||||||
class PlaceHolderRegister(MovingEntityObjectRegister):
|
class PlaceHolders(MovingEntityObjectRegister):
|
||||||
|
|
||||||
_accepted_objects = PlaceHolder
|
_accepted_objects = PlaceHolder
|
||||||
|
|
||||||
|
def __init__(self, *args, fill_value: Union[str, int] = 0, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.fill_value = fill_value
|
||||||
|
|
||||||
# noinspection DuplicatedCode
|
# noinspection DuplicatedCode
|
||||||
def as_array(self):
|
def as_array(self):
|
||||||
self._array[:] = c.FREE_CELL.value
|
if isinstance(self.fill_value, int):
|
||||||
# noinspection PyTupleAssignmentBalance
|
self._array[:] = self.fill_value
|
||||||
for z, x, y, v in zip(range(len(self)), *zip(*[x.pos for x in self]), [x.encoding for x in self]):
|
elif self.fill_value == "normal":
|
||||||
if self.individual_slices:
|
self._array = np.random.normal(size=self._array.shape)
|
||||||
self._array[z, x, y] += v
|
|
||||||
else:
|
|
||||||
self._array[0, x, y] += v
|
|
||||||
if self.individual_slices:
|
if self.individual_slices:
|
||||||
return self._array
|
return self._array
|
||||||
else:
|
else:
|
||||||
return self._array.sum(axis=0, keepdims=True)
|
return self._array[None, 0]
|
||||||
|
|
||||||
|
|
||||||
class Entities(Register):
|
class Entities(Register):
|
||||||
@ -243,6 +245,12 @@ class WallTiles(EntityObjectRegister):
|
|||||||
def from_tiles(cls, tiles, *args, **kwargs):
|
def from_tiles(cls, tiles, *args, **kwargs):
|
||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
|
|
||||||
|
def summarize_states(self, n_steps=None):
|
||||||
|
if n_steps == h.STEPS_START:
|
||||||
|
return super(WallTiles, self).summarize_states(n_steps=n_steps)
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
class FloorTiles(WallTiles):
|
class FloorTiles(WallTiles):
|
||||||
|
|
||||||
@ -272,6 +280,10 @@ class FloorTiles(WallTiles):
|
|||||||
def from_tiles(cls, tiles, *args, **kwargs):
|
def from_tiles(cls, tiles, *args, **kwargs):
|
||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
|
|
||||||
|
def summarize_states(self, n_steps=None):
|
||||||
|
# Do not summarize
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
class Agents(MovingEntityObjectRegister):
|
class Agents(MovingEntityObjectRegister):
|
||||||
|
|
||||||
|
29
environments/factory/env_item_default_param.json
Normal file
29
environments/factory/env_item_default_param.json
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
{
|
||||||
|
"item_properties": {
|
||||||
|
"n_items": 5,
|
||||||
|
"spawn_frequency": 10,
|
||||||
|
"n_drop_off_locations": 5,
|
||||||
|
"max_dropoff_storage_size": 0,
|
||||||
|
"max_agent_inventory_capacity": 5,
|
||||||
|
"agent_can_interact": true
|
||||||
|
},
|
||||||
|
"env_seed": 2,
|
||||||
|
"movement_properties": {
|
||||||
|
"allow_square_movement": true,
|
||||||
|
"allow_diagonal_movement": true,
|
||||||
|
"allow_no_op": false
|
||||||
|
},
|
||||||
|
"level_name": "rooms",
|
||||||
|
"verbose": false,
|
||||||
|
"n_agents": 1,
|
||||||
|
"max_steps": 400,
|
||||||
|
"pomdp_r": 2,
|
||||||
|
"combin_agent_obs": true,
|
||||||
|
"omit_agent_in_obs": true,
|
||||||
|
"cast_shadows": true,
|
||||||
|
"frames_to_stack": 3,
|
||||||
|
"done_at_collision": false,
|
||||||
|
"record_episodes": false,
|
||||||
|
"parse_doors": false,
|
||||||
|
"doors_have_area": false
|
||||||
|
}
|
@ -51,8 +51,8 @@ class Dirt(Entity):
|
|||||||
def set_new_amount(self, amount):
|
def set_new_amount(self, amount):
|
||||||
self._amount = amount
|
self._amount = amount
|
||||||
|
|
||||||
def summarize_state(self):
|
def summarize_state(self, **kwargs):
|
||||||
state_dict = super().summarize_state()
|
state_dict = super().summarize_state(**kwargs)
|
||||||
state_dict.update(amount=float(self.amount))
|
state_dict.update(amount=float(self.amount))
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
import random
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from environments.factory.factory_dirt import DirtFactory, DirtProperties
|
from environments.factory.factory_dirt import DirtFactory, DirtProperties
|
||||||
from environments.factory.factory_item import ItemFactory, ItemProperties
|
from environments.factory.factory_item import ItemFactory, ItemProperties
|
||||||
|
from environments.logging.recorder import RecorderCallback
|
||||||
from environments.utility_classes import MovementProperties
|
from environments.utility_classes import MovementProperties
|
||||||
|
|
||||||
|
|
||||||
@ -12,40 +14,44 @@ class DirtItemFactory(ItemFactory, DirtFactory):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
with RecorderCallback(filepath=Path('debug_out') / f'recorder_xxxx.json', occupation_map=False,
|
||||||
|
trajectory_map=False) as recorder:
|
||||||
|
|
||||||
dirt_props = DirtProperties(clean_amount=2, gain_amount=0.1, max_global_amount=20,
|
dirt_props = DirtProperties(clean_amount=2, gain_amount=0.1, max_global_amount=20,
|
||||||
max_local_amount=1, spawn_frequency=3, max_spawn_ratio=0.05,
|
max_local_amount=1, spawn_frequency=3, max_spawn_ratio=0.05,
|
||||||
dirt_smear_amount=0.0, agent_can_interact=True)
|
dirt_smear_amount=0.0, agent_can_interact=True)
|
||||||
item_props = ItemProperties(n_items=5, agent_can_interact=True)
|
item_props = ItemProperties(n_items=5, agent_can_interact=True)
|
||||||
move_props = MovementProperties(allow_diagonal_movement=True,
|
move_props = MovementProperties(allow_diagonal_movement=True,
|
||||||
allow_square_movement=True,
|
allow_square_movement=True,
|
||||||
allow_no_op=False)
|
allow_no_op=False)
|
||||||
|
|
||||||
render = True
|
render = False
|
||||||
|
|
||||||
factory = DirtItemFactory(n_agents=1, done_at_collision=False, frames_to_stack=0,
|
factory = DirtItemFactory(n_agents=1, done_at_collision=False, frames_to_stack=0,
|
||||||
level_name='rooms', max_steps=400, combin_agent_obs=True,
|
level_name='rooms', max_steps=200, combin_agent_obs=True,
|
||||||
omit_agent_in_obs=True, parse_doors=True, pomdp_r=3,
|
omit_agent_in_obs=True, parse_doors=True, pomdp_r=3,
|
||||||
record_episodes=True, verbose=True, cast_shadows=True,
|
record_episodes=True, verbose=False, cast_shadows=True,
|
||||||
movement_properties=move_props, dirt_properties=dirt_props
|
movement_properties=move_props, dirt_properties=dirt_props
|
||||||
)
|
)
|
||||||
|
|
||||||
# noinspection DuplicatedCode
|
# noinspection DuplicatedCode
|
||||||
n_actions = factory.action_space.n - 1
|
n_actions = factory.action_space.n - 1
|
||||||
_ = factory.observation_space
|
_ = factory.observation_space
|
||||||
|
|
||||||
for epoch in range(4):
|
for epoch in range(4):
|
||||||
random_actions = [[random.randint(0, n_actions) for _
|
random_actions = [[random.randint(0, n_actions) for _
|
||||||
in range(factory.n_agents)] for _
|
in range(factory.n_agents)] for _
|
||||||
in range(factory.max_steps + 1)]
|
in range(factory.max_steps + 1)]
|
||||||
env_state = factory.reset()
|
env_state = factory.reset()
|
||||||
r = 0
|
r = 0
|
||||||
for agent_i_action in random_actions:
|
for agent_i_action in random_actions:
|
||||||
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
|
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
|
||||||
r += step_r
|
recorder.read_info(0, info_obj)
|
||||||
if render:
|
r += step_r
|
||||||
factory.render()
|
if render:
|
||||||
if done_bool:
|
factory.render()
|
||||||
break
|
if done_bool:
|
||||||
print(f'Factory run {epoch} done, reward is:\n {r}')
|
recorder.read_done(0, done_bool)
|
||||||
pass
|
break
|
||||||
|
print(f'Factory run {epoch} done, reward is:\n {r}')
|
||||||
|
pass
|
||||||
|
@ -109,8 +109,10 @@ class Inventory(UserList):
|
|||||||
def belongs_to_entity(self, entity):
|
def belongs_to_entity(self, entity):
|
||||||
return self.agent == entity
|
return self.agent == entity
|
||||||
|
|
||||||
def summarize_state(self):
|
def summarize_state(self, **kwargs):
|
||||||
return {val.name: val.summarize_state() for val in self}
|
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||||
|
attr_dict.update({val.name: val.summarize_state(**kwargs) for val in self})
|
||||||
|
return attr_dict
|
||||||
|
|
||||||
|
|
||||||
class Inventories(ObjectRegister):
|
class Inventories(ObjectRegister):
|
||||||
@ -176,6 +178,10 @@ class DropOffLocation(Entity):
|
|||||||
def is_full(self):
|
def is_full(self):
|
||||||
return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage)
|
return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage)
|
||||||
|
|
||||||
|
def summarize_state(self, n_steps=None) -> dict:
|
||||||
|
if n_steps == h.STEPS_START:
|
||||||
|
return super().summarize_state(n_steps=n_steps)
|
||||||
|
|
||||||
|
|
||||||
class DropOffLocations(EntityObjectRegister):
|
class DropOffLocations(EntityObjectRegister):
|
||||||
|
|
||||||
|
@ -7,7 +7,10 @@ from pathlib import Path
|
|||||||
|
|
||||||
from stable_baselines3 import PPO, DQN, A2C
|
from stable_baselines3 import PPO, DQN, A2C
|
||||||
|
|
||||||
|
MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C)
|
||||||
|
|
||||||
LEVELS_DIR = 'levels'
|
LEVELS_DIR = 'levels'
|
||||||
|
STEPS_START = 1
|
||||||
|
|
||||||
TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
|
TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
|
||||||
IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amount',
|
IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amount',
|
||||||
@ -16,34 +19,35 @@ IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amo
|
|||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
class Constants(Enum):
|
class Constants(Enum):
|
||||||
WALL = '#'
|
WALL = '#'
|
||||||
WALLS = 'Walls'
|
WALLS = 'Walls'
|
||||||
FLOOR = 'Floor'
|
FLOOR = 'Floor'
|
||||||
DOOR = 'D'
|
DOOR = 'D'
|
||||||
DANGER_ZONE = 'x'
|
DANGER_ZONE = 'x'
|
||||||
LEVEL = 'Level'
|
LEVEL = 'Level'
|
||||||
AGENT = 'Agent'
|
AGENT = 'Agent'
|
||||||
FREE_CELL = 0
|
AGENT_PLACEHOLDER = 'AGENT_PLACEHOLDER'
|
||||||
OCCUPIED_CELL = 1
|
FREE_CELL = 0
|
||||||
SHADOWED_CELL = -1
|
OCCUPIED_CELL = 1
|
||||||
NO_POS = (-9999, -9999)
|
SHADOWED_CELL = -1
|
||||||
|
NO_POS = (-9999, -9999)
|
||||||
|
|
||||||
DOORS = 'Doors'
|
DOORS = 'Doors'
|
||||||
CLOSED_DOOR = 'closed'
|
CLOSED_DOOR = 'closed'
|
||||||
OPEN_DOOR = 'open'
|
OPEN_DOOR = 'open'
|
||||||
|
|
||||||
ACTION = 'action'
|
ACTION = 'action'
|
||||||
COLLISIONS = 'collision'
|
COLLISIONS = 'collision'
|
||||||
VALID = 'valid'
|
VALID = 'valid'
|
||||||
NOT_VALID = 'not_valid'
|
NOT_VALID = 'not_valid'
|
||||||
|
|
||||||
# Dirt Env
|
# Dirt Env
|
||||||
DIRT = 'Dirt'
|
DIRT = 'Dirt'
|
||||||
|
|
||||||
# Item Env
|
# Item Env
|
||||||
ITEM = 'Item'
|
ITEM = 'Item'
|
||||||
INVENTORY = 'Inventory'
|
INVENTORY = 'Inventory'
|
||||||
DROP_OFF = 'Drop_Off'
|
DROP_OFF = 'Drop_Off'
|
||||||
|
|
||||||
def __bool__(self):
|
def __bool__(self):
|
||||||
if 'not_' in self.value:
|
if 'not_' in self.value:
|
||||||
@ -144,8 +148,6 @@ def asset_str(agent):
|
|||||||
return c.AGENT.value, 'idle'
|
return c.AGENT.value, 'idle'
|
||||||
|
|
||||||
|
|
||||||
model_map = dict(PPO=PPO, DQN=DQN, A2C=A2C)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parsed_level = parse_level(Path(__file__).parent / 'factory' / 'levels' / 'simple.txt')
|
parsed_level = parse_level(Path(__file__).parent / 'factory' / 'levels' / 'simple.txt')
|
||||||
y = one_hot_level(parsed_level)
|
y = one_hot_level(parsed_level)
|
||||||
|
@ -6,7 +6,7 @@ from typing import List, Dict
|
|||||||
from stable_baselines3.common.callbacks import BaseCallback
|
from stable_baselines3.common.callbacks import BaseCallback
|
||||||
|
|
||||||
from environments.helpers import IGNORED_DF_COLUMNS
|
from environments.helpers import IGNORED_DF_COLUMNS
|
||||||
from environments.logging.plotting import prepare_plot
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
@ -14,85 +14,76 @@ class MonitorCallback(BaseCallback):
|
|||||||
|
|
||||||
ext = 'png'
|
ext = 'png'
|
||||||
|
|
||||||
def __init__(self, filepath=Path('debug_out/monitor.pick'), plotting=True):
|
def __init__(self, filepath=Path('debug_out/monitor.pick')):
|
||||||
super(MonitorCallback, self).__init__()
|
super(MonitorCallback, self).__init__()
|
||||||
self.filepath = Path(filepath)
|
self.filepath = Path(filepath)
|
||||||
self._monitor_df = pd.DataFrame()
|
self._monitor_df = pd.DataFrame()
|
||||||
self._monitor_dicts = defaultdict(dict)
|
self._monitor_dicts = defaultdict(dict)
|
||||||
self.plotting = plotting
|
|
||||||
self.started = False
|
self.started = False
|
||||||
self.closed = False
|
self.closed = False
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self._on_training_start()
|
self.start()
|
||||||
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
self._on_training_end()
|
self.stop()
|
||||||
|
|
||||||
def _on_training_start(self) -> None:
|
def _on_training_start(self) -> None:
|
||||||
if self.started:
|
if self.started:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
self.filepath.parent.mkdir(exist_ok=True, parents=True)
|
self.start()
|
||||||
self.started = True
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _on_training_end(self) -> None:
|
def _on_training_end(self) -> None:
|
||||||
if self.closed:
|
if self.closed:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
# self.out_file.unlink(missing_ok=True)
|
self.stop()
|
||||||
with self.filepath.open('wb') as f:
|
|
||||||
pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
|
|
||||||
if self.plotting:
|
|
||||||
print('Monitor files were dumped to disk, now plotting....')
|
|
||||||
|
|
||||||
# %% Load MonitorList from Disk
|
|
||||||
with self.filepath.open('rb') as f:
|
|
||||||
monitor_list = pickle.load(f)
|
|
||||||
df = None
|
|
||||||
for m_idx, monitor in enumerate(monitor_list):
|
|
||||||
monitor['episode'] = m_idx
|
|
||||||
if df is None:
|
|
||||||
df = pd.DataFrame(columns=monitor.columns)
|
|
||||||
for _, row in monitor.iterrows():
|
|
||||||
df.loc[df.shape[0]] = row
|
|
||||||
if df is None: # The env exited premature, we catch it.
|
|
||||||
self.closed = True
|
|
||||||
return
|
|
||||||
for column in list(df.columns):
|
|
||||||
if column != 'episode':
|
|
||||||
df[f'{column}_roll'] = df[column].rolling(window=50).mean()
|
|
||||||
# result.tail()
|
|
||||||
prepare_plot(filepath=self.filepath, results_df=df.filter(regex=(".+_roll")))
|
|
||||||
print('Plotting done.')
|
|
||||||
self.closed = True
|
|
||||||
|
|
||||||
def _on_step(self, alt_infos: List[Dict] = None, alt_dones: List[bool] = None) -> bool:
|
def _on_step(self, alt_infos: List[Dict] = None, alt_dones: List[bool] = None) -> bool:
|
||||||
infos = alt_infos or self.locals.get('infos', [])
|
if self.started:
|
||||||
if alt_dones is not None:
|
for env_idx, info in enumerate(self.locals.get('infos', [])):
|
||||||
dones = alt_dones
|
self.read_info(env_idx, info)
|
||||||
elif self.locals.get('dones', None) is not None:
|
|
||||||
dones =self.locals.get('dones', None)
|
|
||||||
elif self.locals.get('done', None) is not None:
|
|
||||||
dones = self.locals.get('done', [None])
|
|
||||||
else:
|
|
||||||
dones = []
|
|
||||||
|
|
||||||
for env_idx, (info, done) in enumerate(zip(infos, dones)):
|
for env_idx, done in list(
|
||||||
self._monitor_dicts[env_idx][len(self._monitor_dicts[env_idx])] = {key: val for key, val in info.items()
|
enumerate(self.locals.get('dones', []))) + list(enumerate(self.locals.get('done', []))):
|
||||||
if key not in ['terminal_observation', 'episode']
|
self.read_done(env_idx, done)
|
||||||
and not key.startswith('rec_')}
|
else:
|
||||||
if done:
|
pass
|
||||||
env_monitor_df = pd.DataFrame.from_dict(self._monitor_dicts[env_idx], orient='index')
|
|
||||||
self._monitor_dicts[env_idx] = dict()
|
|
||||||
columns = [col for col in env_monitor_df.columns if col not in IGNORED_DF_COLUMNS]
|
|
||||||
env_monitor_df = env_monitor_df.aggregate(
|
|
||||||
{col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
|
|
||||||
)
|
|
||||||
env_monitor_df['episode'] = len(self._monitor_df)
|
|
||||||
self._monitor_df = self._monitor_df.append([env_monitor_df])
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def read_info(self, env_idx, info: dict):
|
||||||
|
self._monitor_dicts[env_idx][len(self._monitor_dicts[env_idx])] = {
|
||||||
|
key: val for key, val in info.items() if
|
||||||
|
key not in ['terminal_observation', 'episode'] and not key.startswith('rec_')}
|
||||||
|
return
|
||||||
|
|
||||||
|
def read_done(self, env_idx, done):
|
||||||
|
if done:
|
||||||
|
env_monitor_df = pd.DataFrame.from_dict(self._monitor_dicts[env_idx], orient='index')
|
||||||
|
self._monitor_dicts[env_idx] = dict()
|
||||||
|
columns = [col for col in env_monitor_df.columns if col not in IGNORED_DF_COLUMNS]
|
||||||
|
env_monitor_df = env_monitor_df.aggregate(
|
||||||
|
{col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
|
||||||
|
)
|
||||||
|
env_monitor_df['episode'] = len(self._monitor_df)
|
||||||
|
self._monitor_df = self._monitor_df.append([env_monitor_df])
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
return
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
# self.out_file.unlink(missing_ok=True)
|
||||||
|
with self.filepath.open('wb') as f:
|
||||||
|
pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
self.closed = True
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
if self.started:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
self.filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||||
|
self.started = True
|
||||||
|
pass
|
||||||
|
@ -3,11 +3,10 @@ from collections import defaultdict
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import pandas as pd
|
import simplejson
|
||||||
from stable_baselines3.common.callbacks import BaseCallback
|
from stable_baselines3.common.callbacks import BaseCallback
|
||||||
|
|
||||||
from environments.factory.base.base_factory import REC_TAC
|
from environments.factory.base.base_factory import REC_TAC
|
||||||
from environments.helpers import IGNORED_DF_COLUMNS
|
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyAttributeOutsideInit
|
# noinspection PyAttributeOutsideInit
|
||||||
@ -18,8 +17,8 @@ class RecorderCallback(BaseCallback):
|
|||||||
self.trajectory_map = trajectory_map
|
self.trajectory_map = trajectory_map
|
||||||
self.occupation_map = occupation_map
|
self.occupation_map = occupation_map
|
||||||
self.filepath = Path(filepath)
|
self.filepath = Path(filepath)
|
||||||
self._recorder_dict = defaultdict(dict)
|
self._recorder_dict = defaultdict(list)
|
||||||
self._recorder_json_list = list()
|
self._recorder_out_list = list()
|
||||||
self.do_record: bool
|
self.do_record: bool
|
||||||
self.started = False
|
self.started = False
|
||||||
self.closed = False
|
self.closed = False
|
||||||
@ -27,15 +26,15 @@ class RecorderCallback(BaseCallback):
|
|||||||
def read_info(self, env_idx, info: dict):
|
def read_info(self, env_idx, info: dict):
|
||||||
if info_dict := {key.replace(REC_TAC, ''): val for key, val in info.items() if key.startswith(f'{REC_TAC}')}:
|
if info_dict := {key.replace(REC_TAC, ''): val for key, val in info.items() if key.startswith(f'{REC_TAC}')}:
|
||||||
info_dict.update(episode=(self.num_timesteps + env_idx))
|
info_dict.update(episode=(self.num_timesteps + env_idx))
|
||||||
self._recorder_dict[env_idx][len(self._recorder_dict[env_idx])] = info_dict
|
self._recorder_dict[env_idx].append(info_dict)
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
return
|
return
|
||||||
|
|
||||||
def read_done(self, env_idx, done):
|
def read_done(self, env_idx, done):
|
||||||
if done:
|
if done:
|
||||||
self._recorder_json_list.append(json.dumps(self._recorder_dict[env_idx]))
|
self._recorder_out_list.append({'steps': self._recorder_dict[env_idx]})
|
||||||
self._recorder_dict[env_idx] = dict()
|
self._recorder_dict[env_idx] = list()
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -51,8 +50,11 @@ class RecorderCallback(BaseCallback):
|
|||||||
if self.do_record and self.started:
|
if self.do_record and self.started:
|
||||||
# self.out_file.unlink(missing_ok=True)
|
# self.out_file.unlink(missing_ok=True)
|
||||||
with self.filepath.open('w') as f:
|
with self.filepath.open('w') as f:
|
||||||
json_list = self._recorder_json_list
|
out_dict = {'episodes': self._recorder_out_list}
|
||||||
json.dump(json_list, f, indent=4)
|
try:
|
||||||
|
simplejson.dump(out_dict, f, indent=4)
|
||||||
|
except TypeError:
|
||||||
|
print('Shit')
|
||||||
|
|
||||||
if self.occupation_map:
|
if self.occupation_map:
|
||||||
print('Recorder files were dumped to disk, now plotting the occupation map...')
|
print('Recorder files were dumped to disk, now plotting the occupation map...')
|
||||||
|
108
main.py
108
main.py
@ -1,101 +1,27 @@
|
|||||||
import pickle
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Union, List
|
|
||||||
from os import PathLike
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
from stable_baselines3.common.callbacks import CallbackList
|
from stable_baselines3.common.callbacks import CallbackList
|
||||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||||
|
|
||||||
from environments.factory.factory_dirt_item import DirtItemFactory
|
from environments.factory.factory_dirt_item import DirtItemFactory
|
||||||
from environments.factory.factory_item import ItemFactory, ItemProperties
|
from environments.factory.factory_item import ItemFactory, ItemProperties
|
||||||
from environments.factory.factory_dirt import DirtProperties, DirtFactory
|
from environments.factory.factory_dirt import DirtProperties, DirtFactory
|
||||||
from environments.helpers import IGNORED_DF_COLUMNS
|
|
||||||
from environments.logging.monitor import MonitorCallback
|
from environments.logging.monitor import MonitorCallback
|
||||||
from environments.logging.plotting import prepare_plot
|
|
||||||
from environments.logging.recorder import RecorderCallback
|
from environments.logging.recorder import RecorderCallback
|
||||||
from environments.utility_classes import MovementProperties
|
from environments.utility_classes import MovementProperties
|
||||||
|
from plotting.compare_runs import compare_seed_runs, compare_model_runs
|
||||||
|
|
||||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||||
warnings.filterwarnings('ignore', category=UserWarning)
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
|
|
||||||
|
|
||||||
def combine_runs(run_path: Union[str, PathLike]):
|
|
||||||
run_path = Path(run_path)
|
|
||||||
df_list = list()
|
|
||||||
for run, monitor_file in enumerate(run_path.rglob('monitor_*.pick')):
|
|
||||||
with monitor_file.open('rb') as f:
|
|
||||||
monitor_df = pickle.load(f)
|
|
||||||
|
|
||||||
monitor_df['run'] = run
|
|
||||||
monitor_df = monitor_df.fillna(0)
|
|
||||||
df_list.append(monitor_df)
|
|
||||||
|
|
||||||
df = pd.concat(df_list, ignore_index=True)
|
|
||||||
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'}).sort_values(['Run', 'Episode'])
|
|
||||||
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
|
|
||||||
|
|
||||||
roll_n = 50
|
|
||||||
|
|
||||||
non_overlapp_window = df.groupby(['Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
|
|
||||||
|
|
||||||
df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run'],
|
|
||||||
value_vars=columns, var_name="Measurement",
|
|
||||||
value_name="Score")
|
|
||||||
|
|
||||||
if df_melted['Episode'].max() > 800:
|
|
||||||
skip_n = round(df_melted['Episode'].max() * 0.02)
|
|
||||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
|
||||||
|
|
||||||
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted)
|
|
||||||
print('Plotting done.')
|
|
||||||
|
|
||||||
|
|
||||||
def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List[str]]):
|
|
||||||
run_path = Path(run_path)
|
|
||||||
df_list = list()
|
|
||||||
parameter = [parameter] if isinstance(parameter, str) else parameter
|
|
||||||
for path in run_path.iterdir():
|
|
||||||
if path.is_dir() and str(run_identifier) in path.name:
|
|
||||||
for run, monitor_file in enumerate(path.rglob('monitor_*.pick')):
|
|
||||||
with monitor_file.open('rb') as f:
|
|
||||||
monitor_df = pickle.load(f)
|
|
||||||
|
|
||||||
monitor_df['run'] = run
|
|
||||||
monitor_df['model'] = path.name.split('_')[0]
|
|
||||||
monitor_df = monitor_df.fillna(0)
|
|
||||||
df_list.append(monitor_df)
|
|
||||||
|
|
||||||
df = pd.concat(df_list, ignore_index=True)
|
|
||||||
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run', 'model': 'Model'})
|
|
||||||
columns = [col for col in df.columns if col in parameter]
|
|
||||||
|
|
||||||
last_episode_to_report = min(df.groupby(['Model'])['Episode'].max())
|
|
||||||
df = df[df['Episode'] < last_episode_to_report]
|
|
||||||
|
|
||||||
roll_n = 40
|
|
||||||
non_overlapp_window = df.groupby(['Model', 'Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
|
|
||||||
|
|
||||||
df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run', 'Model'],
|
|
||||||
value_vars=columns, var_name="Measurement",
|
|
||||||
value_name="Score")
|
|
||||||
|
|
||||||
if df_melted['Episode'].max() > 100:
|
|
||||||
skip_n = round(df_melted['Episode'].max() * 0.02)
|
|
||||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
|
||||||
|
|
||||||
style = 'Measurement' if len(columns) > 1 else None
|
|
||||||
prepare_plot(run_path / f'{run_identifier}_compare_{parameter}.png', df_melted, hue='Model', style=style)
|
|
||||||
print('Plotting done.')
|
|
||||||
|
|
||||||
|
|
||||||
def make_env(env_kwargs_dict):
|
def make_env(env_kwargs_dict):
|
||||||
|
|
||||||
def _init():
|
def _init():
|
||||||
with ItemFactory(**env_kwargs_dict) as init_env:
|
with DirtFactory(**env_kwargs_dict) as init_env:
|
||||||
return init_env
|
return init_env
|
||||||
|
|
||||||
return _init
|
return _init
|
||||||
@ -110,17 +36,19 @@ if __name__ == '__main__':
|
|||||||
# exit()
|
# exit()
|
||||||
|
|
||||||
from stable_baselines3 import PPO, DQN, A2C
|
from stable_baselines3 import PPO, DQN, A2C
|
||||||
from algorithms.reg_dqn import RegDQN
|
# from algorithms.reg_dqn import RegDQN
|
||||||
# from sb3_contrib import QRDQN
|
# from sb3_contrib import QRDQN
|
||||||
|
|
||||||
dirt_props = DirtProperties(clean_amount=2, gain_amount=0.1, max_global_amount=20,
|
dirt_props = DirtProperties(clean_amount=2, gain_amount=0.1, max_global_amount=20,
|
||||||
max_local_amount=1, spawn_frequency=3, max_spawn_ratio=0.05,
|
max_local_amount=1, spawn_frequency=16, max_spawn_ratio=0.05,
|
||||||
dirt_smear_amount=0.0, agent_can_interact=True)
|
dirt_smear_amount=0.0, agent_can_interact=True)
|
||||||
item_props = ItemProperties(n_items=5, agent_can_interact=True)
|
item_props = ItemProperties(n_items=10, agent_can_interact=True,
|
||||||
|
spawn_frequency=30, n_drop_off_locations=2,
|
||||||
|
max_agent_inventory_capacity=15)
|
||||||
move_props = MovementProperties(allow_diagonal_movement=True,
|
move_props = MovementProperties(allow_diagonal_movement=True,
|
||||||
allow_square_movement=True,
|
allow_square_movement=True,
|
||||||
allow_no_op=False)
|
allow_no_op=False)
|
||||||
train_steps = 8e5
|
train_steps = 5e6
|
||||||
time_stamp = int(time.time())
|
time_stamp = int(time.time())
|
||||||
|
|
||||||
out_path = None
|
out_path = None
|
||||||
@ -128,18 +56,18 @@ if __name__ == '__main__':
|
|||||||
for modeL_type in [A2C, PPO, DQN]: # ,RegDQN, QRDQN]:
|
for modeL_type in [A2C, PPO, DQN]: # ,RegDQN, QRDQN]:
|
||||||
for seed in range(3):
|
for seed in range(3):
|
||||||
env_kwargs = dict(n_agents=1,
|
env_kwargs = dict(n_agents=1,
|
||||||
item_properties=item_props,
|
# item_properties=item_props,
|
||||||
# dirt_properties=dirt_props,
|
dirt_properties=dirt_props,
|
||||||
movement_properties=move_props,
|
movement_properties=move_props,
|
||||||
pomdp_r=2, max_steps=400, parse_doors=False,
|
pomdp_r=2, max_steps=1000, parse_doors=False,
|
||||||
level_name='rooms', frames_to_stack=3,
|
level_name='rooms', frames_to_stack=4,
|
||||||
omit_agent_in_obs=True, combin_agent_obs=True, record_episodes=False,
|
omit_agent_in_obs=True, combin_agent_obs=True, record_episodes=False,
|
||||||
cast_shadows=True, doors_have_area=False, env_seed=seed, verbose=False,
|
cast_shadows=True, doors_have_area=False, env_seed=seed, verbose=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if modeL_type.__name__ in ["PPO", "A2C"]:
|
if modeL_type.__name__ in ["PPO", "A2C"]:
|
||||||
kwargs = dict(ent_coef=0.01)
|
kwargs = dict(ent_coef=0.01)
|
||||||
env = SubprocVecEnv([make_env(env_kwargs) for _ in range(1)], start_method="spawn")
|
env = SubprocVecEnv([make_env(env_kwargs) for _ in range(10)], start_method="spawn")
|
||||||
elif modeL_type.__name__ in ["RegDQN", "DQN", "QRDQN"]:
|
elif modeL_type.__name__ in ["RegDQN", "DQN", "QRDQN"]:
|
||||||
env = make_env(env_kwargs)()
|
env = make_env(env_kwargs)()
|
||||||
kwargs = dict(buffer_size=50000,
|
kwargs = dict(buffer_size=50000,
|
||||||
@ -161,7 +89,7 @@ if __name__ == '__main__':
|
|||||||
out_path /= identifier
|
out_path /= identifier
|
||||||
|
|
||||||
callbacks = CallbackList(
|
callbacks = CallbackList(
|
||||||
[MonitorCallback(filepath=out_path / f'monitor_{identifier}.pick', plotting=False),
|
[MonitorCallback(filepath=out_path / f'monitor_{identifier}.pick'),
|
||||||
RecorderCallback(filepath=out_path / f'recorder_{identifier}.json', occupation_map=False,
|
RecorderCallback(filepath=out_path / f'recorder_{identifier}.json', occupation_map=False,
|
||||||
trajectory_map=False
|
trajectory_map=False
|
||||||
)]
|
)]
|
||||||
@ -172,7 +100,7 @@ if __name__ == '__main__':
|
|||||||
save_path = out_path / f'model_{identifier}.zip'
|
save_path = out_path / f'model_{identifier}.zip'
|
||||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
model.save(save_path)
|
model.save(save_path)
|
||||||
param_path = out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.yaml'
|
param_path = out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.json'
|
||||||
try:
|
try:
|
||||||
env.env_method('save_params', param_path)
|
env.env_method('save_params', param_path)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
@ -181,7 +109,7 @@ if __name__ == '__main__':
|
|||||||
print("Model Group Done.. Plotting...")
|
print("Model Group Done.. Plotting...")
|
||||||
|
|
||||||
if out_path:
|
if out_path:
|
||||||
combine_runs(out_path.parent)
|
compare_seed_runs(out_path.parent)
|
||||||
print("All Models Done... Evaluating")
|
print("All Models Done... Evaluating")
|
||||||
if out_path:
|
if out_path:
|
||||||
compare_runs(Path('debug_out'), time_stamp, 'step_reward')
|
compare_model_runs(Path('debug_out'), time_stamp, 'step_reward')
|
||||||
|
@ -13,7 +13,7 @@ from stable_baselines3 import PPO, DQN, A2C
|
|||||||
from environments.factory.factory_dirt import DirtFactory, DirtProperties
|
from environments.factory.factory_dirt import DirtFactory, DirtProperties
|
||||||
from environments.logging.monitor import MonitorCallback
|
from environments.logging.monitor import MonitorCallback
|
||||||
from algorithms.reg_dqn import RegDQN
|
from algorithms.reg_dqn import RegDQN
|
||||||
from main import compare_runs, combine_runs
|
from main import compare_model_runs, compare_seed_runs
|
||||||
|
|
||||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||||
warnings.filterwarnings('ignore', category=UserWarning)
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
@ -55,7 +55,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
exp_out_path = model_path / 'exp'
|
exp_out_path = model_path / 'exp'
|
||||||
callbacks = CallbackList(
|
callbacks = CallbackList(
|
||||||
[MonitorCallback(filepath=exp_out_path / f'future_exp_name', plotting=True)]
|
[MonitorCallback(filepath=exp_out_path / f'future_exp_name')]
|
||||||
)
|
)
|
||||||
|
|
||||||
n_actions = env.action_space.n
|
n_actions = env.action_space.n
|
||||||
@ -83,4 +83,4 @@ if __name__ == '__main__':
|
|||||||
print(f'Factory run {epoch} done, reward is:\n {r}')
|
print(f'Factory run {epoch} done, reward is:\n {r}')
|
||||||
|
|
||||||
if out_path:
|
if out_path:
|
||||||
combine_runs(out_path.parent)
|
compare_seed_runs(out_path.parent)
|
||||||
|
0
plotting/__init__.py
Normal file
0
plotting/__init__.py
Normal file
155
plotting/compare_runs.py
Normal file
155
plotting/compare_runs.py
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
import pickle
|
||||||
|
import re
|
||||||
|
from os import PathLike
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Union, List
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from environments.helpers import IGNORED_DF_COLUMNS, MODEL_MAP
|
||||||
|
from plotting.plotting import prepare_plot
|
||||||
|
|
||||||
|
|
||||||
|
def compare_seed_runs(run_path: Union[str, PathLike]):
|
||||||
|
run_path = Path(run_path)
|
||||||
|
df_list = list()
|
||||||
|
for run, monitor_file in enumerate(run_path.rglob('monitor*.pick')):
|
||||||
|
with monitor_file.open('rb') as f:
|
||||||
|
monitor_df = pickle.load(f)
|
||||||
|
|
||||||
|
monitor_df['run'] = run
|
||||||
|
monitor_df = monitor_df.fillna(0)
|
||||||
|
df_list.append(monitor_df)
|
||||||
|
|
||||||
|
df = pd.concat(df_list, ignore_index=True)
|
||||||
|
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'}).sort_values(['Run', 'Episode'])
|
||||||
|
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
|
||||||
|
|
||||||
|
roll_n = 50
|
||||||
|
|
||||||
|
non_overlapp_window = df.groupby(['Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
|
||||||
|
|
||||||
|
df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run'],
|
||||||
|
value_vars=columns, var_name="Measurement",
|
||||||
|
value_name="Score")
|
||||||
|
|
||||||
|
if df_melted['Episode'].max() > 800:
|
||||||
|
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||||
|
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||||
|
|
||||||
|
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted)
|
||||||
|
print('Plotting done.')
|
||||||
|
|
||||||
|
|
||||||
|
def compare_model_runs(run_path: Path, run_identifier: Union[str, int], parameter: Union[str, List[str]]):
|
||||||
|
run_path = Path(run_path)
|
||||||
|
df_list = list()
|
||||||
|
parameter = [parameter] if isinstance(parameter, str) else parameter
|
||||||
|
for path in run_path.iterdir():
|
||||||
|
if path.is_dir() and str(run_identifier) in path.name:
|
||||||
|
for run, monitor_file in enumerate(path.rglob('monitor*.pick')):
|
||||||
|
with monitor_file.open('rb') as f:
|
||||||
|
monitor_df = pickle.load(f)
|
||||||
|
|
||||||
|
monitor_df['run'] = run
|
||||||
|
monitor_df['model'] = next((x for x in path.name.split('_') if x in MODEL_MAP.keys()))
|
||||||
|
monitor_df = monitor_df.fillna(0)
|
||||||
|
df_list.append(monitor_df)
|
||||||
|
|
||||||
|
df = pd.concat(df_list, ignore_index=True)
|
||||||
|
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run', 'model': 'Model'})
|
||||||
|
columns = [col for col in df.columns if col in parameter]
|
||||||
|
|
||||||
|
last_episode_to_report = min(df.groupby(['Model'])['Episode'].max())
|
||||||
|
df = df[df['Episode'] < last_episode_to_report]
|
||||||
|
|
||||||
|
roll_n = 40
|
||||||
|
non_overlapp_window = df.groupby(['Model', 'Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
|
||||||
|
|
||||||
|
df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run', 'Model'],
|
||||||
|
value_vars=columns, var_name="Measurement",
|
||||||
|
value_name="Score")
|
||||||
|
|
||||||
|
if df_melted['Episode'].max() > 80:
|
||||||
|
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||||
|
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||||
|
|
||||||
|
style = 'Measurement' if len(columns) > 1 else None
|
||||||
|
prepare_plot(run_path / f'{run_identifier}_compare_{parameter}.png', df_melted, hue='Model', style=style)
|
||||||
|
print('Plotting done.')
|
||||||
|
|
||||||
|
|
||||||
|
def compare_all_parameter_runs(run_root_path: Path, parameter: Union[str, List[str]],
|
||||||
|
param_names: Union[List[str], None] = None, str_to_ignore=''):
|
||||||
|
run_root_path = Path(run_root_path)
|
||||||
|
df_list = list()
|
||||||
|
parameter = [parameter] if isinstance(parameter, str) else parameter
|
||||||
|
for monitor_idx, monitor_file in enumerate(run_root_path.rglob('monitor*.pick')):
|
||||||
|
with monitor_file.open('rb') as f:
|
||||||
|
monitor_df = pickle.load(f)
|
||||||
|
|
||||||
|
parameters = [x.name for x in monitor_file.parents if x.parent not in run_root_path.parents]
|
||||||
|
if str_to_ignore:
|
||||||
|
parameters = [re.sub(f'_*({str_to_ignore})', '', param) for param in parameters]
|
||||||
|
|
||||||
|
if monitor_idx == 0:
|
||||||
|
if param_names is not None:
|
||||||
|
if len(param_names) < len(parameters):
|
||||||
|
# FIXME: Missing Seed Detection, see below @111
|
||||||
|
param_names = [next(param_names) if param not in MODEL_MAP.keys() else 'Model' for param in parameters]
|
||||||
|
elif len(param_names) == len(parameters):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
else:
|
||||||
|
param_names = []
|
||||||
|
for param_idx, param in enumerate(parameters):
|
||||||
|
dtype = None
|
||||||
|
if param in MODEL_MAP.keys():
|
||||||
|
param_name = 'Model'
|
||||||
|
elif '_' in param:
|
||||||
|
param_split = param.split('_')
|
||||||
|
if len(param_split) == 2 and any(split in MODEL_MAP.keys() for split in param_split):
|
||||||
|
# Extract the seed
|
||||||
|
param = int(next(x for x in param_split if x not in MODEL_MAP))
|
||||||
|
param_name = 'Seed'
|
||||||
|
dtype = int
|
||||||
|
else:
|
||||||
|
param_name = f'param_{param_idx}'
|
||||||
|
else:
|
||||||
|
param_name = f'param_{param_idx}'
|
||||||
|
dtype = dtype if dtype is not None else str
|
||||||
|
monitor_df[param_name] = str(param)
|
||||||
|
monitor_df[param_name] = monitor_df[param_name].astype(dtype)
|
||||||
|
if monitor_idx == 0:
|
||||||
|
param_names.append(param_name)
|
||||||
|
|
||||||
|
monitor_df = monitor_df.fillna(0)
|
||||||
|
df_list.append(monitor_df)
|
||||||
|
|
||||||
|
df = pd.concat(df_list, ignore_index=True)
|
||||||
|
df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
|
||||||
|
|
||||||
|
for param_name in param_names:
|
||||||
|
df[param_name] = df[param_name].astype(str)
|
||||||
|
columns = [col for col in df.columns if col in parameter]
|
||||||
|
|
||||||
|
last_episode_to_report = min(df.groupby(['Model'])['Episode'].max())
|
||||||
|
df = df[df['Episode'] < last_episode_to_report]
|
||||||
|
|
||||||
|
if df['Episode'].max() > 80:
|
||||||
|
skip_n = round(df['Episode'].max() * 0.02)
|
||||||
|
df = df[df['Episode'] % skip_n == 0]
|
||||||
|
combinations = [x for x in param_names if x not in ['Model', 'Seed']]
|
||||||
|
df['Parameter Combination'] = df[combinations].apply(lambda row: '_'.join(row.values.astype(str)), axis=1)
|
||||||
|
df.drop(columns=combinations, inplace=True)
|
||||||
|
|
||||||
|
# non_overlapp_window = df.groupby(param_names).sum()
|
||||||
|
|
||||||
|
df_melted = df.reset_index().melt(id_vars=['Parameter Combination', 'Episode'],
|
||||||
|
value_vars=columns, var_name="Measurement",
|
||||||
|
value_name="Score")
|
||||||
|
|
||||||
|
style = 'Measurement' if len(columns) > 1 else None
|
||||||
|
prepare_plot(run_root_path / f'compare_{parameter}.png', df_melted, hue='Parameter Combination', style=style)
|
||||||
|
print('Plotting done.')
|
@ -3,10 +3,10 @@ from pathlib import Path
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from natsort import natsorted
|
from natsort import natsorted
|
||||||
from stable_baselines3.common.evaluation import evaluate_policy
|
from environments import helpers as h
|
||||||
|
|
||||||
from environments.factory.factory_dirt import DirtProperties, DirtFactory
|
from environments.factory.factory_dirt_item import DirtItemFactory
|
||||||
from environments.factory.factory_item import ItemProperties, ItemFactory
|
from environments.logging.recorder import RecorderCallback
|
||||||
|
|
||||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||||
warnings.filterwarnings('ignore', category=UserWarning)
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
@ -14,27 +14,35 @@ warnings.filterwarnings('ignore', category=UserWarning)
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
model_name = 'DQN_1631092016'
|
model_name = 'PPO_1631187073'
|
||||||
run_id = 0
|
run_id = 0
|
||||||
seed = 69
|
seed = 69
|
||||||
out_path = Path(__file__).parent / 'debug_out'
|
out_path = Path(__file__).parent / 'study_out' / 'e_1_1631709932'/ 'no_obs' / 'itemdirt'/'A2C_1631709932' / '0_A2C_1631709932'
|
||||||
model_path = out_path / model_name
|
model_path = out_path / model_name
|
||||||
|
|
||||||
with (model_path / f'env_{model_name}.yaml').open('r') as f:
|
with (out_path / f'env_params.json').open('r') as f:
|
||||||
env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
|
env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
env_kwargs.update(verbose=True, env_seed=seed)
|
env_kwargs.update(verbose=False, env_seed=seed, record_episodes=True)
|
||||||
if False:
|
|
||||||
env_kwargs.update(dirt_properties=DirtProperties(clean_amount=1, gain_amount=0.1, max_global_amount=20,
|
|
||||||
max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05,
|
|
||||||
dirt_smear_amount=0.5),
|
|
||||||
combin_agent_slices_in_obs=True, omit_agent_slice_in_obs=True)
|
|
||||||
with ItemFactory(**env_kwargs) as env:
|
|
||||||
|
|
||||||
# Edit THIS:
|
this_model = out_path / 'model.zip'
|
||||||
env.seed(seed)
|
|
||||||
model_files = list(natsorted((model_path / f'{run_id}_{model_name}').rglob('model_*.zip')))
|
model_cls = next(val for key, val in h.MODEL_MAP.items() if key in model_name)
|
||||||
this_model = model_files[0]
|
model = model_cls.load(this_model)
|
||||||
model_cls = next(val for key, val in model_map.items() if key in model_name)
|
|
||||||
model = model_cls.load(this_model)
|
with RecorderCallback(filepath=Path() / 'recorder_out.json') as recorder:
|
||||||
evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=False, render=True)
|
# Init Env
|
||||||
print(evaluation_result)
|
with DirtItemFactory(**env_kwargs) as env:
|
||||||
|
# Evaluation Loop for i in range(n Episodes)
|
||||||
|
for episode in range(5):
|
||||||
|
obs = env.reset()
|
||||||
|
rew, done_bool = 0, False
|
||||||
|
while not done_bool:
|
||||||
|
action = model.predict(obs, deterministic=False)[0]
|
||||||
|
env_state, step_r, done_bool, info_obj = env.step(action[0])
|
||||||
|
recorder.read_info(0, info_obj)
|
||||||
|
rew += step_r
|
||||||
|
if done_bool:
|
||||||
|
recorder.read_done(0, done_bool)
|
||||||
|
break
|
||||||
|
print(f'Factory run {episode} done, reward is:\n {rew}')
|
||||||
|
print('all done')
|
||||||
|
8
setup.py
8
setup.py
@ -1 +1,7 @@
|
|||||||
# TODO
|
# setup.py
|
||||||
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name='F_IKS',
|
||||||
|
packages=find_packages()
|
||||||
|
)
|
||||||
|
304
studies/e_1.py
304
studies/e_1.py
@ -1,130 +1,234 @@
|
|||||||
import itertools
|
import sys
|
||||||
import random
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
try:
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
|
if __package__ is None:
|
||||||
|
DIR = Path(__file__).resolve().parent
|
||||||
|
sys.path.insert(0, str(DIR.parent))
|
||||||
|
__package__ = DIR.name
|
||||||
|
else:
|
||||||
|
DIR = None
|
||||||
|
except NameError:
|
||||||
|
DIR = None
|
||||||
|
pass
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
import simplejson
|
import simplejson
|
||||||
from stable_baselines3 import DQN, PPO, A2C
|
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||||
|
|
||||||
|
from environments import helpers as h
|
||||||
from environments.factory.factory_dirt import DirtProperties, DirtFactory
|
from environments.factory.factory_dirt import DirtProperties, DirtFactory
|
||||||
|
from environments.factory.factory_dirt_item import DirtItemFactory
|
||||||
from environments.factory.factory_item import ItemProperties, ItemFactory
|
from environments.factory.factory_item import ItemProperties, ItemFactory
|
||||||
|
from environments.logging.monitor import MonitorCallback
|
||||||
|
from environments.utility_classes import MovementProperties
|
||||||
|
from plotting.compare_runs import compare_seed_runs, compare_model_runs, compare_all_parameter_runs
|
||||||
|
|
||||||
if __name__ == '__main__':
|
# Define a global studi save path
|
||||||
"""
|
start_time = 1631709932 # int(time.time())
|
||||||
In this studie, we want to explore the macro behaviour of multi agents which are trained on the same task,
|
study_root_path = (Path('..') if not DIR else Path()) / 'study_out' / f'{Path(__file__).stem}_{start_time}'
|
||||||
but never saw each other in training.
|
|
||||||
Those agents learned
|
"""
|
||||||
|
In this studie, we want to explore the macro behaviour of multi agents which are trained on the same task,
|
||||||
|
but never saw each other in training.
|
||||||
|
Those agents learned
|
||||||
|
|
||||||
|
|
||||||
We start with training a single policy on a single task (dirt cleanup / item pickup).
|
We start with training a single policy on a single task (dirt cleanup / item pickup).
|
||||||
Then multiple agent equipped with the same policy are deployed in the same environment.
|
Then multiple agent equipped with the same policy are deployed in the same environment.
|
||||||
|
|
||||||
There are further distinctions to be made:
|
There are further distinctions to be made:
|
||||||
|
|
||||||
1. No Observation - ['no_obs']:
|
1. No Observation - ['no_obs']:
|
||||||
- Agent do not see each other but their consequences of their combined actions
|
- Agent do not see each other but their consequences of their combined actions
|
||||||
- Agents can collide
|
- Agents can collide
|
||||||
|
|
||||||
2. Observation in seperate slice - [['seperate_0'], ['seperate_1'], ['seperate_N']]:
|
2. Observation in seperate slice - [['seperate_0'], ['seperate_1'], ['seperate_N']]:
|
||||||
- Agents see other entitys on a seperate slice
|
- Agents see other entitys on a seperate slice
|
||||||
- This slice has been filled with $0 | 1 | \mathbb{N}(0, 1)$
|
- This slice has been filled with $0 | 1 | \mathbb{N}(0, 1)$
|
||||||
-- Depending ob the fill value, agents will react diffently
|
-- Depending ob the fill value, agents will react diffently
|
||||||
-> TODO: Test this!
|
-> TODO: Test this!
|
||||||
|
|
||||||
3. Observation in level slice - ['in_lvl_obs']:
|
3. Observation in level slice - ['in_lvl_obs']:
|
||||||
- This tells the agent to treat other agents as obstacle.
|
- This tells the agent to treat other agents as obstacle.
|
||||||
- However, the state space is altered since moving obstacles are not part the original agent observation.
|
- However, the state space is altered since moving obstacles are not part the original agent observation.
|
||||||
- We are out of distribution.
|
- We are out of distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def bundle_model(model_class):
|
def policy_model_kwargs():
|
||||||
if model_class.__class__.__name__ in ["PPO", "A2C"]:
|
return dict(ent_coef=0.01)
|
||||||
kwargs = dict(ent_coef=0.01)
|
|
||||||
elif model_class.__class__.__name__ in ["RegDQN", "DQN", "QRDQN"]:
|
|
||||||
kwargs = dict(buffer_size=50000,
|
def dqn_model_kwargs():
|
||||||
learning_starts=64,
|
return dict(buffer_size=50000,
|
||||||
batch_size=64,
|
learning_starts=64,
|
||||||
target_update_interval=5000,
|
batch_size=64,
|
||||||
exploration_fraction=0.25,
|
target_update_interval=5000,
|
||||||
exploration_final_eps=0.025
|
exploration_fraction=0.25,
|
||||||
)
|
exploration_final_eps=0.025
|
||||||
return lambda: model_class(kwargs)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def encapsule_env_factory(env_fctry, env_kwrgs):
|
||||||
|
|
||||||
|
def _init():
|
||||||
|
with env_fctry(**env_kwrgs) as init_env:
|
||||||
|
return init_env
|
||||||
|
|
||||||
|
return _init
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# Define a global studi save path
|
train_steps = 5e5
|
||||||
study_root_path = Path(Path(__file__).stem) / 'out'
|
|
||||||
|
|
||||||
# TODO: Define Global Env Parameters
|
# Define Global Env Parameters
|
||||||
factory_kwargs = {
|
# Define properties object parameters
|
||||||
|
move_props = MovementProperties(allow_diagonal_movement=True,
|
||||||
|
allow_square_movement=True,
|
||||||
}
|
allow_no_op=False)
|
||||||
|
dirt_props = DirtProperties(clean_amount=2, gain_amount=0.1, max_global_amount=20,
|
||||||
# TODO: Define global model parameters
|
max_local_amount=1, spawn_frequency=15, max_spawn_ratio=0.05,
|
||||||
|
dirt_smear_amount=0.0, agent_can_interact=True)
|
||||||
|
item_props = ItemProperties(n_items=10, agent_can_interact=True,
|
||||||
# TODO: Define parameters for both envs
|
spawn_frequency=30, n_drop_off_locations=2,
|
||||||
dirt_props = DirtProperties()
|
max_agent_inventory_capacity=15)
|
||||||
item_props = ItemProperties()
|
factory_kwargs = dict(n_agents=1,
|
||||||
|
pomdp_r=2, max_steps=400, parse_doors=False,
|
||||||
|
level_name='rooms', frames_to_stack=3,
|
||||||
|
omit_agent_in_obs=True, combin_agent_obs=True, record_episodes=False,
|
||||||
|
cast_shadows=True, doors_have_area=False, verbose=False,
|
||||||
|
movement_properties=move_props
|
||||||
|
)
|
||||||
|
|
||||||
# Bundle both environments with global kwargs and parameters
|
# Bundle both environments with global kwargs and parameters
|
||||||
env_bundles = [lambda: ('dirt', DirtFactory(factory_kwargs, dirt_properties=dirt_props)),
|
env_map = {'dirt': (DirtFactory, dict(dirt_properties=dirt_props, **factory_kwargs)),
|
||||||
lambda: ('item', ItemFactory(factory_kwargs, item_properties=item_props))]
|
'item': (ItemFactory, dict(item_properties=item_props, **factory_kwargs)),
|
||||||
|
'itemdirt': (DirtItemFactory, dict(dirt_properties=dirt_props, item_properties=item_props,
|
||||||
|
**factory_kwargs))}
|
||||||
|
env_names = list(env_map.keys())
|
||||||
|
|
||||||
# Define parameter versions according with #1,2[1,0,N],3
|
# Define parameter versions according with #1,2[1,0,N],3
|
||||||
observation_modes = ['no_obs', 'seperate_0', 'seperate_1', 'seperate_N', 'in_lvl_obs']
|
observation_modes = {
|
||||||
|
# Fill-value = 0
|
||||||
# Define RL-Models
|
'seperate_0': dict(additional_env_kwargs=dict(additional_agent_placeholder=0)),
|
||||||
model_bundles = [bundle_model(model) for model in [A2C, PPO, DQN]]
|
# Fill-value = 1
|
||||||
|
'seperate_1': dict(additional_env_kwargs=dict(additional_agent_placeholder=1)),
|
||||||
# Zip parameters, parameter versions, Env Classes and models
|
# Fill-value = N(0, 1)
|
||||||
combinations = itertools.product(model_bundles, env_bundles)
|
'seperate_N': dict(additional_env_kwargs=dict(additional_agent_placeholder='N')),
|
||||||
|
# Further Adjustments are done post-training
|
||||||
|
'in_lvl_obs': dict(post_training_kwargs=dict(other_agent_obs='in_lvl')),
|
||||||
|
# No further adjustment needed
|
||||||
|
'no_obs': None
|
||||||
|
}
|
||||||
|
|
||||||
# Train starts here ############################################################
|
# Train starts here ############################################################
|
||||||
# Build Major Loop
|
# Build Major Loop parameters, parameter versions, Env Classes and models
|
||||||
for model, (env_identifier, env_bundle) in combinations:
|
if False:
|
||||||
for observation_mode in observation_modes:
|
for observation_mode in observation_modes.keys():
|
||||||
# TODO: Create an identifier, which is unique for every combination and easy to read in filesystem
|
for env_name in env_names:
|
||||||
identifier = f'{model.name}_{observation_mode}_{env_identifier}'
|
for model_cls in h.MODEL_MAP.values():
|
||||||
# Train each combination per seed
|
# Create an identifier, which is unique for every combination and easy to read in filesystem
|
||||||
for seed in range(3):
|
identifier = f'{model_cls.__name__}_{start_time}'
|
||||||
# TODO: Output folder
|
# Train each combination per seed
|
||||||
# TODO: Monitor Init
|
combination_path = study_root_path / observation_mode / env_name / identifier
|
||||||
# TODO: Env Init
|
env_class, env_kwargs = env_map[env_name]
|
||||||
# TODO: Model Init
|
# Retrieve and set the observation mode specific env parameters
|
||||||
# TODO: Model train
|
if observation_mode_kwargs := observation_modes.get(observation_mode, None):
|
||||||
# TODO: Model save
|
if additional_env_kwargs := observation_mode_kwargs.get("additional_env_kwargs", None):
|
||||||
pass
|
env_kwargs.update(additional_env_kwargs)
|
||||||
# TODO: Seed Compare Plot
|
for seed in range(5):
|
||||||
|
env_kwargs.update(env_seed=seed)
|
||||||
|
# Output folder
|
||||||
|
seed_path = combination_path / f'{str(seed)}_{identifier}'
|
||||||
|
seed_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Monitor Init
|
||||||
|
callbacks = [MonitorCallback(seed_path / 'monitor.pick')]
|
||||||
|
|
||||||
|
# Env Init & Model kwargs definition
|
||||||
|
if model_cls.__name__ in ["PPO", "A2C"]:
|
||||||
|
env = env_class(**env_kwargs)
|
||||||
|
|
||||||
|
# env = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs) for _ in range(1)],
|
||||||
|
# start_method="spawn")
|
||||||
|
model_kwargs = policy_model_kwargs()
|
||||||
|
|
||||||
|
elif model_cls.__name__ in ["RegDQN", "DQN", "QRDQN"]:
|
||||||
|
env = env_class(**env_kwargs)
|
||||||
|
model_kwargs = dqn_model_kwargs()
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NameError(f'The model "{model_cls.__name__}" has the wrong name.')
|
||||||
|
|
||||||
|
param_path = seed_path / f'env_params.json'
|
||||||
|
try:
|
||||||
|
env.env_method('save_params', param_path)
|
||||||
|
except AttributeError:
|
||||||
|
env.save_params(param_path)
|
||||||
|
|
||||||
|
# Model Init
|
||||||
|
model = model_cls("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **model_kwargs)
|
||||||
|
|
||||||
|
# Model train
|
||||||
|
model.learn(total_timesteps=int(train_steps), callback=callbacks)
|
||||||
|
|
||||||
|
# Model save
|
||||||
|
save_path = seed_path / f'model.zip'
|
||||||
|
model.save(save_path)
|
||||||
|
pass
|
||||||
|
# Compare perfoormance runs, for each seed within a model
|
||||||
|
compare_seed_runs(combination_path)
|
||||||
|
# Compare performance runs, for each model
|
||||||
|
# FIXME: Check THIS!!!!
|
||||||
|
compare_model_runs(study_root_path / observation_mode / env_name, f'{start_time}', 'step_reward')
|
||||||
# Train ends here ############################################################
|
# Train ends here ############################################################
|
||||||
|
|
||||||
# Evaluation starts here #####################################################
|
# Evaluation starts here #####################################################
|
||||||
# Iterate Observation Modes
|
# Iterate Observation Modes
|
||||||
for observation_mode in observation_modes:
|
|
||||||
# TODO: For trained policy in study_root_path / identifier
|
|
||||||
for policy_group in (x for x in study_root_path.iterdir() if x.is_dir()):
|
|
||||||
# TODO: Pick random seed or iterate over available seeds
|
|
||||||
policy_seed = next((y for y in study_root_path.iterdir() if y.is_dir()))
|
|
||||||
# TODO: retrieve model class
|
|
||||||
# TODO: Load both agents
|
|
||||||
models = []
|
|
||||||
# TODO: Evaluation Loop for i in range(100) Episodes
|
|
||||||
for episode in range(100):
|
|
||||||
with next(policy_seed.glob('*.yaml')).open('r') as f:
|
|
||||||
env_kwargs = simplejson.load(f)
|
|
||||||
# TODO: Monitor Init
|
|
||||||
env = None # TODO: Init Env
|
|
||||||
for step in range(400):
|
|
||||||
random_actions = [[random.randint(0, env.n_actions) for _ in range(len(models))] for _ in range(200)]
|
|
||||||
env_state = env.reset()
|
|
||||||
rew = 0
|
|
||||||
for agent_i_action in random_actions:
|
|
||||||
env_state, step_r, done_bool, info_obj = env.step(agent_i_action)
|
|
||||||
rew += step_r
|
|
||||||
if done_bool:
|
|
||||||
break
|
|
||||||
print(f'Factory run {episode} done, reward is:\n {rew}')
|
|
||||||
# TODO: Plotting
|
|
||||||
|
|
||||||
pass
|
for observation_mode in observation_modes:
|
||||||
|
obs_mode_path = next(x for x in study_root_path.iterdir() if x.is_dir() and x.name == observation_mode)
|
||||||
|
# For trained policy in study_root_path / identifier
|
||||||
|
for env_path in [x for x in obs_mode_path.iterdir() if x.is_dir()]:
|
||||||
|
for policy_path in [x for x in env_path.iterdir() if x. is_dir()]:
|
||||||
|
# TODO: Pick random seed or iterate over available seeds
|
||||||
|
# First seed path version
|
||||||
|
# seed_path = next((y for y in policy_path.iterdir() if y.is_dir()))
|
||||||
|
# Iteration
|
||||||
|
for seed_path in (y for y in policy_path.iterdir() if y.is_dir()):
|
||||||
|
# retrieve model class
|
||||||
|
for model_cls in (val for key, val in h.MODEL_MAP.items() if key in policy_path.name):
|
||||||
|
# Load both agents
|
||||||
|
models = [model_cls.load(seed_path / 'model.zip') for _ in range(2)]
|
||||||
|
# Load old env kwargs
|
||||||
|
with next(seed_path.glob('*.json')).open('r') as f:
|
||||||
|
env_kwargs = simplejson.load(f)
|
||||||
|
env_kwargs.update(n_agents=2, additional_agent_placeholder=None,
|
||||||
|
**observation_modes[observation_mode].get('post_training_env_kwargs', {}))
|
||||||
|
|
||||||
|
# Monitor Init
|
||||||
|
with MonitorCallback(filepath=seed_path / f'e_1_monitor.pick') as monitor:
|
||||||
|
# Init Env
|
||||||
|
env = env_map[env_path.name][0](**env_kwargs)
|
||||||
|
# Evaluation Loop for i in range(n Episodes)
|
||||||
|
for episode in range(50):
|
||||||
|
obs = env.reset()
|
||||||
|
rew, done_bool = 0, False
|
||||||
|
while not done_bool:
|
||||||
|
actions = [model.predict(obs[i], deterministic=False)[0]
|
||||||
|
for i, model in enumerate(models)]
|
||||||
|
env_state, step_r, done_bool, info_obj = env.step(actions)
|
||||||
|
monitor.read_info(0, info_obj)
|
||||||
|
rew += step_r
|
||||||
|
if done_bool:
|
||||||
|
monitor.read_done(0, done_bool)
|
||||||
|
break
|
||||||
|
print(f'Factory run {episode} done, reward is:\n {rew}')
|
||||||
|
# Eval monitor outputs are automatically stored by the monitor object
|
||||||
|
|
||||||
|
# TODO: Plotting
|
||||||
|
pass
|
||||||
|
Loading…
x
Reference in New Issue
Block a user