Major Bug resolved

This commit is contained in:
Steffen Illium
2021-07-16 08:07:02 +02:00
parent e5dd49f0f0
commit de821ebc0c
6 changed files with 54 additions and 42 deletions

View File

@ -27,19 +27,19 @@ class BaseFactory(gym.Env):
@property @property
def observation_space(self): def observation_space(self):
if self.combin_agent_slices_in_obs: if self.combin_agent_slices_in_obs:
agent_slice = 1 n_agent_slices = 1
else: # not self.combin_agent_slices_in_obs: else: # not self.combin_agent_slices_in_obs:
if self.omit_agent_slice_in_obs: if self.omit_agent_slice_in_obs:
agent_slice = self.n_agents - 1 n_agent_slices = self.n_agents - 1
else: # not self.omit_agent_slice_in_obs: else: # not self.omit_agent_slice_in_obs:
agent_slice = self.n_agents n_agent_slices = self.n_agents
if self.pomdp_radius: if self.pomdp_radius:
shape = (self._obs_cube.shape[0] - agent_slice, self.pomdp_radius * 2 + 1, self.pomdp_radius * 2 + 1) shape = (self._slices.n - n_agent_slices, self.pomdp_radius * 2 + 1, self.pomdp_radius * 2 + 1)
space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32) space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
return space return space
else: else:
shape = [x-agent_slice if idx == 0 else x for idx, x in enumerate(self._obs_cube.shape)] shape = [x-n_agent_slices if idx == 0 else x for idx, x in enumerate(self._level_shape)]
space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32) space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
return space return space
@ -133,7 +133,7 @@ class BaseFactory(gym.Env):
# Agents # Agents
agents = [] agents = []
for i in range(self.n_agents): for i in range(self.n_agents):
agents.append(Slice(f'{c.AGENT.name}#{i}', np.zeros_like(level[0].slice))) agents.append(Slice(f'{c.AGENT.name}#{i}', np.zeros_like(level[0].slice, dtype=np.float32)))
state_slices.register_additional_items(level+doors+agents) state_slices.register_additional_items(level+doors+agents)
# Additional Slices from SubDomains # Additional Slices from SubDomains
@ -143,12 +143,14 @@ class BaseFactory(gym.Env):
def _init_obs_cube(self) -> np.ndarray: def _init_obs_cube(self) -> np.ndarray:
x, y = self._slices.by_enum(c.LEVEL).shape x, y = self._slices.by_enum(c.LEVEL).shape
state = np.zeros((len(self._slices), x, y)) state = np.zeros((len(self._slices), x, y), dtype=np.float32)
state[0] = self._slices.by_enum(c.LEVEL).slice state[0] = self._slices.by_enum(c.LEVEL).slice
if r := self.pomdp_radius: if r := self.pomdp_radius:
self._padded_obs_cube = np.full((len(self._slices), x + r*2, y + r*2), c.FREE_CELL.value) self._padded_obs_cube = np.full((len(self._slices), x + r*2, y + r*2), c.FREE_CELL.value, dtype=np.float32)
self._padded_obs_cube[0] = c.OCCUPIED_CELL.value self._padded_obs_cube[0] = c.OCCUPIED_CELL.value
self._padded_obs_cube[:, r:r+x, r:r+y] = state self._padded_obs_cube[:, r:r+x, r:r+y] = state
if self.combin_agent_slices_in_obs and self.n_agents > 1:
self._combined_obs_cube = np.zeros(self.observation_space.shape, dtype=np.float32)
return state return state
def _init_entities(self): def _init_entities(self):
@ -177,18 +179,25 @@ class BaseFactory(gym.Env):
self._slices = self._init_state_slices() self._slices = self._init_state_slices()
self._obs_cube = self._init_obs_cube() self._obs_cube = self._init_obs_cube()
self._entitites = self._init_entities() self._entitites = self._init_entities()
self.do_additional_reset()
self._flush_state() self._flush_state()
self._steps = 0 self._steps = 0
info = self._summarize_state() if self.record_episodes else {} obs = self._get_observations()
return None, None, None, info return obs
def pre_step(self) -> None: def pre_step(self) -> None:
pass pass
def post_step(self) -> dict: def do_additional_reset(self) -> None:
pass pass
def do_additional_step(self) -> dict:
return {}
def post_step(self) -> dict:
return {}
def step(self, actions): def step(self, actions):
actions = [actions] if isinstance(actions, int) or np.isscalar(actions) else actions actions = [actions] if isinstance(actions, int) or np.isscalar(actions) else actions
assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]' assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
@ -219,6 +228,10 @@ class BaseFactory(gym.Env):
agent.temp_action = action agent.temp_action = action
agent.temp_valid = valid agent.temp_valid = valid
# In-between step Hook for later use
info = self.do_additional_step()
# Write to observation cube
self._flush_state() self._flush_state()
tiles_with_collisions = self.get_all_tiles_with_collisions() tiles_with_collisions = self.get_all_tiles_with_collisions()
@ -237,7 +250,8 @@ class BaseFactory(gym.Env):
self._doors.tick_doors() self._doors.tick_doors()
# Finalize # Finalize
reward, info = self.calculate_reward() reward, reward_info = self.calculate_reward()
info.update(reward_info)
if self._steps >= self.max_steps: if self._steps >= self.max_steps:
done = True done = True
info.update(step_reward=reward, step=self._steps) info.update(step_reward=reward, step=self._steps)
@ -255,10 +269,10 @@ class BaseFactory(gym.Env):
self._obs_cube[np.arange(len(self._slices)) != self._slices.get_idx(c.LEVEL)] = c.FREE_CELL.value self._obs_cube[np.arange(len(self._slices)) != self._slices.get_idx(c.LEVEL)] = c.FREE_CELL.value
if self.parse_doors: if self.parse_doors:
for door in self._doors: for door in self._doors:
if door.is_open: if door.is_open and self._obs_cube[self._slices.get_idx(c.DOORS)][door.pos] != c.OPEN_DOOR.value:
self._obs_cube[self._slices.get_idx(c.DOORS)][door.pos] = c.IS_OPEN_DOOR.value self._obs_cube[self._slices.get_idx(c.DOORS)][door.pos] = c.OPEN_DOOR.value
else: elif door.is_closed and self._obs_cube[self._slices.get_idx(c.DOORS)][door.pos] != c.CLOSED_DOOR.value:
self._obs_cube[self._slices.get_idx(c.DOORS)][door.pos] = c.IS_CLOSED_DOOR.value self._obs_cube[self._slices.get_idx(c.DOORS)][door.pos] = c.CLOSED_DOOR.value
for agent in self._agents: for agent in self._agents:
self._obs_cube[self._slices.get_idx_by_name(agent.name)][agent.pos] = c.OCCUPIED_CELL.value self._obs_cube[self._slices.get_idx_by_name(agent.name)][agent.pos] = c.OCCUPIED_CELL.value
if agent.last_pos != h.NO_POS: if agent.last_pos != h.NO_POS:

