pygame import only when needed

This commit is contained in:
romue 2021-11-11 17:52:48 +01:00
parent 62d72e0712
commit 8960cf622b
2 changed files with 26 additions and 8 deletions

View File

@ -11,7 +11,6 @@ from gym import spaces
from gym.wrappers import FrameStack
from environments.factory.base.shadow_casting import Map
from environments.factory.renderer import Renderer, RenderEntity
from environments.helpers import Constants as c, Constants
from environments import helpers as h
from environments.factory.base.objects import Agent, Tile, Action
@ -545,6 +544,8 @@ class BaseFactory(gym.Env):
def render(self, mode='human'):
if not self._renderer: # lazy init
from environments.factory.renderer import Renderer, RenderEntity
global Renderer, RenderEntity
height, width = self._obs_cube.shape[1:]
self._renderer = Renderer(width, height, view_radius=self._pomdp_r, fps=5)

View File

@ -1,12 +1,29 @@
from environments.factory import make
import salina
import torch
from gym.wrappers import FrameStack
n_agents = 4
env = make('DirtyFactory-v0', n_agents=n_agents)
env = FrameStack(env, num_stack=3)
state, *_ = env.reset()
for i in range(1000):
state, *_ = env.step([env.unwrapped.action_space.sample() for _ in range(n_agents)])
env.render()
class MyAgent(salina.TAgent):
def __init__(self):
super(MyAgent, self).__init__()
def forward(self, t, **kwargs):
self.set(('timer', t), torch.tensor([t]))
if __name__ == '__main__':
n_agents = 1
env = make('DirtyFactory-v0', n_agents=n_agents)
env = FrameStack(env, num_stack=3)
env.reset()
agent = MyAgent()
workspace = salina.Workspace()
agent(workspace, t=0, n_steps=10)
print(workspace)
for i in range(1000):
state, *_ = env.step([env.unwrapped.action_space.sample() for _ in range(n_agents)])
#env.render()