moved renderer.py to base, added initial salina experiments

romue
2021-11-12 13:47:53 +01:00
parent f625b9d8a5
commit b6bda84033
7 changed files with 105 additions and 31 deletions

View File

@@ -1,4 +1,4 @@
-def make(env_str, n_agents=1, pomdp_r=2, max_steps=400):
+def make(env_str, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3):
     import yaml
     from pathlib import Path
     from environments.factory.combined_factories import DirtItemFactory
@@ -9,7 +9,8 @@ def make(env_str, n_agents=1, pomdp_r=2, max_steps=400):
     with (Path(__file__).parent / 'levels' / 'parameters' / f'{env_str}.yaml').open('r') as stream:
         dictionary = yaml.load(stream, Loader=yaml.FullLoader)
-    obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED, frames_to_stack=0, pomdp_r=pomdp_r)
+    obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED,
+                                      frames_to_stack=stack_n_frames, pomdp_r=pomdp_r)
     factory_kwargs = dict(n_agents=n_agents, max_steps=max_steps, obs_prop=obs_props,
                           mv_prop=MovementProperties(**dictionary['movement_props']),
@@ -17,4 +18,4 @@ def make(env_str, n_agents=1, pomdp_r=2, max_steps=400):
                           record_episodes=False, verbose=False, **dictionary['factory_props']
                           )
-    return DirtFactory(**factory_kwargs)
+    return DirtFactory(**factory_kwargs).__enter__()
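
Note: `make()` now stacks `stack_n_frames` observation frames itself (which appears to replace the `gym.wrappers.FrameStack` used previously in the experiment script) and returns the factory with its context already entered. A minimal usage sketch, assuming the factory exposes the usual gym-style `reset`/`step`/`action_space` API:

    from environments.factory import make

    env = make('DirtyFactory-v0', n_agents=1, stack_n_frames=3)
    env.reset()
    # one sampled action per agent, mirroring the old experiment script
    state, *_ = env.step([env.action_space.sample()])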

View File

@@ -544,7 +544,7 @@ class BaseFactory(gym.Env):
     def render(self, mode='human'):
         if not self._renderer:  # lazy init
-            from environments.factory.renderer import Renderer, RenderEntity
+            from environments.factory.base.renderer import Renderer, RenderEntity
             global Renderer, RenderEntity
             height, width = self._obs_cube.shape[1:]
             self._renderer = Renderer(width, height, view_radius=self._pomdp_r, fps=5)
@@ -562,7 +562,7 @@ class BaseFactory(gym.Env):
             doors.append(RenderEntity(name, door.pos, 1, 'none', state, i + 1))
         additional_assets = self.render_additional_assets()
-        self._renderer.render(walls + doors + additional_assets + agents)
+        return self._renderer.render(walls + doors + additional_assets + agents)

     def save_params(self, filepath: Path):
         # noinspection PyProtectedMember
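
Since `render()` now passes the renderer's return value through, a caller can grab the rendered frame directly from the environment. A rough usage sketch, assuming `env` is a `BaseFactory` instance (the channels-first shape follows from the renderer change below):

    frame = env.render()          # torch.uint8 tensor, channels first: (3, width, height)
    assert frame.shape[0] == 3    # RGB planes from pygame.surfarray.array3d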

View File

@@ -7,6 +7,8 @@ import pygame
 from typing import NamedTuple, Any
 import time

+import torch

 class RenderEntity(NamedTuple):
     name: str
@@ -22,7 +24,7 @@ class Renderer:
     BG_COLOR = (178, 190, 195)  # (99, 110, 114)
     WHITE = (223, 230, 233)  # (200, 200, 200)
     AGENT_VIEW_COLOR = (9, 132, 227)
-    ASSETS = Path(__file__).parent / 'assets'
+    ASSETS = Path(__file__).parent.parent / 'assets'

     def __init__(self, grid_w=16, grid_h=16, cell_size=40, fps=7, grid_lines=True, view_radius=2):
         self.grid_h = grid_h
@@ -121,6 +123,8 @@ class Renderer:
         pygame.display.flip()
         self.clock.tick(self.fps)
+        rgb_obs = pygame.surfarray.array3d(self.screen)
+        return torch.from_numpy(rgb_obs).permute(2, 0, 1)

 if __name__ == '__main__':
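
The two added lines convert the drawn pygame surface into a channels-first torch tensor. A self-contained sketch of that conversion, using a throwaway `pygame.Surface` as a stand-in for the renderer's screen:

    import pygame
    import torch

    screen = pygame.Surface((160, 160))                 # stand-in for self.screen
    rgb_obs = pygame.surfarray.array3d(screen)          # numpy uint8 array, shape (width, height, 3)
    frame = torch.from_numpy(rgb_obs).permute(2, 0, 1)  # -> shape (3, width, height)
    print(frame.shape, frame.dtype)                     # torch.Size([3, 160, 160]) torch.uint8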

View File

@@ -1,11 +1,11 @@
-from typing import Union, NamedTuple, Dict
+from typing import Union, NamedTuple

 import numpy as np

 from environments.factory.base.base_factory import BaseFactory
 from environments.factory.base.objects import Agent, Action, Entity
 from environments.factory.base.registers import EntityObjectRegister, ObjectRegister
-from environments.factory.renderer import RenderEntity
+from environments.factory.base.renderer import RenderEntity
 from environments.helpers import Constants as c
 from environments import helpers as h

View File

@@ -1,6 +1,5 @@
 import time
 from enum import Enum
-from pathlib import Path
 from typing import List, Union, NamedTuple, Dict
 import random
@@ -12,8 +11,7 @@ from environments.factory.base.base_factory import BaseFactory
 from environments.factory.base.objects import Agent, Action, Entity, Tile
 from environments.factory.base.registers import Entities, MovingEntityObjectRegister
-from environments.factory.renderer import RenderEntity
-from environments.logging.recorder import RecorderCallback
+from environments.factory.base.renderer import RenderEntity
 from environments.utility_classes import ObservationProperties

 CLEAN_UP_ACTION = h.EnvActions.CLEAN_UP

View File

@@ -10,9 +10,9 @@ from environments.helpers import Constants as c
 from environments import helpers as h
 from environments.factory.base.objects import Agent, Entity, Action, Tile, MoveableEntity
 from environments.factory.base.registers import Entities, EntityObjectRegister, ObjectRegister, \
-    MovingEntityObjectRegister, Register
-from environments.factory.renderer import RenderEntity
+    MovingEntityObjectRegister
+from environments.factory.base.renderer import RenderEntity

 NO_ITEM = 0

View File

@@ -1,29 +1,100 @@
 from environments.factory import make
-import salina
+from salina import Workspace, TAgent
+from salina.agents.gyma import AutoResetGymAgent, GymAgent
+from salina.agents import Agents, TemporalAgent
+from salina.rl.functional import _index
 import torch
-from gym.wrappers import FrameStack
+import torch.nn as nn
+from torch.nn.utils import spectral_norm
+import torch.optim as optim
+from torch.distributions import Categorical


-class MyAgent(salina.TAgent):
-    def __init__(self):
-        super(MyAgent, self).__init__()
+class A2CAgent(TAgent):
+    def __init__(self, observation_size, hidden_size, n_actions):
+        super().__init__()
+        self.model = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(observation_size, hidden_size),
+            nn.ELU(),
+            nn.Linear(hidden_size, hidden_size),
+            nn.ELU(),
+            nn.Linear(hidden_size, n_actions),
+        )
+        self.critic_model = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(observation_size, hidden_size),
+            nn.ELU(),
+            spectral_norm(nn.Linear(hidden_size, 1)),
+        )

-    def forward(self, t, **kwargs):
-        self.set(('timer', t), torch.tensor([t]))
+    def forward(self, t, stochastic, **kwargs):
+        observation = self.get(("env/env_obs", t))
+        scores = self.model(observation)
+        probs = torch.softmax(scores, dim=-1)
+        critic = self.critic_model(observation).squeeze(-1)
+        if stochastic:
+            action = torch.distributions.Categorical(probs).sample()
+        else:
+            action = probs.argmax(1)
+        self.set(("action", t), action)
+        self.set(("action_probs", t), probs)
+        self.set(("critic", t), critic)


 if __name__ == '__main__':
-    n_agents = 1
-    env = make('DirtyFactory-v0', n_agents=n_agents)
-    env = FrameStack(env, num_stack=3)
-    env.reset()
-    agent = MyAgent()
-    workspace = salina.Workspace()
-    agent(workspace, t=0, n_steps=10)
-    print(workspace)
-
-    for i in range(1000):
-        state, *_ = env.step([env.unwrapped.action_space.sample() for _ in range(n_agents)])
-        #env.render()
+    # Setup agents and workspace
+    env_agent = AutoResetGymAgent(make, dict(env_str='DirtyFactory-v0'), n_envs=1)
+    a2c_agent = A2CAgent(3*4*5*5, 96, 10)
+    workspace = Workspace()
+
+    eval_agent = Agents(GymAgent(make, dict(env_str='DirtyFactory-v0'), n_envs=1), a2c_agent)
+    for i in range(100):
+        eval_agent(workspace, t=i, save_render=True, stochastic=True)
+    assert False
+
+    # combine agents
+    acquisition_agent = TemporalAgent(Agents(env_agent, a2c_agent))
+    acquisition_agent.seed(0)
+
+    # optimizers & other parameters
+    optimizer = optim.Adam(a2c_agent.parameters(), lr=1e-3)
+    n_timesteps = 10
+
+    # Decision making loop
+    for epoch in range(200000):
+        workspace.zero_grad()
+        if epoch > 0:
+            workspace.copy_n_last_steps(1)
+            acquisition_agent(workspace, t=1, n_steps=n_timesteps-1, stochastic=True)
+        else:
+            acquisition_agent(workspace, t=0, n_steps=n_timesteps, stochastic=True)
+
+        #for k in workspace.keys():
+        #    print(f'{k} ==> {workspace[k].size()}')
+
+        critic, done, action_probs, reward, action = workspace[
+            "critic", "env/done", "action_probs", "env/reward", "action"
+        ]
+        target = reward[1:] + 0.99 * critic[1:].detach() * (1 - done[1:].float())
+        td = target - critic[:-1]
+        td_error = td ** 2
+        critic_loss = td_error.mean()
+
+        entropy_loss = Categorical(action_probs).entropy().mean()
+        action_logp = _index(action_probs, action).log()
+        a2c_loss = action_logp[:-1] * td.detach()
+        a2c_loss = a2c_loss.mean()
+
+        loss = (
+            -0.001 * entropy_loss
+            + 1.0 * critic_loss
+            - 0.1 * a2c_loss
+        )
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        # Compute the cumulated reward on final_state
+        creward = workspace["env/cumulated_reward"]
+        creward = creward[done]
+        if creward.size()[0] > 0:
+            print(f"Cumulative reward at A2C step #{(1+epoch)*n_timesteps}: {creward.mean().item()}")