diff --git a/environments/factory/__init__.py b/environments/factory/__init__.py
index 9e57a7c..bb53095 100644
--- a/environments/factory/__init__.py
+++ b/environments/factory/__init__.py
@@ -1,4 +1,4 @@
-def make(env_str, n_agents=1, pomdp_r=2, max_steps=400):
+def make(env_str, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3):
     import yaml
     from pathlib import Path
     from environments.factory.combined_factories import DirtItemFactory
@@ -9,7 +9,8 @@ def make(env_str, n_agents=1, pomdp_r=2, max_steps=400):
     with (Path(__file__).parent / 'levels' / 'parameters' / f'{env_str}.yaml').open('r') as stream:
         dictionary = yaml.load(stream, Loader=yaml.FullLoader)
 
-    obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED, frames_to_stack=0, pomdp_r=pomdp_r)
+    obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED,
+                                      frames_to_stack=stack_n_frames, pomdp_r=pomdp_r)
     factory_kwargs = dict(n_agents=n_agents, max_steps=max_steps, obs_prop=obs_props,
                           mv_prop=MovementProperties(**dictionary['movement_props']),
@@ -17,4 +18,4 @@ def make(env_str, n_agents=1, pomdp_r=2, max_steps=400):
                           record_episodes=False, verbose=False, **dictionary['factory_props']
                           )
 
-    return DirtFactory(**factory_kwargs)
+    return DirtFactory(**factory_kwargs).__enter__()
diff --git a/environments/factory/base/base_factory.py b/environments/factory/base/base_factory.py
index d417190..a9b3c1d 100644
--- a/environments/factory/base/base_factory.py
+++ b/environments/factory/base/base_factory.py
@@ -544,7 +544,7 @@ class BaseFactory(gym.Env):
 
     def render(self, mode='human'):
         if not self._renderer:  # lazy init
-            from environments.factory.renderer import Renderer, RenderEntity
+            from environments.factory.base.renderer import Renderer, RenderEntity
             global Renderer, RenderEntity
             height, width = self._obs_cube.shape[1:]
             self._renderer = Renderer(width, height, view_radius=self._pomdp_r, fps=5)
@@ -562,7 +562,7 @@ class BaseFactory(gym.Env):
             doors.append(RenderEntity(name, door.pos, 1, 'none', state, i + 1))
         additional_assets = self.render_additional_assets()
 
-        self._renderer.render(walls + doors + additional_assets + agents)
+        return self._renderer.render(walls + doors + additional_assets + agents)
 
     def save_params(self, filepath: Path):
         # noinspection PyProtectedMember
diff --git a/environments/factory/renderer.py b/environments/factory/base/renderer.py
similarity index 96%
rename from environments/factory/renderer.py
rename to environments/factory/base/renderer.py
index e8f4297..92eefc1 100644
--- a/environments/factory/renderer.py
+++ b/environments/factory/base/renderer.py
@@ -7,6 +7,8 @@ import pygame
 from typing import NamedTuple, Any
 import time
 
+import torch
+
 
 class RenderEntity(NamedTuple):
     name: str
@@ -22,7 +24,7 @@ class Renderer:
     BG_COLOR = (178, 190, 195)  # (99, 110, 114)
     WHITE = (223, 230, 233)  # (200, 200, 200)
     AGENT_VIEW_COLOR = (9, 132, 227)
-    ASSETS = Path(__file__).parent / 'assets'
+    ASSETS = Path(__file__).parent.parent / 'assets'
 
     def __init__(self, grid_w=16, grid_h=16, cell_size=40, fps=7, grid_lines=True, view_radius=2):
         self.grid_h = grid_h
@@ -121,6 +123,8 @@ class Renderer:
         pygame.display.flip()
         self.clock.tick(self.fps)
 
+        rgb_obs = pygame.surfarray.array3d(self.screen)
+        return torch.from_numpy(rgb_obs).permute(2, 0, 1)
 
 
 if __name__ == '__main__':
diff --git a/environments/factory/factory_battery.py b/environments/factory/factory_battery.py
index c24bf99..760963b 100644
--- a/environments/factory/factory_battery.py
+++ b/environments/factory/factory_battery.py
@@ -1,11 +1,11 @@
-from typing import Union, NamedTuple, Dict
+from typing import Union, NamedTuple
 
 import numpy as np
 
 from environments.factory.base.base_factory import BaseFactory
 from environments.factory.base.objects import Agent, Action, Entity
 from environments.factory.base.registers import EntityObjectRegister, ObjectRegister
 
-from environments.factory.renderer import RenderEntity
+from environments.factory.base.renderer import RenderEntity
 from environments.helpers import Constants as c
 from environments import helpers as h
diff --git a/environments/factory/factory_dirt.py b/environments/factory/factory_dirt.py
index 67230f4..92d15c6 100644
--- a/environments/factory/factory_dirt.py
+++ b/environments/factory/factory_dirt.py
@@ -1,6 +1,5 @@
 import time
 from enum import Enum
-from pathlib import Path
 from typing import List, Union, NamedTuple, Dict
 import random
 
@@ -12,8 +11,7 @@ from environments.factory.base.base_factory import BaseFactory
 from environments.factory.base.objects import Agent, Action, Entity, Tile
 from environments.factory.base.registers import Entities, MovingEntityObjectRegister
 
-from environments.factory.renderer import RenderEntity
-from environments.logging.recorder import RecorderCallback
+from environments.factory.base.renderer import RenderEntity
 from environments.utility_classes import ObservationProperties
 
 CLEAN_UP_ACTION = h.EnvActions.CLEAN_UP
