moved renderer.py to base, added initial salina experiments

This commit is contained in:
romue
2021-11-12 13:47:53 +01:00
parent f625b9d8a5
commit b6bda84033
7 changed files with 105 additions and 31 deletions

View File

@ -1,4 +1,4 @@
def make(env_str, n_agents=1, pomdp_r=2, max_steps=400):
def make(env_str, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3):
import yaml
from pathlib import Path
from environments.factory.combined_factories import DirtItemFactory
@ -9,7 +9,8 @@ def make(env_str, n_agents=1, pomdp_r=2, max_steps=400):
with (Path(__file__).parent / 'levels' / 'parameters' / f'{env_str}.yaml').open('r') as stream:
dictionary = yaml.load(stream, Loader=yaml.FullLoader)
obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED, frames_to_stack=0, pomdp_r=pomdp_r)
obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED,
frames_to_stack=stack_n_frames, pomdp_r=pomdp_r)
factory_kwargs = dict(n_agents=n_agents, max_steps=max_steps, obs_prop=obs_props,
mv_prop=MovementProperties(**dictionary['movement_props']),
@ -17,4 +18,4 @@ def make(env_str, n_agents=1, pomdp_r=2, max_steps=400):
record_episodes=False, verbose=False, **dictionary['factory_props']
)
return DirtFactory(**factory_kwargs)
return DirtFactory(**factory_kwargs).__enter__()

View File

@ -544,7 +544,7 @@ class BaseFactory(gym.Env):
def render(self, mode='human'):
if not self._renderer: # lazy init
from environments.factory.renderer import Renderer, RenderEntity
from environments.factory.base.renderer import Renderer, RenderEntity
global Renderer, RenderEntity
height, width = self._obs_cube.shape[1:]
self._renderer = Renderer(width, height, view_radius=self._pomdp_r, fps=5)
@ -562,7 +562,7 @@ class BaseFactory(gym.Env):
doors.append(RenderEntity(name, door.pos, 1, 'none', state, i + 1))
additional_assets = self.render_additional_assets()
self._renderer.render(walls + doors + additional_assets + agents)
return self._renderer.render(walls + doors + additional_assets + agents)
def save_params(self, filepath: Path):
# noinspection PyProtectedMember

View File

@ -7,6 +7,8 @@ import pygame
from typing import NamedTuple, Any
import time
import torch
class RenderEntity(NamedTuple):
name: str
@ -22,7 +24,7 @@ class Renderer:
BG_COLOR = (178, 190, 195) # (99, 110, 114)
WHITE = (223, 230, 233) # (200, 200, 200)
AGENT_VIEW_COLOR = (9, 132, 227)
ASSETS = Path(__file__).parent / 'assets'
ASSETS = Path(__file__).parent.parent / 'assets'
def __init__(self, grid_w=16, grid_h=16, cell_size=40, fps=7, grid_lines=True, view_radius=2):
self.grid_h = grid_h
@ -121,6 +123,8 @@ class Renderer:
pygame.display.flip()
self.clock.tick(self.fps)
rgb_obs = pygame.surfarray.array3d(self.screen)
return torch.from_numpy(rgb_obs).permute(2, 0, 1)
if __name__ == '__main__':

View File

@ -1,11 +1,11 @@
from typing import Union, NamedTuple, Dict
from typing import Union, NamedTuple
import numpy as np
from environments.factory.base.base_factory import BaseFactory
from environments.factory.base.objects import Agent, Action, Entity
from environments.factory.base.registers import EntityObjectRegister, ObjectRegister
from environments.factory.renderer import RenderEntity
from environments.factory.base.renderer import RenderEntity
from environments.helpers import Constants as c
from environments import helpers as h

View File

@ -1,6 +1,5 @@
import time
from enum import Enum
from pathlib import Path
from typing import List, Union, NamedTuple, Dict
import random
@ -12,8 +11,7 @@ from environments.factory.base.base_factory import BaseFactory
from environments.factory.base.objects import Agent, Action, Entity, Tile
from environments.factory.base.registers import Entities, MovingEntityObjectRegister
from environments.factory.renderer import RenderEntity
from environments.logging.recorder import RecorderCallback
from environments.factory.base.renderer import RenderEntity
from environments.utility_classes import ObservationProperties
CLEAN_UP_ACTION = h.EnvActions.CLEAN_UP

View File

@ -10,9 +10,9 @@ from environments.helpers import Constants as c
from environments import helpers as h
from environments.factory.base.objects import Agent, Entity, Action, Tile, MoveableEntity
from environments.factory.base.registers import Entities, EntityObjectRegister, ObjectRegister, \
MovingEntityObjectRegister, Register
MovingEntityObjectRegister
from environments.factory.renderer import RenderEntity
from environments.factory.base.renderer import RenderEntity
NO_ITEM = 0