mirror of
				https://github.com/illiumst/marl-factory-grid.git
				synced 2025-10-31 12:37:27 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			222 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			222 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import numpy as np
 | |
| from collections import deque
 | |
| import torch
 | |
| from typing import Union
 | |
| from torch import Tensor
 | |
| from torch.utils.data import Dataset, ConcatDataset
 | |
| import random
 | |
| 
 | |
| 
 | |
class ActorCriticMemory(object):
    """Rolling experience buffer for a single actor-critic agent.

    Every recorded quantity (actions, hidden states, observations, rewards,
    done flags, logits, values) lives in its own ``LazyTensorFiFoQueue`` whose
    ``maxlen`` is ``capacity + 1``.
    """

    def __init__(self, capacity=10):
        self.capacity = capacity
        self.reset()

    def reset(self):
        """Throw away all stored data by creating a fresh queue per field."""
        queue_len = self.capacity + 1

        def fresh_queue():
            return LazyTensorFiFoQueue(maxlen=queue_len)

        self.__actions = fresh_queue()
        self.__hidden_actor = fresh_queue()
        self.__hidden_critic = fresh_queue()
        self.__states = fresh_queue()
        self.__rewards = fresh_queue()
        self.__dones = fresh_queue()
        self.__logits = fresh_queue()
        self.__values = fresh_queue()

    def __len__(self):
        """Return the number of stored rewards minus one."""
        n_rewards = len(self.__rewards)
        return n_rewards - 1

    # NOTE: the getters below are properties, so attribute access never passes
    # ``sls``; the default ``slice(0, None)`` (i.e. everything stored) is
    # always what gets used.

    @property
    def observation(self, sls=slice(0, None)):
        """Stored observations with a new leading axis (stacking adds time)."""
        states = self.__states[sls]
        return states.unsqueeze(0)  # 1 x time x hidden dim

    @property
    def hidden_actor(self, sls=slice(0, None)):
        """Stored actor hidden states with a new leading axis."""
        hidden = self.__hidden_actor[sls]  # per entry: 1 x n_layers x dim
        return hidden.unsqueeze(0)  # 1 x time x n_layers x dim

    @property
    def hidden_critic(self, sls=slice(0, None)):
        """Stored critic hidden states with a new leading axis."""
        hidden = self.__hidden_critic[sls]  # per entry: 1 x n_layers x dim
        return hidden.unsqueeze(0)  # 1 x time x n_layers x dim

    @property
    def reward(self, sls=slice(0, None)):
        """Stored rewards, squeezed and given a new leading axis."""
        rewards = self.__rewards[sls].squeeze()
        return rewards.unsqueeze(0)  # 1 x time

    @property
    def action(self, sls=slice(0, None)):
        """Stored actions cast to ``long``, squeezed, with a new leading axis."""
        actions = self.__actions[sls].long().squeeze()
        return actions.unsqueeze(0)  # 1 x time

    @property
    def done(self, sls=slice(0, None)):
        """Stored done flags cast to ``float``, squeezed, with a new leading axis."""
        dones = self.__dones[sls].float().squeeze()
        return dones.unsqueeze(0)  # 1 x time

    @property
    def logits(self, sls=slice(0, None)):
        """Stored logits, squeezed and given a new leading axis.

        Assumes a trailing 1 for the time dimension - common when using the
        output of a neural network.
        """
        logits = self.__logits[sls].squeeze()
        return logits.unsqueeze(0)  # 1 x time x actions

    @property
    def values(self, sls=slice(0, None)):
        """Stored value estimates, squeezed and given a new leading axis."""
        values = self.__values[sls].squeeze()
        return values.unsqueeze(0)  # 1 x time x actions

    def add_observation(self, state: Union[Tensor, np.ndarray]):
        """Store an observation; anything that is not a ``Tensor`` goes through ``torch.from_numpy``."""
        if not isinstance(state, Tensor):
            state = torch.from_numpy(state)
        self.__states.append(state)

    def add_hidden_actor(self, hidden: Tensor):
        """Store an actor hidden state (layers x hidden dim)."""
        self.__hidden_actor.append(hidden)

    def add_hidden_critic(self, hidden: Tensor):
        """Store a critic hidden state (layers x hidden dim)."""
        self.__hidden_critic.append(hidden)

    def add_action(self, action: Union[int, Tensor]):
        """Store an action, wrapping non-tensor input with ``torch.tensor``."""
        as_tensor = action if isinstance(action, Tensor) else torch.tensor(action)
        self.__actions.append(as_tensor)

    def add_reward(self, reward: Union[float, Tensor]):
        """Store a reward, wrapping non-tensor input with ``torch.tensor``."""
        as_tensor = reward if isinstance(reward, Tensor) else torch.tensor(reward)
        self.__rewards.append(as_tensor)

    def add_done(self, done: bool):
        """Store a done flag, wrapping non-tensor input with ``torch.tensor``."""
        as_tensor = done if isinstance(done, Tensor) else torch.tensor(done)
        self.__dones.append(as_tensor)

    def add_logits(self, logits: Tensor):
        """Store the policy logits as given."""
        self.__logits.append(logits)

    def add_values(self, values: Tensor):
        """Store the value estimates as given."""
        self.__values.append(values)

    def add(self, **kwargs):
        """Store several fields at once.

        Each keyword ``name=value`` is forwarded to ``add_<name>(value)``, so
        an unknown name raises ``AttributeError``.
        """
        for field, value in kwargs.items():
            adder = getattr(ActorCriticMemory, f'add_{field}')
            adder(self, value)
 | |
