Merge remote-tracking branch 'origin/main'

This commit is contained in:
steffen-illium 2021-05-18 14:44:09 +02:00
commit 7aa31b8f76
11 changed files with 17 additions and 10 deletions

View File

Before

Width:  |  Height:  |  Size: 3.3 KiB

After

Width:  |  Height:  |  Size: 3.3 KiB

View File

Before

Width:  |  Height:  |  Size: 6.0 KiB

After

Width:  |  Height:  |  Size: 6.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.6 KiB

View File

Before

Width:  |  Height:  |  Size: 6.5 KiB

After

Width:  |  Height:  |  Size: 6.5 KiB

View File

Before

Width:  |  Height:  |  Size: 6.5 KiB

After

Width:  |  Height:  |  Size: 6.5 KiB

View File

Before

Width:  |  Height:  |  Size: 6.3 KiB

After

Width:  |  Height:  |  Size: 6.3 KiB

View File

Before

Width:  |  Height:  |  Size: 6.4 KiB

After

Width:  |  Height:  |  Size: 6.4 KiB

View File

Before

Width:  |  Height:  |  Size: 6.6 KiB

After

Width:  |  Height:  |  Size: 6.6 KiB

View File

@ -68,6 +68,7 @@ class BaseFactory(gym.Env):
self.steps = 0 self.steps = 0
self.cumulative_reward = 0 self.cumulative_reward = 0
self.monitor = FactoryMonitor(self) self.monitor = FactoryMonitor(self)
self.agent_states = []
# Agent placement ... # Agent placement ...
agents = np.zeros((self.n_agents, *self.level.shape), dtype=np.int8) agents = np.zeros((self.n_agents, *self.level.shape), dtype=np.int8)
floor_tiles = np.argwhere(self.level == h.IS_FREE_CELL) floor_tiles = np.argwhere(self.level == h.IS_FREE_CELL)
@ -75,9 +76,13 @@ class BaseFactory(gym.Env):
np.random.shuffle(floor_tiles) np.random.shuffle(floor_tiles)
for i, (x, y) in enumerate(floor_tiles[:self.n_agents]): for i, (x, y) in enumerate(floor_tiles[:self.n_agents]):
agents[i, x, y] = h.IS_OCCUPIED_CELL agents[i, x, y] = h.IS_OCCUPIED_CELL
agent_state = AgentState(i, -1)
agent_state.update(pos=[x, y])
self.agent_states.append(agent_state)
# state.shape = level, agent 1,..., agent n, # state.shape = level, agent 1,..., agent n,
self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0) self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0)
# Returns State, Reward, Done, Info # Returns State, Reward, Done, Info
return self.state, 0, self.done, {} return self.state, 0, self.done, {}
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool): def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
@ -103,6 +108,7 @@ class BaseFactory(gym.Env):
for i, collision_vec in enumerate(self.check_all_collisions(states, self.state.shape[0])): for i, collision_vec in enumerate(self.check_all_collisions(states, self.state.shape[0])):
states[i].update(collision_vector=collision_vec) states[i].update(collision_vector=collision_vec)
self.agent_states = states
reward, info = self.calculate_reward(states) reward, info = self.calculate_reward(states)
self.cumulative_reward += reward self.cumulative_reward += reward

View File

@ -28,7 +28,7 @@ class Renderer:
self.screen_size = (grid_h*cell_size, grid_w*cell_size) self.screen_size = (grid_h*cell_size, grid_w*cell_size)
self.screen = pygame.display.set_mode(self.screen_size) self.screen = pygame.display.set_mode(self.screen_size)
self.clock = pygame.time.Clock() self.clock = pygame.time.Clock()
assets = list((Path(__file__).parent / 'assets').glob('*.png')) assets = list((Path(__file__).parent / 'assets').rglob('*.png'))
self.assets = {path.stem: self.load_asset(str(path), 0.97) for path in assets} self.assets = {path.stem: self.load_asset(str(path), 0.97) for path in assets}
self.fill_bg() self.fill_bg()
@ -41,9 +41,9 @@ class Renderer:
rect = pygame.Rect(x, y, self.cell_size, self.cell_size) rect = pygame.Rect(x, y, self.cell_size, self.cell_size)
pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1) pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1)
def blit_params(self, entity, name): def blit_params(self, entity):
r, c = entity.pos r, c = entity.pos
img = self.assets[name] img = self.assets[entity.name]
if entity.value_operation == 'opacity': if entity.value_operation == 'opacity':
img.set_alpha(255*entity.value) img.set_alpha(255*entity.value)
elif entity.value_operation == 'scale': elif entity.value_operation == 'scale':
@ -71,7 +71,7 @@ class Renderer:
self.fill_bg() self.fill_bg()
for asset, entities in pos_dict.items(): for asset, entities in pos_dict.items():
for entity in entities: for entity in entities:
bp = self.blit_params(entity, asset) bp = self.blit_params(entity)
if 'agent' in asset and self.view_radius > 0: if 'agent' in asset and self.view_radius > 0:
visibility_rect = bp['dest'].inflate((self.view_radius*2)*self.cell_size, (self.view_radius*2)*self.cell_size) visibility_rect = bp['dest'].inflate((self.view_radius*2)*self.cell_size, (self.view_radius*2)*self.cell_size)
shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA) shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA)

View File

@ -41,10 +41,11 @@ class GettingDirty(BaseFactory):
dirt = [Entity('dirt', [x, y], min(1.1*self.state[DIRT_INDEX, x, y], 1), 'opacity') dirt = [Entity('dirt', [x, y], min(1.1*self.state[DIRT_INDEX, x, y], 1), 'opacity')
for x, y in np.argwhere(self.state[DIRT_INDEX] > h.IS_FREE_CELL)] for x, y in np.argwhere(self.state[DIRT_INDEX] > h.IS_FREE_CELL)]
walls = [Entity('dirt', pos) for pos in np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL)] walls = [Entity('wall', pos) for pos in np.argwhere(self.state[h.LEVEL_IDX] > h.IS_FREE_CELL)]
agents = [Entity('agent1', pos) for pos in np.argwhere(self.state[h.AGENT_START_IDX] > h.IS_FREE_CELL)] violation = lambda agent: agent.action_valid and agent.collision_vector[h.LEVEL_IDX] <= 0
agents = {f'agent{i+1}': [Entity(f'agent{i+1}' if violation(agent) else f'agent{i+1}violation', agent.pos)]
self.renderer.render(OrderedDict(dirt=dirt, wall=walls, agent1=agents)) for i, agent in enumerate(self.agent_states)}
self.renderer.render(OrderedDict(dirt=dirt, wall=walls, **agents))
def spawn_dirt(self) -> None: def spawn_dirt(self) -> None:
free_for_dirt = self.free_cells(excluded_slices=DIRT_INDEX) free_for_dirt = self.free_cells(excluded_slices=DIRT_INDEX)