added first working MAPPO implementation
@@ -13,14 +13,16 @@ class ActorCriticMemory(object):
         self.__actions = []
         self.__rewards = []
         self.__dones = []
         self.__hiddens_actor = []
-        self.__hiddens_actor = []
+        self.__hiddens_critic = []
+        self.__logits = []
+        self.__values = []
 
     def __len__(self):
         return len(self.__states)
 
     @property
-    def observation(self):
+    def observation(self):  # add time dimension through stacking
         return torch.stack(self.__states, 0).unsqueeze(0) # 1 x timesteps x hidden dim
 
     @property
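For reference, the reworked observation property stacks the stored per-step tensors along a new time axis and then prepends a leading dimension; a minimal sketch of that layout (feature size and episode length are assumed here, not taken from the commit):

import torch

states = [torch.randn(5) for _ in range(3)]   # assumed: one flat feature vector per timestep
obs = torch.stack(states, 0).unsqueeze(0)     # stack over time, then add a leading dim
print(obs.shape)                              # torch.Size([1, 3, 5]) -> 1 x timesteps x features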
@@ -47,6 +49,14 @@ class ActorCriticMemory(object):
     def done(self):
         return torch.tensor(self.__dones).float().unsqueeze(0) # 1 x timesteps
 
+    @property
+    def logits(self): # assumes a trailing 1 for time dimension - common when using output from NN
+        return torch.cat(self.__logits, 0).unsqueeze(0) # 1 x timesteps x actions
+
+    @property
+    def values(self):
+        return torch.cat(self.__values, 0).unsqueeze(0) # 1 x timesteps x actions
+
     def add_observation(self, state: Union[Tensor, np.ndarray]):
         self.__states.append(state if isinstance(state, Tensor) else torch.from_numpy(state))
 
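The new logits and values properties concatenate per-step network outputs along the time axis, relying on the trailing 1 noted in the inline comment; a sketch under that assumption (action count and step count are illustrative, not from the commit):

import torch

step_logits = [torch.randn(1, 4) for _ in range(3)]  # assumed: 1 x actions output per step
logits = torch.cat(step_logits, 0).unsqueeze(0)      # timesteps x actions -> 1 x timesteps x actions
print(logits.shape)                                  # torch.Size([1, 3, 4])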
@@ -69,6 +79,12 @@ class ActorCriticMemory(object):
     def add_done(self, done: bool):
         self.__dones.append(done)
 
+    def add_logits(self, logits: Tensor):
+        self.__logits.append(logits)
+
+    def add_values(self, values: Tensor):
+        self.__values.append(values)
+
     def add(self, **kwargs):
         for k, v in kwargs.items():
             func = getattr(ActorCriticMemory, f'add_{k}')
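The add(**kwargs) helper resolves add_<key> on the class via getattr, so the new add_logits and add_values methods become reachable by keyword; a hedged usage sketch, assuming the unshown remainder of add() calls the looked-up function as func(self, v) and that ActorCriticMemory() can be constructed without arguments:

import torch

memory = ActorCriticMemory()
memory.add(logits=torch.randn(1, 4),   # routed to add_logits
           values=torch.randn(1, 1))   # routed to add_values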
@@ -129,3 +145,14 @@ class MARLActorCriticMemory(object):
         all_hc = [mem.hidden_critic for mem in self.memories]
         return torch.cat(all_hc, 0) # agents x layers x timesteps x hidden dim
 
+    @property
+    def logits(self):
+        all_lgts = [mem.logits for mem in self.memories]
+        return torch.cat(all_lgts, 0) # agents x timesteps x actions
+
+    @property
+    def values(self):
+        all_vals = [mem.values for mem in self.memories]
+        return torch.cat(all_vals, 0) # agents x timesteps x actions
+
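Because every per-agent property already carries a leading 1, the MARLActorCriticMemory accessors can build the joint view with a plain torch.cat along dimension 0; a sketch of that aggregation (agent count, step count and action count are assumed for illustration):

import torch

per_agent = [torch.randn(1, 3, 4) for _ in range(2)]  # assumed: two agents, each 1 x timesteps x actions
joint = torch.cat(per_agent, 0)                        # agents x timesteps x actions
print(joint.shape)                                     # torch.Size([2, 3, 4])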