| 
 | |
| 
 | |
class MARLActorCriticMemory(object):
    """Multi-agent wrapper holding one ``ActorCriticMemory`` per agent.

    Calling the instance with an agent index returns that agent's memory.
    Any attribute not defined on the wrapper itself is read from every agent
    memory and the results are concatenated along dimension 0.
    """

    def __init__(self, n_agents, capacity):
        self.n_agents = n_agents
        self.memories = [
            ActorCriticMemory(capacity) for _ in range(n_agents)
        ]

    def __call__(self, agent_i):
        """Return the memory of agent ``agent_i``."""
        return self.memories[agent_i]

    def __len__(self):
        """Return the length of the first agent's memory."""
        return len(self.memories[0])  # todo add assertion check!

    def reset(self):
        """Reset every agent memory."""
        for mem in self.memories:
            mem.reset()

    def add(self, **kwargs):
        """Store per-agent data.

        Each keyword ``name=values`` must be indexable by agent; for agent
        ``i`` the call ``add_<name>(values[i])`` is made on its memory.
        """
        for agent_i in range(self.n_agents):
            for k, v in kwargs.items():
                func = getattr(ActorCriticMemory, f'add_{k}')
                func(self.memories[agent_i], v[agent_i])

    def __getattr__(self, attr):
        """Collect ``attr`` from all agent memories and concatenate on dim 0.

        ``__getattr__`` only runs when normal lookup fails. If ``memories``
        itself is missing (e.g. on the blank instance that ``copy`` or
        ``pickle`` create before restoring state), reading ``self.memories``
        below would re-enter this method forever. Such lookups, and lookups of
        special dunder names probed by those protocols, are therefore
        answered with ``AttributeError``.

        Raises:
            AttributeError: for ``memories``/dunder names, or when an agent
                memory does not provide ``attr``.
        """
        if attr == 'memories' or (attr.startswith('__') and attr.endswith('__')):
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {attr!r}'
            )
        all_attrs = [getattr(mem, attr) for mem in self.memories]
        return torch.cat(all_attrs, 0)  # agent x time ...

    def chunk_dataloader(self, chunk_len, k):
        """Draw ``k`` chunks of length ``chunk_len`` per agent and collate them.

        Builds one ``ExperienceChunks`` dataset per agent memory, reads every
        item of their concatenation and merges the resulting dicts with
        ``custom_collate_fn``.
        """
        datasets = [ExperienceChunks(mem, chunk_len, k) for mem in self.memories]
        dataset = ConcatDataset(datasets)
        data = [dataset[i] for i in range(len(dataset))]
        data = custom_collate_fn(data)
        return data
 | |
| 
 | |
| 
 | |
def custom_collate_fn(batch):
    """Merge a list of sample dicts into a single dict of tensors.

    The keys are taken from the first sample; for each key the tensors of all
    samples are concatenated along dimension 0.
    """
    first_sample = batch[0]
    collated = {}
    for key in first_sample:
        per_sample = [sample[key] for sample in batch]
        collated[key] = torch.cat(per_sample, dim=0)
    return collated
 | |
| 
 | |
| 
 | |
