"""Rollout memories for (multi-agent) actor-critic training.

Contains a per-agent FIFO memory (``ActorCriticMemory``), a multi-agent
wrapper (``MARLActorCriticMemory``), a dataset that samples fixed-length
chunks from a memory (``ExperienceChunks``) and the lazily materialised
tensor queue that backs all of them (``LazyTensorFiFoQueue``).
"""
import random
from collections import deque
from typing import Union

import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset, ConcatDataset


class ActorCriticMemory(object):
    """FIFO memory holding the rollout of a single agent.

    Every quantity is stored in its own ``LazyTensorFiFoQueue`` holding at
    most ``capacity + 1`` entries.

    Note: the property getters declare an ``sls`` default argument, but
    property access never passes it, so each property always returns the
    complete stored history (``slice(0, None)``). Reading a property whose
    queue has never been appended to raises ``TypeError`` because the
    underlying tensor is still ``None``.
    """

    def __init__(self, capacity=10):
        self.capacity = capacity
        self.reset()

    def reset(self):
        """Drop all stored data by replacing every queue with a fresh one."""
        self.__actions = LazyTensorFiFoQueue(maxlen=self.capacity+1)
        self.__hidden_actor = LazyTensorFiFoQueue(maxlen=self.capacity+1)
        self.__hidden_critic = LazyTensorFiFoQueue(maxlen=self.capacity+1)
        self.__states = LazyTensorFiFoQueue(maxlen=self.capacity+1)
        self.__rewards = LazyTensorFiFoQueue(maxlen=self.capacity+1)
        self.__dones = LazyTensorFiFoQueue(maxlen=self.capacity+1)
        self.__logits = LazyTensorFiFoQueue(maxlen=self.capacity+1)
        self.__values = LazyTensorFiFoQueue(maxlen=self.capacity+1)

    def __len__(self):
        """Return the number of stored rewards minus one."""
        return len(self.__rewards) - 1

    @property
    def observation(self, sls=slice(0, None)):
        # add time dimension through stacking
        return self.__states[sls].unsqueeze(0)  # 1 x time x hidden dim

    @property
    def hidden_actor(self, sls=slice(0, None)):
        # each stored entry: 1 x n_layers x dim
        return self.__hidden_actor[sls].unsqueeze(0)  # 1 x time x n_layers x dim

    @property
    def hidden_critic(self, sls=slice(0, None)):
        # each stored entry: 1 x n_layers x dim
        return self.__hidden_critic[sls].unsqueeze(0)  # 1 x time x n_layers x dim

    @property
    def reward(self, sls=slice(0, None)):
        return self.__rewards[sls].squeeze().unsqueeze(0)  # 1 x time

    @property
    def action(self, sls=slice(0, None)):
        return self.__actions[sls].long().squeeze().unsqueeze(0)  # 1 x time

    @property
    def done(self, sls=slice(0, None)):
        return self.__dones[sls].float().squeeze().unsqueeze(0)  # 1 x time

    @property
    def logits(self, sls=slice(0, None)):
        # assumes a trailing 1 for time dimension - common when using output from NN
        return self.__logits[sls].squeeze().unsqueeze(0)  # 1 x time x actions

    @property
    def values(self, sls=slice(0, None)):
        return self.__values[sls].squeeze().unsqueeze(0)  # 1 x time x actions

    def add_observation(self, state: Union[Tensor, np.ndarray]):
        """Store an observation; anything that is not a Tensor goes through ``torch.from_numpy``."""
        self.__states.append(state if isinstance(state, Tensor) else torch.from_numpy(state))

    def add_hidden_actor(self, hidden: Tensor):
        # layers x hidden dim
        self.__hidden_actor.append(hidden)

    def add_hidden_critic(self, hidden: Tensor):
        # layers x hidden dim
        self.__hidden_critic.append(hidden)

    def add_action(self, action: Union[int, Tensor]):
        """Store an action, wrapping non-tensor values with ``torch.tensor``."""
        if not isinstance(action, Tensor):
            action = torch.tensor(action)
        self.__actions.append(action)

    def add_reward(self, reward: Union[float, Tensor]):
        """Store a reward, wrapping non-tensor values with ``torch.tensor``."""
        if not isinstance(reward, Tensor):
            reward = torch.tensor(reward)
        self.__rewards.append(reward)

    def add_done(self, done: Union[bool, Tensor]):
        """Store a done flag, wrapping non-tensor values with ``torch.tensor``."""
        if not isinstance(done, Tensor):
            done = torch.tensor(done)
        self.__dones.append(done)

    def add_logits(self, logits: Tensor):
        self.__logits.append(logits)

    def add_values(self, values: Tensor):
        self.__values.append(values)

    def add(self, **kwargs):
        """Dispatch every ``name=value`` pair to the matching ``add_<name>`` method.

        Raises:
            AttributeError: if no ``add_<name>`` method exists for a key.
        """
        for k, v in kwargs.items():
            func = getattr(ActorCriticMemory, f'add_{k}')
            func(self, v)


