mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 15:26:43 +02:00
Restructuring
This commit is contained in:
parent
d9d8784338
commit
c8883a9c0d
@ -36,6 +36,10 @@ class BaseFactory(gym.Env):
|
|||||||
def movement_actions(self):
|
def movement_actions(self):
|
||||||
return self._actions.movement_actions
|
return self._actions.movement_actions
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_doors(self):
|
||||||
|
return hasattr(self, '_doors')
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self if self.frames_to_stack == 0 else FrameStack(self, self.frames_to_stack)
|
return self if self.frames_to_stack == 0 else FrameStack(self, self.frames_to_stack)
|
||||||
|
|
||||||
@ -43,7 +47,7 @@ class BaseFactory(gym.Env):
|
|||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = 0,
|
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = 0,
|
||||||
movement_properties: MovementProperties = MovementProperties(),
|
movement_properties: MovementProperties = MovementProperties(), parse_doors=False,
|
||||||
combin_agent_slices_in_obs: bool = False, frames_to_stack=0,
|
combin_agent_slices_in_obs: bool = False, frames_to_stack=0,
|
||||||
omit_agent_slice_in_obs=False, **kwargs):
|
omit_agent_slice_in_obs=False, **kwargs):
|
||||||
assert (combin_agent_slices_in_obs != omit_agent_slice_in_obs) or \
|
assert (combin_agent_slices_in_obs != omit_agent_slice_in_obs) or \
|
||||||
@ -64,25 +68,31 @@ class BaseFactory(gym.Env):
|
|||||||
self.done_at_collision = False
|
self.done_at_collision = False
|
||||||
|
|
||||||
self._state_slices = StateSlices()
|
self._state_slices = StateSlices()
|
||||||
|
|
||||||
|
# Level
|
||||||
level_filepath = Path(__file__).parent / h.LEVELS_DIR / f'{self.level_name}.txt'
|
level_filepath = Path(__file__).parent / h.LEVELS_DIR / f'{self.level_name}.txt'
|
||||||
parsed_level = h.parse_level(level_filepath)
|
parsed_level = h.parse_level(level_filepath)
|
||||||
self._level = h.one_hot_level(parsed_level)
|
self._level = h.one_hot_level(parsed_level)
|
||||||
parsed_doors = h.one_hot_level(parsed_level, h.DOOR)
|
level_slices = [h.LEVEL]
|
||||||
if parsed_doors.any():
|
|
||||||
self._doors = parsed_doors
|
# Doors
|
||||||
level_slices = ['level', 'doors']
|
if parse_doors:
|
||||||
can_use_doors = True
|
parsed_doors = h.one_hot_level(parsed_level, h.DOOR)
|
||||||
else:
|
if parsed_doors.any():
|
||||||
level_slices = ['level']
|
self._doors = parsed_doors
|
||||||
can_use_doors = False
|
level_slices.append(h.DOORS)
|
||||||
|
|
||||||
|
# Agents
|
||||||
offset = len(level_slices)
|
offset = len(level_slices)
|
||||||
self._state_slices.register_additional_items([*level_slices,
|
self._state_slices.register_additional_items([*level_slices,
|
||||||
*[f'agent#{i}' for i in range(offset, n_agents + offset)]])
|
*[f'agent#{i}' for i in range(offset, n_agents + offset)]])
|
||||||
|
|
||||||
|
# Additional Slices from SubDomains
|
||||||
if 'additional_slices' in kwargs:
|
if 'additional_slices' in kwargs:
|
||||||
self._state_slices.register_additional_items(kwargs.get('additional_slices'))
|
self._state_slices.register_additional_items(kwargs.get('additional_slices'))
|
||||||
self._zones = Zones(parsed_level)
|
self._zones = Zones(parsed_level)
|
||||||
|
|
||||||
self._actions = Actions(self.movement_properties, can_use_doors=can_use_doors)
|
self._actions = Actions(self.movement_properties, can_use_doors=self.has_doors)
|
||||||
self._actions.register_additional_items(self.additional_actions)
|
self._actions.register_additional_items(self.additional_actions)
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
@ -99,30 +109,29 @@ class BaseFactory(gym.Env):
|
|||||||
raise NotImplementedError('Please register additional actions ')
|
raise NotImplementedError('Please register additional actions ')
|
||||||
|
|
||||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
def reset(self) -> (np.ndarray, int, bool, dict):
|
||||||
|
slices = [np.expand_dims(self._level, 0)]
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
self._agent_states = list()
|
self._agent_states = list()
|
||||||
|
|
||||||
|
# Door Init
|
||||||
|
if self.has_doors:
|
||||||
|
self._door_states = [DoorState(i, tuple(pos)) for i, pos
|
||||||
|
in enumerate(np.argwhere(self._doors == h.IS_OCCUPIED_CELL))]
|
||||||
|
slices.append(np.expand_dims(self._doors, 0))
|
||||||
|
|
||||||
# Agent placement ...
|
# Agent placement ...
|
||||||
agents = np.zeros((self.n_agents, *self._level.shape), dtype=np.int8)
|
|
||||||
floor_tiles = np.argwhere(self._level == h.IS_FREE_CELL)
|
floor_tiles = np.argwhere(self._level == h.IS_FREE_CELL)
|
||||||
# ... on random positions
|
# ... on random positions
|
||||||
np.random.shuffle(floor_tiles)
|
np.random.shuffle(floor_tiles)
|
||||||
|
agents = np.zeros((self.n_agents, *self._level.shape), dtype=np.int8)
|
||||||
for i, (x, y) in enumerate(floor_tiles[:self.n_agents]):
|
for i, (x, y) in enumerate(floor_tiles[:self.n_agents]):
|
||||||
agents[i, x, y] = h.IS_OCCUPIED_CELL
|
agents[i, x, y] = h.IS_OCCUPIED_CELL
|
||||||
agent_state = AgentState(i, -1)
|
agent_state = AgentState(i, -1, pos=(x, y))
|
||||||
agent_state.update(pos=(x, y))
|
|
||||||
self._agent_states.append(agent_state)
|
self._agent_states.append(agent_state)
|
||||||
# state.shape = level, agent 1,..., agent n,
|
slices.append(agents)
|
||||||
if 'doors' in self._state_slices.values():
|
|
||||||
self._door_states = [DoorState(i, tuple(pos)) for i, pos
|
|
||||||
in enumerate(np.argwhere(self._doors == h.IS_OCCUPIED_CELL))]
|
|
||||||
self._state = np.concatenate((np.expand_dims(self._level, axis=0),
|
|
||||||
np.expand_dims(self._doors, axis=0),
|
|
||||||
agents), axis=0)
|
|
||||||
|
|
||||||
else:
|
# GLOBAL STATE
|
||||||
self._state = np.concatenate((np.expand_dims(self._level, axis=0), agents), axis=0)
|
self._state = np.concatenate(slices, axis=0)
|
||||||
# Returns State
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _get_observations(self) -> np.ndarray:
|
def _get_observations(self) -> np.ndarray:
|
||||||
@ -138,21 +147,22 @@ class BaseFactory(gym.Env):
|
|||||||
first_agent_slice = self._state_slices.AGENTSTARTIDX
|
first_agent_slice = self._state_slices.AGENTSTARTIDX
|
||||||
# Todo: make this more efficient!
|
# Todo: make this more efficient!
|
||||||
if self.pomdp_radius:
|
if self.pomdp_radius:
|
||||||
global_pos = self._agent_states[agent_i].pos
|
pomdp_diameter = self.pomdp_radius * 2 + 1
|
||||||
x0, x1 = max(0, global_pos[0] - self.pomdp_radius), global_pos[0] + self.pomdp_radius + 1
|
global_x, global_y = self._agent_states[agent_i].pos
|
||||||
y0, y1 = max(0, global_pos[1] - self.pomdp_radius), global_pos[1] + self.pomdp_radius + 1
|
x0, x1 = max(0, global_x - self.pomdp_radius), global_x + self.pomdp_radius + 1
|
||||||
|
y0, y1 = max(0, global_y - self.pomdp_radius), global_y + self.pomdp_radius + 1
|
||||||
obs = self._state[:, x0:x1, y0:y1]
|
obs = self._state[:, x0:x1, y0:y1]
|
||||||
if obs.shape[1] != self.pomdp_radius * 2 + 1 or obs.shape[2] != self.pomdp_radius * 2 + 1:
|
if obs.shape[1] != pomdp_diameter or obs.shape[2] != pomdp_diameter:
|
||||||
obs_padded = np.full((obs.shape[0], self.pomdp_radius * 2 + 1, self.pomdp_radius * 2 + 1), 1)
|
obs_padded = np.full((obs.shape[0], pomdp_diameter, pomdp_diameter), h.IS_OCCUPIED_CELL)
|
||||||
a_pos = np.argwhere(obs[first_agent_slice + agent_i] == h.IS_OCCUPIED_CELL)[0]
|
local_x, local_y = np.argwhere(obs[first_agent_slice + agent_i] == h.IS_OCCUPIED_CELL)[0]
|
||||||
obs_padded[:,
|
obs_padded[:,
|
||||||
abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1],
|
abs(local_x-self.pomdp_radius):abs(local_x-self.pomdp_radius)+obs.shape[1],
|
||||||
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
|
abs(local_y-self.pomdp_radius):abs(local_y-self.pomdp_radius)+obs.shape[2]] = obs
|
||||||
obs = obs_padded
|
obs = obs_padded
|
||||||
else:
|
else:
|
||||||
obs = self._state
|
obs = self._state
|
||||||
if self.omit_agent_slice_in_obs:
|
if self.omit_agent_slice_in_obs:
|
||||||
obs_new = obs[[key for key, val in self._state_slices.items() if 'agent' not in val]]
|
obs_new = obs[[key for key, val in self._state_slices.items() if h.AGENT not in val]]
|
||||||
return obs_new
|
return obs_new
|
||||||
else:
|
else:
|
||||||
if self.combin_agent_slices_in_obs:
|
if self.combin_agent_slices_in_obs:
|
||||||
@ -174,16 +184,19 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
# Move this in a seperate function?
|
# Move this in a seperate function?
|
||||||
for agent_i, action in enumerate(actions):
|
for agent_i, action in enumerate(actions):
|
||||||
|
agent = self._agent_states[agent_i]
|
||||||
if self._actions.is_moving_action(action):
|
if self._actions.is_moving_action(action):
|
||||||
pos, valid = self.move_or_colide(agent_i, action)
|
pos, valid = self.move_or_colide(agent_i, action)
|
||||||
elif self._actions.is_no_op(action):
|
elif self._actions.is_no_op(action):
|
||||||
pos, valid = self._agent_states[agent_i].pos, h.VALID
|
pos, valid = agent.pos, h.VALID
|
||||||
elif self._actions.is_door_usage(action):
|
elif self._actions.is_door_usage(action):
|
||||||
try:
|
# Check if agent raly stands on a door:
|
||||||
|
if self._state[self._state_slices.by_name(h.DOORS)][agent.pos] in [h.IS_OCCUPIED_CELL, ]:
|
||||||
door = [door for door in self._door_states if door.pos == self._agent_states[agent_i].pos][0]
|
door = [door for door in self._door_states if door.pos == self._agent_states[agent_i].pos][0]
|
||||||
door.use()
|
door.use()
|
||||||
pos, valid = self._agent_states[agent_i].pos, h.VALID
|
pos, valid = self._agent_states[agent_i].pos, h.VALID
|
||||||
except IndexError:
|
# When he doesn't...
|
||||||
|
else:
|
||||||
pos, valid = self._agent_states[agent_i].pos, h.NOT_VALID
|
pos, valid = self._agent_states[agent_i].pos, h.NOT_VALID
|
||||||
else:
|
else:
|
||||||
pos, valid = self.do_additional_actions(agent_i, action)
|
pos, valid = self.do_additional_actions(agent_i, action)
|
||||||
@ -202,6 +215,7 @@ class BaseFactory(gym.Env):
|
|||||||
door.time_to_close -= 1
|
door.time_to_close -= 1
|
||||||
elif door.is_open and not door.time_to_close and door.pos not in agents_pos:
|
elif door.is_open and not door.time_to_close and door.pos not in agents_pos:
|
||||||
door.use()
|
door.use()
|
||||||
|
self._state[self._state_slices.by_name(h.DOORS)] = 1 if door.is_closed else -1
|
||||||
|
|
||||||
reward, info = self.calculate_reward(self._agent_states)
|
reward, info = self.calculate_reward(self._agent_states)
|
||||||
|
|
||||||
@ -230,11 +244,12 @@ class BaseFactory(gym.Env):
|
|||||||
collisions_vec[self._state_slices.by_name('doors')] = h.IS_FREE_CELL # no door-collisions
|
collisions_vec[self._state_slices.by_name('doors')] = h.IS_FREE_CELL # no door-collisions
|
||||||
|
|
||||||
if agent_state.action_valid:
|
if agent_state.action_valid:
|
||||||
# ToDo: Place a function hook here
|
# All well, no collision.
|
||||||
|
# Place a function hook here if needed.
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
# Place a marker to indicate a collision with the level boundrys
|
# Place a marker to indicate a collision with the level boundrys
|
||||||
collisions_vec[h.LEVEL_IDX] = h.IS_OCCUPIED_CELL
|
collisions_vec[self._state_slices.by_name(h.LEVEL)] = h.IS_OCCUPIED_CELL
|
||||||
return collisions_vec
|
return collisions_vec
|
||||||
|
|
||||||
def do_move(self, agent_i: int, old_pos: (int, int), new_pos: (int, int)) -> None:
|
def do_move(self, agent_i: int, old_pos: (int, int), new_pos: (int, int)) -> None:
|
||||||
@ -265,7 +280,7 @@ class BaseFactory(gym.Env):
|
|||||||
x_new = x + x_diff
|
x_new = x + x_diff
|
||||||
y_new = y + y_diff
|
y_new = y + y_diff
|
||||||
|
|
||||||
if h.DOORS in self._state_slices.values() and self._agent_states[agent_i]._last_pos != (-1, -1):
|
if self.has_doors and self._agent_states[agent_i]._last_pos != (-1, -1):
|
||||||
door = [door for door in self._door_states if door.pos == (x, y)]
|
door = [door for door in self._door_states if door.pos == (x, y)]
|
||||||
if door:
|
if door:
|
||||||
door = door[0]
|
door = door[0]
|
||||||
@ -298,7 +313,7 @@ class BaseFactory(gym.Env):
|
|||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
valid = h.check_position(self._state[h.LEVEL_IDX], (x_new, y_new))
|
valid = h.check_position(self._state[self._state_slices.by_name(h.LEVEL)], (x_new, y_new))
|
||||||
|
|
||||||
return (x, y), (x_new, y_new), valid
|
return (x, y), (x_new, y_new), valid
|
||||||
|
|
||||||
|
@ -4,10 +4,10 @@
|
|||||||
#333333xx#4444#
|
#333333xx#4444#
|
||||||
#333333#444444#
|
#333333#444444#
|
||||||
#333333#444444#
|
#333333#444444#
|
||||||
###x#######x###
|
###x#######D###
|
||||||
#1111##2222222#
|
#1111##2222222#
|
||||||
#11111#2222#22#
|
#11111#2222#22#
|
||||||
#11111x2222222#
|
#11111D2222222#
|
||||||
#11111#2222222#
|
#11111#2222222#
|
||||||
#11111#2222222#
|
#11111#2222222#
|
||||||
###############
|
###############
|
@ -49,28 +49,30 @@ class SimpleFactory(BaseFactory):
|
|||||||
|
|
||||||
dirt = [Entity('dirt', [x, y], min(0.15 + self._state[DIRT_INDEX, x, y], 1.5), 'scale')
|
dirt = [Entity('dirt', [x, y], min(0.15 + self._state[DIRT_INDEX, x, y], 1.5), 'scale')
|
||||||
for x, y in np.argwhere(self._state[DIRT_INDEX] > h.IS_FREE_CELL)]
|
for x, y in np.argwhere(self._state[DIRT_INDEX] > h.IS_FREE_CELL)]
|
||||||
walls = [Entity('wall', pos) for pos in np.argwhere(self._state[h.LEVEL_IDX] > h.IS_FREE_CELL)]
|
walls = [Entity('wall', pos)
|
||||||
|
for pos in np.argwhere(self._state[self._state_slices.by_name(h.LEVEL)] > h.IS_FREE_CELL)]
|
||||||
|
|
||||||
def asset_str(agent):
|
def asset_str(agent):
|
||||||
if any([x is None for x in [self._state_slices[j] for j in agent.collisions]]):
|
if any([x is None for x in [self._state_slices[j] for j in agent.collisions]]):
|
||||||
print('error')
|
print('error')
|
||||||
cols = ' '.join([self._state_slices[j] for j in agent.collisions])
|
cols = ' '.join([self._state_slices[j] for j in agent.collisions])
|
||||||
if 'agent' in cols:
|
if h.AGENT in cols:
|
||||||
return 'agent_collision', 'blank'
|
return 'agent_collision', 'blank'
|
||||||
elif not agent.action_valid or 'level' in cols or 'agent' in cols:
|
elif not agent.action_valid or 'level' in cols or h.AGENT in cols:
|
||||||
return 'agent', 'invalid'
|
return h.AGENT, 'invalid'
|
||||||
elif self._is_clean_up_action(agent.action):
|
elif self._is_clean_up_action(agent.action):
|
||||||
return 'agent', 'valid'
|
return h.AGENT, 'valid'
|
||||||
else:
|
else:
|
||||||
return 'agent', 'idle'
|
return h.AGENT, 'idle'
|
||||||
agents = []
|
agents = []
|
||||||
for i, agent in enumerate(self._agent_states):
|
for i, agent in enumerate(self._agent_states):
|
||||||
name, state = asset_str(agent)
|
name, state = asset_str(agent)
|
||||||
agents.append(Entity(name, agent.pos, 1, 'none', state, i+1))
|
agents.append(Entity(name, agent.pos, 1, 'none', state, i+1))
|
||||||
doors = []
|
doors = []
|
||||||
for i, door in enumerate(self._door_states):
|
if self.has_doors:
|
||||||
name, state = 'door_open' if door.is_open else 'door_closed', 'blank'
|
for i, door in enumerate(self._door_states):
|
||||||
agents.append(Entity(name, door.pos, 1, 'none', state, i+1))
|
name, state = 'door_open' if door.is_open else 'door_closed', 'blank'
|
||||||
|
agents.append(Entity(name, door.pos, 1, 'none', state, i+1))
|
||||||
self._renderer.render(dirt+walls+agents+doors)
|
self._renderer.render(dirt+walls+agents+doors)
|
||||||
|
|
||||||
def spawn_dirt(self) -> None:
|
def spawn_dirt(self) -> None:
|
||||||
@ -141,26 +143,25 @@ class SimpleFactory(BaseFactory):
|
|||||||
reward = 0
|
reward = 0
|
||||||
|
|
||||||
for agent_state in agent_states:
|
for agent_state in agent_states:
|
||||||
|
agent_name = f'{h.AGENT.capitalize()} {agent_state.i}'
|
||||||
cols = agent_state.collisions
|
cols = agent_state.collisions
|
||||||
|
|
||||||
list_of_collisions = [self._state_slices[entity] for entity in cols
|
list_of_collisions = [self._state_slices[entity] for entity in cols
|
||||||
if entity != self._state_slices.by_name("dirt")]
|
if entity != self._state_slices.by_name('dirt')]
|
||||||
|
|
||||||
if list_of_collisions:
|
if list_of_collisions:
|
||||||
self.print(f't = {self._steps}\tAgent {agent_state.i} has collisions with '
|
self.print(f't = {self._steps}\t{agent_name} has collisions with {list_of_collisions}')
|
||||||
f'{list_of_collisions}')
|
|
||||||
|
|
||||||
if self._is_clean_up_action(agent_state.action):
|
if self._is_clean_up_action(agent_state.action):
|
||||||
if agent_state.action_valid:
|
if agent_state.action_valid:
|
||||||
reward += 1
|
reward += 1
|
||||||
self.print(f'Agent {agent_state.i} did just clean up some dirt at {agent_state.pos}.')
|
self.print(f'{agent_name} did just clean up some dirt at {agent_state.pos}.')
|
||||||
info_dict.update(dirt_cleaned=1)
|
info_dict.update(dirt_cleaned=1)
|
||||||
else:
|
else:
|
||||||
reward -= 0.01
|
reward -= 0.01
|
||||||
self.print(f'Agent {agent_state.i} just tried to clean up some dirt '
|
self.print(f'{agent_name} just tried to clean up some dirt at {agent_state.pos}, but failed.')
|
||||||
f'at {agent_state.pos}, but was unsucsessfull.')
|
info_dict.update({f'{h.AGENT}_{agent_state.i}_failed_action': 1})
|
||||||
info_dict.update({f'agent_{agent_state.i}_failed_action': 1})
|
info_dict.update({f'{h.AGENT}_{agent_state.i}_failed_dirt_cleanup': 1})
|
||||||
info_dict.update({f'agent_{agent_state.i}_failed_dirt_cleanup': 1})
|
|
||||||
|
|
||||||
elif self._actions.is_moving_action(agent_state.action):
|
elif self._actions.is_moving_action(agent_state.action):
|
||||||
if agent_state.action_valid:
|
if agent_state.action_valid:
|
||||||
@ -173,21 +174,20 @@ class SimpleFactory(BaseFactory):
|
|||||||
elif self._actions.is_door_usage(agent_state.action):
|
elif self._actions.is_door_usage(agent_state.action):
|
||||||
if agent_state.action_valid:
|
if agent_state.action_valid:
|
||||||
reward += 0.1
|
reward += 0.1
|
||||||
self.print(f'Agent {agent_state.i} did just use the door at {agent_state.pos}.')
|
self.print(f'{agent_name} did just use the door at {agent_state.pos}.')
|
||||||
info_dict.update(door_used=1)
|
info_dict.update(door_used=1)
|
||||||
else:
|
else:
|
||||||
self.print(f'Agent {agent_state.i} just tried to use a door '
|
self.print(f'{agent_name} just tried to use a door at {agent_state.pos}, but failed.')
|
||||||
f'at {agent_state.pos}, but was unsucsessfull.')
|
info_dict.update({f'{h.AGENT}_{agent_state.i}_failed_action': 1})
|
||||||
info_dict.update({f'agent_{agent_state.i}_failed_action': 1})
|
info_dict.update({f'{h.AGENT}_{agent_state.i}_failed_door_open': 1})
|
||||||
info_dict.update({f'agent_{agent_state.i}_failed_door_open': 1})
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
info_dict.update(no_op=1)
|
info_dict.update(no_op=1)
|
||||||
reward -= 0.00
|
reward -= 0.00
|
||||||
|
|
||||||
for entity in list_of_collisions:
|
for entity in list_of_collisions:
|
||||||
entity = 'agent' if 'agent' in entity else entity
|
entity = h.AGENT if h.AGENT in entity else entity
|
||||||
info_dict.update({f'agent_{agent_state.i}_vs_{entity}': 1})
|
info_dict.update({f'{h.AGENT}_{agent_state.i}_vs_{entity}': 1})
|
||||||
|
|
||||||
self.print(f"reward is {reward}")
|
self.print(f"reward is {reward}")
|
||||||
# Potential based rewards ->
|
# Potential based rewards ->
|
||||||
|
@ -10,10 +10,16 @@ DOOR = 'D'
|
|||||||
DANGER_ZONE = 'x'
|
DANGER_ZONE = 'x'
|
||||||
LEVELS_DIR = 'levels'
|
LEVELS_DIR = 'levels'
|
||||||
LEVEL = 'level'
|
LEVEL = 'level'
|
||||||
DOORS = 'doors'
|
AGENT = 'agent'
|
||||||
LEVEL_IDX = 0
|
|
||||||
IS_FREE_CELL = 0
|
IS_FREE_CELL = 0
|
||||||
IS_OCCUPIED_CELL = 1
|
IS_OCCUPIED_CELL = 1
|
||||||
|
|
||||||
|
DOORS = 'doors'
|
||||||
|
IS_CLOSED_DOOR = IS_OCCUPIED_CELL
|
||||||
|
IS_OPEN_DOOR = -1
|
||||||
|
|
||||||
|
LEVEL_IDX = 0
|
||||||
|
|
||||||
TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
|
TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
|
||||||
IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count']
|
IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count']
|
||||||
|
|
||||||
|
@ -83,13 +83,13 @@ class AgentState:
|
|||||||
curr_x, curr_y = self.pos
|
curr_x, curr_y = self.pos
|
||||||
return last_x-curr_x, last_y-curr_y
|
return last_x-curr_x, last_y-curr_y
|
||||||
|
|
||||||
def __init__(self, i: int, action: int):
|
def __init__(self, i: int, action: int, pos=None):
|
||||||
self.i = i
|
self.i = i
|
||||||
self.action = action
|
self.action = action
|
||||||
|
|
||||||
self.collision_vector = None
|
self.collision_vector = None
|
||||||
self.action_valid = None
|
self.action_valid = None
|
||||||
self.pos = None
|
self.pos = pos
|
||||||
self._last_pos = (-1, -1)
|
self._last_pos = (-1, -1)
|
||||||
|
|
||||||
def update(self, **kwargs): # is this hacky?? o.0
|
def update(self, **kwargs): # is this hacky?? o.0
|
||||||
@ -248,7 +248,7 @@ class StateSlices(Register):
|
|||||||
if self._agent_start_idx:
|
if self._agent_start_idx:
|
||||||
return self._agent_start_idx
|
return self._agent_start_idx
|
||||||
else:
|
else:
|
||||||
self._agent_start_idx = min([idx for idx, x in self.items() if 'agent' in x])
|
self._agent_start_idx = min([idx for idx, x in self.items() if h.AGENT in x])
|
||||||
return self._agent_start_idx
|
return self._agent_start_idx
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -29,8 +29,8 @@ if __name__ == '__main__':
|
|||||||
# rewards += [total reward]
|
# rewards += [total reward]
|
||||||
# boxplot total rewards
|
# boxplot total rewards
|
||||||
|
|
||||||
run_id = '1623241962'
|
run_id = '1623923982'
|
||||||
model_name = 'PPO'
|
model_name = 'A2C'
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
out_path = Path(__file__).parent / 'debug_out'
|
out_path = Path(__file__).parent / 'debug_out'
|
||||||
@ -48,7 +48,7 @@ if __name__ == '__main__':
|
|||||||
env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
|
env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
dirt_props = DirtProperties(clean_amount=3, gain_amount=0.2, max_global_amount=30,
|
dirt_props = DirtProperties(clean_amount=3, gain_amount=0.2, max_global_amount=30,
|
||||||
max_local_amount=3, spawn_frequency=1, max_spawn_ratio=0.05)
|
max_local_amount=3, spawn_frequency=1, max_spawn_ratio=0.05)
|
||||||
env_kwargs.update(n_agents=1, dirt_properties=dirt_props)
|
# env_kwargs.update(n_agents=1, dirt_properties=dirt_props)
|
||||||
env = SimpleFactory(**env_kwargs)
|
env = SimpleFactory(**env_kwargs)
|
||||||
|
|
||||||
env = FrameStack(env, 4)
|
env = FrameStack(env, 4)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user