Major Bug resolved
This commit is contained in:
@ -27,19 +27,19 @@ class BaseFactory(gym.Env):
|
|||||||
@property
|
@property
|
||||||
def observation_space(self):
|
def observation_space(self):
|
||||||
if self.combin_agent_slices_in_obs:
|
if self.combin_agent_slices_in_obs:
|
||||||
agent_slice = 1
|
n_agent_slices = 1
|
||||||
else: # not self.combin_agent_slices_in_obs:
|
else: # not self.combin_agent_slices_in_obs:
|
||||||
if self.omit_agent_slice_in_obs:
|
if self.omit_agent_slice_in_obs:
|
||||||
agent_slice = self.n_agents - 1
|
n_agent_slices = self.n_agents - 1
|
||||||
else: # not self.omit_agent_slice_in_obs:
|
else: # not self.omit_agent_slice_in_obs:
|
||||||
agent_slice = self.n_agents
|
n_agent_slices = self.n_agents
|
||||||
|
|
||||||
if self.pomdp_radius:
|
if self.pomdp_radius:
|
||||||
shape = (self._obs_cube.shape[0] - agent_slice, self.pomdp_radius * 2 + 1, self.pomdp_radius * 2 + 1)
|
shape = (self._slices.n - n_agent_slices, self.pomdp_radius * 2 + 1, self.pomdp_radius * 2 + 1)
|
||||||
space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
||||||
return space
|
return space
|
||||||
else:
|
else:
|
||||||
shape = [x-agent_slice if idx == 0 else x for idx, x in enumerate(self._obs_cube.shape)]
|
shape = [x-n_agent_slices if idx == 0 else x for idx, x in enumerate(self._level_shape)]
|
||||||
space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
||||||
return space
|
return space
|
||||||
|
|
||||||
@ -133,7 +133,7 @@ class BaseFactory(gym.Env):
|
|||||||
# Agents
|
# Agents
|
||||||
agents = []
|
agents = []
|
||||||
for i in range(self.n_agents):
|
for i in range(self.n_agents):
|
||||||
agents.append(Slice(f'{c.AGENT.name}#{i}', np.zeros_like(level[0].slice)))
|
agents.append(Slice(f'{c.AGENT.name}#{i}', np.zeros_like(level[0].slice, dtype=np.float32)))
|
||||||
state_slices.register_additional_items(level+doors+agents)
|
state_slices.register_additional_items(level+doors+agents)
|
||||||
|
|
||||||
# Additional Slices from SubDomains
|
# Additional Slices from SubDomains
|
||||||
@ -143,12 +143,14 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
def _init_obs_cube(self) -> np.ndarray:
|
def _init_obs_cube(self) -> np.ndarray:
|
||||||
x, y = self._slices.by_enum(c.LEVEL).shape
|
x, y = self._slices.by_enum(c.LEVEL).shape
|
||||||
state = np.zeros((len(self._slices), x, y))
|
state = np.zeros((len(self._slices), x, y), dtype=np.float32)
|
||||||
state[0] = self._slices.by_enum(c.LEVEL).slice
|
state[0] = self._slices.by_enum(c.LEVEL).slice
|
||||||
if r := self.pomdp_radius:
|
if r := self.pomdp_radius:
|
||||||
self._padded_obs_cube = np.full((len(self._slices), x + r*2, y + r*2), c.FREE_CELL.value)
|
self._padded_obs_cube = np.full((len(self._slices), x + r*2, y + r*2), c.FREE_CELL.value, dtype=np.float32)
|
||||||
self._padded_obs_cube[0] = c.OCCUPIED_CELL.value
|
self._padded_obs_cube[0] = c.OCCUPIED_CELL.value
|
||||||
self._padded_obs_cube[:, r:r+x, r:r+y] = state
|
self._padded_obs_cube[:, r:r+x, r:r+y] = state
|
||||||
|
if self.combin_agent_slices_in_obs and self.n_agents > 1:
|
||||||
|
self._combined_obs_cube = np.zeros(self.observation_space.shape, dtype=np.float32)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def _init_entities(self):
|
def _init_entities(self):
|
||||||
@ -177,18 +179,25 @@ class BaseFactory(gym.Env):
|
|||||||
self._slices = self._init_state_slices()
|
self._slices = self._init_state_slices()
|
||||||
self._obs_cube = self._init_obs_cube()
|
self._obs_cube = self._init_obs_cube()
|
||||||
self._entitites = self._init_entities()
|
self._entitites = self._init_entities()
|
||||||
|
self.do_additional_reset()
|
||||||
self._flush_state()
|
self._flush_state()
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
|
|
||||||
info = self._summarize_state() if self.record_episodes else {}
|
obs = self._get_observations()
|
||||||
return None, None, None, info
|
return obs
|
||||||
|
|
||||||
def pre_step(self) -> None:
|
def pre_step(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def post_step(self) -> dict:
|
def do_additional_reset(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def do_additional_step(self) -> dict:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def post_step(self) -> dict:
|
||||||
|
return {}
|
||||||
|
|
||||||
def step(self, actions):
|
def step(self, actions):
|
||||||
actions = [actions] if isinstance(actions, int) or np.isscalar(actions) else actions
|
actions = [actions] if isinstance(actions, int) or np.isscalar(actions) else actions
|
||||||
assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
|
assert isinstance(actions, Iterable), f'"actions" has to be in [{int, list}]'
|
||||||
@ -219,6 +228,10 @@ class BaseFactory(gym.Env):
|
|||||||
agent.temp_action = action
|
agent.temp_action = action
|
||||||
agent.temp_valid = valid
|
agent.temp_valid = valid
|
||||||
|
|
||||||
|
# In-between step Hook for later use
|
||||||
|
info = self.do_additional_step()
|
||||||
|
|
||||||
|
# Write to observation cube
|
||||||
self._flush_state()
|
self._flush_state()
|
||||||
|
|
||||||
tiles_with_collisions = self.get_all_tiles_with_collisions()
|
tiles_with_collisions = self.get_all_tiles_with_collisions()
|
||||||
@ -237,7 +250,8 @@ class BaseFactory(gym.Env):
|
|||||||
self._doors.tick_doors()
|
self._doors.tick_doors()
|
||||||
|
|
||||||
# Finalize
|
# Finalize
|
||||||
reward, info = self.calculate_reward()
|
reward, reward_info = self.calculate_reward()
|
||||||
|
info.update(reward_info)
|
||||||
if self._steps >= self.max_steps:
|
if self._steps >= self.max_steps:
|
||||||
done = True
|
done = True
|
||||||
info.update(step_reward=reward, step=self._steps)
|
info.update(step_reward=reward, step=self._steps)
|
||||||
@ -255,10 +269,10 @@ class BaseFactory(gym.Env):
|
|||||||
self._obs_cube[np.arange(len(self._slices)) != self._slices.get_idx(c.LEVEL)] = c.FREE_CELL.value
|
self._obs_cube[np.arange(len(self._slices)) != self._slices.get_idx(c.LEVEL)] = c.FREE_CELL.value
|
||||||
if self.parse_doors:
|
if self.parse_doors:
|
||||||
for door in self._doors:
|
for door in self._doors:
|
||||||
if door.is_open:
|
if door.is_open and self._obs_cube[self._slices.get_idx(c.DOORS)][door.pos] != c.OPEN_DOOR.value:
|
||||||
self._obs_cube[self._slices.get_idx(c.DOORS)][door.pos] = c.IS_OPEN_DOOR.value
|
self._obs_cube[self._slices.get_idx(c.DOORS)][door.pos] = c.OPEN_DOOR.value
|
||||||
else:
|
elif door.is_closed and self._obs_cube[self._slices.get_idx(c.DOORS)][door.pos] != c.CLOSED_DOOR.value:
|
||||||
self._obs_cube[self._slices.get_idx(c.DOORS)][door.pos] = c.IS_CLOSED_DOOR.value
|
self._obs_cube[self._slices.get_idx(c.DOORS)][door.pos] = c.CLOSED_DOOR.value
|
||||||
for agent in self._agents:
|
for agent in self._agents:
|
||||||
self._obs_cube[self._slices.get_idx_by_name(agent.name)][agent.pos] = c.OCCUPIED_CELL.value
|
self._obs_cube[self._slices.get_idx_by_name(agent.name)][agent.pos] = c.OCCUPIED_CELL.value
|
||||||
if agent.last_pos != h.NO_POS:
|
if agent.last_pos != h.NO_POS:
|
||||||
|
@ -199,7 +199,7 @@ class Door(Entity):
|
|||||||
|
|
||||||
def __init__(self, *args, context, closed_on_init=True, auto_close_interval=10):
|
def __init__(self, *args, context, closed_on_init=True, auto_close_interval=10):
|
||||||
super(Door, self).__init__(*args)
|
super(Door, self).__init__(*args)
|
||||||
self._state = c.IS_CLOSED_DOOR
|
self._state = c.CLOSED_DOOR
|
||||||
self.auto_close_interval = auto_close_interval
|
self.auto_close_interval = auto_close_interval
|
||||||
self.time_to_close = -1
|
self.time_to_close = -1
|
||||||
neighbor_pos = list(itertools.product([-1, 1, 0], repeat=2))[:-1]
|
neighbor_pos = list(itertools.product([-1, 1, 0], repeat=2))[:-1]
|
||||||
@ -215,18 +215,18 @@ class Door(Entity):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def is_closed(self):
|
def is_closed(self):
|
||||||
return self._state == c.IS_CLOSED_DOOR
|
return self._state == c.CLOSED_DOOR
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_open(self):
|
def is_open(self):
|
||||||
return self._state == c.IS_OPEN_DOOR
|
return self._state == c.OPEN_DOOR
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def status(self):
|
def status(self):
|
||||||
return self._state
|
return self._state
|
||||||
|
|
||||||
def use(self):
|
def use(self):
|
||||||
if self._state == c.IS_OPEN_DOOR:
|
if self._state == c.OPEN_DOOR:
|
||||||
self._close()
|
self._close()
|
||||||
else:
|
else:
|
||||||
self._open()
|
self._open()
|
||||||
@ -239,12 +239,12 @@ class Door(Entity):
|
|||||||
|
|
||||||
def _open(self):
|
def _open(self):
|
||||||
self.connectivity.add_edges_from([(self.pos, x) for x in self.connectivity.nodes])
|
self.connectivity.add_edges_from([(self.pos, x) for x in self.connectivity.nodes])
|
||||||
self._state = c.IS_OPEN_DOOR
|
self._state = c.OPEN_DOOR
|
||||||
self.time_to_close = self.auto_close_interval
|
self.time_to_close = self.auto_close_interval
|
||||||
|
|
||||||
def _close(self):
|
def _close(self):
|
||||||
self.connectivity.remove_node(self.pos)
|
self.connectivity.remove_node(self.pos)
|
||||||
self._state = c.IS_CLOSED_DOOR
|
self._state = c.CLOSED_DOOR
|
||||||
|
|
||||||
def is_linked(self, old_pos, new_pos):
|
def is_linked(self, old_pos, new_pos):
|
||||||
try:
|
try:
|
||||||
|
@ -99,7 +99,7 @@ class SimpleFactory(BaseFactory):
|
|||||||
free_for_dirt = self._tiles.empty_tiles
|
free_for_dirt = self._tiles.empty_tiles
|
||||||
|
|
||||||
# randomly distribute dirt across the grid
|
# randomly distribute dirt across the grid
|
||||||
n_dirt_tiles = int(random.uniform(0, self.dirt_properties.max_spawn_ratio) * len(free_for_dirt))
|
n_dirt_tiles = max(0, int(random.uniform(0, self.dirt_properties.max_spawn_ratio) * len(free_for_dirt)))
|
||||||
for tile in free_for_dirt[:n_dirt_tiles]:
|
for tile in free_for_dirt[:n_dirt_tiles]:
|
||||||
new_value = dirt_slice[tile.pos] + self.dirt_properties.gain_amount
|
new_value = dirt_slice[tile.pos] + self.dirt_properties.gain_amount
|
||||||
dirt_slice[tile.pos] = min(new_value, self.dirt_properties.max_local_amount)
|
dirt_slice[tile.pos] = min(new_value, self.dirt_properties.max_local_amount)
|
||||||
@ -115,7 +115,7 @@ class SimpleFactory(BaseFactory):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def post_step(self) -> dict:
|
def do_additional_step(self) -> dict:
|
||||||
if smear_amount := self.dirt_properties.dirt_smear_amount:
|
if smear_amount := self.dirt_properties.dirt_smear_amount:
|
||||||
dirt_slice = self._slices.by_name(DIRT).slice
|
dirt_slice = self._slices.by_name(DIRT).slice
|
||||||
for agent in self._agents:
|
for agent in self._agents:
|
||||||
@ -144,12 +144,9 @@ class SimpleFactory(BaseFactory):
|
|||||||
else:
|
else:
|
||||||
raise RuntimeError('This should not happen!!!')
|
raise RuntimeError('This should not happen!!!')
|
||||||
|
|
||||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
def do_additional_reset(self) -> None:
|
||||||
_ = super().reset() # state, reward, done, info ... =
|
|
||||||
self.spawn_dirt()
|
self.spawn_dirt()
|
||||||
self._next_dirt_spawn = self.dirt_properties.spawn_frequency
|
self._next_dirt_spawn = self.dirt_properties.spawn_frequency
|
||||||
obs = self._get_observations()
|
|
||||||
return obs
|
|
||||||
|
|
||||||
def calculate_reward(self) -> (int, dict):
|
def calculate_reward(self) -> (int, dict):
|
||||||
info_dict = dict()
|
info_dict = dict()
|
||||||
@ -174,7 +171,7 @@ class SimpleFactory(BaseFactory):
|
|||||||
|
|
||||||
if self._is_clean_up_action(agent.temp_action):
|
if self._is_clean_up_action(agent.temp_action):
|
||||||
if agent.temp_valid:
|
if agent.temp_valid:
|
||||||
reward += 1
|
reward += 0.5
|
||||||
self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.')
|
self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.')
|
||||||
info_dict.update(dirt_cleaned=1)
|
info_dict.update(dirt_cleaned=1)
|
||||||
else:
|
else:
|
||||||
|
@ -17,8 +17,8 @@ class Constants(Enum):
|
|||||||
OCCUPIED_CELL = 1
|
OCCUPIED_CELL = 1
|
||||||
|
|
||||||
DOORS = 'doors'
|
DOORS = 'doors'
|
||||||
IS_CLOSED_DOOR = 1
|
CLOSED_DOOR = 1
|
||||||
IS_OPEN_DOOR = -1
|
OPEN_DOOR = -1
|
||||||
|
|
||||||
LEVEL_IDX = 0
|
LEVEL_IDX = 0
|
||||||
|
|
||||||
|
10
main.py
10
main.py
@ -92,8 +92,8 @@ if __name__ == '__main__':
|
|||||||
from algorithms.reg_dqn import RegDQN
|
from algorithms.reg_dqn import RegDQN
|
||||||
# from sb3_contrib import QRDQN
|
# from sb3_contrib import QRDQN
|
||||||
|
|
||||||
dirt_props = DirtProperties(clean_amount=6, gain_amount=1, max_global_amount=30,
|
dirt_props = DirtProperties(clean_amount=1, gain_amount=0.1, max_global_amount=20,
|
||||||
max_local_amount=5, spawn_frequency=5, max_spawn_ratio=0.05,
|
max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05,
|
||||||
dirt_smear_amount=0.0)
|
dirt_smear_amount=0.0)
|
||||||
move_props = MovementProperties(allow_diagonal_movement=True,
|
move_props = MovementProperties(allow_diagonal_movement=True,
|
||||||
allow_square_movement=True,
|
allow_square_movement=True,
|
||||||
@ -102,11 +102,11 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
out_path = None
|
out_path = None
|
||||||
|
|
||||||
for modeL_type in [A2C, PPO, RegDQN, DQN]: # , QRDQN]:
|
for modeL_type in [A2C]: # , PPO, RegDQN, DQN]: # , QRDQN]:
|
||||||
for seed in range(3):
|
for seed in range(3):
|
||||||
|
|
||||||
with SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400, parse_doors=False,
|
with SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400, parse_doors=True,
|
||||||
movement_properties=move_props, level_name='rooms', frames_to_stack=4,
|
movement_properties=move_props, level_name='rooms', frames_to_stack=0,
|
||||||
omit_agent_slice_in_obs=True, combin_agent_slices_in_obs=True, record_episodes=False
|
omit_agent_slice_in_obs=True, combin_agent_slices_in_obs=True, record_episodes=False
|
||||||
) as env:
|
) as env:
|
||||||
|
|
||||||
|
@ -14,17 +14,18 @@ warnings.filterwarnings('ignore', category=UserWarning)
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
model_name = 'A2C_1626103200'
|
model_name = 'PPO_1626384768'
|
||||||
run_id = 0
|
run_id = 0
|
||||||
out_path = Path(__file__).parent / 'debug_out'
|
out_path = Path(__file__).parent / 'debug_out'
|
||||||
model_path = out_path / model_name
|
model_path = out_path / model_name
|
||||||
|
|
||||||
with (model_path / f'env_{model_name}.yaml').open('r') as f:
|
with (model_path / f'env_{model_name}.yaml').open('r') as f:
|
||||||
env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
|
env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
env_kwargs.update(dirt_properties=DirtProperties(clean_amount=1, gain_amount=0.3, max_global_amount=20,
|
if False:
|
||||||
max_local_amount=2, spawn_frequency=5, max_spawn_ratio=0.05,
|
env_kwargs.update(dirt_properties=DirtProperties(clean_amount=1, gain_amount=0.1, max_global_amount=20,
|
||||||
dirt_smear_amount=0.2),
|
max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05,
|
||||||
combin_agent_slices_in_obs=True, omit_agent_slice_in_obs=True)
|
dirt_smear_amount=0.5),
|
||||||
|
combin_agent_slices_in_obs=True, omit_agent_slice_in_obs=True)
|
||||||
with SimpleFactory(**env_kwargs) as env:
|
with SimpleFactory(**env_kwargs) as env:
|
||||||
|
|
||||||
# Edit THIS:
|
# Edit THIS:
|
||||||
@ -32,5 +33,5 @@ if __name__ == '__main__':
|
|||||||
this_model = model_files[0]
|
this_model = model_files[0]
|
||||||
|
|
||||||
model = PPO.load(this_model)
|
model = PPO.load(this_model)
|
||||||
evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=True, render=True)
|
evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=False, render=True)
|
||||||
print(evaluation_result)
|
print(evaluation_result)
|
||||||
|
Reference in New Issue
Block a user