clean up algos + adjust renderer to support "ray casting"
@@ -1,5 +1,5 @@
 from typing import NamedTuple, Union
-from collections import deque, OrderedDict
+from collections import deque, OrderedDict, defaultdict
 import numpy as np
 import random
 import torch
@@ -18,12 +18,13 @@ class Experience(NamedTuple):
 
 
 class BaseLearner:
-    def __init__(self, env, n_agents=1, train_every=('step', 4), n_grad_steps=1):
+    def __init__(self, env, n_agents=1, train_every=('step', 4), n_grad_steps=1, stack_n_frames=1):
         assert train_every[0] in ['step', 'episode'], 'train_every[0] must be one of ["step", "episode"]'
         self.env = env
         self.n_agents = n_agents
         self.n_grad_steps = n_grad_steps
         self.train_every = train_every
+        self.stack_n_frames = deque(maxlen=stack_n_frames)
         self.device = 'cpu'
         self.n_updates = 0
         self.step = 0
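Note: the new stack_n_frames deque presumably buffers the most recent observations for frame stacking. A minimal, self-contained sketch of that pattern, assuming nothing beyond the standard library and numpy (the class and method names below are illustrative, not from this repo):

from collections import deque

import numpy as np


# Hypothetical illustration: keep the last n observations in a bounded deque
# and stack them along a new leading axis.
class FrameStacker:
    def __init__(self, n_frames):
        self.frames = deque(maxlen=n_frames)  # oldest frame is evicted automatically

    def push(self, obs):
        self.frames.append(obs)
        while len(self.frames) < self.frames.maxlen:
            # pad with copies of the oldest frame until the stack is full
            self.frames.appendleft(self.frames[0])
        return np.stack(list(self.frames), axis=0)


stacker = FrameStacker(n_frames=4)
stacked = stacker.push(np.zeros((3, 5, 5)))  # shape: (4, 3, 5, 5)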
@@ -102,8 +103,8 @@ class BaseBuffer:
     def __len__(self):
         return len(self.experience)
 
-    def add(self, experience):
-        self.experience.append(experience)
+    def add(self, exp: Experience):
+        self.experience.append(exp)
 
     def sample(self, k, cer=4):
         sample = random.choices(self.experience, k=k-cer)
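Note on the cer argument: only k - cer transitions are drawn at random, which leaves room to always include the cer most recent transitions in the batch (combined experience replay). The body of sample between these hunks is not shown in the diff, so the sketch below is a hedged illustration of that pattern, not the repo's implementation:

import random


# Hedged sketch of combined-experience-replay style sampling; `buffer` is assumed
# to be a sequence of transitions ordered oldest -> newest.
def sample_with_cer(buffer, k, cer=4):
    batch = random.choices(buffer, k=k - cer)  # uniform draws with replacement
    batch += list(buffer)[-cer:]               # always include the `cer` newest transitions
    return batch                               # len(batch) == k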
@@ -113,9 +114,22 @@ class BaseBuffer:
         actions = torch.tensor([e.action for e in sample]).long()
         rewards = torch.tensor([e.reward for e in sample]).float().view(-1, 1)
         dones = torch.tensor([e.done for e in sample]).float().view(-1, 1)
+        #print(observations.shape, next_observations.shape, actions.shape, rewards.shape, dones.shape)
         return Experience(observations, next_observations, actions, rewards, dones)
 
 
+class TrajectoryBuffer(BaseBuffer):
+    def __init__(self, size):
+        super(TrajectoryBuffer, self).__init__(size)
+        self.experience = defaultdict(list)
+
+    def add(self, exp: Experience):
+        self.experience[exp.episode].append(exp)
+        if len(self.experience) > self.size:
+            oldest_traj_key = list(sorted(self.experience.keys()))[0]
+            del self.experience[oldest_traj_key]
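The new TrajectoryBuffer groups experiences by exp.episode and keeps at most size episodes; sorting the keys and taking the first element is equivalent to taking the minimum key. A self-contained toy illustration of that eviction rule (the tuple contents are placeholders, not the real Experience fields):

from collections import defaultdict

# Episode ids are dict keys; the smallest (oldest) id is dropped once more than
# `size` episodes are stored.
episodes = defaultdict(list)
size = 2
for episode_id in range(4):
    episodes[episode_id].append(('obs', 'action', 'reward'))
    if len(episodes) > size:
        del episodes[min(episodes.keys())]  # same episode as list(sorted(keys))[0]
print(sorted(episodes.keys()))  # -> [2, 3]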
+
+
 def soft_update(local_model, target_model, tau):
     # taken from https://github.com/BY571/Munchausen-RL/blob/master/M-DQN.ipynb
     for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
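soft_update performs Polyak averaging of the target network toward the local network; the loop body is cut off by the hunk above. A hedged sketch of the conventional form (the exact line in the repo is not shown here and is an assumption):

import torch.nn as nn


def soft_update_sketch(local_model: nn.Module, target_model: nn.Module, tau: float):
    for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
        # theta_target <- tau * theta_local + (1 - tau) * theta_target
        target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)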
@@ -152,9 +166,10 @@ class BaseDDQN(BaseDQN):
     def __init__(self,
                  backbone_dims=[3*5*5, 64, 64],
                  value_dims=[64, 1],
-                 advantage_dims=[64, 9]):
+                 advantage_dims=[64, 9],
+                 activation='elu'):
         super(BaseDDQN, self).__init__(backbone_dims)
-        self.net = mlp_maker(backbone_dims, flatten=True)
+        self.net = mlp_maker(backbone_dims, activation=activation, flatten=True)
         self.value_head = mlp_maker(value_dims)
         self.advantage_head = mlp_maker(advantage_dims)
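The separate value_head and advantage_head indicate a dueling architecture, where a typical forward pass combines the heads as Q = V + A - mean(A). The sketch below is an assumption about that composition (mlp_maker and BaseDDQN's actual forward are not part of this diff); the layer sizes mirror the defaults shown above:

import torch
import torch.nn as nn


class DuelingHeadSketch(nn.Module):
    # Hypothetical stand-in mirroring backbone_dims=[3*5*5, 64, 64], value_dims=[64, 1],
    # advantage_dims=[64, 9] with ELU activations; not the repo's mlp_maker.
    def __init__(self, feat_dim=64, n_actions=9):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Flatten(),
            nn.Linear(3 * 5 * 5, feat_dim), nn.ELU(),
            nn.Linear(feat_dim, feat_dim), nn.ELU(),
        )
        self.value_head = nn.Linear(feat_dim, 1)
        self.advantage_head = nn.Linear(feat_dim, n_actions)

    def forward(self, x):
        features = self.backbone(x)
        value = self.value_head(features)          # (batch, 1)
        advantage = self.advantage_head(features)  # (batch, n_actions)
        # subtract the mean advantage so value and advantage stay identifiable
        return value + advantage - advantage.mean(dim=-1, keepdim=True)


q_values = DuelingHeadSketch()(torch.zeros(8, 3, 5, 5))  # shape: (8, 9)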