class ExperienceChunks(Dataset):
    """Dataset that draws random fixed-length chunks from a memory.

    The dataset reports a length of ``k``; every item access draws a new
    random start index (the requested item index is ignored) and returns the
    chunk beginning there.
    """

    def __init__(self, memory, chunk_len, k):
        assert chunk_len <= len(memory), 'chunk_len cannot be longer than the size of the memory'
        self.memory, self.chunk_len, self.k = memory, chunk_len, k

    @property
    def whitelist(self):
        """Sampling weight (1.0 allowed / 0.0 blocked) for each start index.

        For every step flagged in ``memory.done`` the starts from
        ``chunk_len + 1`` steps before it up to one step after it are
        blocked. Start index 0 is always blocked.
        """
        n_starts = len(self.memory) - self.chunk_len
        weights = torch.ones(n_starts)
        done_steps = self.memory.done.squeeze().nonzero().flatten().tolist()
        for step in done_steps:
            first_blocked = max(0, step - self.chunk_len - 1)
            weights[first_blocked:step + 2] = 0
        weights[0] = 0
        return weights.tolist()

    def sample(self, start=1):
        """Return the chunk starting at ``start`` as a dict of memory slices.

        Observations cover ``chunk_len + 1`` steps from ``start``; actions
        cover ``chunk_len + 1`` steps from ``start - 1``; hidden states are
        taken at ``start - 1``; the remaining fields cover ``chunk_len``
        steps from ``start``.
        """
        mem = self.memory
        stop = start + self.chunk_len
        prev = start - 1
        return {
            'observation': mem.observation[:, start:stop + 1],
            'action': mem.action[:, prev:stop],
            'hidden_actor': mem.hidden_actor[:, prev],
            'hidden_critic': mem.hidden_critic[:, prev],
            'reward': mem.reward[:, start:stop],
            'done': mem.done[:, start:stop],
            'logits': mem.logits[:, start:stop],
            'values': mem.values[:, start:stop],
        }

    def __len__(self):
        """Return ``k``, the number of chunks this dataset yields."""
        return self.k

    def __getitem__(self, i):
        """Draw one start index weighted by ``whitelist`` and return its chunk."""
        candidates = range(0, len(self.memory) - self.chunk_len)
        (start,) = random.choices(candidates, weights=self.whitelist, k=1)
        return self.sample(start)
 | |
| 
 | |
| 
 | |
class LazyTensorFiFoQueue:
    """Bounded FIFO of tensors that is stacked into one tensor on demand.

    Appended tensors are first collected in a ``deque``. They are merged into
    ``self.queue`` (a tensor stacked along dimension 0) when the deque holds
    ``maxlen`` entries or when the queue is indexed. ``self.queue`` never
    keeps more than the ``maxlen`` most recent entries.
    """

    def __init__(self, maxlen):
        self.maxlen = maxlen
        self.reset()

    def reset(self):
        """Drop all stored data and forget the recorded shape."""
        self.__lazy_queue = deque(maxlen=self.maxlen)
        self.shape = None
        self.queue = None

    def shape_init(self, tensor: Tensor):
        """Record ``shape`` as ``maxlen`` followed by the shape of ``tensor``."""
        self.shape = torch.Size([self.maxlen, *tensor.shape])

    def build_tensor_queue(self):
        """Merge pending entries into ``self.queue`` and clear the pending deque.

        Does nothing when no entries are pending. After merging, only the
        trailing ``maxlen`` rows are kept, so exactly the overflow is evicted
        from the old end. (Dropping as many old rows as there are pending
        ones would discard valid entries whenever the materialised queue is
        not yet full.)
        """
        if not self.__lazy_queue:
            return
        pending = torch.stack(list(self.__lazy_queue), dim=0)
        if self.queue is None:
            self.queue = pending
        else:
            merged = torch.cat((self.queue, pending), dim=0)
            self.queue = merged[-self.maxlen:]
        self.__lazy_queue.clear()

    def append(self, data):
        """Add one tensor; the first append also records ``shape``."""
        if self.shape is None:
            self.shape_init(data)
        self.__lazy_queue.append(data)
        # Materialise once a full window is pending so the deque never has to
        # silently evict entries itself.
        if len(self.__lazy_queue) >= self.maxlen:
            self.build_tensor_queue()

    def true_len(self):
        """Return pending entries plus materialised rows (may exceed ``maxlen``)."""
        return len(self.__lazy_queue) + (0 if self.queue is None else self.queue.shape[0])

    def __len__(self):
        """Return the number of retained entries, capped at ``maxlen``."""
        return min((self.true_len(), self.maxlen))

    def __str__(self):
        return f'LazyTensorFiFoQueue\tmaxlen: {self.maxlen}, shape: {self.shape}, ' \
               f'len: {len(self)}, true_len: {self.true_len()}, elements in lazy queue: {len(self.__lazy_queue)}'

    def __getitem__(self, item_or_slice):
        """Materialise pending entries, then index the stacked tensor.

        If nothing has been appended yet, ``self.queue`` is still ``None`` and
        indexing it raises ``TypeError``.
        """
        self.build_tensor_queue()
        return self.queue[item_or_slice]
 | |
| 
 | |
| 
 | |
| 
 | |
| 
 | 
