moved renderer.py to base, added initial salina experiments
@@ -1,4 +1,4 @@
-def make(env_str, n_agents=1, pomdp_r=2, max_steps=400):
+def make(env_str, n_agents=1, pomdp_r=2, max_steps=400, stack_n_frames=3):
     import yaml
     from pathlib import Path
     from environments.factory.combined_factories import DirtItemFactory
@@ -9,7 +9,8 @@ def make(env_str, n_agents=1, pomdp_r=2, max_steps=400):
     with (Path(__file__).parent / 'levels' / 'parameters' / f'{env_str}.yaml').open('r') as stream:
         dictionary = yaml.load(stream, Loader=yaml.FullLoader)
 
-    obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED, frames_to_stack=0, pomdp_r=pomdp_r)
+    obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED,
+                                      frames_to_stack=stack_n_frames, pomdp_r=pomdp_r)
 
     factory_kwargs = dict(n_agents=n_agents, max_steps=max_steps, obs_prop=obs_props,
                           mv_prop=MovementProperties(**dictionary['movement_props']),
@@ -17,4 +18,4 @@ def make(env_str, n_agents=1, pomdp_r=2, max_steps=400):
                           record_episodes=False, verbose=False, **dictionary['factory_props']
                           )
 
-    return DirtFactory(**factory_kwargs)
+    return DirtFactory(**factory_kwargs).__enter__()
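Two changes land here: frame stacking moves into the environment itself (`stack_n_frames=3` replaces the `gym.wrappers.FrameStack` wrapper that the salina script further down drops), and `make()` now returns an already-entered context manager, so callers that just want an env instance, such as salina's `GymAgent`, can use `make` directly as a factory. A minimal usage sketch, assuming the `DirtyFactory-v0` parameter YAML exists:

```python
# Sketch only; mirrors how the salina script below constructs its envs.
env = make('DirtyFactory-v0', n_agents=1, stack_n_frames=3)
env.reset()
state, *_ = env.step([env.action_space.sample()])  # one random action per agent
```

Returning `__enter__()` skips the matching `__exit__()`, so cleanup is left to garbage collection; a pragmatic shortcut for these experiments rather than a general pattern.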
@@ -544,7 +544,7 @@ class BaseFactory(gym.Env):
 
     def render(self, mode='human'):
         if not self._renderer:  # lazy init
-            from environments.factory.renderer import Renderer, RenderEntity
+            from environments.factory.base.renderer import Renderer, RenderEntity
             global Renderer, RenderEntity
             height, width = self._obs_cube.shape[1:]
             self._renderer = Renderer(width, height, view_radius=self._pomdp_r, fps=5)
@@ -562,7 +562,7 @@ class BaseFactory(gym.Env):
             doors.append(RenderEntity(name, door.pos, 1, 'none', state, i + 1))
         additional_assets = self.render_additional_assets()
 
-        self._renderer.render(walls + doors + additional_assets + agents)
+        return self._renderer.render(walls + doors + additional_assets + agents)
 
     def save_params(self, filepath: Path):
         # noinspection PyProtectedMember
@@ -7,6 +7,8 @@ import pygame
 from typing import NamedTuple, Any
 import time
 
+import torch
+
 
 class RenderEntity(NamedTuple):
     name: str
@@ -22,7 +24,7 @@ class Renderer:
     BG_COLOR = (178, 190, 195)  # (99, 110, 114)
     WHITE = (223, 230, 233)  # (200, 200, 200)
     AGENT_VIEW_COLOR = (9, 132, 227)
-    ASSETS = Path(__file__).parent / 'assets'
+    ASSETS = Path(__file__).parent.parent / 'assets'
 
     def __init__(self, grid_w=16, grid_h=16, cell_size=40, fps=7, grid_lines=True, view_radius=2):
         self.grid_h = grid_h
@@ -121,6 +123,8 @@ class Renderer:
 
         pygame.display.flip()
         self.clock.tick(self.fps)
+        rgb_obs = pygame.surfarray.array3d(self.screen)
+        return torch.from_numpy(rgb_obs).permute(2, 0, 1)
 
 
 if __name__ == '__main__':
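`Renderer.render()` now also hands the drawn frame back to the caller, which is why `BaseFactory.render()` above gained its `return`: `pygame.surfarray.array3d` copies the screen into a `(width, height, 3)` uint8 array, and `permute(2, 0, 1)` makes it channels-first for torch. The extra `.parent` on `ASSETS` compensates for `renderer.py` living one directory deeper under `base/` after the move. A standalone sketch of the conversion (the plain surface here is a stand-in for the renderer's screen):

```python
import pygame
import torch

surf = pygame.Surface((160, 160))                 # stand-in for self.screen
rgb = pygame.surfarray.array3d(surf)              # (width, height, 3) uint8 ndarray
frame = torch.from_numpy(rgb).permute(2, 0, 1)    # (3, width, height), channels first
print(frame.shape, frame.dtype)                   # torch.Size([3, 160, 160]) torch.uint8
```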
@@ -1,11 +1,11 @@
-from typing import Union, NamedTuple, Dict
+from typing import Union, NamedTuple
 
 import numpy as np
 
 from environments.factory.base.base_factory import BaseFactory
 from environments.factory.base.objects import Agent, Action, Entity
 from environments.factory.base.registers import EntityObjectRegister, ObjectRegister
-from environments.factory.renderer import RenderEntity
+from environments.factory.base.renderer import RenderEntity
 from environments.helpers import Constants as c
 
 from environments import helpers as h
@@ -1,6 +1,5 @@
 import time
 from enum import Enum
-from pathlib import Path
 from typing import List, Union, NamedTuple, Dict
 import random
 
@@ -12,8 +11,7 @@ from environments.factory.base.base_factory import BaseFactory
 from environments.factory.base.objects import Agent, Action, Entity, Tile
 from environments.factory.base.registers import Entities, MovingEntityObjectRegister
 
-from environments.factory.renderer import RenderEntity
-from environments.logging.recorder import RecorderCallback
+from environments.factory.base.renderer import RenderEntity
 from environments.utility_classes import ObservationProperties
 
 CLEAN_UP_ACTION = h.EnvActions.CLEAN_UP
@@ -10,9 +10,9 @@ from environments.helpers import Constants as c
 from environments import helpers as h
 from environments.factory.base.objects import Agent, Entity, Action, Tile, MoveableEntity
 from environments.factory.base.registers import Entities, EntityObjectRegister, ObjectRegister, \
-    MovingEntityObjectRegister, Register
+    MovingEntityObjectRegister
 
-from environments.factory.renderer import RenderEntity
+from environments.factory.base.renderer import RenderEntity
 
 
 NO_ITEM = 0
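The experiment script below leans on salina's workspace model: a `TAgent` reads and writes time-indexed tensors through `self.get((key, t))` and `self.set((key, t), value)`, and wrappers like `TemporalAgent` drive it across a span of timesteps. A minimal sketch of that contract, independent of this repo (the agent name is illustrative):

```python
import torch
from salina import TAgent, Workspace

class ConstantAgent(TAgent):
    # Writes a constant action at whichever timestep it is queried for.
    def forward(self, t, **kwargs):
        self.set(("action", t), torch.zeros(1, dtype=torch.long))

workspace = Workspace()
agent = ConstantAgent()
for t in range(5):
    agent(workspace, t=t)
print(workspace["action"].shape)  # expected: torch.Size([5, 1])
```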
@@ -1,29 +1,100 @@
 from environments.factory import make
-import salina
+from salina import Workspace, TAgent
+from salina.agents.gyma import AutoResetGymAgent, GymAgent
+from salina.agents import Agents, TemporalAgent
+from salina.rl.functional import _index
 import torch
-from gym.wrappers import FrameStack
+import torch.nn as nn
+from torch.nn.utils import spectral_norm
+import torch.optim as optim
+from torch.distributions import Categorical
 
 
-class MyAgent(salina.TAgent):
-    def __init__(self):
-        super(MyAgent, self).__init__()
+class A2CAgent(TAgent):
+    def __init__(self, observation_size, hidden_size, n_actions):
+        super().__init__()
+        self.model = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(observation_size, hidden_size),
+            nn.ELU(),
+            nn.Linear(hidden_size, hidden_size),
+            nn.ELU(),
+            nn.Linear(hidden_size, n_actions),
+        )
+        self.critic_model = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(observation_size, hidden_size),
+            nn.ELU(),
+            spectral_norm(nn.Linear(hidden_size, 1)),
+        )
 
-    def forward(self, t, **kwargs):
-        self.set(('timer', t), torch.tensor([t]))
+    def forward(self, t, stochastic, **kwargs):
+        observation = self.get(("env/env_obs", t))
+        scores = self.model(observation)
+        probs = torch.softmax(scores, dim=-1)
+        critic = self.critic_model(observation).squeeze(-1)
+        if stochastic:
+            action = torch.distributions.Categorical(probs).sample()
+        else:
+            action = probs.argmax(1)
+
+        self.set(("action", t), action)
+        self.set(("action_probs", t), probs)
+        self.set(("critic", t), critic)
 
 
 if __name__ == '__main__':
-    n_agents = 1
-    env = make('DirtyFactory-v0', n_agents=n_agents)
-    env = FrameStack(env, num_stack=3)
-    env.reset()
-    agent = MyAgent()
-    workspace = salina.Workspace()
-    agent(workspace, t=0, n_steps=10)
+    # Setup agents and workspace
+    env_agent = AutoResetGymAgent(make, dict(env_str='DirtyFactory-v0'), n_envs=1)
+    a2c_agent = A2CAgent(3*4*5*5, 96, 10)
+    workspace = Workspace()
 
-    print(workspace)
+    eval_agent = Agents(GymAgent(make, dict(env_str='DirtyFactory-v0'), n_envs=1), a2c_agent)
+    for i in range(100):
+        eval_agent(workspace, t=i, save_render=True, stochastic=True)
 
+    assert False
+    # combine agents
+    acquisition_agent = TemporalAgent(Agents(env_agent, a2c_agent))
+    acquisition_agent.seed(0)
 
-    for i in range(1000):
-        state, *_ = env.step([env.unwrapped.action_space.sample() for _ in range(n_agents)])
-        #env.render()
+    # optimizers & other parameters
+    optimizer = optim.Adam(a2c_agent.parameters(), lr=1e-3)
+    n_timesteps = 10
 
+    # Decision making loop
+    for epoch in range(200000):
+        workspace.zero_grad()
+        if epoch > 0:
+            workspace.copy_n_last_steps(1)
+            acquisition_agent(workspace, t=1, n_steps=n_timesteps-1, stochastic=True)
+        else:
+            acquisition_agent(workspace, t=0, n_steps=n_timesteps, stochastic=True)
+        #for k in workspace.keys():
+        #    print(f'{k} ==> {workspace[k].size()}')
+        critic, done, action_probs, reward, action = workspace[
+            "critic", "env/done", "action_probs", "env/reward", "action"
+        ]
+
+        target = reward[1:] + 0.99 * critic[1:].detach() * (1 - done[1:].float())
+        td = target - critic[:-1]
+        td_error = td ** 2
+        critic_loss = td_error.mean()
+        entropy_loss = Categorical(action_probs).entropy().mean()
+        action_logp = _index(action_probs, action).log()
+        a2c_loss = action_logp[:-1] * td.detach()
+        a2c_loss = a2c_loss.mean()
+        loss = (
+            -0.001 * entropy_loss
+            + 1.0 * critic_loss
+            - 0.1 * a2c_loss
+        )
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
 
+        # Compute the cumulated reward on final_state
+        creward = workspace["env/cumulated_reward"]
+        creward = creward[done]
+        if creward.size()[0] > 0:
+            print(f"Cumulative reward at A2C step #{(1+epoch)*n_timesteps}: {creward.mean().item()}")
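For reference, the update above follows salina's A2C example: the critic is regressed onto the one-step TD(0) target, reward[t+1] + 0.99 * V(s[t+1]) * (1 - done[t+1]), and the actor term weights the action log-probability by the detached TD error. A sketch of the same computation on dummy tensors, with salina's `_index` replaced by an explicit `gather` (shapes and names are illustrative, not part of the commit):

```python
import torch

T, B, A = 10, 1, 10                         # timesteps, parallel envs, actions
critic = torch.randn(T, B)                  # V(s_t) from the critic head
reward = torch.randn(T, B)
done = torch.zeros(T, B)                    # 1.0 where an episode ended
probs = torch.softmax(torch.randn(T, B, A), dim=-1)
action = probs.argmax(-1)                   # (T, B) long

target = reward[1:] + 0.99 * critic[1:].detach() * (1 - done[1:])
td = target - critic[:-1]                   # advantage estimate
critic_loss = (td ** 2).mean()
action_logp = probs.gather(-1, action.unsqueeze(-1)).squeeze(-1).log()
a2c_loss = (action_logp[:-1] * td.detach()).mean()
```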