class MARLActorCriticMemory(object):
    """Collection of one ``ActorCriticMemory`` per agent.

    Calling the object with an agent index returns that agent's memory.
    Attributes that are not found on the wrapper itself are fetched from
    every agent memory and concatenated along dimension 0.
    """

    def __init__(self, n_agents, capacity):
        self.n_agents = n_agents
        self.memories = [
            ActorCriticMemory(capacity) for _ in range(n_agents)
        ]

    def __call__(self, agent_i):
        return self.memories[agent_i]

    def __len__(self):
        return len(self.memories[0])  # todo add assertion check!

    def reset(self):
        """Reset the memory of every agent."""
        for mem in self.memories:
            mem.reset()

    def add(self, **kwargs):
        """Add one entry per agent.

        Each keyword value is indexed by agent (``v[agent_i]``) and passed to
        the ``add_<name>`` method of that agent's memory.
        """
        for agent_i in range(self.n_agents):
            for k, v in kwargs.items():
                func = getattr(ActorCriticMemory, f'add_{k}')
                func(self.memories[agent_i], v[agent_i])

    def __getattr__(self, attr):
        """Gather ``attr`` from all agent memories and concatenate the results.

        ``__getattr__`` only runs when normal lookup fails. If ``memories``
        itself is not set yet (e.g. on an instance created via ``__new__``,
        as ``copy``/``pickle`` do), reading ``self.memories`` here would
        re-enter this method forever, so it is fetched from the instance
        dict and a regular ``AttributeError`` is raised when absent.
        """
        try:
            memories = self.__dict__['memories']
        except KeyError:
            raise AttributeError(attr) from None
        all_attrs = [getattr(mem, attr) for mem in memories]
        return torch.cat(all_attrs, 0)  # agents x time ...

    def chunk_dataloader(self, chunk_len, k):
        """Draw ``k`` chunks of length ``chunk_len`` per agent and collate them.

        Returns:
            dict mapping each sample key to the chunks of all agents
            concatenated along dimension 0.
        """
        datasets = [ExperienceChunks(mem, chunk_len, k) for mem in self.memories]
        dataset = ConcatDataset(datasets)
        data = [dataset[i] for i in range(len(dataset))]
        data = custom_collate_fn(data)
        return data


def custom_collate_fn(batch):
    """Concatenate a list of sample dicts key-wise along dimension 0.

    The keys are taken from the first element of ``batch``.
    """
    elem = batch[0]
    return {key: torch.cat([d[key] for d in batch], dim=0) for key in elem}


