diff --git a/environments/factory/base/base_factory.py b/environments/factory/base/base_factory.py index f788ebe..df05fa7 100644 --- a/environments/factory/base/base_factory.py +++ b/environments/factory/base/base_factory.py @@ -11,7 +11,6 @@ from gym import spaces from gym.wrappers import FrameStack from environments.factory.base.shadow_casting import Map -from environments.factory.renderer import Renderer, RenderEntity from environments.helpers import Constants as c, Constants from environments import helpers as h from environments.factory.base.objects import Agent, Tile, Action @@ -545,6 +544,8 @@ class BaseFactory(gym.Env): def render(self, mode='human'): if not self._renderer: # lazy init + from environments.factory.renderer import Renderer, RenderEntity + global Renderer, RenderEntity height, width = self._obs_cube.shape[1:] self._renderer = Renderer(width, height, view_radius=self._pomdp_r, fps=5) diff --git a/studies/sat_mad.py b/studies/sat_mad.py index 000e0bc..6b3b1f6 100644 --- a/studies/sat_mad.py +++ b/studies/sat_mad.py @@ -1,12 +1,29 @@ from environments.factory import make import salina +import torch from gym.wrappers import FrameStack -n_agents = 4 -env = make('DirtyFactory-v0', n_agents=n_agents) -env = FrameStack(env, num_stack=3) -state, *_ = env.reset() -for i in range(1000): - state, *_ = env.step([env.unwrapped.action_space.sample() for _ in range(n_agents)]) - env.render() \ No newline at end of file +class MyAgent(salina.TAgent): + def __init__(self): + super(MyAgent, self).__init__() + + def forward(self, t, **kwargs): + self.set(('timer', t), torch.tensor([t])) + + +if __name__ == '__main__': + n_agents = 1 + env = make('DirtyFactory-v0', n_agents=n_agents) + env = FrameStack(env, num_stack=3) + env.reset() + agent = MyAgent() + workspace = salina.Workspace() + agent(workspace, t=0, n_steps=10) + + print(workspace) + + + for i in range(1000): + state, *_ = env.step([env.unwrapped.action_space.sample() for _ in range(n_agents)]) + #env.render() \ No newline at end of file