from environments.factory import make

from salina import Workspace, TAgent
from salina.agents.gyma import AutoResetGymAgent, GymAgent
from salina.agents import Agents, TemporalAgent
from salina.rl.functional import _index

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm
import torch.optim as optim
from torch.distributions import Categorical
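

# Actor-critic agent for salina: at each time step it reads the current
# observation from the workspace and writes back the chosen action, the
# action probabilities and the critic value for that observation.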
class A2CAgent(TAgent):
    def __init__(self, observation_size, hidden_size, n_actions):
        super().__init__()
        # Actor head: maps the flattened observation to action scores (logits)
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(observation_size, hidden_size),
            nn.ELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ELU(),
            nn.Linear(hidden_size, n_actions),
        )
        # Critic head: maps the flattened observation to a scalar state value;
        # the output layer is spectrally normalized
        self.critic_model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(observation_size, hidden_size),
            nn.ELU(),
            spectral_norm(nn.Linear(hidden_size, 1)),
        )
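
    # Called once per time step t by the TemporalAgent; `stochastic` switches
    # between sampling from the policy and greedy action selection.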
    def forward(self, t, stochastic, **kwargs):
        observation = self.get(("env/env_obs", t))
        scores = self.model(observation)
        probs = torch.softmax(scores, dim=-1)
        critic = self.critic_model(observation).squeeze(-1)
        if stochastic:
            action = Categorical(probs).sample()
        else:
            action = probs.argmax(-1)

        self.set(("action", t), action)
        self.set(("action_probs", t), probs)
        self.set(("critic", t), critic)


if __name__ == '__main__':
    # Setup agents and workspace
    env_agent = AutoResetGymAgent(make, dict(env_str='DirtyFactory-v0'), n_envs=1)
    a2c_agent = A2CAgent(3 * 4 * 5 * 5, 96, 10)
    workspace = Workspace()

    # Quick sanity check: roll out the (untrained) stochastic policy for 100
    # steps with rendering enabled, in a separate workspace so it does not
    # interfere with the training workspace below.
    eval_agent = Agents(GymAgent(make, dict(env_str='DirtyFactory-v0'), n_envs=1), a2c_agent)
    eval_workspace = Workspace()
    for i in range(100):
        eval_agent(eval_workspace, t=i, save_render=True, stochastic=True)

    # Combine the environment agent and the policy into one temporal agent
    acquisition_agent = TemporalAgent(Agents(env_agent, a2c_agent))
    acquisition_agent.seed(0)

    # Optimizer & other parameters
    optimizer = optim.Adam(a2c_agent.parameters(), lr=1e-3)
    n_timesteps = 10

    # Decision making loop
    for epoch in range(200000):
        workspace.zero_grad()
        if epoch > 0:
            # Re-use the last step of the previous rollout as the first step
            # of the new one, then acquire n_timesteps - 1 fresh steps
            workspace.copy_n_last_steps(1)
            acquisition_agent(workspace, t=1, n_steps=n_timesteps - 1, stochastic=True)
        else:
            acquisition_agent(workspace, t=0, n_steps=n_timesteps, stochastic=True)
        # for k in workspace.keys():
        #     print(f'{k} ==> {workspace[k].size()}')
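
        # Workspace variables are time-major tensors of shape
        # (n_timesteps, n_envs, ...); slicing over time gives the transitions
        # needed for the bootstrapped TD target below.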
        critic, done, action_probs, reward, action = workspace[
            "critic", "env/done", "action_probs", "env/reward", "action"
        ]

        # Critic loss: squared one-step TD error with discount factor 0.99
        target = reward[1:] + 0.99 * critic[1:].detach() * (1 - done[1:].float())
        td = target - critic[:-1]
        td_error = td ** 2
        critic_loss = td_error.mean()

        # Entropy bonus (for exploration) and policy-gradient (A2C) loss
        entropy_loss = Categorical(action_probs).entropy().mean()
        action_logp = _index(action_probs, action).log()
        a2c_loss = (action_logp[:-1] * td.detach()).mean()

        loss = (
            -0.001 * entropy_loss
            + 1.0 * critic_loss
            - 0.1 * a2c_loss
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report the cumulated reward of episodes that finished in this rollout
        creward = workspace["env/cumulated_reward"]
        creward = creward[done]
        if creward.size()[0] > 0:
            print(f"Cumulative reward at A2C step #{(1 + epoch) * n_timesteps}: {creward.mean().item()}")