View File

@ -199,7 +199,7 @@ class Door(Entity):
def __init__(self, *args, context, closed_on_init=True, auto_close_interval=10): def __init__(self, *args, context, closed_on_init=True, auto_close_interval=10):
super(Door, self).__init__(*args) super(Door, self).__init__(*args)
self._state = c.IS_CLOSED_DOOR self._state = c.CLOSED_DOOR
self.auto_close_interval = auto_close_interval self.auto_close_interval = auto_close_interval
self.time_to_close = -1 self.time_to_close = -1
neighbor_pos = list(itertools.product([-1, 1, 0], repeat=2))[:-1] neighbor_pos = list(itertools.product([-1, 1, 0], repeat=2))[:-1]
@ -215,18 +215,18 @@ class Door(Entity):
@property @property
def is_closed(self): def is_closed(self):
return self._state == c.IS_CLOSED_DOOR return self._state == c.CLOSED_DOOR
@property @property
def is_open(self): def is_open(self):
return self._state == c.IS_OPEN_DOOR return self._state == c.OPEN_DOOR
@property @property
def status(self): def status(self):
return self._state return self._state
def use(self): def use(self):
if self._state == c.IS_OPEN_DOOR: if self._state == c.OPEN_DOOR:
self._close() self._close()
else: else:
self._open() self._open()
@ -239,12 +239,12 @@ class Door(Entity):
def _open(self): def _open(self):
self.connectivity.add_edges_from([(self.pos, x) for x in self.connectivity.nodes]) self.connectivity.add_edges_from([(self.pos, x) for x in self.connectivity.nodes])
self._state = c.IS_OPEN_DOOR self._state = c.OPEN_DOOR
self.time_to_close = self.auto_close_interval self.time_to_close = self.auto_close_interval
def _close(self): def _close(self):
self.connectivity.remove_node(self.pos) self.connectivity.remove_node(self.pos)
self._state = c.IS_CLOSED_DOOR self._state = c.CLOSED_DOOR
def is_linked(self, old_pos, new_pos): def is_linked(self, old_pos, new_pos):
try: try:

View File

