first icm steps

This commit is contained in:
romue
2021-07-30 14:01:09 +02:00
parent aebe3b2f60
commit ebf49cadea
6 changed files with 61 additions and 35 deletions

View File

@ -24,7 +24,7 @@ class BaseLearner:
self.n_agents = n_agents
self.n_grad_steps = n_grad_steps
self.train_every = train_every
self.stack_n_frames = deque(stack_n_frames)
self.stack_n_frames = deque(maxlen=stack_n_frames)
self.device = 'cpu'
self.n_updates = 0
self.step = 0
@ -51,6 +51,9 @@ class BaseLearner:
def on_episode_end(self, n_steps):
pass
def on_all_done(self):
pass
def train(self):
pass
@ -93,6 +96,7 @@ class BaseLearner:
f' eps: {self.eps:.4f}\tRunning loss: {sum(list(self.running_loss)) / len(self.running_loss):.4f}\tUpdates:{self.n_updates}')
except Exception as e:
pass
self.on_all_done()
class BaseBuffer:
@ -180,18 +184,17 @@ class BaseDDQN(BaseDQN):
return values + (advantages - advantages.mean())
class QTRANtestNet(nn.Module):
def __init__(self, backbone_dims=[3*5*5, 64, 64], q_head=[64, 9]):
super(QTRANtestNet, self).__init__()
self.backbone = mlp_maker(backbone_dims, flatten=True, activation_last='elu')
self.q_head = mlp_maker(q_head)
class BaseICM(nn.Module):
def __init__(self, backbone_dims=[2*3*5*5, 64, 64], head_dims=[2*64, 64, 9]):
super(BaseICM, self).__init__()
self.backbone = mlp_maker(backbone_dims, flatten=True)
self.icm = mlp_maker(head_dims)
self.ce = nn.CrossEntropyLoss()
def forward(self, x):
features = self.backbone(x)
qs = self.q_head(features)
return qs, features
@torch.no_grad()
def act(self, x) -> np.ndarray:
action = self.forward(x)[0].max(-1)[1].numpy()
return action
def forward(self, s0, s1, a):
phi_s0 = self.backbone(s0)
phi_s1 = self.backbone(s1)
cat = torch.cat((phi_s0, phi_s1), dim=1)
a_prime = torch.softmax(self.icm(cat), dim=-1)
ce = self.ce(a_prime, a)
return dict(prediction=a_prime, loss=ce)