class ExperienceChunks(Dataset):
    """Dataset yielding randomly positioned chunks from a memory.

    ``len(dataset)`` is ``k``; the index passed to ``__getitem__`` is ignored
    and a start position is drawn at random for every access.
    """

    def __init__(self, memory, chunk_len, k):
        assert chunk_len <= len(memory), 'chunk_len cannot be longer than the size of the memory'
        self.memory = memory
        self.chunk_len = chunk_len
        self.k = k

    @property
    def whitelist(self):
        """Per-start-index sampling weights (1 = allowed, 0 = forbidden).

        Start indices in the window ``[d - chunk_len - 1, d + 2)`` around
        every non-zero done flag ``d`` are zeroed, as is index 0 (``sample``
        reads position ``start - 1``).
        """
        whitelist = torch.ones(len(self.memory) - self.chunk_len)
        done_indices = self.memory.done.squeeze().nonzero().flatten().tolist()
        for d in done_indices:
            whitelist[max((0, d-self.chunk_len-1)):d+2] = 0
        whitelist[0] = 0
        return whitelist.tolist()

    def sample(self, start=1):
        """Return the chunk beginning at ``start`` as a dict of tensors.

        Observations cover ``chunk_len + 1`` steps and actions start one step
        earlier (``start - 1``); hidden states are taken at ``start - 1``; all
        other entries cover ``chunk_len`` steps from ``start``.
        """
        cl = self.chunk_len
        sample = dict(observation=self.memory.observation[:, start:start+cl+1],
                      action=self.memory.action[:, start-1:start+cl],
                      hidden_actor=self.memory.hidden_actor[:, start-1],
                      hidden_critic=self.memory.hidden_critic[:, start-1],
                      reward=self.memory.reward[:, start:start + cl],
                      done=self.memory.done[:, start:start + cl],
                      logits=self.memory.logits[:, start:start + cl],
                      values=self.memory.values[:, start:start + cl])
        return sample

    def __len__(self):
        return self.k

    def __getitem__(self, i):
        # `i` is intentionally unused: every access draws a fresh random start
        # index, weighted by the current whitelist.
        idx = random.choices(range(0, len(self.memory) - self.chunk_len), weights=self.whitelist, k=1)
        return self.sample(idx[0])


class LazyTensorFiFoQueue:
    """Bounded FIFO of tensors that is stacked into one tensor on demand.

    Appended items are buffered in a ``deque``; they are merged into the
    tensor ``queue`` (new leading dimension = item index) when the buffer is
    full or when the queue is indexed. At most ``maxlen`` items are kept.
    """

    def __init__(self, maxlen):
        self.maxlen = maxlen
        self.reset()

    def reset(self):
        """Discard all buffered and materialised data."""
        self.__lazy_queue = deque(maxlen=self.maxlen)
        self.shape = None
        self.queue = None

    def shape_init(self, tensor: Tensor):
        """Record the shape of a full queue: ``maxlen`` followed by the item shape."""
        self.shape = torch.Size([self.maxlen, *tensor.shape])

    def build_tensor_queue(self):
        """Merge buffered items into ``queue``, keeping only the newest ``maxlen``."""
        if not self.__lazy_queue:
            return
        block = torch.stack(list(self.__lazy_queue), dim=0)
        if self.queue is None:
            self.queue = block
        else:
            # Concatenate first, then trim from the front, so that only the
            # actual overflow is dropped even when the materialised part was
            # not yet full.
            self.queue = torch.cat((self.queue, block), dim=0)[-self.maxlen:]
        self.__lazy_queue.clear()

    def append(self, data):
        """Buffer ``data``; materialise once the buffer holds ``maxlen`` items."""
        if self.shape is None:
            self.shape_init(data)
        self.__lazy_queue.append(data)
        if len(self.__lazy_queue) >= self.maxlen:
            self.build_tensor_queue()

    def true_len(self):
        """Return buffered plus materialised item count (may exceed ``maxlen``)."""
        return len(self.__lazy_queue) + (0 if self.queue is None else self.queue.shape[0])

    def __len__(self):
        return min((self.true_len(), self.maxlen))

    def __str__(self):
        return f'LazyTensorFiFoQueue\tmaxlen: {self.maxlen}, shape: {self.shape}, ' \
               f'len: {len(self)}, true_len: {self.true_len()}, elements in lazy queue: {len(self.__lazy_queue)}'

    def __getitem__(self, item_or_slice):
        """Materialise pending items and index the resulting tensor.

        Raises:
            TypeError: if nothing has been appended yet (``queue`` is ``None``).
        """
        self.build_tensor_queue()
        return self.queue[item_or_slice]