mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-22 11:41:34 +02:00
moved renderer.py to base, added initial salina experiments
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
def make(env_str, n_agents=1, pomdp_r=2, max_steps=400):
|
||||
def make(env_str, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3):
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from environments.factory.combined_factories import DirtItemFactory
|
||||
@ -9,7 +9,8 @@ def make(env_str, n_agents=1, pomdp_r=2, max_steps=400):
|
||||
with (Path(__file__).parent / 'levels' / 'parameters' / f'{env_str}.yaml').open('r') as stream:
|
||||
dictionary = yaml.load(stream, Loader=yaml.FullLoader)
|
||||
|
||||
obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED, frames_to_stack=0, pomdp_r=pomdp_r)
|
||||
obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED,
|
||||
frames_to_stack=stack_n_frames, pomdp_r=pomdp_r)
|
||||
|
||||
factory_kwargs = dict(n_agents=n_agents, max_steps=max_steps, obs_prop=obs_props,
|
||||
mv_prop=MovementProperties(**dictionary['movement_props']),
|
||||
@ -17,4 +18,4 @@ def make(env_str, n_agents=1, pomdp_r=2, max_steps=400):
|
||||
record_episodes=False, verbose=False, **dictionary['factory_props']
|
||||
)
|
||||
|
||||
return DirtFactory(**factory_kwargs)
|
||||
return DirtFactory(**factory_kwargs).__enter__()
|
||||
|
@ -544,7 +544,7 @@ class BaseFactory(gym.Env):
|
||||
|
||||
def render(self, mode='human'):
|
||||
if not self._renderer: # lazy init
|
||||
from environments.factory.renderer import Renderer, RenderEntity
|
||||
from environments.factory.base.renderer import Renderer, RenderEntity
|
||||
global Renderer, RenderEntity
|
||||
height, width = self._obs_cube.shape[1:]
|
||||
self._renderer = Renderer(width, height, view_radius=self._pomdp_r, fps=5)
|
||||
@ -562,7 +562,7 @@ class BaseFactory(gym.Env):
|
||||
doors.append(RenderEntity(name, door.pos, 1, 'none', state, i + 1))
|
||||
additional_assets = self.render_additional_assets()
|
||||
|
||||
self._renderer.render(walls + doors + additional_assets + agents)
|
||||
return self._renderer.render(walls + doors + additional_assets + agents)
|
||||
|
||||
def save_params(self, filepath: Path):
|
||||
# noinspection PyProtectedMember
|
||||
|
@ -7,6 +7,8 @@ import pygame
|
||||
from typing import NamedTuple, Any
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class RenderEntity(NamedTuple):
|
||||
name: str
|
||||
@ -22,7 +24,7 @@ class Renderer:
|
||||
BG_COLOR = (178, 190, 195) # (99, 110, 114)
|
||||
WHITE = (223, 230, 233) # (200, 200, 200)
|
||||
AGENT_VIEW_COLOR = (9, 132, 227)
|
||||
ASSETS = Path(__file__).parent / 'assets'
|
||||
ASSETS = Path(__file__).parent.parent / 'assets'
|
||||
|
||||
def __init__(self, grid_w=16, grid_h=16, cell_size=40, fps=7, grid_lines=True, view_radius=2):
|
||||
self.grid_h = grid_h
|
||||
@ -121,6 +123,8 @@ class Renderer:
|
||||
|
||||
pygame.display.flip()
|
||||
self.clock.tick(self.fps)
|
||||
rgb_obs = pygame.surfarray.array3d(self.screen)
|
||||
return torch.from_numpy(rgb_obs).permute(2, 0, 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
@ -1,11 +1,11 @@
|
||||
from typing import Union, NamedTuple, Dict
|
||||
from typing import Union, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.factory.base.objects import Agent, Action, Entity
|
||||
from environments.factory.base.registers import EntityObjectRegister, ObjectRegister
|
||||
from environments.factory.renderer import RenderEntity
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
from environments.helpers import Constants as c
|
||||
|
||||
from environments import helpers as h
|
||||
|
@ -1,6 +1,5 @@
|
||||
import time
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import List, Union, NamedTuple, Dict
|
||||
import random
|
||||
|
||||
@ -12,8 +11,7 @@ from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.factory.base.objects import Agent, Action, Entity, Tile
|
||||
from environments.factory.base.registers import Entities, MovingEntityObjectRegister
|
||||
|
||||
from environments.factory.renderer import RenderEntity
|
||||
from environments.logging.recorder import RecorderCallback
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
from environments.utility_classes import ObservationProperties
|
||||
|
||||
CLEAN_UP_ACTION = h.EnvActions.CLEAN_UP
|
||||
|
@ -10,9 +10,9 @@ from environments.helpers import Constants as c
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.objects import Agent, Entity, Action, Tile, MoveableEntity
|
||||
from environments.factory.base.registers import Entities, EntityObjectRegister, ObjectRegister, \
|
||||
MovingEntityObjectRegister, Register
|
||||
MovingEntityObjectRegister
|
||||
|
||||
from environments.factory.renderer import RenderEntity
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
|
||||
|
||||
NO_ITEM = 0
|
||||
|
Reference in New Issue
Block a user