added first working MAPPO implementation

This commit is contained in:
Robert Müller
2022-01-28 11:07:25 +01:00
parent ffc47752a7
commit b09c461754
11 changed files with 194 additions and 61 deletions

View File

@ -13,14 +13,16 @@ class ActorCriticMemory(object):
self.__actions = []
self.__rewards = []
self.__dones = []
self.__hiddens_actor = []
self.__hiddens_actor = []
self.__hiddens_critic = []
self.__logits = []
self.__values = []
# Length of the memory is the number of stored states (one entry per add_observation call).
def __len__(self):
return len(self.__states)
@property
def observation(self):
def observation(self): # add time dimension through stacking
return torch.stack(self.__states, 0).unsqueeze(0) # 1 x timesteps x hidden dim
@property
@ -47,6 +49,14 @@ class ActorCriticMemory(object):
def done(self):
return torch.tensor(self.__dones).float().unsqueeze(0) # 1 x timesteps
# All stored logits joined into one tensor with a new leading dimension of size 1.
# NOTE(review): torch.cat(..., 0) joins along the FIRST dimension, so each stored
# entry presumably carries a leading (not trailing) time dimension of size 1,
# as is common for NN outputs - TODO confirm against the caller.
@property
def logits(self):
return torch.cat(self.__logits, 0).unsqueeze(0) # 1 x timesteps x actions
# All stored value estimates joined along dim 0, with a new leading dimension of size 1.
@property
def values(self):
# NOTE(review): the shape note below looks copied from `logits`; for a critic
# output the last dimension is presumably 1 rather than `actions` - TODO confirm.
return torch.cat(self.__values, 0).unsqueeze(0) # 1 x timesteps x actions
# Append one observation. Tensors are stored unchanged; NumPy arrays are wrapped
# with torch.from_numpy, which shares memory with the source array (later in-place
# changes to that array are visible in the stored tensor).
def add_observation(self, state: Union[Tensor, np.ndarray]):
self.__states.append(state if isinstance(state, Tensor) else torch.from_numpy(state))
@ -69,6 +79,12 @@ class ActorCriticMemory(object):
# Append one done flag; the `done` property later converts the flags to a float tensor.
def add_done(self, done: bool):
self.__dones.append(done)
# Append one logits tensor; the `logits` property later joins all entries with torch.cat along dim 0.
def add_logits(self, logits: Tensor):
self.__logits.append(logits)
# Append one value tensor; the `values` property later joins all entries with torch.cat along dim 0.
# NOTE(review): the parameter is named `logits` although it is stored in the values
# buffer - looks like a copy-paste name. Renaming it would break keyword callers,
# so it is only flagged here.
def add_values(self, logits: Tensor):
self.__values.append(logits)
def add(self, **kwargs):
for k, v in kwargs.items():
func = getattr(ActorCriticMemory, f'add_{k}')
@ -129,3 +145,14 @@ class MARLActorCriticMemory(object):
all_hc = [mem.hidden_critic for mem in self.memories]
return torch.cat(all_hc, 0) # agents x layers x timesteps x hidden dim
# Logits of every per-agent memory, concatenated along dim 0.
@property
def logits(self):
all_lgts = [mem.logits for mem in self.memories]
# NOTE(review): the previous shape note ("agents x layers x timesteps x hidden dim")
# looks copied from the hidden-state property. If each mem.logits is
# 1 x timesteps x actions, the result is presumably agents x timesteps x actions - TODO confirm.
return torch.cat(all_lgts, 0)
# Value estimates of every per-agent memory, concatenated along dim 0.
@property
def values(self):
all_vals = [mem.values for mem in self.memories]
# NOTE(review): the previous shape note ("agents x layers x timesteps x hidden dim")
# looks copied from the hidden-state property. The result presumably has one
# leading slice per agent followed by the per-memory `values` shape - TODO confirm.
return torch.cat(all_vals, 0)