diff --git a/environments/factory/factory_item.py b/environments/factory/factory_item.py
index c67d59e..5538953 100644
--- a/environments/factory/factory_item.py
+++ b/environments/factory/factory_item.py
@@ -10,9 +10,9 @@ from environments.helpers import Constants as c
 from environments import helpers as h
 from environments.factory.base.objects import Agent, Entity, Action, Tile, MoveableEntity
 from environments.factory.base.registers import Entities, EntityObjectRegister, ObjectRegister, \
-    MovingEntityObjectRegister, Register
+    MovingEntityObjectRegister
 
-from environments.factory.renderer import RenderEntity
+from environments.factory.base.renderer import RenderEntity
 
 NO_ITEM = 0
diff --git a/studies/sat_mad.py b/studies/sat_mad.py
index 6b3b1f6..06b477d 100644
--- a/studies/sat_mad.py
+++ b/studies/sat_mad.py
@@ -1,29 +1,100 @@
 from environments.factory import make
-import salina
+from salina import Workspace, TAgent
+from salina.agents.gyma import AutoResetGymAgent, GymAgent
+from salina.agents import Agents, TemporalAgent
+from salina.rl.functional import _index
 import torch
-from gym.wrappers import FrameStack
+import torch.nn as nn
+from torch.nn.utils import spectral_norm
+import torch.optim as optim
+from torch.distributions import Categorical
 
 
-class MyAgent(salina.TAgent):
-    def __init__(self):
-        super(MyAgent, self).__init__()
+class A2CAgent(TAgent):
+    def __init__(self, observation_size, hidden_size, n_actions):
+        super().__init__()
+        self.model = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(observation_size, hidden_size),
+            nn.ELU(),
+            nn.Linear(hidden_size, hidden_size),
+            nn.ELU(),
+            nn.Linear(hidden_size, n_actions),
+        )
+        self.critic_model = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(observation_size, hidden_size),
+            nn.ELU(),
+            spectral_norm(nn.Linear(hidden_size, 1)),
+        )
 
-    def forward(self, t, **kwargs):
-        self.set(('timer', t), torch.tensor([t]))
+    def forward(self, t, stochastic, **kwargs):
+        observation = self.get(("env/env_obs", t))
+        scores = self.model(observation)
+        probs = torch.softmax(scores, dim=-1)
+        critic = self.critic_model(observation).squeeze(-1)
+        if stochastic:
+            action = torch.distributions.Categorical(probs).sample()
+        else:
+            action = probs.argmax(1)
+
+        self.set(("action", t), action)
+        self.set(("action_probs", t), probs)
+        self.set(("critic", t), critic)
 
 
 if __name__ == '__main__':
-    n_agents = 1
-    env = make('DirtyFactory-v0', n_agents=n_agents)
-    env = FrameStack(env, num_stack=3)
-    env.reset()
-    agent = MyAgent()
-    workspace = salina.Workspace()
-    agent(workspace, t=0, n_steps=10)
+    # Setup agents and workspace
+    env_agent = AutoResetGymAgent(make, dict(env_str='DirtyFactory-v0'), n_envs=1)
+    a2c_agent = A2CAgent(3*4*5*5, 96, 10)
+    workspace = Workspace()
 
-    print(workspace)
+    eval_agent = Agents(GymAgent(make, dict(env_str='DirtyFactory-v0'), n_envs=1), a2c_agent)
+    for i in range(100):
+        eval_agent(workspace, t=i, save_render=True, stochastic=True)
+    assert False
+    # combine agents
+    acquisition_agent = TemporalAgent(Agents(env_agent, a2c_agent))
+    acquisition_agent.seed(0)
 
-    for i in range(1000):
-        state, *_ = env.step([env.unwrapped.action_space.sample() for _ in range(n_agents)])
-        #env.render()
\ No newline at end of file
+    # optimizers & other parameters
+    optimizer = optim.Adam(a2c_agent.parameters(), lr=1e-3)
+    n_timesteps = 10
+
+    # Decision making loop
+    for epoch in range(200000):
+        workspace.zero_grad()
+        if epoch > 0:
+            workspace.copy_n_last_steps(1)
+            acquisition_agent(workspace, t=1, n_steps=n_timesteps-1, stochastic=True)
+        else:
+            acquisition_agent(workspace, t=0, n_steps=n_timesteps, stochastic=True)
+        #for k in workspace.keys():
+        #    print(f'{k} ==> {workspace[k].size()}')
+        critic, done, action_probs, reward, action = workspace[
+            "critic", "env/done", "action_probs", "env/reward", "action"
+        ]
+
+        target = reward[1:] + 0.99 * critic[1:].detach() * (1 - done[1:].float())
+        td = target - critic[:-1]
+        td_error = td ** 2
+        critic_loss = td_error.mean()
+        entropy_loss = Categorical(action_probs).entropy().mean()
+        action_logp = _index(action_probs, action).log()
+        a2c_loss = action_logp[:-1] * td.detach()
+        a2c_loss = a2c_loss.mean()
+        loss = (
+            -0.001 * entropy_loss
+            + 1.0 * critic_loss
+            - 0.1 * a2c_loss
+        )
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        # Compute the cumulated reward on final_state
+        creward = workspace["env/cumulated_reward"]
+        creward = creward[done]
+        if creward.size()[0] > 0:
+            print(f"Cumulative reward at A2C step #{(1+epoch)*n_timesteps}: {creward.mean().item()}")
\ No newline at end of file