From 444ffe3f37d37483c498878d5db81d1c36780ad6 Mon Sep 17 00:00:00 2001 From: Steffen Illium Date: Tue, 7 Sep 2021 16:04:04 +0200 Subject: [PATCH] Item und dirt funktionieren nun zusammen names refactored --- algorithms/q_learner.py | 4 +- environments/factory/base/base_factory.py | 2 +- environments/factory/base/objects.py | 2 +- .../{simple_factory.py => factory_dirt.py} | 30 ++++----- environments/factory/factory_dirt_item.py | 51 ++++++++++++++++ ...double_task_factory.py => factory_item.py} | 61 ++++++++++++------- main.py | 11 ++-- main_test.py | 4 +- reload_agent.py | 6 +- 9 files changed, 121 insertions(+), 50 deletions(-) rename environments/factory/{simple_factory.py => factory_dirt.py} (90%) create mode 100644 environments/factory/factory_dirt_item.py rename environments/factory/{double_task_factory.py => factory_item.py} (84%) diff --git a/algorithms/q_learner.py b/algorithms/q_learner.py index 93ea949..53e891e 100644 --- a/algorithms/q_learner.py +++ b/algorithms/q_learner.py @@ -99,7 +99,7 @@ class QLearner(BaseLearner): if __name__ == '__main__': - from environments.factory.simple_factory import SimpleFactory, DirtProperties, MovementProperties + from environments.factory.factory_dirt import DirtFactory, DirtProperties, MovementProperties from algorithms.common import BaseDDQN, BaseICM from algorithms.m_q_learner import MQLearner, MQICMLearner from algorithms.vdn_learner import VDNLearner @@ -109,7 +109,7 @@ if __name__ == '__main__': with (Path(f'../environments/factory/env_default_param.yaml')).open('r') as f: env_kwargs = yaml.load(f, Loader=yaml.FullLoader) - env = SimpleFactory(**env_kwargs) + env = DirtFactory(**env_kwargs) obs_shape = np.prod(env.observation_space.shape) n_actions = env.action_space.n diff --git a/environments/factory/base/base_factory.py b/environments/factory/base/base_factory.py index 16a4c3b..8930f79 100644 --- a/environments/factory/base/base_factory.py +++ b/environments/factory/base/base_factory.py @@ -279,7 +279,7 @@ class BaseFactory(gym.Env): for key, array in state_array_dict.items(): # Flush state array object representation to obs cube if self[key].is_per_agent: - per_agent_idx = self[key].get_idx_by_name(agent.name) + per_agent_idx = self[key].idx_by_entity(agent) z = 1 self._obs_cube[running_idx: running_idx+z] = array[per_agent_idx] else: diff --git a/environments/factory/base/objects.py b/environments/factory/base/objects.py index ece397c..01d782f 100644 --- a/environments/factory/base/objects.py +++ b/environments/factory/base/objects.py @@ -182,7 +182,7 @@ class Door(Entity): @property def encoding(self): - return 1 if self.is_closed else 0.5 + return 1 if self.is_closed else 2 @property def str_state(self): diff --git a/environments/factory/simple_factory.py b/environments/factory/factory_dirt.py similarity index 90% rename from environments/factory/simple_factory.py rename to environments/factory/factory_dirt.py index 20fd10e..acce0e3 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/factory_dirt.py @@ -116,18 +116,18 @@ def entropy(x): # noinspection PyAttributeOutsideInit, PyAbstractClass -class SimpleFactory(BaseFactory): +class DirtFactory(BaseFactory): @property def additional_actions(self) -> Union[Action, List[Action]]: - super_actions = super(SimpleFactory, self).additional_actions + super_actions = super().additional_actions if self.dirt_properties.agent_can_interact: super_actions.append(Action(enum_ident=CLEAN_UP_ACTION)) return super_actions @property def additional_entities(self) -> Dict[(Enum, Entities)]: - super_entities = super(SimpleFactory, self).additional_entities + super_entities = super().additional_entities dirt_register = DirtRegister(self.dirt_properties, self._level_shape) super_entities.update(({c.DIRT: dirt_register})) return super_entities @@ -139,10 +139,10 @@ class SimpleFactory(BaseFactory): self._dirt_rng = np.random.default_rng(env_seed) self._dirt: DirtRegister kwargs.update(env_seed=env_seed) - super(SimpleFactory, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) def render_additional_assets(self, mode='human'): - additional_assets = super(SimpleFactory, self).render_additional_assets() + additional_assets = super().render_additional_assets() dirt = [RenderEntity('dirt', dirt.tile.pos, min(0.15 + dirt.amount, 1.5), 'scale') for dirt in self[c.DIRT]] additional_assets.extend(dirt) @@ -170,7 +170,7 @@ class SimpleFactory(BaseFactory): self[c.DIRT].spawn_dirt(free_for_dirt[:n_dirt_tiles]) def do_additional_step(self) -> dict: - info_dict = super(SimpleFactory, self).do_additional_step() + info_dict = super().do_additional_step() if smear_amount := self.dirt_properties.dirt_smear_amount: for agent in self[c.AGENT]: if agent.temp_valid and agent.last_pos != c.NO_POS: @@ -193,7 +193,7 @@ class SimpleFactory(BaseFactory): return info_dict def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]: - valid = super(SimpleFactory, self).do_additional_actions(agent, action) + valid = super().do_additional_actions(agent, action) if valid is None: if action == CLEAN_UP_ACTION: if self.dirt_properties.agent_can_interact: @@ -207,12 +207,12 @@ class SimpleFactory(BaseFactory): return valid def do_additional_reset(self) -> None: - super(SimpleFactory, self).do_additional_reset() + super().do_additional_reset() self.trigger_dirt_spawn() self._next_dirt_spawn = self.dirt_properties.spawn_frequency def calculate_additional_reward(self, agent: Agent) -> (int, dict): - reward, info_dict = super(SimpleFactory, self).calculate_additional_reward(agent) + reward, info_dict = super().calculate_additional_reward(agent) dirt = [dirt.amount for dirt in self[c.DIRT]] current_dirt_amount = sum(dirt) dirty_tile_count = len(dirt) @@ -253,12 +253,12 @@ if __name__ == '__main__': with RecorderCallback(filepath=Path('debug_out') / f'recorder_xxxx.json', occupation_map=False, trajectory_map=False) as recorder: - factory = SimpleFactory(n_agents=1, done_at_collision=False, frames_to_stack=0, - level_name='rooms', max_steps=400, combin_agent_obs=True, - omit_agent_in_obs=True, parse_doors=True, pomdp_r=3, - record_episodes=True, verbose=True, cast_shadows=True, - movement_properties=move_props, dirt_properties=dirt_props - ) + factory = DirtFactory(n_agents=1, done_at_collision=False, frames_to_stack=0, + level_name='rooms', max_steps=400, combin_agent_obs=True, + omit_agent_in_obs=True, parse_doors=True, pomdp_r=3, + record_episodes=True, verbose=True, cast_shadows=True, + movement_properties=move_props, dirt_properties=dirt_props + ) # noinspection DuplicatedCode n_actions = factory.action_space.n - 1 diff --git a/environments/factory/factory_dirt_item.py b/environments/factory/factory_dirt_item.py new file mode 100644 index 0000000..7ef10d0 --- /dev/null +++ b/environments/factory/factory_dirt_item.py @@ -0,0 +1,51 @@ +import random + +from environments.factory.factory_dirt import DirtFactory, DirtProperties +from environments.factory.factory_item import ItemFactory, ItemProperties +from environments.utility_classes import MovementProperties + + +class DirtItemFactory(ItemFactory, DirtFactory): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +if __name__ == '__main__': + + dirt_props = DirtProperties(clean_amount=2, gain_amount=0.1, max_global_amount=20, + max_local_amount=1, spawn_frequency=3, max_spawn_ratio=0.05, + dirt_smear_amount=0.0, agent_can_interact=True) + item_props = ItemProperties(n_items=5, agent_can_interact=True) + move_props = MovementProperties(allow_diagonal_movement=True, + allow_square_movement=True, + allow_no_op=False) + + render = True + + factory = DirtItemFactory(n_agents=1, done_at_collision=False, frames_to_stack=0, + level_name='rooms', max_steps=400, combin_agent_obs=True, + omit_agent_in_obs=True, parse_doors=True, pomdp_r=3, + record_episodes=True, verbose=True, cast_shadows=True, + movement_properties=move_props, dirt_properties=dirt_props + ) + + # noinspection DuplicatedCode + n_actions = factory.action_space.n - 1 + _ = factory.observation_space + + for epoch in range(4): + random_actions = [[random.randint(0, n_actions) for _ + in range(factory.n_agents)] for _ + in range(factory.max_steps + 1)] + env_state = factory.reset() + r = 0 + for agent_i_action in random_actions: + env_state, step_r, done_bool, info_obj = factory.step(agent_i_action) + r += step_r + if render: + factory.render() + if done_bool: + break + print(f'Factory run {epoch} done, reward is:\n {r}') + pass diff --git a/environments/factory/double_task_factory.py b/environments/factory/factory_item.py similarity index 84% rename from environments/factory/double_task_factory.py rename to environments/factory/factory_item.py index 215cfd1..e97b8e7 100644 --- a/environments/factory/double_task_factory.py +++ b/environments/factory/factory_item.py @@ -4,7 +4,7 @@ from enum import Enum from typing import List, Union, NamedTuple, Dict import numpy as np -from environments.factory.simple_factory import SimpleFactory +from environments.factory.base.base_factory import BaseFactory from environments.helpers import Constants as c from environments import helpers as h from environments.factory.base.objects import Agent, Entity, Action, Tile, MoveableEntity @@ -31,6 +31,7 @@ class Item(MoveableEntity): def can_collide(self): return False + @property def encoding(self): # Edit this if you want items to be drawn in the ops differntly return 1 @@ -42,13 +43,13 @@ class ItemRegister(MovingEntityObjectRegister): self._array[:] = c.FREE_CELL.value for item in self: if item.pos != c.NO_POS.value: - self._array[0, item.x, item.y] = item.encoding() + self._array[0, item.x, item.y] = item.encoding return self._array _accepted_objects = Item def spawn_items(self, tiles: List[Tile]): - items = [Item(idx, tile) for idx, tile in enumerate(tiles)] + items = [Item(tile) for tile in tiles] self.register_additional_items(items) @@ -80,7 +81,7 @@ class Inventory(UserList): for item_idx, item in enumerate(self): x_diff, y_diff = divmod(item_idx, max_x) - self._array[0].slice[int(x + x_diff), int(y + y_diff)] = item.encoding + self._array[0, int(x + x_diff), int(y + y_diff)] = item.encoding return self._array def __repr__(self): @@ -92,6 +93,12 @@ class Inventory(UserList): else: raise RuntimeError('Inventory is full') + def belongs_to_entity(self, entity): + return self.agent == entity + + def summarize_state(self): + return {val.name: val.summarize_state() for val in self} + class Inventories(ObjectRegister): @@ -114,6 +121,18 @@ class Inventories(ObjectRegister): for _, agent in enumerate(agents)] self.register_additional_items(inventories) + def idx_by_entity(self, entity): + try: + return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity))) + except StopIteration: + return None + + def by_entity(self, entity): + try: + return next((inv for inv in self if inv.belongs_to_entity(entity))) + except StopIteration: + return None + class DropOffLocation(Entity): @@ -164,28 +183,28 @@ class ItemProperties(NamedTuple): # noinspection PyAttributeOutsideInit, PyAbstractClass -class DoubleTaskFactory(SimpleFactory): +class ItemFactory(BaseFactory): # noinspection PyMissingConstructor - def __init__(self, item_properties: ItemProperties, *args, env_seed=time.time_ns(), **kwargs): + def __init__(self, *args, item_properties: ItemProperties = ItemProperties(), env_seed=time.time_ns(), **kwargs): if isinstance(item_properties, dict): item_properties = ItemProperties(**item_properties) self.item_properties = item_properties kwargs.update(env_seed=env_seed) self._item_rng = np.random.default_rng(env_seed) - assert item_properties.n_items < kwargs.get('pomdp_r', 0) ** 2 or not kwargs.get('pomdp_r', 0) - super(DoubleTaskFactory, self).__init__(*args, **kwargs) + assert (item_properties.n_items <= ((1 + kwargs.get('pomdp_r', 0) * 2) ** 2)) or not kwargs.get('pomdp_r', 0) + super().__init__(*args, **kwargs) @property def additional_actions(self) -> Union[Action, List[Action]]: # noinspection PyUnresolvedReferences - super_actions = super(DoubleTaskFactory, self).additional_actions + super_actions = super().additional_actions super_actions.append(Action(enum_ident=h.EnvActions.ITEM_ACTION)) return super_actions @property def additional_entities(self) -> Dict[(Enum, Entities)]: # noinspection PyUnresolvedReferences - super_entities = super(DoubleTaskFactory, self).additional_entities + super_entities = super().additional_entities empty_tiles = self[c.FLOOR].empty_tiles[:self.item_properties.n_drop_off_locations] drop_offs = DropOffLocations.from_tiles(empty_tiles, self._level_shape, @@ -202,7 +221,7 @@ class DoubleTaskFactory(SimpleFactory): return super_entities def do_item_action(self, agent: Agent): - inventory = self[c.INVENTORY].by_name(agent.name) + inventory = self[c.INVENTORY].by_entity(agent) if drop_off := self[c.DROP_OFF].by_pos(agent.pos): if inventory: valid = drop_off.place_item(inventory.pop(0)) @@ -221,7 +240,7 @@ class DoubleTaskFactory(SimpleFactory): def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]: # noinspection PyUnresolvedReferences - valid = super(DoubleTaskFactory, self).do_additional_actions(agent, action) + valid = super().do_additional_actions(agent, action) if valid is None: if action == h.EnvActions.ITEM_ACTION: if self.item_properties.agent_can_interact: @@ -236,7 +255,7 @@ class DoubleTaskFactory(SimpleFactory): def do_additional_reset(self) -> None: # noinspection PyUnresolvedReferences - super(DoubleTaskFactory, self).do_additional_reset() + super().do_additional_reset() self._next_item_spawn = self.item_properties.spawn_frequency self.trigger_item_spawn() @@ -251,7 +270,7 @@ class DoubleTaskFactory(SimpleFactory): def do_additional_step(self) -> dict: # noinspection PyUnresolvedReferences - info_dict = super(DoubleTaskFactory, self).do_additional_step() + info_dict = super().do_additional_step() if not self._next_item_spawn: self.trigger_item_spawn() else: @@ -260,7 +279,7 @@ class DoubleTaskFactory(SimpleFactory): def calculate_additional_reward(self, agent: Agent) -> (int, dict): # noinspection PyUnresolvedReferences - reward, info_dict = super(DoubleTaskFactory, self).calculate_additional_reward(agent) + reward, info_dict = super().calculate_additional_reward(agent) if h.EnvActions.ITEM_ACTION == agent.temp_action: if agent.temp_valid: if self[c.DROP_OFF].by_pos(agent.pos): @@ -277,7 +296,7 @@ class DoubleTaskFactory(SimpleFactory): def render_additional_assets(self, mode='human'): # noinspection PyUnresolvedReferences - additional_assets = super(DoubleTaskFactory, self).render_additional_assets() + additional_assets = super().render_additional_assets() items = [RenderEntity(c.ITEM.value, item.tile.pos) for item in self[c.ITEM]] additional_assets.extend(items) drop_offs = [RenderEntity(c.DROP_OFF.value, drop_off.tile.pos) for drop_off in self[c.DROP_OFF]] @@ -291,11 +310,11 @@ if __name__ == '__main__': item_props = ItemProperties() - factory = DoubleTaskFactory(item_props, n_agents=3, done_at_collision=False, frames_to_stack=0, - level_name='rooms', max_steps=4000, - omit_agent_in_obs=True, parse_doors=True, pomdp_r=3, - record_episodes=False, verbose=False - ) + factory = ItemFactory(item_properties=item_props, n_agents=3, done_at_collision=False, frames_to_stack=0, + level_name='rooms', max_steps=4000, + omit_agent_in_obs=True, parse_doors=True, pomdp_r=3, + record_episodes=False, verbose=False + ) # noinspection DuplicatedCode n_actions = factory.action_space.n - 1 diff --git a/main.py b/main.py index 140909e..07d1212 100644 --- a/main.py +++ b/main.py @@ -10,8 +10,9 @@ import pandas as pd from stable_baselines3.common.callbacks import CallbackList from stable_baselines3.common.vec_env import SubprocVecEnv -from environments.factory.double_task_factory import DoubleTaskFactory, ItemProperties -from environments.factory.simple_factory import DirtProperties, SimpleFactory +from environments.factory.factory_dirt_item import DirtItemFactory +from environments.factory.factory_item import ItemFactory, ItemProperties +from environments.factory.factory_dirt import DirtProperties, DirtFactory from environments.helpers import IGNORED_DF_COLUMNS from environments.logging.monitor import MonitorCallback from environments.logging.plotting import prepare_plot @@ -94,7 +95,7 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List def make_env(env_kwargs_dict): def _init(): - with SimpleFactory(**env_kwargs_dict) as init_env: + with DirtItemFactory(**env_kwargs_dict) as init_env: return init_env return _init @@ -128,7 +129,7 @@ if __name__ == '__main__': for seed in range(3): env_kwargs = dict(n_agents=1, # with_dirt=True, - # item_properties=item_props, + item_properties=item_props, dirt_properties=dirt_props, movement_properties=move_props, pomdp_r=2, max_steps=400, parse_doors=True, @@ -139,7 +140,7 @@ if __name__ == '__main__': if modeL_type.__name__ in ["PPO", "A2C"]: kwargs = dict(ent_coef=0.01) - env = SubprocVecEnv([make_env(env_kwargs) for _ in range(10)], start_method="spawn") + env = SubprocVecEnv([make_env(env_kwargs) for _ in range(1)], start_method="spawn") elif modeL_type.__name__ in ["RegDQN", "DQN", "QRDQN"]: env = make_env(env_kwargs)() kwargs = dict(buffer_size=50000, diff --git a/main_test.py b/main_test.py index 7648316..0940285 100644 --- a/main_test.py +++ b/main_test.py @@ -10,7 +10,7 @@ from stable_baselines3.common.callbacks import CallbackList from stable_baselines3 import PPO, DQN, A2C # our imports -from environments.factory.simple_factory import SimpleFactory, DirtProperties +from environments.factory.factory_dirt import DirtFactory, DirtProperties from environments.logging.monitor import MonitorCallback from algorithms.reg_dqn import RegDQN from main import compare_runs, combine_runs @@ -49,7 +49,7 @@ if __name__ == '__main__': dirt_props = DirtProperties(clean_amount=3, gain_amount=0.2, max_global_amount=30, max_local_amount=3, spawn_frequency=1, max_spawn_ratio=0.05) # env_kwargs.update(n_agents=1, dirt_properties=dirt_props) - env = SimpleFactory(**env_kwargs) + env = DirtFactory(**env_kwargs) env = FrameStack(env, 4) diff --git a/reload_agent.py b/reload_agent.py index 80b5e49..4a9ae79 100644 --- a/reload_agent.py +++ b/reload_agent.py @@ -6,8 +6,8 @@ from natsort import natsorted from stable_baselines3 import PPO, DQN, A2C from stable_baselines3.common.evaluation import evaluate_policy -from environments.factory.simple_factory import DirtProperties, SimpleFactory -from environments.factory.double_task_factory import ItemProperties, DoubleTaskFactory +from environments.factory.factory_dirt import DirtProperties, DirtFactory +from environments.factory.factory_item import ItemProperties, ItemFactory warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=UserWarning) @@ -30,7 +30,7 @@ if __name__ == '__main__': max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05, dirt_smear_amount=0.5), combin_agent_slices_in_obs=True, omit_agent_slice_in_obs=True) - with SimpleFactory(**env_kwargs) as env: + with DirtFactory(**env_kwargs) as env: # Edit THIS: env.seed(seed)