added mlpmaker

romue 2021-06-22 16:23:39 +02:00
parent c5d677e9ba
commit b5d729e597


@@ -1,5 +1,5 @@
 from typing import NamedTuple, Union
-from collections import deque
+from collections import deque, OrderedDict
 import numpy as np
 import random
 import torch
@@ -39,42 +39,27 @@ class BaseBuffer:
         return Experience(observations, next_observations, actions, rewards, dones)
 
-class BaseDDQN(nn.Module):
-    def __init__(self):
-        super(BaseDDQN, self).__init__()
-        self.net = nn.Sequential(
-            nn.Flatten(),
-            nn.Linear(3*5*5, 64),
-            nn.ELU(),
-            nn.Linear(64, 64),
-            nn.ELU()
-        )
-        self.value_head = nn.Linear(64, 1)
-        self.advantage_head = nn.Linear(64, 9)
-
-    def act(self, x) -> np.ndarray:
-        with torch.no_grad():
-            action = self.forward(x).max(-1)[1].numpy()
-        return action
-
-    def forward(self, x):
-        features = self.net(x)
-        advantages = self.advantage_head(features)
-        values = self.value_head(features)
-        return values + (advantages - advantages.mean())
+def soft_update(local_model, target_model, tau):
+    # taken from https://github.com/BY571/Munchausen-RL/blob/master/M-DQN.ipynb
+    for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
+        target_param.data.copy_(tau*local_param.data + (1.-tau)*target_param.data)
+
+
+def mlp_maker(dims):
+    layers = [('Flatten', nn.Flatten())]
+    for i in range(1, len(dims)):
+        layers.append((f'Linear#{i - 1}', nn.Linear(dims[i - 1], dims[i])))
+        if i != len(dims) - 1:
+            layers.append(('ELU', nn.ELU()))
+    return nn.Sequential(OrderedDict(layers))
 
 
 class BaseDQN(nn.Module):
-    def __init__(self):
+    def __init__(self, dims=[3*5*5, 64, 64, 9]):
         super(BaseDQN, self).__init__()
-        self.net = nn.Sequential(
-            nn.Flatten(),
-            nn.Linear(3*5*5, 64),
-            nn.ELU(),
-            nn.Linear(64, 64),
-            nn.ELU(),
-            nn.Linear(64, 9)
-        )
+        self.net = mlp_maker(dims)
 
     def act(self, x) -> np.ndarray:
         with torch.no_grad():
@@ -85,23 +70,34 @@ class BaseDQN(nn.Module):
         return self.net(x)
 
-def soft_update(local_model, target_model, tau):
-    # taken from https://github.com/BY571/Munchausen-RL/blob/master/M-DQN.ipynb
-    for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
-        target_param.data.copy_(tau*local_param.data + (1.-tau)*target_param.data)
+class BaseDDQN(BaseDQN):
+    def __init__(self,
+                 backbone_dims=[3*5*5, 64, 64],
+                 value_dims=[64,1],
+                 advantage_dims=[64,9]):
+        super(BaseDDQN, self).__init__(backbone_dims)
+        self.value_head = mlp_maker(value_dims)
+        self.advantage_head = mlp_maker(advantage_dims)
+
+    def forward(self, x):
+        features = self.net(x)
+        advantages = self.advantage_head(features)
+        values = self.value_head(features)
+        return values + (advantages - advantages.mean())
 
 
 class BaseQlearner:
-    def __init__(self, q_net, target_q_net, env, buffer, target_update, eps_end, n_agents=1,
+    def __init__(self, q_net, target_q_net, env, buffer_size, target_update, eps_end, n_agents=1,
                  gamma=0.99, train_every_n_steps=4, n_grad_steps=1, tau=1.0, max_grad_norm=10,
                  exploration_fraction=0.2, batch_size=64, lr=1e-4, reg_weight=0.0):
         self.q_net = q_net
+        print(self.q_net)
         self.target_q_net = target_q_net
+        #self.q_net.apply(self.weights_init)
         self.target_q_net.eval()
         soft_update(self.q_net, self.target_q_net, tau=1.0)
         self.env = env
-        self.buffer = buffer
+        self.buffer = BaseBuffer(buffer_size)
         self.target_update = target_update
         self.eps = 1.
         self.eps_end = eps_end
@@ -128,7 +124,7 @@ class BaseQlearner:
 
     @staticmethod
     def weights_init(module, activation='leaky_relu'):
         if isinstance(module, (nn.Linear, nn.Conv2d)):
-            nn.init.xavier_normal_(module.weight, gain=torch.nn.init.calculate_gain(activation))
+            nn.init.orthogonal_(module.weight, gain=torch.nn.init.calculate_gain(activation))
             if module.bias is not None:
                 module.bias.data.fill_(0.0)
@@ -179,7 +175,6 @@ class BaseQlearner:
         current_q_values = self.q_net(obs)
         current_q_values = torch.gather(current_q_values, dim=-1, index=action)
         next_q_values_raw = self.target_q_net(next_obs).max(dim=-1)[0].reshape(-1, 1).detach()
-        #print(current_q_values.shape, next_q_values_raw.shape)
         return current_q_values, next_q_values_raw
 
     def _backprop_loss(self, loss):
@@ -265,8 +260,9 @@ class MDQN(BaseQlearner):
 
 if __name__ == '__main__':
     from environments.factory.simple_factory import SimpleFactory, DirtProperties, MovementProperties
+    from gym.wrappers import FrameStack
 
-    N_AGENTS = 2
+    N_AGENTS = 1
 
     dirt_props = DirtProperties(clean_amount=3, gain_amount=0.2, max_global_amount=30,
                                 max_local_amount=5, spawn_frequency=1, max_spawn_ratio=0.05)
@@ -275,7 +271,7 @@ if __name__ == '__main__':
                                     allow_no_op=False)
     env = SimpleFactory(dirt_properties=dirt_props, movement_properties=move_props, n_agents=N_AGENTS, pomdp_radius=2, max_steps=400, omit_agent_slice_in_obs=False, combin_agent_slices_in_obs=True)
 
-    dqn, target_dqn = BaseDQN(), BaseDQN()
-    learner = BaseQlearner(dqn, target_dqn, env, BaseBuffer(40000), target_update=3500, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
+    dqn, target_dqn = BaseDDQN(), BaseDDQN()
+    learner = MDQN(dqn, target_dqn, env, 40000, target_update=3500, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
                    train_every_n_steps=4, eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25, batch_size=64)
     learner.learn(100000)
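
For reference, a minimal standalone sketch of the new mlp_maker helper in use. The chosen dims and the fake observation batch below are illustrative assumptions, not part of this commit; only torch is required.

from collections import OrderedDict

import torch
import torch.nn as nn


def mlp_maker(dims):
    # Flatten, then alternating Linear/ELU; no activation after the final Linear.
    layers = [('Flatten', nn.Flatten())]
    for i in range(1, len(dims)):
        layers.append((f'Linear#{i - 1}', nn.Linear(dims[i - 1], dims[i])))
        if i != len(dims) - 1:
            layers.append(('ELU', nn.ELU()))
    return nn.Sequential(OrderedDict(layers))


net = mlp_maker([3 * 5 * 5, 64, 64, 9])   # same dims as BaseDQN's default
obs = torch.randn(4, 3, 5, 5)             # batch of 4 fake 3x5x5 observations (assumed shape)
q_values = net(obs)                       # -> torch.Size([4, 9]), one Q-value per action
print(q_values.shape)

BaseDDQN reuses the same helper for its backbone and for its two heads, and combines them as values + (advantages - advantages.mean()), the usual dueling-network aggregation.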