@ -99,7 +99,7 @@ class SimpleFactory(BaseFactory):
free_for_dirt = self._tiles.empty_tiles free_for_dirt = self._tiles.empty_tiles
# randomly distribute dirt across the grid # randomly distribute dirt across the grid
n_dirt_tiles = int(random.uniform(0, self.dirt_properties.max_spawn_ratio) * len(free_for_dirt)) n_dirt_tiles = max(0, int(random.uniform(0, self.dirt_properties.max_spawn_ratio) * len(free_for_dirt)))
for tile in free_for_dirt[:n_dirt_tiles]: for tile in free_for_dirt[:n_dirt_tiles]:
new_value = dirt_slice[tile.pos] + self.dirt_properties.gain_amount new_value = dirt_slice[tile.pos] + self.dirt_properties.gain_amount
dirt_slice[tile.pos] = min(new_value, self.dirt_properties.max_local_amount) dirt_slice[tile.pos] = min(new_value, self.dirt_properties.max_local_amount)
@ -115,7 +115,7 @@ class SimpleFactory(BaseFactory):
else: else:
return False return False
def post_step(self) -> dict: def do_additional_step(self) -> dict:
if smear_amount := self.dirt_properties.dirt_smear_amount: if smear_amount := self.dirt_properties.dirt_smear_amount:
dirt_slice = self._slices.by_name(DIRT).slice dirt_slice = self._slices.by_name(DIRT).slice
for agent in self._agents: for agent in self._agents:
@ -144,12 +144,9 @@ class SimpleFactory(BaseFactory):
else: else:
raise RuntimeError('This should not happen!!!') raise RuntimeError('This should not happen!!!')
def reset(self) -> (np.ndarray, int, bool, dict): def do_additional_reset(self) -> None:
_ = super().reset() # state, reward, done, info ... =
self.spawn_dirt() self.spawn_dirt()
self._next_dirt_spawn = self.dirt_properties.spawn_frequency self._next_dirt_spawn = self.dirt_properties.spawn_frequency
obs = self._get_observations()
return obs
def calculate_reward(self) -> (int, dict): def calculate_reward(self) -> (int, dict):
info_dict = dict() info_dict = dict()
@ -174,7 +171,7 @@ class SimpleFactory(BaseFactory):
if self._is_clean_up_action(agent.temp_action): if self._is_clean_up_action(agent.temp_action):
if agent.temp_valid: if agent.temp_valid:
reward += 1 reward += 0.5
self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.') self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.')
info_dict.update(dirt_cleaned=1) info_dict.update(dirt_cleaned=1)
else: else:

View File

@ -17,8 +17,8 @@ class Constants(Enum):
OCCUPIED_CELL = 1 OCCUPIED_CELL = 1
DOORS = 'doors' DOORS = 'doors'
IS_CLOSED_DOOR = 1 CLOSED_DOOR = 1
IS_OPEN_DOOR = -1 OPEN_DOOR = -1
LEVEL_IDX = 0 LEVEL_IDX = 0

10
main.py
View File

@ -92,8 +92,8 @@ if __name__ == '__main__':
from algorithms.reg_dqn import RegDQN from algorithms.reg_dqn import RegDQN
# from sb3_contrib import QRDQN # from sb3_contrib import QRDQN
dirt_props = DirtProperties(clean_amount=6, gain_amount=1, max_global_amount=30, dirt_props = DirtProperties(clean_amount=1, gain_amount=0.1, max_global_amount=20,
max_local_amount=5, spawn_frequency=5, max_spawn_ratio=0.05, max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05,
dirt_smear_amount=0.0) dirt_smear_amount=0.0)
move_props = MovementProperties(allow_diagonal_movement=True, move_props = MovementProperties(allow_diagonal_movement=True,
allow_square_movement=True, allow_square_movement=True,
@ -102,11 +102,11 @@ if __name__ == '__main__':
out_path = None out_path = None
for modeL_type in [A2C, PPO, RegDQN, DQN]: # , QRDQN]: for modeL_type in [A2C]: # , PPO, RegDQN, DQN]: # , QRDQN]:
for seed in range(3): for seed in range(3):
with SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400, parse_doors=False, with SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400, parse_doors=True,
movement_properties=move_props, level_name='rooms', frames_to_stack=4, movement_properties=move_props, level_name='rooms', frames_to_stack=0,
omit_agent_slice_in_obs=True, combin_agent_slices_in_obs=True, record_episodes=False omit_agent_slice_in_obs=True, combin_agent_slices_in_obs=True, record_episodes=False
) as env: ) as env:

View File

@ -14,17 +14,18 @@ warnings.filterwarnings('ignore', category=UserWarning)
if __name__ == '__main__': if __name__ == '__main__':
model_name = 'A2C_1626103200' model_name = 'PPO_1626384768'
run_id = 0 run_id = 0
out_path = Path(__file__).parent / 'debug_out' out_path = Path(__file__).parent / 'debug_out'
model_path = out_path / model_name model_path = out_path / model_name
with (model_path / f'env_{model_name}.yaml').open('r') as f: with (model_path / f'env_{model_name}.yaml').open('r') as f:
env_kwargs = yaml.load(f, Loader=yaml.FullLoader) env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
env_kwargs.update(dirt_properties=DirtProperties(clean_amount=1, gain_amount=0.3, max_global_amount=20, if False:
max_local_amount=2, spawn_frequency=5, max_spawn_ratio=0.05, env_kwargs.update(dirt_properties=DirtProperties(clean_amount=1, gain_amount=0.1, max_global_amount=20,
dirt_smear_amount=0.2), max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05,
combin_agent_slices_in_obs=True, omit_agent_slice_in_obs=True) dirt_smear_amount=0.5),
combin_agent_slices_in_obs=True, omit_agent_slice_in_obs=True)
with SimpleFactory(**env_kwargs) as env: with SimpleFactory(**env_kwargs) as env:
# Edit THIS: # Edit THIS:
@ -32,5 +33,5 @@ if __name__ == '__main__':
this_model = model_files[0] this_model = model_files[0]
model = PPO.load(this_model) model = PPO.load(this_model)
evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=True, render=True) evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=False, render=True)
print(evaluation_result) print(evaluation_result)