import torch
from typing import Union
from torch import Tensor
import numpy as np


class ActorCriticMemory(object):
    """Per-agent rollout buffer that stores one transition per time step."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.__states = []
        self.__actions = []
        self.__rewards = []
        self.__dones = []
        self.__hiddens_actor = []
        self.__hiddens_critic = []
        self.__logits = []
        self.__values = []

    def __len__(self):
        return len(self.__states)
    @property
    def observation(self):  # add time dimension through stacking
        return torch.stack(self.__states, 0).unsqueeze(0)          # 1 x timesteps x hidden dim

    @property
    def hidden_actor(self):
        if len(self.__hiddens_actor) == 1:
            return self.__hiddens_actor[0]
        return torch.stack(self.__hiddens_actor, 0)                # layers x timesteps x hidden dim

    @property
    def hidden_critic(self):
        if len(self.__hiddens_critic) == 1:
            return self.__hiddens_critic[0]
        return torch.stack(self.__hiddens_critic, 0)               # layers x timesteps x hidden dim

    @property
    def reward(self):
        return torch.tensor(self.__rewards).float().unsqueeze(0)   # 1 x timesteps

    @property
    def action(self):
        return torch.tensor(self.__actions).long().unsqueeze(0)    # 1 x timesteps+1

    @property
    def done(self):
        return torch.tensor(self.__dones).float().unsqueeze(0)     # 1 x timesteps

    @property
    def logits(self):  # assumes each entry has a leading singleton time dimension - common for NN output
        return torch.cat(self.__logits, 0).unsqueeze(0)            # 1 x timesteps x actions

    @property
    def values(self):
        return torch.cat(self.__values, 0).unsqueeze(0)            # 1 x timesteps x actions
    def add_observation(self, state: Union[Tensor, np.ndarray]):
        self.__states.append(state if isinstance(state, Tensor) else torch.from_numpy(state))

    def add_hidden_actor(self, hidden: Tensor):
        # 1 x layers x hidden dim
        if len(hidden.shape) < 3:
            hidden = hidden.unsqueeze(0)
        self.__hiddens_actor.append(hidden)

    def add_hidden_critic(self, hidden: Tensor):
        # 1 x layers x hidden dim
        if len(hidden.shape) < 3:
            hidden = hidden.unsqueeze(0)
        self.__hiddens_critic.append(hidden)

    def add_action(self, action: int):
        self.__actions.append(action)

    def add_reward(self, reward: float):
        self.__rewards.append(reward)

    def add_done(self, done: bool):
        self.__dones.append(done)

    def add_logits(self, logits: Tensor):
        self.__logits.append(logits)

    def add_values(self, values: Tensor):
        self.__values.append(values)
    def add(self, **kwargs):
        # dispatch each keyword to the matching add_<name> method
        for k, v in kwargs.items():
            func = getattr(ActorCriticMemory, f'add_{k}')
            func(self, v)
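
# Example single-agent usage (a sketch; the variable names, tensor shapes and
# episode loop below are illustrative assumptions, not part of the original file):
#
#   mem = ActorCriticMemory()
#   for t in range(episode_length):
#       mem.add(observation=obs_t,        # Tensor or np.ndarray, e.g. shape (obs_dim,)
#               action=int(action_t),
#               reward=float(reward_t),
#               done=bool(done_t),
#               logits=logits_t,          # Tensor, e.g. shape (1, n_actions)
#               values=value_t)           # Tensor, e.g. shape (1, 1)
#   stacked_obs = mem.observation         # 1 x episode_length x obs_dim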


class MARLActorCriticMemory(object):
    """Thin wrapper that keeps one ActorCriticMemory per agent."""

    def __init__(self, n_agents):
        self.n_agents = n_agents
        self.memories = [
            ActorCriticMemory() for _ in range(n_agents)
        ]

    def __call__(self, agent_i):
        return self.memories[agent_i]

    def __len__(self):
        return len(self.memories[0])  # todo: add assertion check!

    def reset(self):
        for mem in self.memories:
            mem.reset()

    def add(self, **kwargs):
        # todo: try/except - print all possible functions
        for agent_i in range(self.n_agents):
            for k, v in kwargs.items():
                func = getattr(ActorCriticMemory, f'add_{k}')
                func(self.memories[agent_i], v[agent_i])
    @property
    def observation(self):
        all_obs = [mem.observation for mem in self.memories]
        return torch.cat(all_obs, 0)              # agents x timesteps+1 x ...

    @property
    def action(self):
        all_actions = [mem.action for mem in self.memories]
        return torch.cat(all_actions, 0)          # agents x timesteps+1 x ...

    @property
    def done(self):
        all_dones = [mem.done for mem in self.memories]
        return torch.cat(all_dones, 0).float()    # agents x timesteps x ...

    @property
    def reward(self):
        all_rewards = [mem.reward for mem in self.memories]
        return torch.cat(all_rewards, 0).float()  # agents x timesteps x ...

    @property
    def hidden_actor(self):
        all_ha = [mem.hidden_actor for mem in self.memories]
        return torch.cat(all_ha, 0)               # agents x layers x timesteps x hidden dim

    @property
    def hidden_critic(self):
        all_hc = [mem.hidden_critic for mem in self.memories]
        return torch.cat(all_hc, 0)               # agents x layers x timesteps x hidden dim

    @property
    def logits(self):
        all_lgts = [mem.logits for mem in self.memories]
        return torch.cat(all_lgts, 0)             # agents x timesteps x actions

    @property
    def values(self):
        all_vals = [mem.values for mem in self.memories]
        return torch.cat(all_vals, 0)             # agents x timesteps x actions
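

if __name__ == '__main__':
    # Minimal smoke test (a sketch; the agent count, episode length and tensor
    # shapes below are illustrative assumptions, not part of the original file).
    n_agents, timesteps, obs_dim, n_actions = 2, 5, 4, 3
    marl_mem = MARLActorCriticMemory(n_agents)
    for _ in range(timesteps):
        # each keyword gets a per-agent list; add() forwards entry i to agent i's buffer
        marl_mem.add(
            observation=[torch.randn(obs_dim) for _ in range(n_agents)],
            action=[0] * n_agents,
            reward=[1.0] * n_agents,
            done=[False] * n_agents,
            logits=[torch.randn(1, n_actions) for _ in range(n_agents)],
            values=[torch.randn(1, 1) for _ in range(n_agents)],
        )
    assert marl_mem.observation.shape == (n_agents, timesteps, obs_dim)
    assert marl_mem.logits.shape == (n_agents, timesteps, n_actions)
    assert marl_mem.reward.shape == (n_agents, timesteps)
    assert len(marl_mem) == timesteps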