refactoring and init.py

This commit is contained in:
Steffen Illium
2023-06-20 18:21:43 +02:00
parent 1332cee7e1
commit c7d77acbbe
138 changed files with 328 additions and 320 deletions

0
mfg_package/__init__.py Normal file
View File

View File

@ -0,0 +1 @@
# Append the directory that contains this file to sys.path at import time.
import os, sys; sys.path.append(os.path.dirname(os.path.realpath(__file__)))

View File

@ -0,0 +1 @@
from mfg_package.algorithms.marl.memory import MARLActorCriticMemory

View File

@ -0,0 +1,221 @@
import torch
from typing import Union, List, Dict
import numpy as np
from torch.distributions import Categorical
from mfg_package.algorithms.marl.memory import MARLActorCriticMemory
from mfg_package.algorithms.utils import add_env_props, instantiate_class
from pathlib import Path
import pandas as pd
from collections import deque
class Names:
    """String constants used as keys into the config and the network output dicts."""

    REWARD = 'reward'
    DONE = 'done'
    ACTION = 'action'
    OBSERVATION = 'observation'
    LOGITS = 'logits'
    HIDDEN_ACTOR = 'hidden_actor'
    HIDDEN_CRITIC = 'hidden_critic'
    AGENT = 'agent'
    ENV = 'environment'
    N_AGENTS = 'n_agents'
    ALGORITHM = 'algorithm'
    MAX_STEPS = 'max_steps'
    N_STEPS = 'n_steps'
    BUFFER_SIZE = 'buffer_size'
    CRITIC = 'critic'
    # Fixed typo ('bnatch_size'): the algorithm config is read with the key 'batch_size'.
    BATCH_SIZE = 'batch_size'
    N_ACTIONS = 'n_actions'
# Short alias for the key constants, e.g. cfg[nms.ENV].
nms = Names
# Type alias for values that are either a plain list or a torch tensor.
ListOrTensor = Union[List, torch.Tensor]
class BaseActorCritic:
    """Shared train/eval loops and the A2C loss for the multi-agent actor-critic variants.

    The base ``init_hidden``, ``forward`` and ``load_state_dict`` are no-ops;
    subclasses are expected to implement them. ``setup`` builds a single
    network and optimizer and may be overridden.
    """

    def __init__(self, cfg):
        """Store the config, read the number of agents and build the network(s)."""
        add_env_props(cfg)
        self.__training = True
        self.cfg = cfg
        self.n_agents = cfg[nms.ENV][nms.N_AGENTS]
        self.reset_memory_after_epoch = True
        self.setup()

    def setup(self):
        """Instantiate the network from ``cfg['agent']`` and its RMSprop optimizer."""
        self.net = instantiate_class(self.cfg[nms.AGENT])
        self.optimizer = torch.optim.RMSprop(self.net.parameters(), lr=3e-4, eps=1e-5)

    def _networks(self):
        """Return the network(s) as a list, whether ``self.net`` is one module or a list."""
        return self.net if isinstance(self.net, list) else [self.net]

    @classmethod
    def _as_torch(cls, x):
        """Convert numpy arrays, lists and scalars to tensors; return anything else unchanged."""
        if isinstance(x, np.ndarray):
            return torch.from_numpy(x)
        elif isinstance(x, list):
            return torch.tensor(x)
        elif isinstance(x, (int, float)):
            return torch.tensor([x])
        return x

    def train(self):
        """Switch all networks to training mode."""
        # Fixed: this flag was set to False here (copy of eval()).
        self.__training = True
        for net in self._networks():
            net.train()

    def eval(self):
        """Switch all networks to evaluation mode."""
        self.__training = False
        for net in self._networks():
            net.eval()

    def load_state_dict(self, path: Path):
        """Load network weights from ``path``; no-op in the base class."""
        pass

    def get_actions(self, out) -> "ListOrTensor":
        """Sample one action per entry of ``out['logits']`` from a categorical distribution."""
        actions = [Categorical(logits=logits).sample().item() for logits in out[nms.LOGITS]]
        return actions

    def init_hidden(self) -> Dict[str, "ListOrTensor"]:
        """Return the initial hidden states; no-op in the base class."""
        pass

    def forward(self,
                observations: "ListOrTensor",
                actions: "ListOrTensor",
                hidden_actor: "ListOrTensor",
                hidden_critic: "ListOrTensor"
                ) -> Dict[str, "ListOrTensor"]:
        """Run the network(s) for one step; no-op in the base class."""
        pass

    @torch.no_grad()
    def train_loop(self, checkpointer=None):
        """Interact with the environment until ``max_steps`` and learn every ``n_steps``.

        :param checkpointer: optional object with ``step(list_of_(name, net))`` and a
            ``path`` attribute; when given, results are also written to ``results.csv``.
        :return: DataFrame with one row per finished episode.
        """
        env = instantiate_class(self.cfg[nms.ENV])
        algo_cfg = self.cfg[nms.ALGORITHM]
        n_steps, max_steps = [algo_cfg[k] for k in [nms.N_STEPS, nms.MAX_STEPS]]
        tm = MARLActorCriticMemory(self.n_agents, algo_cfg.get(nms.BUFFER_SIZE, n_steps))
        global_steps, episode, df_results = 0, 0, []

        while global_steps < max_steps:
            obs = env.reset()
            last_hiddens = self.init_hidden()
            last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
            done, rew_log = [False] * self.n_agents, 0

            if self.reset_memory_after_epoch:
                tm.reset()

            # Seed the memory with the initial observation and dummy action/logits/values.
            tm.add(observation=obs, action=last_action,
                   logits=torch.zeros(self.n_agents, 1, self.cfg[nms.AGENT][nms.N_ACTIONS]),
                   values=torch.zeros(self.n_agents, 1), reward=reward, done=done, **last_hiddens)

            while not all(done):
                out = self.forward(obs, last_action, **last_hiddens)
                action = self.get_actions(out)
                next_obs, reward, done, info = env.step(action)
                done = [done] * self.n_agents if isinstance(done, bool) else done

                last_hiddens = dict(hidden_actor=out[nms.HIDDEN_ACTOR],
                                    hidden_critic=out[nms.HIDDEN_CRITIC])

                tm.add(observation=obs, action=action, reward=reward, done=done,
                       logits=out.get(nms.LOGITS, None), values=out.get(nms.CRITIC, None),
                       **last_hiddens)

                obs = next_obs
                last_action = action

                if (global_steps + 1) % n_steps == 0 or all(done):
                    # The loop runs without gradients; re-enable them for the update.
                    with torch.inference_mode(False):
                        self.learn(tm)

                global_steps += 1
                rew_log += sum(reward)

                if checkpointer is not None:
                    checkpointer.step([
                        (f'agent#{i}', agent) for i, agent in enumerate(self._networks())
                    ])

                if global_steps >= max_steps:
                    break

            print(f'reward at episode: {episode} = {rew_log}')
            episode += 1
            df_results.append([episode, rew_log, *reward])

        df_results = pd.DataFrame(df_results,
                                  columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]])
        if checkpointer is not None:
            df_results.to_csv(checkpointer.path / 'results.csv', index=False)
        return df_results

    @torch.inference_mode(True)
    def eval_loop(self, n_episodes, render=False):
        """Run ``n_episodes`` episodes without learning.

        :return: long-format DataFrame with columns ``episode``, ``agent`` and ``reward``
            (one row per agent plus a ``sum`` row for every episode).
        """
        env = instantiate_class(self.cfg[nms.ENV])
        episode, results = 0, []
        while episode < n_episodes:
            obs = env.reset()
            last_hiddens = self.init_hidden()
            last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
            done, rew_log, eps_rew = [False] * self.n_agents, 0, torch.zeros(self.n_agents)
            while not all(done):
                if render:
                    env.render()

                out = self.forward(obs, last_action, **last_hiddens)
                action = self.get_actions(out)
                next_obs, reward, done, info = env.step(action)

                if isinstance(done, bool):
                    done = [done] * obs.shape[0]
                obs = next_obs
                last_action = action
                last_hiddens = dict(hidden_actor=out.get(nms.HIDDEN_ACTOR, None),
                                    hidden_critic=out.get(nms.HIDDEN_CRITIC, None))
                eps_rew += torch.tensor(reward)
            results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode])
            episode += 1
        agent_columns = [f'agent#{i}' for i in range(self.cfg['environment']['n_agents'])]
        results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode'])
        results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'],
                          value_name='reward', var_name='agent')
        return results

    @staticmethod
    def compute_advantages(critic, reward, done, gamma, gae_coef=0.0):
        """Return one-step TD errors, or GAE-style accumulated errors if ``gae_coef > 0``.

        ``critic`` must have one more entry along dim 1 than ``reward`` and ``done``
        (it is sliced as ``critic[:, 1:]`` and ``critic[:, :-1]``).
        """
        tds = (reward + gamma * (1.0 - done) * critic[:, 1:].detach()) - critic[:, :-1]
        if gae_coef <= 0:
            return tds

        gae = torch.zeros_like(tds[:, -1])
        gaes = []
        # Accumulate backwards in time, then restore chronological order.
        for t in range(tds.shape[1] - 1, -1, -1):
            gae = tds[:, t] + gamma * gae_coef * (1.0 - done[:, t]) * gae
            gaes.append(gae)
        gaes.reverse()
        return torch.stack(gaes, dim=1)

    def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
        """Compute the combined policy, value and entropy loss on the memory content.

        The first memory entry only supplies the previous action and the initial
        hidden states; rewards, dones and the actions that are scored start at index 1.
        """
        obs, actions, done, reward = tm.observation, tm.action, tm.done[:, 1:], tm.reward[:, 1:]

        out = network(obs, actions, tm.hidden_actor[:, 0], tm.hidden_critic[:, 0])
        logits = out[nms.LOGITS][:, :-1]  # last one only needed for v_{t+1}
        critic = out[nms.CRITIC]

        entropy_loss = Categorical(logits=logits).entropy().mean(-1)
        advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
        value_loss = advantages.pow(2).mean(-1)  # n_agent

        # policy loss
        log_ap = torch.log_softmax(logits, -1)
        # squeeze(-1) instead of squeeze(): a bare squeeze also drops an agent or time
        # axis of size 1, which then broadcasts wrongly against `advantages`.
        log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze(-1)
        a2c_loss = -(advantages.detach() * log_ap).mean(-1)
        # weighted loss
        loss = a2c_loss + vf_coef * value_loss - entropy_coef * entropy_loss
        return loss.mean()

    def learn(self, tm: "MARLActorCriticMemory", **kwargs):
        """Run one gradient step (gradient norm clipped to 0.5) on the memory content."""
        loss = self.actor_critic(tm, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
        self.optimizer.step()

View File

@ -0,0 +1,24 @@
agent:
classname: algorithms.marl.networks.RecurrentAC
n_agents: 2
obs_emb_size: 96
action_emb_size: 16
hidden_size_actor: 64
hidden_size_critic: 64
use_agent_embedding: False
env:
classname: environments.factory.make
env_name: "DirtyFactory-v0"
n_agents: 2
max_steps: 250
pomdp_r: 2
stack_n_frames: 0
individual_rewards: True
method: algorithms.marl.LoopSEAC
algorithm:
gamma: 0.99
entropy_coef: 0.01
vf_coef: 0.5
n_steps: 5
max_steps: 1000000

View File

@ -0,0 +1,57 @@
import torch
from mfg_package.algorithms.marl.base_ac import BaseActorCritic, nms
from mfg_package.algorithms.utils import instantiate_class
from pathlib import Path
from natsort import natsorted
from mfg_package.algorithms.marl.memory import MARLActorCriticMemory
class LoopIAC(BaseActorCritic):
    """Actor-critic variant that keeps one network and one optimizer per agent."""

    def __init__(self, cfg):
        super().__init__(cfg)

    def setup(self):
        """Create ``n_agents`` networks, each with its own RMSprop optimizer."""
        self.net = [
            instantiate_class(self.cfg[nms.AGENT]) for _ in range(self.n_agents)
        ]
        self.optimizer = [
            torch.optim.RMSprop(self.net[ag_i].parameters(), lr=3e-4, eps=1e-5)
            for ag_i in range(self.n_agents)
        ]

    def load_state_dict(self, path: Path):
        """Load the naturally sorted ``*.pt`` files in ``path`` into the networks, in order."""
        weight_files = natsorted(list(path.glob('*.pt')))
        for weight_file, net in zip(weight_files, self.net):
            net.load_state_dict(torch.load(weight_file))

    @staticmethod
    def merge_dicts(ds):  # todo could be recursive for more than 1 hierarchy
        """Turn a list of dicts into a dict of lists, using the keys of the first dict."""
        return {key: [single[key] for single in ds] for key in ds[0].keys()}

    def init_hidden(self):
        """Collect the initial actor and critic hidden states of every network."""
        actor_states = [net.init_hidden_actor() for net in self.net]
        critic_states = [net.init_hidden_critic() for net in self.net]
        return dict(hidden_actor=actor_states, hidden_critic=critic_states)

    def forward(self, observations, actions, hidden_actor, hidden_critic):
        """Feed every network its own agent's inputs and merge the per-agent outputs."""
        outputs = []
        for ag_i, net in enumerate(self.net):
            # add leading agent and time dimensions
            agent_obs = self._as_torch(observations[ag_i]).unsqueeze(0).unsqueeze(0)
            agent_action = self._as_torch(actions[ag_i]).unsqueeze(0)
            outputs.append(net(agent_obs, agent_action, hidden_actor[ag_i], hidden_critic[ag_i]))
        return self.merge_dicts(outputs)

    def learn(self, tms: MARLActorCriticMemory, **kwargs):
        """Update every agent's network on that agent's own memory."""
        for ag_i in range(self.n_agents):
            net = self.net[ag_i]
            optimizer = self.optimizer[ag_i]
            loss = self.actor_critic(tms(ag_i), net, **self.cfg[nms.ALGORITHM], **kwargs)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 0.5)
            optimizer.step()

View File

@ -0,0 +1,66 @@
from mfg_package.algorithms.marl.base_ac import Names as nms
from mfg_package.algorithms.marl.snac import LoopSNAC
from mfg_package.algorithms.marl.memory import MARLActorCriticMemory
import torch
from torch.distributions import Categorical
from mfg_package.algorithms.utils import instantiate_class
class LoopMAPPO(LoopSNAC):
    """Shared-network variant trained with a clipped surrogate (PPO-style) objective."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # keep the memory across episodes
        self.reset_memory_after_epoch = False

    def setup(self):
        """Instantiate the network and an Adam optimizer for it."""
        self.net = instantiate_class(self.cfg[nms.AGENT])
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=3e-4, eps=1e-5)

    def learn(self, tm: MARLActorCriticMemory, **kwargs):
        """Run ``n_updates`` gradient steps on sampled chunks once the buffer is full."""
        algo_cfg = self.cfg['algorithm']
        if len(tm) < algo_cfg['buffer_size']:
            # only learn when buffer is full
            return
        for _ in range(algo_cfg['n_updates']):
            batch = tm.chunk_dataloader(chunk_len=algo_cfg['n_steps'],
                                        k=algo_cfg['batch_size'])
            loss = self.mappo(batch, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
            self.optimizer.step()

    def monte_carlo_returns(self, rewards, done, gamma):
        """Return discounted returns along dim 1; the running sum is cut where ``done`` is 1."""
        running = torch.zeros_like(rewards[:, -1])
        backwards = []
        for t in reversed(range(rewards.shape[1])):
            running = rewards[:, t] + (gamma * (1.0 - done[:, t]) * running)
            backwards.append(running)
        return torch.stack(backwards[::-1], dim=1)

    def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **kwargs):
        """Compute the clipped-surrogate policy loss plus value and entropy terms for a batch."""
        out = network(batch[nms.OBSERVATION], batch[nms.ACTION],
                      batch[nms.HIDDEN_ACTOR], batch[nms.HIDDEN_CRITIC])
        logits = out[nms.LOGITS][:, :-1]  # last one only needed for v_{t+1}

        # log-probabilities of the stored actions under the logits saved in the batch
        taken = batch[nms.ACTION][:, 1:].unsqueeze(-1)
        old_log_probs = torch.gather(torch.log_softmax(batch[nms.LOGITS], -1), index=taken, dim=-1).squeeze()

        # monte carlo returns
        mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma)
        mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8)  # todo: norm across agent ok?
        advantages = mc_returns - out[nms.CRITIC][:, :-1]
        fixed_advantages = advantages.detach()

        # policy loss
        log_ap = torch.gather(torch.log_softmax(logits, -1), dim=-1, index=taken).squeeze()
        ratio = (log_ap - old_log_probs).exp()
        unclipped = ratio * fixed_advantages
        clipped = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * fixed_advantages
        policy_loss = -torch.min(unclipped, clipped).mean(-1)

        # entropy & value loss
        entropy_loss = Categorical(logits=logits).entropy().mean(-1)
        value_loss = advantages.pow(2).mean(-1)  # n_agent

        # weighted loss
        loss = policy_loss + vf_coef*value_loss - entropy_coef * entropy_loss
        return loss.mean()

View File

@ -0,0 +1,221 @@
import numpy as np
from collections import deque
import torch
from typing import Union
from torch import Tensor
from torch.utils.data import Dataset, ConcatDataset
import random
class ActorCriticMemory(object):
    """Per-agent rolling buffer of transitions, backed by one FIFO queue per field.

    Every queue holds ``capacity + 1`` entries; ``len()`` reports one less than
    the number of stored rewards.
    """

    def __init__(self, capacity=10):
        self.capacity = capacity
        self.reset()

    def reset(self):
        """Replace all field queues with fresh, empty ones."""
        slots = self.capacity + 1
        self.__actions = LazyTensorFiFoQueue(maxlen=slots)
        self.__hidden_actor = LazyTensorFiFoQueue(maxlen=slots)
        self.__hidden_critic = LazyTensorFiFoQueue(maxlen=slots)
        self.__states = LazyTensorFiFoQueue(maxlen=slots)
        self.__rewards = LazyTensorFiFoQueue(maxlen=slots)
        self.__dones = LazyTensorFiFoQueue(maxlen=slots)
        self.__logits = LazyTensorFiFoQueue(maxlen=slots)
        self.__values = LazyTensorFiFoQueue(maxlen=slots)

    def __len__(self):
        return len(self.__rewards) - 1

    # Each property returns the whole stored sequence with a leading axis of size 1.

    @property
    def observation(self):
        stored = self.__states[0:]
        return stored.unsqueeze(0)

    @property
    def hidden_actor(self):
        stored = self.__hidden_actor[0:]
        return stored.unsqueeze(0)

    @property
    def hidden_critic(self):
        stored = self.__hidden_critic[0:]
        return stored.unsqueeze(0)

    @property
    def reward(self):
        stored = self.__rewards[0:]
        return stored.squeeze().unsqueeze(0)

    @property
    def action(self):
        stored = self.__actions[0:]
        return stored.long().squeeze().unsqueeze(0)

    @property
    def done(self):
        stored = self.__dones[0:]
        return stored.float().squeeze().unsqueeze(0)

    @property
    def logits(self):
        stored = self.__logits[0:]
        return stored.squeeze().unsqueeze(0)

    @property
    def values(self):
        stored = self.__values[0:]
        return stored.squeeze().unsqueeze(0)

    def add_observation(self, state: Union[Tensor, np.ndarray]):
        """Store an observation; anything that is not a tensor goes through ``torch.from_numpy``."""
        if not isinstance(state, Tensor):
            state = torch.from_numpy(state)
        self.__states.append(state)

    def add_hidden_actor(self, hidden: Tensor):
        self.__hidden_actor.append(hidden)

    def add_hidden_critic(self, hidden: Tensor):
        self.__hidden_critic.append(hidden)

    def add_action(self, action: Union[int, Tensor]):
        self.__actions.append(action if isinstance(action, Tensor) else torch.tensor(action))

    def add_reward(self, reward: Union[float, Tensor]):
        self.__rewards.append(reward if isinstance(reward, Tensor) else torch.tensor(reward))

    def add_done(self, done: bool):
        self.__dones.append(done if isinstance(done, Tensor) else torch.tensor(done))

    def add_logits(self, logits: Tensor):
        self.__logits.append(logits)

    def add_values(self, values: Tensor):
        self.__values.append(values)

    def add(self, **kwargs):
        """Dispatch every ``name=value`` pair to the matching ``add_<name>`` method."""
        for field, value in kwargs.items():
            adder = getattr(ActorCriticMemory, f'add_{field}')
            adder(self, value)
class MARLActorCriticMemory(object):
    """Container of one ``ActorCriticMemory`` per agent.

    Attributes that are not found on the container itself are fetched from every
    agent memory and concatenated along dim 0 (see ``__getattr__``).
    """

    def __init__(self, n_agents, capacity):
        self.n_agents = n_agents
        self.memories = [
            ActorCriticMemory(capacity) for _ in range(n_agents)
        ]

    def __call__(self, agent_i):
        """Return the memory of agent ``agent_i``."""
        return self.memories[agent_i]

    def __len__(self):
        return len(self.memories[0])  # todo add assertion check!

    def reset(self):
        """Reset every agent memory."""
        for mem in self.memories:
            mem.reset()

    def add(self, **kwargs):
        """Store ``value[agent_i]`` of every keyword in the memory of agent ``agent_i``."""
        for agent_i in range(self.n_agents):
            for k, v in kwargs.items():
                func = getattr(ActorCriticMemory, f'add_{k}')
                func(self.memories[agent_i], v[agent_i])

    def __getattr__(self, attr):
        """Concatenate ``attr`` of all agent memories along dim 0.

        :raises AttributeError: for ``memories`` itself and for dunder names.
        """
        # __getattr__ only runs when normal lookup failed. If `memories` is missing
        # (copy/pickle probe the instance before __dict__ is restored), touching
        # self.memories below would re-enter this method and recurse forever.
        if attr == 'memories' or (attr.startswith('__') and attr.endswith('__')):
            raise AttributeError(attr)
        all_attrs = [getattr(mem, attr) for mem in self.memories]
        return torch.cat(all_attrs, 0)  # agent x time ...

    def chunk_dataloader(self, chunk_len, k):
        """Sample ``k`` chunks of length ``chunk_len`` per agent and collate them into one batch."""
        datasets = [ExperienceChunks(mem, chunk_len, k) for mem in self.memories]
        dataset = ConcatDataset(datasets)
        data = [dataset[i] for i in range(len(dataset))]
        data = custom_collate_fn(data)
        return data
def custom_collate_fn(batch):
    """Merge a list of sample dicts into one dict by concatenating each key along dim 0.

    The keys are taken from the first sample.
    """
    collated = {}
    for key in batch[0]:
        pieces = [sample[key] for sample in batch]
        collated[key] = torch.cat(pieces, dim=0)
    return collated
class ExperienceChunks(Dataset):
    """Dataset that draws ``k`` random fixed-length chunks from a single agent memory."""

    def __init__(self, memory, chunk_len, k):
        assert chunk_len <= len(memory), 'chunk_len cannot be longer than the size of the memory'
        self.memory = memory
        self.chunk_len = chunk_len
        self.k = k

    @property
    def whitelist(self):
        """Sampling weights per start index: 0 around every done step and at index 0, else 1."""
        weights = torch.ones(len(self.memory) - self.chunk_len)
        done_steps = self.memory.done.squeeze().nonzero().flatten()
        for done_t in done_steps:
            first_blocked = max((0, done_t - self.chunk_len - 1))
            weights[first_blocked:done_t + 2] = 0
        weights[0] = 0
        return weights.tolist()

    def sample(self, start=1):
        """Cut one chunk beginning at ``start``.

        Observations get one extra step, actions start one step earlier, and the
        hidden states are taken from the step before ``start``.
        """
        stop = start + self.chunk_len
        prev = start - 1
        mem = self.memory
        return {
            'observation': mem.observation[:, start:stop + 1],
            'action': mem.action[:, prev:stop],
            'hidden_actor': mem.hidden_actor[:, prev],
            'hidden_critic': mem.hidden_critic[:, prev],
            'reward': mem.reward[:, start:stop],
            'done': mem.done[:, start:stop],
            'logits': mem.logits[:, start:stop],
            'values': mem.values[:, start:stop],
        }

    def __len__(self):
        return self.k

    def __getitem__(self, i):
        # The index is ignored: every access draws a new random start position.
        candidates = range(0, len(self.memory) - self.chunk_len)
        drawn = random.choices(candidates, weights=self.whitelist, k=1)
        return self.sample(drawn[0])
class LazyTensorFiFoQueue:
    """FIFO queue of tensors that is stacked into one tensor only when needed.

    Appended items are collected in a deque. They are merged into ``self.queue``
    (stacked along dim 0) when the deque is full or when the queue is indexed.
    At most the ``maxlen`` most recent items are kept.
    """

    def __init__(self, maxlen):
        self.maxlen = maxlen
        self.reset()

    def reset(self):
        """Drop all content."""
        self.__lazy_queue = deque(maxlen=self.maxlen)
        self.shape = None
        self.queue = None

    def shape_init(self, tensor: Tensor):
        """Record the shape of a completely filled queue, derived from the first item."""
        self.shape = torch.Size([self.maxlen, *tensor.shape])

    def build_tensor_queue(self):
        """Merge pending items into ``self.queue`` and keep the newest ``maxlen`` entries."""
        if len(self.__lazy_queue) > 0:
            block = torch.stack(list(self.__lazy_queue), dim=0)
            if self.queue is not None:
                block = torch.cat((self.queue, block), dim=0)
            # Trim from the front by the actual overflow. The previous code dropped
            # len(new items) entries from the old queue whenever the total exceeded
            # maxlen, which discarded too much if the old queue was only partly full.
            self.queue = block[-self.maxlen:]
            self.__lazy_queue.clear()

    def append(self, data):
        """Add one item; the tensor queue is rebuilt once ``maxlen`` items are pending."""
        if self.shape is None:
            self.shape_init(data)
        self.__lazy_queue.append(data)
        if len(self.__lazy_queue) >= self.maxlen:
            self.build_tensor_queue()

    def true_len(self):
        """Number of pending plus already merged items (may exceed ``maxlen``)."""
        return len(self.__lazy_queue) + (0 if self.queue is None else self.queue.shape[0])

    def __len__(self):
        return min((self.true_len(), self.maxlen))

    def __str__(self):
        return f'LazyTensorFiFoQueue\tmaxlen: {self.maxlen}, shape: {self.shape}, ' \
               f'len: {len(self)}, true_len: {self.true_len()}, elements in lazy queue: {len(self.__lazy_queue)}'

    def __getitem__(self, item_or_slice):
        self.build_tensor_queue()
        return self.queue[item_or_slice]

View File

@ -0,0 +1,104 @@
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
class RecurrentAC(nn.Module):
    """Recurrent actor-critic network with separate GRUs for the policy and the value head.

    Observations are projected, concatenated with an embedding of the previous
    action (and optionally an agent embedding), mixed by a small MLP and fed to
    both GRUs.
    """

    def __init__(self, observation_size, n_actions, obs_emb_size,
                 action_emb_size, hidden_size_actor, hidden_size_critic,
                 n_agents, use_agent_embedding=True):
        super(RecurrentAC, self).__init__()
        observation_size = int(np.prod(observation_size))
        self.n_layers = 1
        self.n_actions = n_actions
        self.use_agent_embedding = use_agent_embedding
        self.hidden_size_actor = hidden_size_actor
        self.hidden_size_critic = hidden_size_critic
        self.action_emb_size = action_emb_size
        self.obs_proj = nn.Linear(observation_size, obs_emb_size)
        # one extra row: index 0 is the padding entry used for "no previous action" (-1)
        self.action_emb = nn.Embedding(n_actions + 1, action_emb_size, padding_idx=0)
        self.agent_emb = nn.Embedding(n_agents, action_emb_size)
        # forward() concatenates the observation embedding, the action embedding and,
        # if enabled, ONE agent embedding of width action_emb_size. The former
        # `n_agents * action_emb_size` only matched that for n_agents == 2.
        mix_in_size = obs_emb_size + action_emb_size
        if use_agent_embedding:
            mix_in_size += action_emb_size
        self.mix = nn.Sequential(nn.Tanh(),
                                 nn.Linear(mix_in_size, obs_emb_size),
                                 nn.Tanh(),
                                 nn.Linear(obs_emb_size, obs_emb_size)
                                 )
        self.gru_actor = nn.GRU(obs_emb_size, hidden_size_actor, batch_first=True, num_layers=self.n_layers)
        self.gru_critic = nn.GRU(obs_emb_size, hidden_size_critic, batch_first=True, num_layers=self.n_layers)
        self.action_head = nn.Sequential(
            nn.Linear(hidden_size_actor, hidden_size_actor),
            nn.Tanh(),
            nn.Linear(hidden_size_actor, n_actions)
        )
        self.critic_head = nn.Sequential(
            nn.Linear(hidden_size_critic, hidden_size_critic),
            nn.Tanh(),
            nn.Linear(hidden_size_critic, 1)
        )

    def init_hidden_actor(self):
        """Return a zero hidden state of shape (1, n_layers, hidden_size_actor)."""
        return torch.zeros(1, self.n_layers, self.hidden_size_actor)

    def init_hidden_critic(self):
        """Return a zero hidden state of shape (1, n_layers, hidden_size_critic)."""
        return torch.zeros(1, self.n_layers, self.hidden_size_critic)

    def forward(self, observations, actions, hidden_actor=None, hidden_critic=None):
        """Compute logits and values for a batch of sequences.

        :param observations: tensor whose first two dims are (agents, time); the rest is flattened.
        :param actions: integer tensor of shape (agents, time) holding the previous actions,
            with -1 meaning "none".
        :param hidden_actor: (agents, n_layers, hidden) initial actor state, or None for zeros.
        :param hidden_critic: (agents, n_layers, hidden) initial critic state, or None for zeros.
        :return: dict with ``logits``, ``critic`` and the full GRU output sequences as
            ``hidden_actor`` / ``hidden_critic``.
        """
        n_agents, t, *_ = observations.shape
        obs_emb = self.obs_proj(observations.view(n_agents, t, -1).float())
        action_emb = self.action_emb(actions + 1)  # shift by one due to padding idx

        if not self.use_agent_embedding:
            x_t = torch.cat((obs_emb, action_emb), -1)
        else:
            # agent index repeated along time; created on the input's device
            agent_ids = torch.arange(0, n_agents, 1, device=observations.device).view(-1, 1).repeat(1, t)
            agent_emb = self.agent_emb(agent_ids)
            x_t = torch.cat((obs_emb, agent_emb, action_emb), -1)

        mixed_x_t = self.mix(x_t)
        # nn.GRU expects (n_layers, batch, hidden); None lets it start from zeros.
        hx_actor = None if hidden_actor is None else hidden_actor.swapaxes(1, 0)
        hx_critic = None if hidden_critic is None else hidden_critic.swapaxes(1, 0)
        output_p, _ = self.gru_actor(input=mixed_x_t, hx=hx_actor)
        output_c, _ = self.gru_critic(input=mixed_x_t, hx=hx_critic)

        logits = self.action_head(output_p)
        critic = self.critic_head(output_c).squeeze(-1)
        return dict(logits=logits, critic=critic, hidden_actor=output_p, hidden_critic=output_c)
class RecurrentACL2(RecurrentAC):
    """RecurrentAC whose policy head ends in a ``NormalizedLinear`` layer with trainable scale."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        width = self.hidden_size_actor
        head_layers = [
            nn.Linear(width, width),
            nn.Tanh(),
            NormalizedLinear(width, self.n_actions, trainable_magnitude=True),
        ]
        # replaces the head built by the parent constructor
        self.action_head = nn.Sequential(*head_layers)
class NormalizedLinear(nn.Linear):
    """Bias-free linear layer on L2-normalised inputs and weight rows.

    The result is multiplied by ``sqrt(in_features)`` and by ``scale``, which is
    trainable only if ``trainable_magnitude`` is set.
    """

    def __init__(self, in_features: int, out_features: int,
                 device=None, dtype=None, trainable_magnitude=False):
        super().__init__(in_features, out_features, bias=False, device=device, dtype=dtype)
        self.d_sqrt = in_features ** 0.5
        self.trainable_magnitude = trainable_magnitude
        self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude)

    def forward(self, input):
        unit_input = F.normalize(input, dim=-1, p=2, eps=1e-5)
        unit_weight = F.normalize(self.weight, dim=-1, p=2, eps=1e-5)
        projected = F.linear(unit_input, unit_weight)
        return projected * self.d_sqrt * self.scale
class L2Norm(nn.Module):
    """L2-normalise the last dimension, then multiply by ``sqrt(in_features)`` and ``scale``."""

    def __init__(self, in_features, trainable_magnitude=False):
        super().__init__()
        self.d_sqrt = in_features ** 0.5
        self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude)

    def forward(self, x):
        unit = F.normalize(x, dim=-1, p=2, eps=1e-5)
        return unit * self.d_sqrt * self.scale

View File

@ -0,0 +1,56 @@
import torch
from torch.distributions import Categorical
from mfg_package.algorithms.marl.iac import LoopIAC
from mfg_package.algorithms.marl.base_ac import nms
from mfg_package.algorithms.marl.memory import MARLActorCriticMemory
class LoopSEAC(LoopIAC):
    """Per-agent networks that each learn from all agents' experience.

    Every network is evaluated on the joint memory; the loss terms are weighted
    by importance weights relative to the log-probabilities that the acting
    agent's own network assigns to the stored actions.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

    def actor_critic(self, tm, networks, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
        """Return one importance-weighted loss per network."""
        obs, actions = tm.observation, tm.action
        done, reward = tm.done[:, 1:], tm.reward[:, 1:]

        outputs = []
        for net in networks:
            outputs.append(net(obs, actions, tm.hidden_actor[:, 0], tm.hidden_critic[:, 0]))

        # log-prob of each agent's stored actions under that agent's own network
        with torch.inference_mode(True):
            own_logps = []
            for ag_i, out in enumerate(outputs):
                own_logits = out[nms.LOGITS][ag_i, :-1]
                own_logps.append(
                    torch.log_softmax(own_logits, -1).gather(index=actions[ag_i, 1:, None], dim=-1)
                )
            true_action_logp = torch.stack(own_logps, 0).squeeze()

        losses = []
        for ag_i, out in enumerate(outputs):
            logits = out[nms.LOGITS][:, :-1]  # last one only needed for v_{t+1}
            entropy_loss = Categorical(logits=logits[ag_i]).entropy().mean()
            advantages = self.compute_advantages(out[nms.CRITIC], reward, done, gamma, gae_coef)

            # policy loss
            all_log_probs = torch.log_softmax(logits, -1)
            log_ap = torch.gather(all_log_probs, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze()

            # importance weights
            iw = (log_ap - true_action_logp).exp().detach()

            a2c_loss = (-iw*log_ap * advantages.detach()).mean(-1)
            value_loss = (iw*advantages.pow(2)).mean(-1)  # n_agent

            # weighted loss
            losses.append((a2c_loss + vf_coef*value_loss - entropy_coef * entropy_loss).mean())
        return losses

    def learn(self, tms: MARLActorCriticMemory, **kwargs):
        """Compute all losses on the joint memory, then step every agent's optimizer."""
        losses = self.actor_critic(tms, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
        for ag_i, loss in enumerate(losses):
            optimizer = self.optimizer[ag_i]
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.net[ag_i].parameters(), 0.5)
            optimizer.step()

View File

@ -0,0 +1,33 @@
from mfg_package.algorithms.marl.base_ac import BaseActorCritic
from mfg_package.algorithms.marl.base_ac import nms
import torch
from torch.distributions import Categorical
from pathlib import Path
class LoopSNAC(BaseActorCritic):
    """Variant in which all agents are processed by one shared network."""

    def __init__(self, cfg):
        super().__init__(cfg)

    def load_state_dict(self, path: Path):
        """Load the single ``*.pt`` file found in ``path`` into the network."""
        path2weights = list(path.glob('*.pt'))
        n_found = len(path2weights)
        assert n_found == 1, f'Expected a single set of weights but got {n_found}'
        self.net.load_state_dict(torch.load(path2weights[0]))

    def init_hidden(self):
        """Repeat the network's initial hidden states once per agent along dim 0."""
        single_actor = self.net.init_hidden_actor()
        single_critic = self.net.init_hidden_critic()
        stacked_actor = torch.cat([single_actor] * self.n_agents, 0)
        stacked_critic = torch.cat([single_critic] * self.n_agents, 0)
        return {'hidden_actor': stacked_actor, 'hidden_critic': stacked_critic}

    def get_actions(self, out):
        """Sample actions for all agents from the categorical given by the logits."""
        policy = Categorical(logits=out[nms.LOGITS])
        return policy.sample().squeeze()

    def forward(self, observations, actions, hidden_actor, hidden_critic):
        """Run the shared network on all agents, inserting a time axis at dim 1."""
        obs_t = self._as_torch(observations).unsqueeze(1)
        act_t = self._as_torch(actions).unsqueeze(1)
        return self.net(obs_t, act_t, hidden_actor, hidden_critic)

View File

@ -0,0 +1,130 @@
import itertools
from random import choice
import numpy as np
import networkx as nx
from networkx.algorithms.approximation import traveling_salesman as tsp
from mfg_package.modules.doors import constants as do
from mfg_package.environment import constants as c
from mfg_package.utils.helpers import MOVEMAP
from abc import abstractmethod, ABC
future_planning = 7
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
    """
    Build an undirected graph from a set of coordinates by connecting adjacent points.

    There are three combinations of settings:
        Allow all neighbors:    Distance(a, b) <= sqrt(2)
        Allow only manhattan:   Distance(a, b) == 1
        Allow only euclidean:   Distance(a, b) == sqrt(2)

    :param coordiniates_or_tiles: A set of coordinates, or an object with a `positions` attribute.
    :type coordiniates_or_tiles: Tiles
    :param allow_euclidean_connections: Whether to regard diagonally adjacent cells as neighbors
    :type: bool
    :param allow_manhattan_connections: Whether to regard directly adjacent cells as neighbors
    :type: bool

    :return: A graph with nodes that are connected as specified by the parameters.
    :rtype: nx.Graph
    """
    assert allow_euclidean_connections or allow_manhattan_connections
    points = coordiniates_or_tiles
    if hasattr(points, 'positions'):
        points = points.positions

    diagonal = np.sqrt(2)

    def connected(dist):
        # Pick the adjacency rule that matches the two flags.
        if allow_manhattan_connections and allow_euclidean_connections:
            return dist <= diagonal
        if allow_euclidean_connections:
            return dist == diagonal
        if allow_manhattan_connections:
            return dist == 1
        return False

    graph = nx.Graph()
    for a, b in itertools.combinations(points, 2):
        if connected(np.linalg.norm(np.asarray(a) - np.asarray(b))):
            graph.add_edge(a, b)
    return graph
class TSPBaseAgent(ABC):
    """Base class for scripted agents that follow a greedy travelling-salesman route."""

    def __init__(self, state, agent_i, static_problem: bool = True):
        self.static_problem = static_problem
        self.local_optimization = True
        self._env = state
        self.state = self._env.state[c.AGENT][agent_i]
        self._floortile_graph = points_to_graph(self._env[c.FLOOR].positions)
        self._static_route = None

    @abstractmethod
    def predict(self, *_, **__) -> int:
        return 0

    def _use_door_or_move(self, door, target):
        """Use the door if it is closed, otherwise keep moving towards ``target``."""
        if not door.is_closed:
            return self._predict_move(target)
        # Translate the action_object to an integer to have the same output as any other model
        return do.ACTION_DOOR_USE

    def calculate_tsp_route(self, target_identifier):
        """Compute a greedy TSP cycle over the agent position and (a subset of) the targets."""
        own_pos = self.state.pos
        positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS]
        everything = [own_pos] + positions
        if not self.local_optimization:
            nodes = everything
        else:
            # Start with targets close to the agent, then fill up to 7 nodes.
            nearby = [x for x in positions if max(abs(np.subtract(x, own_pos))) < 3]
            nodes = [own_pos] + nearby
            try:
                while len(nodes) < 7:
                    nodes += [next(x for x in positions if x not in nodes)]
            except StopIteration:
                # fewer targets than needed: use all of them
                nodes = everything

        return tsp.traveling_salesman_problem(self._floortile_graph,
                                              nodes=nodes, cycle=True, method=tsp.greedy_tsp)

    def _door_is_close(self):
        """Return the first door found on a neighboring floor tile, or None."""
        nearby_doors = (guest
                        for tile in self.state.tile.neighboring_floor
                        for guest in tile.guests
                        if do.DOOR in guest.name)
        try:
            return next(nearby_doors)
        except StopIteration:
            return None

    def _has_targets(self, target_identifier):
        """Tell whether at least one target entity has a valid position."""
        return any(x.pos != c.VALUE_NO_POS for x in self._env.state[target_identifier])

    def _predict_move(self, target_identifier):
        """Return the action name for the next step along the route, or a random one."""
        if not self._has_targets(target_identifier):
            return choice(self.state.actions).name

        if not self._static_route:
            route = self.calculate_tsp_route(target_identifier)
            # dynamic problems only follow the first few steps before re-planning
            self._static_route = route if self.static_problem else route[:7]

        next_pos = self._static_route.pop(0)
        while next_pos == self.state.pos:
            next_pos = self._static_route.pop(0)

        diff = np.subtract(next_pos, self.state.pos)
        # Retrieve action based on the pos dif (like in: What do I have to do to get there?)
        try:
            action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff))
        except StopIteration:
            print(f'diff: {diff}')
            print('This Should not happen!')
            action = choice(self.state.actions).name
        return action

View File

@ -0,0 +1,27 @@
from mfg_package.algorithms.static.TSP_base_agent import TSPBaseAgent
from mfg_package.modules.clean_up import constants as di
future_planning = 7
class TSPDirtAgent(TSPBaseAgent):
    """Scripted agent that cleans dirt where it stands and otherwise walks towards dirt."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def predict(self, *_, **__):
        """Return the index of the chosen action within ``self.state.actions``."""
        standing_on_dirt = self._env.state[di.DIRT].by_pos(self.state.pos) is not None
        if standing_on_dirt:
            action = di.CLEAN_UP
        else:
            door = self._door_is_close()
            if door:
                action = self._use_door_or_move(door, di.DIRT)
            else:
                action = self._predict_move(di.DIRT)
        # Translate the action_object to an integer to have the same output as any other model
        try:
            matches = (idx for idx, candidate in enumerate(self.state.actions) if candidate.name == action)
            action_obj = next(matches)
        except (StopIteration, UnboundLocalError):
            print('Will not happen')
            raise EnvironmentError
        return action_obj

View File

@ -0,0 +1,59 @@
import numpy as np
from mfg_package.algorithms.static.TSP_base_agent import TSPBaseAgent
from mfg_package.modules.items import constants as i
future_planning = 7
inventory_size = 3
MODE_GET = 'Mode_Get'
MODE_BRING = 'Mode_Bring'
class TSPItemAgent(TSPBaseAgent):
    """Scripted agent that alternates between collecting items and bringing them to a drop-off."""

    def __init__(self, *args, mode=MODE_GET, **kwargs):
        super(TSPItemAgent, self).__init__(*args, **kwargs)
        self.mode = mode

    def predict(self, *_, **__):
        """Return the index of the chosen action and update the get/bring mode.

        :raises EnvironmentError: if the chosen action name is not among the agent's actions.
        """
        if self._env.state[i.ITEM].by_pos(self.state.pos) is not None:
            action = i.ITEM_ACTION
        elif self._env.state[i.DROP_OFF].by_pos(self.state.pos) is not None:
            action = i.ITEM_ACTION
        elif door := self._door_is_close():
            action = self._use_door_or_move(door, i.DROP_OFF if self.mode == MODE_BRING else i.ITEM)
        else:
            action = self._choose()

        if isinstance(action, int):
            # _choose() falls back to a random action *index*. Looking that up by
            # name below could never match and always ended in EnvironmentError.
            action_obj = action
        else:
            # Translate the action name to its index to have the same output as any other model
            try:
                action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action)
            except (StopIteration, UnboundLocalError):
                print('Will not happen')
                raise EnvironmentError

        # Mode switching: stay in BRING while something is carried, go back to GET
        # once the inventory is empty, switch to BRING when more than
        # `inventory_size` items are carried.
        n_carried = len(self._env[i.INVENTORY].by_entity(self.state))
        if self.mode == MODE_BRING and not n_carried:
            self.mode = MODE_GET
        elif self.mode == MODE_GET and n_carried > inventory_size:
            self.mode = MODE_BRING
        return action_obj

    def _choose(self):
        """Pick a move towards the current target, or a random action index if nothing is left to do."""
        target = i.DROP_OFF if self.mode == MODE_BRING else i.ITEM
        if len(self._env.state[i.ITEM]) >= 1:
            action = self._predict_move(target)
        elif len(self._env[i.INVENTORY].by_entity(self.state)):
            self.mode = MODE_BRING
            action = self._predict_move(target)
        else:
            action = int(np.random.randint(self._env.action_space.n))
        return action

View File

@ -0,0 +1,32 @@
from mfg_package.algorithms.static.TSP_base_agent import TSPBaseAgent
from mfg_package.modules.destinations import constants as d
from mfg_package.modules.doors import constants as do
# NOTE(review): not referenced anywhere in this module - presumably a planning horizon; confirm or remove.
future_planning = 7
class TSPTargetAgent(TSPBaseAgent):
    """Static agent that always heads for entities registered under ``d.DESTINATION``."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``TSPBaseAgent``."""
        super(TSPTargetAgent, self).__init__(*args, **kwargs)

    def _handle_doors(self):
        """Return the first guest on a neighbouring floor tile whose name contains ``do.DOOR``.

        :return: The matching guest, or ``None`` if there is none.
        """
        try:
            return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
        except StopIteration:
            return None

    def predict(self, *_, **__):
        """Choose the next action and return its index within ``self.state.actions``.

        All arguments are ignored; they only exist to mirror the call signature
        of other models.

        :raises EnvironmentError: If the chosen action is not among the agent's actions.
        """
        if door := self._door_is_close():
            action = self._use_door_or_move(door, d.DESTINATION)
        else:
            action = self._predict_move(d.DESTINATION)
        # Translate the action_object to an integer to have the same output as any other model
        try:
            action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action)
        except (StopIteration, UnboundLocalError):
            print('Will not happen')
            # Fail explicitly, as the other TSP agents do, instead of running
            # into an unbound `action_obj` on the return statement below.
            raise EnvironmentError
        return action_obj

View File

@ -0,0 +1,15 @@
from random import randint
from mfg_package.algorithms.static.TSP_base_agent import TSPBaseAgent
# NOTE(review): not referenced anywhere in this module - presumably a planning horizon; confirm or remove.
future_planning = 7
class TSPRandomAgent(TSPBaseAgent):
    """Baseline agent that ignores its inputs and returns a random action index."""

    def __init__(self, n_actions, *args, **kwargs):
        """Store the number of available actions; everything else goes to ``TSPBaseAgent``."""
        super().__init__(*args, **kwargs)
        self.n_action = n_actions

    def predict(self, *_, **__):
        """Return a random integer in ``[0, n_action - 1]`` (both ends included)."""
        highest_index = self.n_action - 1
        return randint(0, highest_index)

View File

@ -0,0 +1,85 @@
import torch
import numpy as np
import yaml
from pathlib import Path
def load_class(classname):
    """Import and return the attribute named by a dotted path such as ``package.module.Name``."""
    from importlib import import_module

    # Everything before the last dot is the module, the remainder the attribute.
    module_path, class_name = classname.rsplit(".", 1)
    return getattr(import_module(module_path), class_name)
def instantiate_class(arguments):
    """Build an object from a mapping holding a ``classname`` plus constructor keywords.

    ``arguments`` is copied first, so the caller's mapping is left untouched.
    The ``classname`` entry is a dotted path; all remaining entries are passed
    to the constructor as keyword arguments.
    """
    from importlib import import_module

    constructor_kwargs = dict(arguments)
    dotted_path = constructor_kwargs["classname"]
    del constructor_kwargs["classname"]
    module_path, class_name = dotted_path.rsplit(".", 1)
    target = getattr(import_module(module_path), class_name)
    return target(**constructor_kwargs)
def get_class(arguments):
    """Return the class named by ``arguments``' ``classname`` entry without instantiating it.

    ``arguments`` may be a dict (key ``"classname"``) or any object exposing a
    ``classname`` attribute; the value is a dotted path.
    """
    from importlib import import_module

    # Only the way the dotted path is read differs between the two input kinds.
    if isinstance(arguments, dict):
        dotted_path = arguments["classname"]
    else:
        dotted_path = arguments.classname
    module_path, class_name = dotted_path.rsplit(".", 1)
    return getattr(import_module(module_path), class_name)
def get_arguments(arguments):
    """Return a copy of ``arguments`` as a dict, without its ``classname`` entry.

    The input is not modified. A missing ``classname`` entry is not an error.
    """
    remaining = dict(arguments)
    remaining.pop("classname", None)
    return remaining
def load_yaml_file(path: Path):
    """Parse the YAML file at ``path`` and return the loaded content."""
    # NOTE(review): `yaml.FullLoader` is not meant for untrusted input - only
    # pass files from trusted sources, or consider `yaml.safe_load`.
    with path.open() as handle:
        return yaml.load(handle, Loader=yaml.FullLoader)
def add_env_props(cfg):
    """Add ``observation_size`` and ``n_actions`` to ``cfg['agent']`` in place.

    A throw-away environment is instantiated from a copy of
    ``cfg['environment']`` to read the values from its observation and action
    space.
    """
    environment = instantiate_class(cfg['environment'].copy())
    derived_props = {
        'observation_size': list(environment.observation_space.shape),
        'n_actions': environment.action_space.n,
    }
    cfg['agent'].update(derived_props)
class Checkpointer:
    """Saves model state dicts at evenly spaced steps of an experiment.

    Checkpoints are written to ``root / experiment_name / checkpoint_<k>``; the
    experiment configuration is dumped once to ``config.yaml`` on creation.
    """

    def __init__(self, experiment_name, root, config, total_steps, n_checkpoints):
        """Create the experiment folder, dump ``config`` and plan the checkpoint steps.

        :param experiment_name: Name of the sub folder below ``root``.
        :param root: Base directory (must support the ``/`` operator, e.g. ``Path``).
        :param config: Configuration object written to ``config.yaml``.
        :param total_steps: Upper end of the step range the checkpoints are spread over.
        :param n_checkpoints: Number of checkpoints to spread over that range.
        """
        self.path = root / experiment_name
        # Zero-based step indices at which `step` writes a checkpoint.
        planned_steps = np.linspace(1, total_steps, n_checkpoints, dtype=int) - 1
        self.checkpoint_indices = list(planned_steps)
        self.__current_checkpoint = 0
        self.__current_step = 0
        self.path.mkdir(exist_ok=True, parents=True)
        config_file = self.path / 'config.yaml'
        with config_file.open('w') as outfile:
            yaml.dump(config, outfile, default_flow_style=False)

    def save_experiment(self, name: str, model):
        """Write ``model.state_dict()`` to ``<name>.pt`` inside the current checkpoint folder."""
        target_dir = self.path / f'checkpoint_{self.__current_checkpoint}'
        target_dir.mkdir(exist_ok=True, parents=True)
        torch.save(model.state_dict(), target_dir / f'{name}.pt')

    def step(self, to_save):
        """Advance the step counter, saving every ``(name, model)`` pair if a checkpoint is due."""
        checkpoint_due = self.__current_step in self.checkpoint_indices
        if checkpoint_due:
            print(f'Checkpointing #{self.__current_checkpoint}')
            for model_name, model in to_save:
                self.save_experiment(model_name, model)
            self.__current_checkpoint += 1
        self.__current_step += 1

View File

View File

@ -0,0 +1,100 @@
import abc
from typing import Union
from mfg_package.environment import rewards as r, constants as c
from mfg_package.utils.helpers import MOVEMAP
from mfg_package.utils.results import ActionResult
class Action(abc.ABC):
    """Abstract base of everything an entity can do; identified by a string."""

    @abc.abstractmethod
    def __init__(self, identifier: str):
        """Store ``identifier``, which is exposed through :attr:`name`."""
        self._identifier = identifier

    @property
    def name(self):
        """The identifier this action was created with."""
        return self._identifier

    @abc.abstractmethod
    def do(self, entity, state) -> Union[None, ActionResult]:
        """Perform the action for ``entity``; the base implementation returns ``None``."""
        return None

    def __repr__(self):
        return f'Action[{self._identifier}]'
class Noop(Action):
    """The "do nothing" action, identified by ``c.NOOP``."""

    def __init__(self):
        super(Noop, self).__init__(c.NOOP)

    def do(self, entity, *_) -> Union[None, ActionResult]:
        """Return a valid result for ``entity`` carrying the ``r.NOOP`` reward."""
        outcome = ActionResult(entity=entity, identifier=self._identifier,
                               validity=c.VALID, reward=r.NOOP)
        return outcome
class Move(Action, abc.ABC):
    """Abstract base of all movement actions; the offset is looked up in ``MOVEMAP``."""

    @abc.abstractmethod
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _calc_new_pos(self, pos):
        """Return ``pos`` shifted by the offset registered for this action's identifier."""
        delta_x, delta_y = MOVEMAP[self._identifier]
        return pos[0] + delta_x, pos[1] + delta_y

    def do(self, entity, env):
        """Try to move ``entity`` and report the outcome as an ``ActionResult``.

        The move is invalid if there is no floor tile at the target position or
        if ``entity.move`` reports failure.
        """
        target_tile = env[c.FLOOR].by_pos(self._calc_new_pos(entity.pos))
        # noinspection PyUnresolvedReferences
        valid = entity.move(target_tile) if target_tile else c.NOT_VALID
        if valid:
            reward = r.MOVEMENTS_VALID
        else:
            reward = r.MOVEMENTS_FAIL
        return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=reward)
class North(Move):
    """Move identified by ``c.NORTH``."""

    def __init__(self, *args, **kwargs):
        super(North, self).__init__(c.NORTH, *args, **kwargs)


class NorthEast(Move):
    """Move identified by ``c.NORTHEAST``."""

    def __init__(self, *args, **kwargs):
        super(NorthEast, self).__init__(c.NORTHEAST, *args, **kwargs)


class East(Move):
    """Move identified by ``c.EAST``."""

    def __init__(self, *args, **kwargs):
        super(East, self).__init__(c.EAST, *args, **kwargs)


class SouthEast(Move):
    """Move identified by ``c.SOUTHEAST``."""

    def __init__(self, *args, **kwargs):
        super(SouthEast, self).__init__(c.SOUTHEAST, *args, **kwargs)


class South(Move):
    """Move identified by ``c.SOUTH``."""

    def __init__(self, *args, **kwargs):
        super(South, self).__init__(c.SOUTH, *args, **kwargs)


class SouthWest(Move):
    """Move identified by ``c.SOUTHWEST``."""

    def __init__(self, *args, **kwargs):
        super(SouthWest, self).__init__(c.SOUTHWEST, *args, **kwargs)


class West(Move):
    """Move identified by ``c.WEST``."""

    def __init__(self, *args, **kwargs):
        super(West, self).__init__(c.WEST, *args, **kwargs)


class NorthWest(Move):
    """Move identified by ``c.NORTHWEST``."""

    def __init__(self, *args, **kwargs):
        super(NorthWest, self).__init__(c.NORTHWEST, *args, **kwargs)


# The four straight moves, and those four followed by the four diagonal ones.
Move4 = [North, East, South, West]
# noinspection PyTypeChecker
Move8 = [*Move4, NorthEast, SouthEast, SouthWest, NorthWest]

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 KiB

View File

@ -0,0 +1,60 @@
# Names
DANGER_ZONE = 'x' # Danger-zone tile identifier used when resolving the string-based map files.
DEFAULTS = 'Defaults'
SELF = 'Self'
PLACEHOLDER = 'Placeholder'
FLOOR = 'Floor' # Identifier of Floor objects and groups.
FLOORS = 'Floors' # Identifier of Floor objects and groups.
WALL = 'Wall' # Identifier of Wall objects and groups.
WALLS = 'Walls' # Identifier of Wall objects and groups.
LEVEL = 'Level' # Identifier of Level objects and groups.
AGENT = 'Agent' # Identifier of Agent objects and groups.
AGENTS = 'Agents' # Identifier of Agent objects and groups.
OTHERS = 'Other'
COMBINED = 'Combined'
GLOBAL_POSITION = 'GLOBAL_POSITION' # Identifier of the global position slice.
# Attributes
IS_BLOCKING_LIGHT = 'is_blocking_light'
HAS_POSITION = 'has_position'
HAS_NO_POSITION = 'has_no_position'
ALL = 'All'
# Symbols (read from map files)
SYMBOL_WALL = '#'
SYMBOL_FLOOR = '-'
VALID = True # Named alias for the boolean outcome of actions: valid.
NOT_VALID = False # Named alias for the boolean outcome of actions: not valid.
VALUE_FREE_CELL = 0 # Free-cell value used in observations.
VALUE_OCCUPIED_CELL = 1 # Occupied-cell value used in observations.
VALUE_NO_POS = (-9999, -9999) # Invalid position value used in the environment (something is off-grid).
ACTION = 'action' # Identifier of Action objects and groups.
COLLISION = 'Collision' # Identifier to use in the context of collisions.
LAST_POS = 'LAST_POS' # Identifier for retrieving an entity's last position.
VALIDITY = 'VALIDITY' # Identifier for retrieving the validity of an Action, Tick, etc.
# Actions
# Movements
NORTH = 'north'
EAST = 'east'
SOUTH = 'south'
WEST = 'west'
NORTHEAST = 'north_east'
SOUTHEAST = 'south_east'
SOUTHWEST = 'south_west'
NORTHWEST = 'north_west'
# Move Groups
MOVE8 = 'Move8'
MOVE4 = 'Move4'
# No-Action / Wait
NOOP = 'Noop'
# Result Identifier
MOVEMENTS_VALID = 'motion_valid'
MOVEMENTS_FAIL = 'motion_not_valid'

View File

@ -0,0 +1,76 @@
from typing import List, Union
from mfg_package.environment import constants as c
from mfg_package.environment.actions import Action
from mfg_package.environment.entity.entity import Entity
from mfg_package.utils.render import RenderEntity
from mfg_package.utils import renderer
from mfg_package.utils.helpers import is_move
from mfg_package.utils.results import ActionResult, Result
class Agent(Entity):
    """Entity that owns a list of actions and a list of observation names.

    The result of the agent's most recent action is kept as its "state" until
    ``clear_temp_state`` is called.
    """

    def __init__(self, actions: List[Action], observations: List[str], *args, **kwargs):
        """Create the agent; remaining arguments are forwarded to ``Entity``.

        :param actions: Actions available to this agent.
        :param observations: Names of the observations this agent receives.
        """
        super().__init__(*args, **kwargs)
        self._actions = actions
        self._observations = observations
        self._state: Union[Result, None] = None
        # Shadows the `step_result` method below on the instance.
        self.step_result = dict()

    @property
    def obs_tag(self):
        return self.name

    @property
    def actions(self):
        return self._actions

    @property
    def observations(self):
        return self._observations

    @property
    def collection(self):
        return self._collection

    @property
    def can_collide(self):
        return True

    @property
    def state(self):
        """The stored result, or a fresh valid ``c.NOOP`` result with reward 0 if none is stored."""
        fallback_needed = not self._state
        if fallback_needed:
            return ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)
        return self._state

    def step_result(self):
        pass

    def set_state(self, action_result):
        """Store ``action_result`` as the agent's current state."""
        self._state = action_result

    # noinspection PyAttributeOutsideInit
    def clear_temp_state(self):
        """Forget the stored state and return ``self``."""
        self._state = None
        return self

    def summarize_state(self):
        """Extend the parent summary with the validity and identifier of the current state."""
        summary = super().summarize_state()
        summary.update(valid=bool(self.state.validity), action=str(self.state.identifier))
        return summary

    def render(self):
        """Build the ``RenderEntity`` for this agent, choosing the icon state from ``self.state``."""
        own_index = next(idx for idx, member in enumerate(self._collection) if member.name == self.name)
        current = self.state
        if current.identifier == c.COLLISION:
            visual_state = renderer.STATE_COLLISION
        elif not current.validity:
            visual_state = renderer.STATE_INVALID
        elif current.identifier == c.NOOP:
            visual_state = renderer.STATE_IDLE
        elif is_move(current.identifier):
            visual_state = renderer.STATE_MOVE
        else:
            visual_state = renderer.STATE_VALID
        return RenderEntity(c.AGENT, self.pos, 1, 'none', visual_state, own_index + 1, real_name=self.name)

View File

@ -0,0 +1,79 @@
import abc
from mfg_package.environment import constants as c
from mfg_package.environment.entity.object import EnvObject
from mfg_package.utils.render import RenderEntity
class Entity(EnvObject, abc.ABC):
    """Environment object that occupies a tile of the grid (doors, items, dirt piles, ...)."""

    def __init__(self, tile, **kwargs):
        """Place the new entity on ``tile``; keyword arguments go to ``EnvObject``."""
        super().__init__(**kwargs)
        self._tile = tile
        tile.enter(self)

    @property
    def tile(self):
        return self._tile

    @property
    def pos(self):
        """Position of the tile the entity currently stands on."""
        return self._tile.pos

    @property
    def x(self):
        return self.pos[0]

    @property
    def y(self):
        return self.pos[1]

    @property
    def has_position(self):
        return self.pos != c.VALUE_NO_POS

    @property
    def last_tile(self):
        """The tile left by the last successful move, or ``None`` if there was none."""
        try:
            previous = self._last_tile
        except AttributeError:
            # noinspection PyAttributeOutsideInit
            previous = self._last_tile = None
        return previous

    @property
    def last_pos(self):
        """Position of :attr:`last_tile`, or ``c.VALUE_NO_POS`` if there is none."""
        try:
            return self.last_tile.pos
        except AttributeError:
            return c.VALUE_NO_POS

    @property
    def direction_of_view(self):
        """Component-wise difference ``last_pos - pos``."""
        previous_x, previous_y = self.last_pos
        current_x, current_y = self.pos
        return previous_x - current_x, previous_y - current_y

    def move(self, next_tile):
        """Try to move onto ``next_tile`` and return whether that succeeded.

        Moving onto the current tile fails. On success the old tile is left,
        the tile bookkeeping is updated and all observers are notified.
        """
        origin = self.tile
        if not (tile_differs := origin != next_tile):
            return tile_differs
        entered = next_tile.enter(self)
        if entered:
            origin.leave(self)
            self._tile = next_tile
            self._last_tile = origin
            for watcher in self.observers:
                watcher.notify_change_pos(self)
        return entered

    def summarize_state(self) -> dict:
        """Return name, coordinates, tile name and collision flag as plain Python values."""
        return {
            'name': str(self.name),
            'x': int(self.x),
            'y': int(self.y),
            'tile': str(self.tile.name),
            'can_collide': bool(self.can_collide),
        }

    @abc.abstractmethod
    def render(self):
        """Default rendering: a ``RenderEntity`` named after the lower-cased class name."""
        return RenderEntity(self.__class__.__name__.lower(), self.pos)

    def __repr__(self):
        return f'{super().__repr__()}(@{self.pos})'

View File

@ -0,0 +1,18 @@
# noinspection PyAttributeOutsideInit
class BoundEntityMixin:
    """Mixin for objects that are attached ("bound") to one other entity via ``bind_to``."""

    def bind_to(self, entity):
        """Attach this object to ``entity``."""
        self._bound_entity = entity

    def belongs_to_entity(self, entity):
        """Return whether ``entity`` compares equal to the bound entity."""
        return entity == self.bound_entity

    @property
    def bound_entity(self):
        """The entity this object was bound to."""
        return self._bound_entity

    @property
    def name(self):
        """Class name followed by the bound entity's name in parentheses."""
        owner = self._bound_entity
        return f'{type(self).__name__}({owner.name})'

View File

@ -0,0 +1,126 @@
from collections import defaultdict
from typing import Union
from mfg_package.environment import constants as c
class Object:
    """General object for organisation and maintenance, such as actions.

    Every instance gets a running number per class name and may carry an
    optional string identifier. Equality and hashing are based on
    :attr:`identifier`.
    """

    # Running instance counter, keyed by class name.
    _u_idx = defaultdict(lambda: 0)

    def __init__(self, str_ident: Union[str, None] = None, **kwargs):
        """Set up identifier, counter and observer list; unknown keywords are only reported."""
        self._observers = []
        self._str_ident = str_ident
        self.identifier_int = self._identify_and_count_up()
        self._collection = None
        if kwargs:
            print(f'Following kwargs were passed, but ignored: {kwargs}')

    def _identify_and_count_up(self):
        """Return the current counter value for this class name and increase it by one."""
        counter_key = self.__class__.__name__
        current = Object._u_idx[counter_key]
        Object._u_idx[counter_key] = current + 1
        return current

    @property
    def observers(self):
        return self._observers

    @property
    def name(self):
        """``Class[ident]`` if a string identifier was given, ``Class#number`` otherwise."""
        class_name = self.__class__.__name__
        if self._str_ident is None:
            return f'{class_name}#{self.identifier_int}'
        return f'{class_name}[{self._str_ident}]'

    @property
    def identifier(self):
        """The string identifier if one was given, the name otherwise."""
        return self._str_ident if self._str_ident is not None else self.name

    def set_collection(self, collection):
        """Remember the collection this object is stored in."""
        self._collection = collection

    def add_observer(self, observer):
        """Register ``observer`` and immediately notify it via ``notify_change_pos``."""
        self.observers.append(observer)
        observer.notify_change_pos(self)

    def del_observer(self, observer):
        """Unregister ``observer``."""
        self.observers.remove(observer)

    def __bool__(self):
        return True

    def __repr__(self):
        return f'{self.name}'

    def __eq__(self, other) -> bool:
        return other == self.identifier

    def __hash__(self):
        return hash(self.identifier)
class EnvObject(Object):
    """Observable object without a position on the environment grid (inventories etc.).

    Most flags are read from the collection the object is stored in and fall
    back to ``False`` when there is no collection or it lacks the attribute.
    """

    # NOTE(review): `Object._identify_and_count_up` always uses `Object._u_idx`,
    # so this separate counter appears to be unused - confirm.
    _u_idx = defaultdict(lambda: 0)

    def __init__(self, **kwargs):
        super(EnvObject, self).__init__(**kwargs)

    def _collection_flag(self, attribute):
        """Return ``attribute`` of the parent collection, or ``False`` if unavailable or falsy."""
        try:
            return getattr(self._collection, attribute) or False
        except AttributeError:
            return False

    @property
    def obs_tag(self):
        """Name of the parent collection; the object's own name if there is none."""
        try:
            return self._collection.name or self.name
        except AttributeError:
            return self.name

    @property
    def is_blocking_light(self):
        return self._collection_flag('is_blocking_light')

    @property
    def can_move(self):
        return self._collection_flag('can_move')

    @property
    def is_blocking_pos(self):
        return self._collection_flag('is_blocking_pos')

    @property
    def has_position(self):
        return self._collection_flag('has_position')

    @property
    def can_collide(self):
        return self._collection_flag('can_collide')

    @property
    def encoding(self):
        return c.VALUE_OCCUPIED_CELL

    def change_parent_collection(self, other_collection):
        """Move this object from its current collection into ``other_collection``.

        :return: Whether the object's collection is ``other_collection`` afterwards.
        """
        # Remember the old collection first: `add_item` may already re-point
        # `self._collection` (via `set_collection`), in which case deleting
        # through `self._collection` would remove the object from the *new*
        # collection again and leave it in the old one.
        previous_collection = self._collection
        other_collection.add_item(self)
        if previous_collection is not None and previous_collection is not other_collection:
            previous_collection.delete_env_object(self)
        self._collection = other_collection
        return self._collection == other_collection

View File

@ -0,0 +1,45 @@
import math
import numpy as np
from mfg_package.environment.entity.mixin import BoundEntityMixin
from mfg_package.environment.entity.object import Object, EnvObject
##########################################################################
# ####################### Objects and Entitys ########################## #
##########################################################################
class PlaceHolder(Object):
    """Object that only carries a constant fill value as its encoding."""

    def __init__(self, *args, fill_value=0, **kwargs):
        """Store ``fill_value``; all other arguments are forwarded to ``Object``."""
        super(PlaceHolder, self).__init__(*args, **kwargs)
        self._fill_value = fill_value

    @property
    def name(self):
        return "PlaceHolder"

    @property
    def encoding(self):
        """The fill value given at construction."""
        return self._fill_value

    @property
    def can_collide(self):
        return False
class GlobalPosition(BoundEntityMixin, EnvObject):
    """Exposes the position of the bound entity as an encoding, optionally normalized."""

    def __init__(self, *args, normalized: bool = True, **kwargs):
        """Create the object; ``normalized`` selects the form returned by :attr:`encoding`."""
        super().__init__(*args, **kwargs)
        # NOTE(review): `size` is not defined by any base class visible here -
        # presumably it is expected to be provided elsewhere; confirm.
        self._level_shape = math.sqrt(self.size)
        self._normalized = normalized

    @property
    def encoding(self):
        """The bound entity's position, divided by ``sqrt(size)`` when normalized."""
        if not self._normalized:
            return self.bound_entity.pos
        scaled = np.divide(self._bound_entity.pos, self._level_shape)
        return tuple(scaled)

View File

@ -0,0 +1,131 @@
from typing import List
import numpy as np
from mfg_package.environment import constants as c
from mfg_package.environment.entity.object import EnvObject
from mfg_package.utils.render import RenderEntity
from mfg_package.utils import helpers as h
class Floor(EnvObject):
    """A single tile of the grid that keeps track of the guests standing on it."""

    def __init__(self, pos, **kwargs):
        """Create an empty tile at ``pos``; keyword arguments go to ``EnvObject``."""
        super().__init__(**kwargs)
        self._guests = dict()
        self.pos = tuple(pos)
        self._neighboring_floor: List[Floor] = list()
        self._blocked_by = None

    @property
    def has_position(self):
        return True

    @property
    def can_collide(self):
        return False

    @property
    def can_move(self):
        return False

    @property
    def is_blocking_pos(self):
        return False

    @property
    def is_blocking_light(self):
        return False

    @property
    def encoding(self):
        return c.VALUE_OCCUPIED_CELL

    @property
    def x(self):
        return self.pos[0]

    @property
    def y(self):
        return self.pos[1]

    @property
    def neighboring_floor(self):
        """Tiles of the own collection found at the offsets of ``h.POS_MASK`` (own position excluded).

        The list is computed on first access and re-used while it is non-empty.
        """
        if not self._neighboring_floor:
            found = []
            for offset in h.POS_MASK.reshape(-1, 2):
                if np.all(offset == [0, 0]):
                    continue
                candidate = self._collection.by_pos(np.add(self.pos, offset))
                if candidate:
                    found.append(candidate)
            self._neighboring_floor = found
        return self._neighboring_floor

    @property
    def neighboring_floor_pos(self):
        """Positions of all tiles in :attr:`neighboring_floor`."""
        return [tile.pos for tile in self.neighboring_floor]

    @property
    def guests(self):
        return self._guests.values()

    @property
    def guests_that_can_collide(self):
        return [guest for guest in self.guests if guest.can_collide]

    @property
    def is_blocked(self):
        """Whether any guest on this tile blocks the position."""
        return any(guest.is_blocking_pos for guest in self.guests)

    def __len__(self):
        return len(self._guests)

    def is_empty(self):
        return len(self._guests) == 0

    def is_occupied(self):
        return len(self._guests) > 0

    def enter(self, guest):
        """Register ``guest`` on this tile and return ``c.VALID``, or ``c.NOT_VALID`` if refused.

        Entering is refused if the guest is already present, if the tile is
        blocked, or if a position-blocking guest wants to enter an occupied tile.
        """
        if guest.name in self._guests or self.is_blocked:
            return c.NOT_VALID
        if guest.is_blocking_pos and self.is_occupied():
            return c.NOT_VALID
        self._guests[guest.name] = guest
        return c.VALID

    def leave(self, guest):
        """Remove ``guest`` from this tile; ``c.NOT_VALID`` if it was not registered."""
        try:
            del self._guests[guest.name]
        except (ValueError, KeyError):
            return c.NOT_VALID
        else:
            return c.VALID

    def __repr__(self):
        return f'{self.name}(@{self.pos})'

    def summarize_state(self, **_):
        return {'name': self.name, 'x': int(self.x), 'y': int(self.y)}

    def render(self):
        return None
class Wall(Floor):
    """Tile that blocks both position and light, and is rendered as ``c.WALL``."""

    @property
    def is_blocking_pos(self):
        return True

    @property
    def is_blocking_light(self):
        return True

    @property
    def can_collide(self):
        return True

    @property
    def encoding(self):
        return c.VALUE_OCCUPIED_CELL

    def render(self):
        """Return the ``RenderEntity`` for this wall at its position."""
        wall_sprite = RenderEntity(c.WALL, self.pos)
        return wall_sprite

View File

@ -0,0 +1,201 @@
import shutil
from collections import defaultdict
from itertools import chain
from os import PathLike
from pathlib import Path
from typing import Union
import gymnasium as gym
from mfg_package.utils.level_parser import LevelParser
from mfg_package.utils.observation_builder import OBSBuilder
from mfg_package.utils.config_parser import FactoryConfigParser
from mfg_package.utils import helpers as h
import mfg_package.environment.constants as c
from mfg_package.utils.states import Gamestate
# Key prefix of every entry in the dict built by `BaseFactory._summarize_state`.
REC_TAC = 'rec_'
class BaseFactory(gym.Env):
    """Grid-world factory environment assembled from a configuration file.

    The constructor parses the configuration, loads the level and calls
    ``reset`` once, so ``state`` and ``obs_builder`` exist right after creation.
    """

    @property
    def action_space(self):
        """Action space as reported by the agent group of the current state."""
        return self.state[c.AGENT].action_space

    @property
    def named_action_space(self):
        """Mapping from action names to indices, as reported by the agent group."""
        return self.state[c.AGENT].named_action_space

    @property
    def observation_space(self):
        return self.obs_builder.observation_space(self.state)

    @property
    def named_observation_space(self):
        return self.obs_builder.named_observation_space(self.state)

    @property
    def params(self) -> dict:
        """Content of the configuration file, freshly parsed with ``yaml.safe_load``."""
        import yaml
        config_path = Path(self._config_file)
        # Use a context manager so the file handle is closed again.
        with config_path.open() as config_stream:
            config_dict = yaml.safe_load(config_stream)
        return config_dict

    @property
    def summarize_header(self):
        """State summary restricted to stateless entity groups."""
        summary_dict = self._summarize_state(stateless_entities=True)
        return summary_dict

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def __init__(self, config_file: Union[str, PathLike]):
        """Parse ``config_file``, load the level it names and reset the environment.

        :param config_file: Path of the configuration file.
        """
        self._config_file = config_file
        self.conf = FactoryConfigParser(self._config_file)
        # Attribute Assignment
        self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt'
        self._renderer = None  # expensive - don't use; unless required !

        parsed_entities = self.conf.load_entities()
        self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r)

        # Init for later usage:
        self.state: Gamestate
        self.map: LevelParser
        self.obs_builder: OBSBuilder

        # TODO: Reset ---> document this
        self.reset()

    def __getitem__(self, item):
        """Shortcut for ``self.state.entities[item]``."""
        return self.state.entities[item]

    def reset(self) -> (dict, dict):
        """Rebuild game state and observation builder; return the initial observations.

        :return: Whatever ``OBSBuilder.refresh_and_build_for_all`` returns for the new state.
        """
        self.state = None

        # Init entity:
        entities = self.map.do_init()

        # Grab all rules:
        rules = self.conf.load_rules()

        # Agents
        # noinspection PyAttributeOutsideInit
        self.state = Gamestate(entities, rules, self.conf.env_seed)

        agents = self.conf.load_agents(self.map.size, self[c.FLOOR].empty_tiles)
        self.state.entities.add_item({c.AGENT: agents})

        # All is set up, trigger additional init (after agent entity spawn etc)
        self.state.rules.do_all_init(self.state)

        # Observations
        # noinspection PyAttributeOutsideInit
        self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r)
        return self.obs_builder.refresh_and_build_for_all(self.state)

    def step(self, actions):
        """Advance the environment by one tick.

        :param actions: A list of actions, or a single value that is converted with ``int``.
        :return: The tuple ``(None, observations, reward, done, info)``.
        """
        if not isinstance(actions, list):
            actions = [int(actions)]

        # Apply rules, do actions, tick the state, etc...
        tick_result = self.state.tick(actions)

        # Check Done Conditions
        done_results = self.state.check_done()

        # Finalize
        reward, reward_info, done = self.summarize_step_results(tick_result, done_results)

        info = reward_info
        # `reward` is a list with individual rewards but a plain number
        # otherwise; only the list form can (and must) be summed up.
        step_reward = sum(reward) if isinstance(reward, (list, tuple)) else reward
        info.update(step_reward=step_reward, step=self.state.curr_step)
        # TODO:
        # if self._record_episodes:
        #     info.update(self._summarize_state())

        obs, reset_info = self.obs_builder.refresh_and_build_for_all(self.state)
        info.update(reset_info)
        return None, [x for x in obs.values()], reward, done, info

    def summarize_step_results(self, tick_results: list, done_check_results: list) -> (int, dict, bool):
        """Combine all results of one step into reward, info dict and done flag.

        :return: ``(reward, info, done)``. ``reward`` is a list with one entry
            per agent if ``conf.individual_rewards`` is set, a single sum otherwise.
        """
        # Returns: Reward, Info
        rewards = defaultdict(lambda: 0.0)

        # Gather per agent environment rewards and
        # Combine Info dicts into a global one
        combined_info_dict = defaultdict(lambda: 0.0)
        for result in chain(tick_results, done_check_results):
            if result.reward is not None:
                try:
                    rewards[result.entity.name] += result.reward
                except AttributeError:
                    # No entity attached to the result: count it as a global reward.
                    rewards['global'] += result.reward
            infos = result.get_infos()
            for info in infos:
                assert isinstance(info.value, (float, int))
                combined_info_dict[info.identifier] += info.value

        # Check Done Rule Results
        try:
            done_reason = next(x for x in done_check_results if x.validity)
            done = True
            self.state.print(f'Env done, Reason: {done_reason.name}.')
        except StopIteration:
            done = False

        if self.conf.individual_rewards:
            # Every agent receives its own reward plus the global share.
            global_rewards = rewards['global']
            del rewards['global']
            reward = [rewards[agent.name] for agent in self.state[c.AGENT]]
            reward = [x + global_rewards for x in reward]
            self.state.print(f"rewards are {rewards}")
            return reward, combined_info_dict, done
        else:
            reward = sum(rewards.values())
            self.state.print(f"reward is {reward}")
            return reward, combined_info_dict, done

    def start_recording(self):
        """Set ``conf.do_record`` and return the new value."""
        self.conf.do_record = True
        return self.conf.do_record

    def stop_recording(self):
        """Clear ``conf.do_record`` and return ``True`` once it is cleared."""
        self.conf.do_record = False
        return not self.conf.do_record

    # noinspection PyGlobalUndefined
    def render(self, mode='human'):
        """Render all entities of the current state; the renderer is created on first use.

        :param mode: Not used by this implementation.
        """
        if not self._renderer:  # lazy init
            global Renderer
            from mfg_package.utils.renderer import Renderer
            self._renderer = Renderer(self.map.level_shape, view_radius=self.conf.pomdp_r, fps=10)

        render_entities = self.state.entities.render()
        if self.conf.pomdp_r:
            for render_entity in render_entities:
                if render_entity.name == c.AGENT:
                    render_entity.aux = self.obs_builder.curr_lightmaps[render_entity.real_name]
        return self._renderer.render(render_entities)

    def _summarize_state(self, stateless_entities=False):
        """Summarize the current step plus all entity groups whose ``is_stateless`` matches."""
        summary = {f'{REC_TAC}step': self.state.curr_step}

        for entity_group in self.state:
            if entity_group.is_stateless == stateless_entities:
                summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()})
        return summary

    def print(self, string):
        """Print ``string`` if ``conf.verbose`` is set."""
        if self.conf.verbose:
            print(string)

    def save_params(self, filepath: Path):
        """Copy the configuration file to ``filepath``, creating parent folders as needed."""
        # noinspection PyProtectedMember
        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(self._config_file, filepath)

View File

@ -0,0 +1,29 @@
from mfg_package.environment.groups.env_objects import EnvObjects
from mfg_package.environment.groups.mixins import PositionMixin
from mfg_package.environment.entity.agent import Agent
class Agents(PositionMixin, EnvObjects):
    """Collection of ``Agent`` entities."""

    _entity = Agent
    is_blocking_light = False
    can_move = True

    def __init__(self, *args, **kwargs):
        super(Agents, self).__init__(*args, **kwargs)

    @property
    def obs_pairs(self):
        """One ``(name, agent)`` pair per agent."""
        pairs = []
        for agent in self:
            pairs.append((agent.name, agent))
        return pairs

    @property
    def action_space(self):
        """A ``Tuple`` space holding one ``Discrete`` space per agent, sized by its action count."""
        from gymnasium import spaces
        per_agent_spaces = [spaces.Discrete(len(agent.actions)) for agent in self]
        return spaces.Tuple(per_agent_spaces)

    @property
    def named_action_space(self):
        """For every agent name, a mapping from action name to action index."""
        return {
            agent.name: {action.name: idx for idx, action in enumerate(agent.actions)}
            for agent in self
        }

View File

@ -0,0 +1,33 @@
from mfg_package.environment.groups.objects import Objects
from mfg_package.environment.entity.object import EnvObject
class EnvObjects(Objects):
    """Collection of ``EnvObject`` instances with a declared size."""

    _entity = EnvObject
    is_blocking_light: bool = False
    can_collide: bool = False
    has_position: bool = False
    can_move: bool = False

    def __init__(self, size, *args, **kwargs):
        """Store ``size``; all other arguments are forwarded to ``Objects``."""
        super().__init__(*args, **kwargs)
        self.size = size

    @property
    def encodings(self):
        """The encoding of every member, in iteration order."""
        return [member.encoding for member in self]

    def add_item(self, item: EnvObject):
        """Add ``item`` and return the collection.

        Collections without position assert that their length does not exceed
        ``size`` before the item is added.
        """
        assert self.has_position or (len(self) <= self.size)
        super().add_item(item)
        return self

    def summarize_states(self):
        """Return the ``summarize_state()`` result of every member."""
        summaries = []
        for member in self.values():
            summaries.append(member.summarize_state())
        return summaries

    def delete_env_object(self, env_object: EnvObject):
        """Remove ``env_object`` by its name."""
        del self[env_object.name]

    def delete_env_object_by_name(self, name):
        """Remove the member stored under ``name``."""
        del self[name]

View File

@ -0,0 +1,63 @@
from collections import defaultdict
from operator import itemgetter
from typing import Dict
from mfg_package.environment.groups.objects import Objects
from mfg_package.utils.helpers import POS_MASK
class Entities(Objects):
    """Top-level container that holds all entity groups, keyed by group name.

    It registers itself as observer of every added group and keeps
    ``pos_dict``, a mapping from positions to the entities located there.
    """

    _entity = Objects

    @staticmethod
    def neighboring_positions(pos):
        """Return ``POS_MASK`` shifted by ``pos`` as an array of 2-element rows."""
        return (POS_MASK + pos).reshape(-1, 2)

    def get_near_pos(self, pos):
        """Return all entities registered at any of the ``neighboring_positions`` of ``pos``."""
        return [y for x in itemgetter(*(tuple(x) for x in self.neighboring_positions(pos)))(self.pos_dict) for y in x]

    def render(self):
        """Collect the render results of all groups, skipping ``None`` entries."""
        # The None-check has to run before `x.render()` is called.
        return [y for x in self if x is not None for y in x.render()]

    @property
    def names(self):
        return list(self._data.keys())

    def __init__(self):
        self.pos_dict = defaultdict(list)
        super().__init__()

    def iter_entities(self):
        """Iterate over the members of all groups."""
        return iter((x for sublist in self.values() for x in sublist))

    def add_items(self, items: Dict):
        return self.add_item(items)

    def add_item(self, item: dict):
        """Add the groups in ``item`` (name -> group) and start observing them.

        :raises AssertionError: If one of the names is already present.
        """
        assert_str = 'This group of entity has already been added!'
        assert not any([key for key in item.keys() if key in self.keys()]), assert_str
        self._data.update(item)
        for val in item.values():
            val.add_observer(self)
        return self

    def __delitem__(self, name):
        """Stop observing the group stored under ``name`` and remove it.

        :raises AssertionError: If no group is stored under ``name``.
        """
        assert_str = 'This group of entity does not exist in this collection!'
        # `name` is a single key (it used to be treated like a mapping here).
        assert name in self.keys(), assert_str
        # Detaches this container from the group and from each of its entities.
        self._data[name].del_observer(self)
        return super(Entities, self).__delitem__(name)

    @property
    def obs_pairs(self):
        return [y for x in self for y in x.obs_pairs]

    def by_pos(self, pos: (int, int)):
        """Return the list of entities registered at ``pos``."""
        return self.pos_dict[pos]
        # found_entities = [y for y in (x.by_pos(pos) for x in self.values() if hasattr(x, 'by_pos')) if y is not None]
        # return found_entities

    @property
    def positions(self):
        """Every registered position, repeated once per entity located there."""
        return [k for k, v in self.pos_dict.items() for _ in v]

View File

@ -0,0 +1,97 @@
from mfg_package.environment import constants as c
from mfg_package.environment.entity.entity import Entity
# noinspection PyUnresolvedReferences,PyTypeChecker,PyArgumentList
class PositionMixin:
    """Mixin for collections whose members live on tiles of the grid."""

    _entity = Entity
    is_blocking_light: bool = True
    can_collide: bool = True
    has_position: bool = True

    def render(self):
        """Return the render results of all members, without ``None`` results."""
        rendered = [member.render() for member in self]
        return [picture for picture in rendered if picture is not None]

    @classmethod
    def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
        """Create a collection with one new entity per tile.

        Entities are created through ``cls._entity`` and receive their index
        within ``tiles`` as ``str_ident``; ``entity_kwargs`` is passed to each of them.
        """
        collection = cls(*args, **kwargs)
        extra_kwargs = entity_kwargs if entity_kwargs is not None else {}
        members = []
        for index, tile in enumerate(tiles):
            members.append(cls._entity(tile, str_ident=index, **extra_kwargs))
        collection.add_items(members)
        return collection

    @classmethod
    def from_coordinates(cls, positions: [(int, int)], tiles, *args, entity_kwargs=None, **kwargs, ):
        """Like :meth:`from_tiles`, with the tiles looked up by position in ``tiles``."""
        selected_tiles = [tiles.by_pos(position) for position in positions]
        return cls.from_tiles(selected_tiles, tiles.size, *args, entity_kwargs=entity_kwargs, **kwargs)

    @property
    def tiles(self):
        return [member.tile for member in self]

    @property
    def positions(self):
        return [member.pos for member in self]

    def __delitem__(self, name):
        """Let the member called ``name`` leave its tile, then delete it from the collection."""
        member = next(candidate for candidate in self if candidate.name == name)
        member.tile.leave(member)
        super().__delitem__(name)

    def by_pos(self, pos: (int, int)):
        """Return the first member located at ``pos``, or ``None`` if there is none."""
        wanted = tuple(pos)
        try:
            return next((member for member in self if member.pos == wanted), None)
        except ValueError:
            print()
            return None

    def notify_del_entity(self, entity: Entity):
        """Drop ``entity`` from ``pos_dict``; silently ignored if that is not possible."""
        try:
            registered_here = self.pos_dict[entity.pos]
            registered_here.remove(entity)
        except (ValueError, AttributeError):
            pass
# noinspection PyUnresolvedReferences,PyTypeChecker
class IsBoundMixin:
    """Mixin for collections that belong to one entity, set through ``bind``."""

    def bind(self, entity):
        """Attach this collection to ``entity`` and return ``c.VALID``."""
        # noinspection PyAttributeOutsideInit
        self._bound_entity = entity
        return c.VALID

    def belongs_to_entity(self, entity):
        """Return whether the bound entity compares equal to ``entity``."""
        return self._bound_entity == entity

    @property
    def name(self):
        """Class name followed by the bound entity's name in parentheses."""
        owner_name = self._bound_entity.name
        return f'{type(self).__name__}({owner_name})'

    def __repr__(self):
        owner_name = self._bound_entity.name
        return f'{type(self).__name__}#{owner_name}({self._data})'
# noinspection PyUnresolvedReferences,PyTypeChecker
class HasBoundedMixin:
    """Mixin for collections whose members are each bound to an entity."""

    @property
    def obs_names(self):
        """The names of all members."""
        return [member.name for member in self]

    def by_entity(self, entity):
        """Return the first member that belongs to ``entity``, or ``None``."""
        matches = (member for member in self if member.belongs_to_entity(entity))
        return next(matches, None)

    def idx_by_entity(self, entity):
        """Return the index of the first member that belongs to ``entity``, or ``None``."""
        matching_indices = (idx for idx, member in enumerate(self) if member.belongs_to_entity(entity))
        return next(matching_indices, None)

View File

@ -0,0 +1,141 @@
from collections import defaultdict
from typing import List
import numpy as np
from mfg_package.environment.entity.object import Object
class Objects:
    """Name-keyed collection of ``Object`` instances with observer support.

    Members are stored in insertion order and can be accessed by name or by
    integer index. Observers are told about added and deleted members; the
    collection can itself act as an observer that maintains ``pos_dict``, a
    mapping from positions to the entities located there.
    """

    _entity = Object

    @property
    def observers(self):
        return self._observers

    @property
    def obs_tag(self):
        return self.__class__.__name__

    @staticmethod
    def render():
        return []

    @property
    def obs_pairs(self):
        return [(self.name, self)]

    @property
    def names(self):
        # noinspection PyUnresolvedReferences
        return [x.name for x in self]

    @property
    def name(self):
        return f'{self.__class__.__name__}'

    def __init__(self, *args, **kwargs):
        """Create an empty collection; all arguments are ignored here."""
        self._data = defaultdict(lambda: None)
        self._observers = list()
        self.pos_dict = defaultdict(list)

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        return iter(self.values())

    def add_item(self, item: _entity):
        """Add ``item`` under its name, set its collection and notify all observers.

        :raises AssertionError: If ``item`` has the wrong type or its name is taken.
        """
        assert_str = f'All item names have to be of type {self._entity}, but were {item.__class__}.,'
        assert isinstance(item, self._entity), assert_str
        # `.get` keeps the check from inserting a placeholder entry.
        assert self._data.get(item.name) is None, f'{item.name} allready exists!!!'
        self._data.update({item.name: item})
        item.set_collection(self)
        for observer in self.observers:
            observer.notify_add_entity(item)
        return self

    # noinspection PyUnresolvedReferences
    def del_observer(self, observer):
        """Unregister ``observer`` from the collection and from every member that knows it."""
        self.observers.remove(observer)
        for entity in self:
            if observer in entity.observers:
                entity.del_observer(observer)

    # noinspection PyUnresolvedReferences
    def add_observer(self, observer):
        """Register ``observer`` on the collection and on every member that lacks it."""
        self.observers.append(observer)
        for entity in self:
            if observer not in entity.observers:
                entity.add_observer(observer)

    def __delitem__(self, name):
        """Notify all observers about the removal, then delete the member called ``name``.

        :raises KeyError: If there is no member called ``name``.
        """
        # Observers expect the entity itself (they read `entity.pos`), not its
        # name; `.get` avoids creating a placeholder entry for unknown names.
        entity = self._data.get(name)
        if entity is not None:
            for observer in self.observers:
                observer.notify_del_entity(entity)
        # noinspection PyTypeChecker
        del self._data[name]

    def add_items(self, items: List[_entity]):
        """Add every item via :meth:`add_item` and return the collection."""
        for item in items:
            self.add_item(item)
        return self

    def keys(self):
        return self._data.keys()

    def values(self):
        return self._data.values()

    def items(self):
        return self._data.items()

    def _get_index(self, item):
        """Return the index of the first member equal to ``item``, or ``None``."""
        try:
            return next(i for i, v in enumerate(self._data.values()) if v == item)
        except StopIteration:
            return None

    def __getitem__(self, item):
        """Return the member at integer index ``item`` or stored under name ``item``.

        Negative indices count from the end. Unknown names and indices yield
        ``None``.

        :raises TypeError: If ``item`` cannot be used as a key.
        """
        if isinstance(item, (int, np.int64, np.int32)):
            if item < 0:
                item = len(self._data) - abs(item)
            try:
                return next(v for i, v in enumerate(self._data.values()) if i == item)
            except StopIteration:
                return None
        try:
            # A plain subscript on the defaultdict would insert a `None` entry
            # on every miss; a lookup must not change the collection.
            return self._data.get(item)
        except TypeError:
            print('Ups')
            raise TypeError

    def __repr__(self):
        return f'{self.__class__.__name__}[{dict(self._data)}]'

    def notify_change_pos(self, entity: object):
        """Move ``entity`` from its last to its current position within ``pos_dict``."""
        try:
            self.pos_dict[entity.last_pos].remove(entity)
        except (ValueError, AttributeError):
            # The entity has no last position or was not registered there.
            pass
        if entity.has_position:
            try:
                self.pos_dict[entity.pos].append(entity)
            except (ValueError, AttributeError):
                pass

    def notify_del_entity(self, entity: Object):
        """Drop ``entity`` from ``pos_dict``; silently ignored if that is not possible."""
        try:
            self.pos_dict[entity.pos].remove(entity)
        except (ValueError, AttributeError):
            pass

    def notify_add_entity(self, entity: Object):
        """Start observing ``entity`` and register it at its position in ``pos_dict``."""
        try:
            # NOTE(review): `add_observer` on entities calls `notify_change_pos`,
            # which may already have appended the entity - check whether this
            # results in duplicate `pos_dict` entries.
            entity.add_observer(self)
            self.pos_dict[entity.pos].append(entity)
        except (ValueError, AttributeError):
            pass

View File

@ -0,0 +1,77 @@
from typing import List, Union
import numpy as np
from mfg_package.environment.groups.env_objects import EnvObjects
from mfg_package.environment.groups.objects import Objects
from mfg_package.environment.groups.mixins import HasBoundedMixin, PositionMixin
from mfg_package.environment.entity.util import GlobalPosition
from mfg_package.utils import helpers as h
from mfg_package.environment import constants as c
class Combined(PositionMixin, EnvObjects):
    """Group that stands for several named groups under one identifier."""

    def __init__(self, names: List[str], *args, identifier: Union[None, str] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self._ident = identifier
        self._names = names if names else list()

    @property
    def name(self):
        """Parent name followed by the identifier (or the name list) in parentheses."""
        label = self._ident or self._names
        return f'{super().name}({label})'

    @property
    def names(self):
        """The names this group combines."""
        return self._names

    @property
    def obs_tag(self):
        """Same as ``name``."""
        return self.name

    @property
    def obs_pairs(self):
        """One ``(name, None)`` pair per combined name."""
        pairs = []
        for combined_name in self.names:
            pairs.append((combined_name, None))
        return pairs
class GlobalPositions(HasBoundedMixin, EnvObjects):
    """Group of ``GlobalPosition`` objects, each bound to an entity."""

    _entity = GlobalPosition
    # Was ``False,`` - the trailing comma made this the tuple ``(False,)``,
    # which is truthy and therefore read as "blocks light".
    is_blocking_light = False
    can_collide = False

    def __init__(self, *args, **kwargs):
        super(GlobalPositions, self).__init__(*args, **kwargs)
class Zones(Objects):
    """Zone slices of a parsed level.

    NOTE(review): the constructor raises ``NotImplementedError`` as its very
    first statement, so everything after it is currently unreachable.
    """

    @property
    def accounting_zones(self):
        return [self[idx] for idx, name in self.items() if name != c.DANGER_ZONE]

    def __init__(self, parsed_level):
        raise NotImplementedError('This needs a Rework')
        # --- unreachable until the rework is done ---
        super(Zones, self).__init__()
        slices = list()
        self._accounting_zones = list()
        self._danger_zones = list()
        for symbol in np.unique(parsed_level):
            if symbol == c.VALUE_OCCUPIED_CELL:
                continue
            self + symbol
            slices.append(h.one_hot_level(parsed_level, symbol))
            if symbol == c.DANGER_ZONE:
                self._danger_zones.append(symbol)
            else:
                self._accounting_zones.append(symbol)
        self._zone_slices = np.stack(slices)

    def __getitem__(self, item):
        return self._zone_slices[item]

    def add_items(self, other: Union[str, List[str]]):
        """Always refuses: zones cannot be added at runtime."""
        raise AttributeError('You are not allowed to add additional Zones in runtime.')

View File

@ -0,0 +1,54 @@
import random
from typing import List
from mfg_package.environment import constants as c
from mfg_package.environment.groups.env_objects import EnvObjects
from mfg_package.environment.groups.mixins import PositionMixin
from mfg_package.environment.entity.wall_floor import Wall, Floor
class Walls(PositionMixin, EnvObjects):
    """Group of ``Wall`` entities."""

    _entity = Wall
    symbol = c.SYMBOL_WALL

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._value = c.VALUE_OCCUPIED_CELL

    @classmethod
    def from_coordinates(cls, argwhere_coordinates, *args, **kwargs):
        """Build a group holding one ``_entity`` per coordinate."""
        group = cls(*args, **kwargs)
        # noinspection PyTypeChecker
        members = [cls._entity(position) for position in argwhere_coordinates]
        group.add_items(members)
        return group

    @classmethod
    def from_tiles(cls, tiles, *args, **kwargs):
        """Not supported; always raises ``RuntimeError``."""
        raise RuntimeError()
class Floors(Walls):
    """Group of ``Floor`` tiles."""

    _entity = Floor
    symbol = c.SYMBOL_FLOOR
    is_blocking_light: bool = False
    can_collide: bool = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._value = c.VALUE_FREE_CELL

    @property
    def occupied_tiles(self):
        """All tiles reporting ``is_occupied()``, in random order."""
        occupied = []
        for tile in self:
            if tile.is_occupied():
                occupied.append(tile)
        random.shuffle(occupied)
        return occupied

    @property
    def empty_tiles(self) -> List[Floor]:
        """All tiles reporting ``is_empty()``, in random order."""
        vacant = []
        for tile in self:
            if tile.is_empty():
                vacant.append(tile)
        random.shuffle(vacant)
        return vacant

    @classmethod
    def from_tiles(cls, tiles, *args, **kwargs):
        """Not supported; always raises ``RuntimeError``."""
        raise RuntimeError()

View File

@ -0,0 +1,4 @@
# Reward values (all negative, i.e. penalties).
# NOTE(review): the meaning of each value is inferred from its name - confirm against the code using it.
MOVEMENTS_VALID: float = -0.001  # presumably applied to a valid movement
MOVEMENTS_FAIL: float = -0.05  # presumably applied to a failed movement
NOOP: float = -0.01  # presumably applied to the no-op action
COLLISION: float = -0.5  # presumably the reward attached to collision results

View File

@ -0,0 +1,82 @@
import abc
from typing import List
from mfg_package.utils.results import TickResult, DoneResult
from mfg_package.environment import rewards as r, constants as c
class Rule(abc.ABC):
    """Base class for rules; every hook defaults to an empty result list."""

    def __init__(self):
        """Nothing to set up in the base class."""

    @property
    def name(self):
        """Class name of the concrete rule."""
        return type(self).__name__

    def __repr__(self):
        return f'{self.name}'

    def on_init(self, state):
        """Initialisation hook; returns an empty list."""
        return list()

    def on_reset(self):
        """Reset hook; returns an empty list."""
        return list()

    def tick_pre_step(self, state) -> List[TickResult]:
        """Pre-step hook; returns an empty list."""
        return list()

    def tick_step(self, state) -> List[TickResult]:
        """Step hook; returns an empty list."""
        return list()

    def tick_post_step(self, state) -> List[TickResult]:
        """Post-step hook; returns an empty list."""
        return list()

    def on_check_done(self, state) -> List[DoneResult]:
        """Done-check hook; returns an empty list."""
        return list()
class MaxStepsReached(Rule):
    """Reports done once ``state.curr_step`` has reached ``max_steps``."""

    def __init__(self, max_steps: int = 500):
        super().__init__()
        self.max_steps = max_steps

    def on_init(self, state):
        """Nothing to initialise; returns ``None``."""
        return None

    def on_check_done(self, state):
        """Return one ``DoneResult`` that is valid iff the step limit is reached."""
        limit_reached = self.max_steps <= state.curr_step
        validity = c.VALID if limit_reached else c.NOT_VALID
        return [DoneResult(validity=validity, identifier=self.name, reward=0)]
class Collision(Rule):
    """Emits a collision result for every collidable guest sharing a tile."""

    def __init__(self, done_at_collisions: bool = False):
        super().__init__()
        self.done_at_collisions = done_at_collisions
        self.curr_done = False

    def tick_post_step(self, state) -> List[TickResult]:
        """Flag colliding guests and return one ``TickResult`` per guest."""
        self.curr_done = False
        colliding_tiles = state.get_all_tiles_with_collisions()
        results = []
        for tile in colliding_tiles:
            colliders = tile.guests_that_can_collide
            if len(colliders) < 2:
                continue
            for guest in colliders:
                try:
                    # NOTE(review): the rule itself is passed as ``entity`` here - confirm this is intended.
                    guest.set_state(TickResult(identifier=c.COLLISION, reward=r.COLLISION,
                                               validity=c.NOT_VALID, entity=self))
                except AttributeError:
                    # Guests without ``set_state`` are skipped silently.
                    pass
                results.append(TickResult(entity=guest, identifier=c.COLLISION,
                                          reward=r.COLLISION, validity=c.VALID))
            self.curr_done = True
        return results

    def on_check_done(self, state) -> List[DoneResult]:
        """Valid done result iff a collision happened and ``done_at_collisions`` is set."""
        if not (self.curr_done and self.done_at_collisions):
            return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
        return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)]

View File

View File

@ -0,0 +1,64 @@
import pickle
from os import PathLike
from pathlib import Path
from typing import Union
from gymnasium import Wrapper
from mfg_package.utils.helpers import IGNORED_DF_COLUMNS
from mfg_package.environment.factory import REC_TAC
import pandas as pd
from mfg_package.plotting.compare_runs import plot_single_run
class EnvMonitor(Wrapper):
    """Wrapper that collects per-step ``info`` dicts and aggregates them per episode.

    Each step's ``info`` (minus excluded keys) is buffered; when ``done`` is
    reported the buffer is reduced to one row and added to ``_monitor_df``.
    """

    ext = 'png'

    def __init__(self, env, filepath: Union[str, PathLike] = None):
        super(EnvMonitor, self).__init__(env)
        self._filepath = filepath
        self._monitor_df = pd.DataFrame()
        self._monitor_dict = dict()

    def __getattr__(self, item):
        # Unknown attributes are looked up on the unwrapped environment.
        return getattr(self.unwrapped, item)

    def step(self, action):
        """Step the wrapped env and record its ``info`` / ``done`` outputs."""
        obs_type, obs, reward, done, info = self.env.step(action)
        self._read_info(info)
        self._read_done(done)
        return obs_type, obs, reward, done, info

    def reset(self):
        """Reset the unwrapped environment and return its result."""
        return self.unwrapped.reset()

    def _read_info(self, info: dict):
        """Buffer ``info`` without the excluded and ``REC_TAC``-prefixed keys."""
        self._monitor_dict[len(self._monitor_dict)] = {
            key: val for key, val in info.items() if
            key not in ['terminal_observation', 'episode'] and not key.startswith(REC_TAC)}
        return

    def _read_done(self, done):
        """On ``done``, aggregate the buffered steps into one episode row."""
        if not done:
            return
        env_monitor_df = pd.DataFrame.from_dict(self._monitor_dict, orient='index')
        self._monitor_dict = dict()
        columns = [col for col in env_monitor_df.columns if col not in IGNORED_DF_COLUMNS]
        # Columns ending in 'ount' are averaged, all others are summed.
        env_monitor_df = env_monitor_df.aggregate(
            {col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
        )
        env_monitor_df['episode'] = len(self._monitor_df)
        # ``DataFrame.append`` was removed in pandas 2.0; ``pd.concat`` with a
        # one-row frame reproduces what ``append([series])`` used to do.
        episode_row = pd.DataFrame([env_monitor_df])
        if self._monitor_df.empty:
            self._monitor_df = episode_row
        else:
            self._monitor_df = pd.concat([self._monitor_df, episode_row])
        return

    def save_run(self, filepath: Union[Path, str, None] = None, auto_plotting_keys=None):
        """Pickle the collected episode frame and optionally plot it.

        Args:
            filepath: target file; falls back to the path given at construction.
            auto_plotting_keys: if truthy, passed as ``column_keys`` to ``plot_single_run``.
        """
        filepath = Path(filepath or self._filepath)
        filepath.parent.mkdir(exist_ok=True, parents=True)
        with filepath.open('wb') as f:
            pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
        if auto_plotting_keys:
            plot_single_run(filepath, column_keys=auto_plotting_keys)

View File

@ -0,0 +1,152 @@
import warnings
from collections import defaultdict
from os import PathLike
from pathlib import Path
from typing import Union
import yaml
from gymnasium import Wrapper
import numpy as np
import pandas as pd
from mfg_package.environment.factory import REC_TAC
class EnvRecorder(Wrapper):
    """Wrapper that records ``REC_TAC``-prefixed ``info`` entries per episode."""

    def __init__(self, env, entities: str = 'all', filepath: Union[str, PathLike] = None, freq: int = 0):
        super(EnvRecorder, self).__init__(env)
        self.filepath = filepath
        self.freq = freq
        self._recorder_dict = defaultdict(list)
        self._recorder_out_list = list()
        self._episode_counter = 1
        self._do_record_dict = defaultdict(lambda: False)
        if not isinstance(entities, str):
            self._entities = entities
        elif entities.lower() == 'all':
            # ``None`` means: do not filter recorded entries.
            self._entities = None
        else:
            self._entities = [entities]

    def __getattr__(self, item):
        # Unknown attributes are looked up on the unwrapped environment.
        return getattr(self.unwrapped, item)

    def reset(self):
        self._on_training_start()
        return self.unwrapped.reset()

    def _on_training_start(self) -> None:
        assert self.start_recording()

    def _read_info(self, env_idx, info: dict):
        """Buffer the ``REC_TAC`` entries of ``info`` (prefix stripped) for ``env_idx``."""
        prefix = f'{REC_TAC}'
        info_dict = {key.replace(REC_TAC, ''): val for key, val in info.items() if key.startswith(prefix)}
        if info_dict:
            if self._entities:
                info_dict = {k: v for k, v in info_dict.items() if k in self._entities}
            self._recorder_dict[env_idx].append(info_dict)
        return True

    def _read_done(self, env_idx, done):
        """On ``done``, move the buffered steps of ``env_idx`` to the output list."""
        if not done:
            return
        finished_episode = {'steps': self._recorder_dict[env_idx],
                            'episode': self._episode_counter}
        self._recorder_out_list.append(finished_episode)
        self._recorder_dict[env_idx] = list()

    def step(self, actions):
        """Step the unwrapped env; record info/done when recording is active."""
        step_result = self.unwrapped.step(actions)
        if self.do_record_episode(0):
            self._read_info(0, step_result[-1])
        if self._do_record_dict[0]:
            self._read_done(0, step_result[-2])
        return step_result

    def finalize(self):
        self._on_training_end()
        return True

    def save_records(self, filepath: Union[Path, str, None] = None,
                     only_deltas=True,
                     save_occupation_map=False,
                     save_trajectory_map=False,
                     ):
        """Dump the recorded episodes as YAML and optionally show an occupation heatmap."""
        filepath = Path(filepath or self.filepath)
        filepath.parent.mkdir(exist_ok=True, parents=True)
        # cls.out_file.unlink(missing_ok=True)
        with filepath.open('w') as f:
            if only_deltas:
                from deepdiff import DeepDiff
                # Diff every recorded episode against its successor.
                successive = zip(self._recorder_out_list, self._recorder_out_list[1:])
                diff_dict = [DeepDiff(t1, t2, ignore_order=True) for t1, t2 in successive]
                out_dict = {'episodes': diff_dict}
            else:
                out_dict = {'episodes': self._recorder_out_list}
            out_dict.update(
                {'n_episodes': self._episode_counter,
                 'env_params': self.env.params,
                 'header': self.env.summarize_header
                 })
            try:
                yaml.dump(out_dict, f, indent=4)
            except TypeError:
                print('Shit')

        if save_occupation_map:
            a = np.zeros((15, 15))
            # noinspection PyTypeChecker
            for episode in out_dict['episodes']:
                agent_rows = [y for x in episode['steps'] for y in x['Agents']]
                df = pd.DataFrame(agent_rows)
                b = list(df[['x', 'y']].to_records(index=False))
                np.add.at(a, tuple(zip(*b)), 1)
            # a = np.rot90(a)
            import seaborn as sns
            from matplotlib import pyplot as plt
            hm = sns.heatmap(data=a)
            hm.set_title('Very Nice Heatmap')
            plt.show()

        if save_trajectory_map:
            raise NotImplementedError('This has not yet been implemented.')

    def do_record_episode(self, env_idx):
        """Decide (while the buffer of ``env_idx`` is empty) whether to record; return the flag."""
        if not self._recorder_dict[env_idx]:
            if self.freq:
                self._do_record_dict[env_idx] = (self.freq == -1) or (self._episode_counter % self.freq) == 0
            else:
                self._do_record_dict[env_idx] = False
                warnings.warn('You did wrap your Environment with a recorder, but set the freq to zero\n'
                              'Nothing will be recorded')
            self._episode_counter += 1
        return self._do_record_dict[env_idx]

    def _on_step(self) -> bool:
        for env_idx, info in enumerate(self.locals.get('infos', [])):
            if self._do_record_dict[env_idx]:
                self._read_info(env_idx, info)
        dones = list(enumerate(self.locals.get('dones', [])))
        dones.extend(list(enumerate(self.locals.get('done', []))))
        for env_idx, done in dones:
            if self._do_record_dict[env_idx]:
                self._read_done(env_idx, done)
        return True

    def _on_training_end(self) -> None:
        """Flush every non-empty step buffer to the output list."""
        for env_idx in range(len(self._recorder_dict)):
            if self._recorder_dict[env_idx]:
                self._recorder_out_list.append({'steps': self._recorder_dict[env_idx],
                                                'episode': self._episode_counter})

View File

View File

@ -0,0 +1,11 @@
# Placeholder identifier of the template module; replace it in your own module.
TEMPLATE = '#' # TEMPLATE _identifier. Define your own!
# Movements
# Identifier strings for the four cardinal and four diagonal directions.
NORTH = 'north'
EAST = 'east'
SOUTH = 'south'
WEST = 'west'
NORTHEAST = 'north_east'
SOUTHEAST = 'south_east'
SOUTHWEST = 'south_west'
NORTHWEST = 'north_west'

View File

@ -0,0 +1,24 @@
from typing import List
from mfg_package.environment.rules import Rule
from mfg_package.utils.results import TickResult, DoneResult
class TemplateRule(Rule):
    """Skeleton rule: every hook is a placeholder that returns ``None``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def on_init(self, state):
        """Placeholder; does nothing."""
        return None

    def tick_pre_step(self, state) -> List[TickResult]:
        """Placeholder; does nothing."""
        return None

    def tick_step(self, state) -> List[TickResult]:
        """Placeholder; does nothing."""
        return None

    def tick_post_step(self, state) -> List[TickResult]:
        """Placeholder; does nothing."""
        return None

    def on_check_done(self, state) -> List[DoneResult]:
        """Placeholder; does nothing."""
        return None

View File

@ -0,0 +1,26 @@
from typing import Union
from mfg_package.environment.actions import Action
from mfg_package.utils.results import ActionResult
from mfg_package.modules.batteries import constants as b, rewards as r
from mfg_package.environment import constants as c
class BtryCharge(Action):
    """Action that charges the acting entity's battery at a charge pod."""

    def __init__(self):
        super().__init__(b.CHARGE)

    def do(self, entity, state) -> Union[None, ActionResult]:
        """Try to charge at the pod on ``entity.pos`` and return the result."""
        charge_pod = state[b.CHARGE_PODS].by_pos(entity.pos)
        if not charge_pod:
            # No pod at the entity's position.
            valid = c.NOT_VALID
            state.print(f'{entity.name} failed to charged batteries at {entity.pos}.')
        else:
            valid = charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity))
            if valid:
                state.print(f'{entity.name} just charged batteries at {charge_pod.name}.')
            else:
                state.print(f'{entity.name} failed to charged batteries at {charge_pod.name}.')
        reward = r.CHARGE_VALID if valid else r.CHARGE_FAIL
        return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=reward)

View File

@ -0,0 +1,19 @@
from typing import NamedTuple, Union
# Battery Env
# Identifier constants of the battery module.
CHARGE_PODS = 'ChargePods'  # presumably the state key of the charge pod group - confirm
BATTERIES = 'Batteries'  # presumably the state key of the battery group - confirm
BATTERY_DISCHARGED = 'DISCHARGED'
CHARGE_POD_SYMBOL = 1
CHARGE = 'do_charge_action'  # presumably the identifier of the charge action - confirm
class BatteryProperties(NamedTuple):
    """Immutable bundle of battery configuration values with defaults."""

    initial_charge: float = 0.8  # default initial charge
    charge_rate: float = 0.4  # default charge rate
    charge_locations: int = 20  # default number of charge locations
    per_action_costs: Union[dict, float] = 0.02  # one flat cost, or a dict of costs
    done_when_discharged: bool = False
    multi_charge: bool = False

View File

@ -0,0 +1,75 @@
from mfg_package.environment.entity.mixin import BoundEntityMixin
from mfg_package.environment.entity.object import EnvObject
from mfg_package.environment.entity.entity import Entity
from mfg_package.environment import constants as c
from mfg_package.utils.render import RenderEntity
from mfg_package.modules.batteries import constants as b
class Battery(BoundEntityMixin, EnvObject):
    """Battery bound to an owning entity, with a charge level starting at ``initial_charge_level``."""

    @property
    def is_discharged(self):
        """``True`` once the charge level is exactly 0."""
        return self.charge_level == 0

    @property
    def obs_tag(self):
        return self.name

    @property
    def encoding(self):
        """The current charge level."""
        return self.charge_level

    def __init__(self, initial_charge_level: float, owner: Entity, *args, **kwargs):
        super(Battery, self).__init__(*args, **kwargs)
        self.charge_level = initial_charge_level
        self.bind_to(owner)

    def do_charge_action(self, amount):
        """Raise the charge level by ``amount`` (capped at 1).

        Returns:
            ``c.VALID`` if the battery was below 1, otherwise ``c.NOT_VALID``.
        """
        if self.charge_level < 1:
            # noinspection PyTypeChecker
            self.charge_level = min(1, amount + self.charge_level)
            return c.VALID
        else:
            return c.NOT_VALID

    def decharge(self, amount) -> float:
        """Lower the charge level by the magnitude of ``amount`` (floored at 0).

        The previous implementation *added* ``amount``, so positive costs
        charged the battery instead of draining it. Using ``abs`` means a
        negative ``amount`` keeps draining, just as it did before.

        Returns:
            ``c.VALID`` if the battery was not empty, otherwise ``c.NOT_VALID``.
        """
        if self.charge_level != 0:
            # noinspection PyTypeChecker
            self.charge_level = max(0, self.charge_level - abs(amount))
            return c.VALID
        else:
            return c.NOT_VALID

    def summarize_state(self, **_):
        """Return the public attributes (as strings) plus name and owner name."""
        attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
        attr_dict.update(dict(name=self.name, belongs_to=self._bound_entity.name))
        return attr_dict

    def render(self):
        """Batteries are not rendered."""
        return None
class ChargePod(Entity):
    """Entity at which a battery can be charged by ``charge_rate`` per use."""

    @property
    def encoding(self):
        return b.CHARGE_POD_SYMBOL

    def __init__(self, *args, charge_rate: float = 0.4,
                 multi_charge: bool = False, **kwargs):
        super(ChargePod, self).__init__(*args, **kwargs)
        self.charge_rate = charge_rate
        # Stored only; ``charge_battery`` does not consult it.
        self.multi_charge = multi_charge

    def charge_battery(self, battery: Battery):
        """Charge ``battery`` unless it is full or the pod's tile hosts several agents.

        Returns:
            ``c.NOT_VALID`` if the battery is already at 1.0 or more than one
            guest with 'agent' in its name is on this pod's tile; otherwise
            the result of ``battery.do_charge_action``.
        """
        if battery.charge_level == 1.0:
            return c.NOT_VALID
        # Count the agents; the previous code summed the guest objects
        # themselves, which raises ``TypeError`` as soon as one matches.
        n_agents = sum(1 for guest in self.tile.guests if 'agent' in guest.name.lower())
        if n_agents > 1:
            return c.NOT_VALID
        valid = battery.do_charge_action(self.charge_rate)
        return valid

    def render(self):
        return RenderEntity(b.CHARGE_PODS, self.pos)

View File

@ -0,0 +1,36 @@
from mfg_package.environment.groups.env_objects import EnvObjects
from mfg_package.environment.groups.mixins import PositionMixin, HasBoundedMixin
from mfg_package.modules.batteries.entitites import ChargePod, Battery
class Batteries(HasBoundedMixin, EnvObjects):
    """Group of ``Battery`` objects, one per owning agent."""

    _entity = Battery
    is_blocking_light: bool = False
    can_collide: bool = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def obs_tag(self):
        """Class name of this group."""
        return type(self).__name__

    @property
    def obs_pairs(self):
        """One ``(name, battery)`` pair per stored battery."""
        return [(battery.name, battery) for battery in self]

    def spawn_batteries(self, agents, initial_charge_level):
        """Create one battery per agent and add them to this group."""
        new_batteries = []
        for agent in agents:
            new_batteries.append(self._entity(initial_charge_level, agent))
        self.add_items(new_batteries)
class ChargePods(PositionMixin, EnvObjects):
    """Group of ``ChargePod`` entities."""

    _entity = ChargePod

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __repr__(self):
        # Delegates unchanged to the parent representation.
        return super().__repr__()

View File

@ -0,0 +1,3 @@
# Reward values of the battery module.
# NOTE(review): usage is inferred from the names - confirm against the battery actions and rules.
CHARGE_VALID: float = 0.1  # presumably for a successful charge action
CHARGE_FAIL: float = -0.1  # presumably for a failed charge action
BATTERY_DISCHARGED: float = -1.0  # presumably for a discharged battery

View File

@ -0,0 +1,61 @@
from typing import List, Union
from mfg_package.environment.rules import Rule
from mfg_package.utils.results import TickResult, DoneResult
from mfg_package.environment import constants as c
from mfg_package.modules.batteries import constants as b, rewards as r
class Btry(Rule):
    """Spawns batteries for all agents and applies per-step energy costs."""

    def __init__(self, initial_charge: float = 0.8, per_action_costs: Union[dict, float] = 0.02):
        super().__init__()
        self.per_action_costs = per_action_costs
        self.initial_charge = initial_charge

    def on_init(self, state):
        """Create one battery per agent with the configured initial charge."""
        agents = state[c.AGENT]
        state[b.BATTERIES].spawn_batteries(agents, self.initial_charge)

    def tick_pre_step(self, state) -> List[TickResult]:
        """Nothing to do before the step; returns ``None``."""
        return None

    def tick_step(self, state) -> List[TickResult]:
        """Decharge every agent's battery and return one result per agent."""
        # Decharge
        batteries = state[b.BATTERIES]
        results = list()
        for agent in state[c.AGENT]:
            costs = self.per_action_costs
            # A dict of costs is indexed by the agent's last action; a scalar applies to all.
            energy_consumption = costs[agent.step_result()['action']] if isinstance(costs, dict) else costs
            batteries.by_entity(agent).decharge(energy_consumption)
            results.append(TickResult(self.name, reward=0, entity=agent, validity=c.VALID))
        return results

    def tick_post_step(self, state) -> List[TickResult]:
        """Return a penalty result for every battery that is discharged."""
        results = list()
        for btry in state[b.BATTERIES]:
            if not btry.is_discharged:
                continue
            state.print(f'Battery of {btry.bound_entity.name} is discharged!')
            results.append(
                TickResult(self.name, entity=btry.bound_entity, reward=r.BATTERY_DISCHARGED, validity=c.VALID))
        return results
class BtryDoneAtDischarge(Rule):
    """Reports done as soon as any battery is discharged."""

    def __init__(self):
        super().__init__()

    def on_check_done(self, state) -> List[DoneResult]:
        """Valid done result (with penalty) iff at least one battery is empty."""
        any_empty = any(battery.is_discharged for battery in state[b.BATTERIES])
        if not any_empty:
            return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
        return [DoneResult(self.name, validity=c.VALID, reward=r.BATTERY_DISCHARGED)]

View File

View File

@ -0,0 +1,36 @@
from typing import Union
from mfg_package.environment.actions import Action
from mfg_package.utils.results import ActionResult
from mfg_package.modules.clean_up import constants as d, rewards as r
from mfg_package.environment import constants as c
class CleanUp(Action):
    """Action that removes dirt at the acting entity's position."""

    def __init__(self):
        super().__init__(d.CLEAN_UP)

    def do(self, entity, state) -> Union[None, ActionResult]:
        """Reduce or delete the dirt pile under ``entity`` and return the result."""
        dirt = state[d.DIRT].by_pos(entity.pos)
        if not dirt:
            # Nothing to clean at this position.
            state.print(f'{entity.name} just tried to clean up some dirt at {entity.pos}, but failed.')
            return ActionResult(identifier=d.CLEAN_UP_FAIL, validity=c.NOT_VALID,
                                reward=r.CLEAN_UP_FAIL, entity=entity)

        remaining = dirt.amount - state[d.DIRT].clean_amount
        if remaining <= 0:
            state[d.DIRT].delete_env_object(dirt)
        else:
            dirt.set_new_amount(max(remaining, c.VALUE_FREE_CELL))
        state.print(f'{entity.name} did just clean up some dirt at {entity.pos}.')
        return ActionResult(identifier=d.CLEAN_UP, validity=c.VALID,
                            reward=r.CLEAN_UP_VALID, entity=entity)

View File

@ -0,0 +1,7 @@
# Identifier constants of the clean-up module.
DIRT = 'DirtPiles'  # presumably the state key of the dirt pile group - confirm
CLEAN_UP = 'do_cleanup_action'  # presumably the identifier of the clean-up action - confirm
CLEAN_UP_VALID = 'clean_up_valid'
CLEAN_UP_FAIL = 'clean_up_fail'
CLEAN_UP_ALL = 'all_cleaned_up'

Binary file not shown.

After

Width:  |  Height:  |  Size: 38 KiB

View File

@ -0,0 +1,35 @@
from numpy import random
from mfg_package.environment.entity.entity import Entity
from mfg_package.utils.render import RenderEntity
from mfg_package.modules.clean_up import constants as d
class DirtPile(Entity):
    """Entity representing an amount of dirt on a tile."""

    def __init__(self, *args, max_local_amount=5, initial_amount=2, spawn_variation=0.05, **kwargs):
        super().__init__(*args, **kwargs)
        # Perturb the initial amount by a normally distributed relative offset.
        relative_noise = random.normal(loc=0, scale=spawn_variation, size=1).item()
        self._amount = abs(initial_amount + (relative_noise * initial_amount))
        self.max_local_amount = max_local_amount

    @property
    def amount(self):
        """Current amount of dirt."""
        return self._amount

    @property
    def encoding(self):
        # Edit this if you want piles to be drawn differently in the observation.
        return self._amount

    def set_new_amount(self, amount):
        """Set the amount, capped at ``max_local_amount``."""
        capped = min(amount, self.max_local_amount)
        self._amount = capped

    def summarize_state(self):
        """Parent summary extended by the amount as a float."""
        summary = super().summarize_state()
        summary.update(amount=float(self.amount))
        return summary

    def render(self):
        scale = min(0.15 + self.amount, 1.5)
        return RenderEntity(d.DIRT, self.tile.pos, scale, 'scale')

View File

@ -0,0 +1,64 @@
from mfg_package.environment.groups.env_objects import EnvObjects
from mfg_package.environment.groups.mixins import PositionMixin
from mfg_package.environment.entity.wall_floor import Floor
from mfg_package.modules.clean_up.entitites import DirtPile
from mfg_package.environment import constants as c
class DirtPiles(PositionMixin, EnvObjects):
    """Group of ``DirtPile`` entities plus the parameters used to spawn them."""

    _entity = DirtPile
    is_blocking_light: bool = False
    can_collide: bool = False

    @property
    def amount(self):
        """Total amount of dirt over all piles."""
        return sum(dirt.amount for dirt in self)

    def __init__(self, *args,
                 initial_amount=2,
                 initial_dirt_ratio=0.05,
                 dirt_spawn_r_var=0.1,
                 max_local_amount=5,
                 clean_amount=1,
                 max_global_amount: int = 20, **kwargs):
        super(DirtPiles, self).__init__(*args, **kwargs)
        self.clean_amount = clean_amount
        self.initial_amount = initial_amount
        self.initial_dirt_ratio = initial_dirt_ratio
        self.dirt_spawn_r_var = dirt_spawn_r_var
        self.max_global_amount = max_global_amount
        self.max_local_amount = max_local_amount

    def spawn_dirt(self, then_dirty_tiles, amount) -> bool:
        """Add ``amount`` of dirt to each given tile.

        Existing piles are increased; otherwise a new pile is created. Stops
        with ``c.NOT_VALID`` as soon as the total exceeds ``max_global_amount``.

        Args:
            then_dirty_tiles: a single ``Floor`` or an iterable of tiles.
            amount: dirt to add per tile.

        Returns:
            ``c.VALID`` if all tiles were processed, otherwise ``c.NOT_VALID``.
        """
        if isinstance(then_dirty_tiles, Floor):
            then_dirty_tiles = [then_dirty_tiles]
        for tile in then_dirty_tiles:
            if not self.amount > self.max_global_amount:
                if dirt := self.by_pos(tile.pos):
                    new_value = dirt.amount + amount
                    dirt.set_new_amount(new_value)
                else:
                    # Hand the group's cap down to the pile; it was stored
                    # here but never applied, so piles always used their own default.
                    dirt = DirtPile(tile, initial_amount=amount, spawn_variation=self.dirt_spawn_r_var,
                                    max_local_amount=self.max_local_amount)
                    self.add_item(dirt)
            else:
                return c.NOT_VALID
        return c.VALID

    def trigger_dirt_spawn(self, state, initial_spawn=False) -> bool:
        """Spawn dirt on a random share of the floor tiles that are free for dirt.

        A tile qualifies when it has no guests or its only guest is a
        ``DirtPile``. The share is ``initial_dirt_ratio``, randomly varied by
        ``dirt_spawn_r_var`` when ``initial_spawn`` is set.
        """
        free_for_dirt = [x for x in state[c.FLOOR]
                         if len(x.guests) == 0 or (
                                 len(x.guests) == 1 and
                                 isinstance(next(y for y in x.guests), DirtPile))
                         ]
        state.rng.shuffle(free_for_dirt)

        var = self.dirt_spawn_r_var
        new_spawn = abs(self.initial_dirt_ratio + (state.rng.uniform(-var, var) if initial_spawn else 0))
        n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt)))
        return self.spawn_dirt(free_for_dirt[:n_dirt_tiles], self.initial_amount)

    def __repr__(self):
        s = super(DirtPiles, self).__repr__()
        return f'{s[:-1]}, {self.amount})'

View File

@ -0,0 +1,3 @@
# Reward values of the clean-up module.
# NOTE(review): usage is inferred from the names - confirm against the clean-up action and rules.
CLEAN_UP_VALID: float = 0.5  # presumably for a successful clean-up action
CLEAN_UP_FAIL: float = -0.1  # presumably for a failed clean-up action
CLEAN_UP_ALL: float = 4.5  # presumably for having cleaned up all dirt

View File

@ -0,0 +1,15 @@
from mfg_package.environment import constants as c
from mfg_package.environment.rules import Rule
from mfg_package.utils.results import DoneResult
from mfg_package.modules.clean_up import constants as d, rewards as r
class DirtAllCleanDone(Rule):
    """Reports done when no dirt is left after the first step."""

    def __init__(self):
        super().__init__()

    def on_check_done(self, state) -> [DoneResult]:
        """Valid done result iff the dirt group is empty and ``curr_step`` is non-zero."""
        if len(state[d.DIRT]) == 0 and state.curr_step:
            validity, reward = c.VALID, r.CLEAN_UP_ALL
        else:
            validity, reward = c.NOT_VALID, 0
        return [DoneResult(validity=validity, identifier=self.name, reward=reward)]

View File

@ -0,0 +1,28 @@
from mfg_package.environment.rules import Rule
from mfg_package.utils.results import TickResult
from mfg_package.modules.clean_up import constants as d
class DirtRespawnRule(Rule):
    """Spawns dirt initially and then again every ``spawn_freq`` steps.

    A negative ``spawn_freq`` disables the periodic respawn.
    """

    def __init__(self, spawn_freq=15):
        super().__init__()
        self.spawn_freq = spawn_freq
        self._next_dirt_spawn = spawn_freq

    def on_init(self, state) -> str:
        """Trigger the initial dirt spawn and return a message listing the positions."""
        state[d.DIRT].trigger_dirt_spawn(state, initial_spawn=True)
        return f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}'

    def tick_step(self, state):
        """Count down and spawn dirt when the countdown reaches zero.

        Returns:
            A list with one ``TickResult`` on a spawn step, otherwise ``[]``.
        """
        if self._next_dirt_spawn < 0:
            # No DirtPile Spawn
            return []
        if self._next_dirt_spawn:
            self._next_dirt_spawn -= 1
            return []
        validity = state[d.DIRT].trigger_dirt_spawn(state)
        # Restart the countdown *before* returning. The reset used to sit
        # after the ``return`` and never ran, so dirt spawned on every
        # step once the countdown had reached zero.
        self._next_dirt_spawn = self.spawn_freq
        return [TickResult(entity=None, validity=validity, identifier=self.name, reward=0)]

View File

@ -0,0 +1,24 @@
from mfg_package.environment.rules import Rule
from mfg_package.utils.helpers import is_move
from mfg_package.utils.results import TickResult
from mfg_package.environment import constants as c
from mfg_package.modules.clean_up import constants as d
class DirtSmearOnMove(Rule):
    """Carries a fraction of dirt from an entity's last position to its new tile."""

    def __init__(self, smear_amount: float = 0.2):
        super().__init__()
        self.smear_amount = smear_amount

    def tick_post_step(self, state):
        """Smear dirt for every entity that just moved validly; return the results."""
        results = []
        for entity in state.moving_entites:
            moved_validly = is_move(entity.state.identifier) and entity.state.validity == c.VALID
            if not moved_validly:
                continue
            old_pos_dirt = state[d.DIRT].by_pos(entity.last_pos)
            if not old_pos_dirt:
                continue
            smeared_dirt = round(old_pos_dirt.amount * self.smear_amount, 2)
            if not smeared_dirt:
                # Rounded down to zero: nothing to smear.
                continue
            if state[d.DIRT].spawn_dirt(entity.tile, amount=smeared_dirt):
                results.append(TickResult(identifier=self.name, entity=entity,
                                          reward=0, validity=c.VALID))
        return results

View File

@ -0,0 +1,23 @@
from typing import Union
from mfg_package.environment.actions import Action
from mfg_package.utils.results import ActionResult
from mfg_package.modules.destinations import constants as d, rewards as r
from mfg_package.environment import constants as c
class DestAction(Action):
    """Action that waits on the destination at the acting entity's position."""

    def __init__(self):
        super().__init__(d.DESTINATION)

    def do(self, entity, state) -> Union[None, ActionResult]:
        """Perform the wait action if a destination is at ``entity.pos``."""
        destination = state[d.DESTINATION].by_pos(entity.pos)
        if destination:
            valid = destination.do_wait_action(entity)
            state.print(f'{entity.name} just waited at {entity.pos}')
        else:
            valid = c.NOT_VALID
            state.print(f'{entity.name} just tried to do_wait_action do_wait_action at {entity.pos} but failed')
        reward = r.WAIT_VALID if valid else r.WAIT_FAIL
        return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=reward)

View File

@ -0,0 +1,14 @@
# Destination Env
# Identifier constants of the destination module.
DESTINATION = 'Destinations'  # presumably the state key of the open destinations - confirm
DEST_SYMBOL = 1
DEST_REACHED_REWARD = 0.5
DEST_REACHED = 'ReachedDestinations'  # presumably the state key of the reached destinations - confirm
WAIT_ON_DEST = 'WAIT'
# Spawn modes
MODE_SINGLE = 'SINGLE'
MODE_GROUPED = 'GROUPED'
# Done modes
DONE_ALL = 'DONE_ALL'
DONE_SINGLE = 'DONE_SINGLE'

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.9 KiB

View File

@ -0,0 +1,51 @@
from collections import defaultdict
from mfg_package.environment.entity.agent import Agent
from mfg_package.environment.entity.entity import Entity
from mfg_package.environment import constants as c
from mfg_package.utils.render import RenderEntity
from mfg_package.modules.destinations import constants as d
class Destination(Entity):
    """Destination entity tracking, per agent name, the remaining dwell time."""

    @property
    def any_agent_has_dwelled(self):
        """``True`` if at least one agent has a dwell-time entry."""
        return bool(len(self._per_agent_times))

    @property
    def currently_dwelling_names(self):
        """Names of all agents with a dwell-time entry."""
        return list(self._per_agent_times.keys())

    @property
    def encoding(self):
        return d.DEST_SYMBOL

    def __init__(self, *args, dwell_time: int = 0, **kwargs):
        super(Destination, self).__init__(*args, **kwargs)
        self.dwell_time = dwell_time
        # Unknown agents start with the full dwell time on first access.
        self._per_agent_times = defaultdict(lambda: dwell_time)

    def do_wait_action(self, agent: Agent):
        """Decrease the remaining dwell time of ``agent`` by one; returns ``c.VALID``."""
        self._per_agent_times[agent.name] -= 1
        return c.VALID

    def leave(self, agent: Agent):
        """Forget the dwell-time entry of ``agent``.

        Raises:
            KeyError: if ``agent`` has no entry.
        """
        del self._per_agent_times[agent.name]

    @property
    def is_considered_reached(self):
        """Reached if an agent stands here and no dwelling is required, or any entry hit 0."""
        agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
        return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values())

    def agent_is_dwelling(self, agent: Agent):
        """Return whether ``agent`` has already waited here at least once."""
        # ``.get`` keeps this a pure query; subscripting the defaultdict
        # would register every asked-about agent as currently dwelling.
        return self._per_agent_times.get(agent.name, self.dwell_time) < self.dwell_time

    def summarize_state(self) -> dict:
        """Parent summary extended by the per-agent times and the dwell time."""
        state_summary = super().summarize_state()
        # ``.items()`` - iterating ``.keys()`` tried to unpack each agent
        # name into ``(key, val)`` and failed.
        state_summary.update(per_agent_times=[
            dict(belongs_to=key, time=val) for key, val in self._per_agent_times.items()],
            dwell_time=self.dwell_time)
        return state_summary

    def render(self):
        return RenderEntity(d.DESTINATION, self.pos)

View File

@ -0,0 +1,28 @@
from mfg_package.environment.groups.env_objects import EnvObjects
from mfg_package.environment.groups.mixins import PositionMixin
from mfg_package.modules.destinations.entitites import Destination
class Destinations(PositionMixin, EnvObjects):
    """Group of ``Destination`` entities."""

    _entity = Destination
    is_blocking_light: bool = False
    can_collide: bool = False

    def __init__(self, *args, **kwargs):
        super(Destinations, self).__init__(*args, **kwargs)

    def __repr__(self):
        # Delegates unchanged to the parent representation.
        return super().__repr__()
class ReachedDestinations(Destinations):
    """Group holding ``Destination`` entities in a separate collection."""

    _entity = Destination
    is_blocking_light = False
    can_collide = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __repr__(self):
        # Delegates unchanged to the parent representation.
        return super().__repr__()

View File

@ -0,0 +1,3 @@
# Reward values of the destination module.
# NOTE(review): usage is inferred from the names - confirm against the destination action and rules.
WAIT_VALID: float = 0.1  # presumably for a successful wait action
WAIT_FAIL: float = -0.1  # presumably for a failed wait action
DEST_REACHED: float = 5.0  # presumably for reaching a destination

View File

@ -0,0 +1,89 @@
from typing import List, Union
from mfg_package.environment.rules import Rule
from mfg_package.utils.results import TickResult, DoneResult
from mfg_package.environment import constants as c
from mfg_package.modules.destinations import constants as d, rewards as r
from mfg_package.modules.destinations.entitites import Destination
class DestinationReach(Rule):
    """Moves reached destinations to the reached group and rewards the agents on them."""

    def __init__(self, n_dests: int = 1, tiles: Union[List, None] = None):
        super(DestinationReach, self).__init__()
        self.n_dests = n_dests or len(tiles)
        self._tiles = tiles

    def tick_step(self, state) -> List[TickResult]:
        """Re-parent reached destinations; drop dwell entries of agents that left."""
        # Iterate over a snapshot: re-parenting changes the collection.
        for dest in list(state[d.DESTINATION].values()):
            if dest.is_considered_reached:
                dest.change_parent_collection(state[d.DEST_REACHED])
                state.print(f'{dest.name} is reached now, removing...')
            else:
                for agent_name in dest.currently_dwelling_names:
                    agent = state[c.AGENT][agent_name]
                    if agent.pos == dest.pos:
                        state.print(f'{agent.name} is still waiting.')
                    else:
                        dest.leave(agent)
                        state.print(f'{agent.name} left the destination early.')
        return [TickResult(self.name, validity=c.VALID, reward=0, entity=None)]

    def tick_post_step(self, state) -> List[TickResult]:
        """Reward every agent standing on a reached destination, then delete it.

        Returns:
            One ``TickResult`` per rewarded agent.
        """
        results = list()
        # Snapshot first: deleting while iterating the live collection
        # changes it mid-loop (as ``tick_step`` already avoids).
        for reached_dest in list(state[d.DEST_REACHED]):
            # Collect the agents before deleting, so the guest list is not
            # walked while the destination is removed from its tile.
            agents_on_dest = [guest for guest in reached_dest.tile.guests if guest in state[c.AGENT]]
            if not agents_on_dest:
                continue
            for guest in agents_on_dest:
                state.print(f'{guest.name} just reached destination at {guest.pos}')
                results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=guest))
            # Delete exactly once, however many agents share the tile.
            state[d.DEST_REACHED].delete_env_object(reached_dest)
        return results
class DestinationDone(Rule):
    """Done-check that fires as soon as ``state[d.DESTINATION]`` is empty."""

    def __init__(self):
        super().__init__()

    def on_check_done(self, state) -> List[DoneResult]:
        """Report the episode as done when no destination is left.

        :param state: Game state; indexed with ``d.DESTINATION``.
        :return: A single valid ``DoneResult`` carrying ``r.DEST_REACHED`` when
                 the destination collection has length 0, otherwise an empty list.
        """
        remaining = len(state[d.DESTINATION])
        if remaining:
            return []
        return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
class DestinationSpawn(Rule):
    """Rule that spawns destinations at start-up and refills them while running."""

    def __init__(self, spawn_frequency: int = 5, n_dests: int = 1,
                 spawn_mode: str = d.MODE_GROUPED):
        """
        :param spawn_frequency: Stored and copied into ``_dest_spawn_timer`` on
                                init; not read elsewhere in this class.
        :param n_dests:         Number of destinations that should exist.
        :param spawn_mode:      Spawn strategy; only ``d.MODE_GROUPED`` is
                                handled by ``tick_step``.
        """
        super(DestinationSpawn, self).__init__()
        self.spawn_frequency = spawn_frequency
        self.n_dests = n_dests
        self.spawn_mode = spawn_mode

    def on_init(self, state):
        """Reset the spawn timer and spawn the initial ``n_dests`` destinations."""
        # noinspection PyAttributeOutsideInit
        self._dest_spawn_timer = self.spawn_frequency
        self.trigger_destination_spawn(self.n_dests, state)

    def tick_pre_step(self, state) -> List[TickResult]:
        """Nothing to do before a step; returns an empty result list."""
        # Return a list (as annotated) instead of an implicit ``None``.
        return []

    def tick_step(self, state) -> List[TickResult]:
        """Respawn all destinations at once when none are left (grouped mode).

        :param state: Game state; indexed with ``d.DESTINATION``.
        :return: One ``TickResult`` describing the spawn attempt, or an empty
                 list when nothing was spawned.
        """
        n_dest_spawn = max(0, self.n_dests - len(state[d.DESTINATION]))
        if n_dest_spawn and self.spawn_mode == d.MODE_GROUPED and n_dest_spawn == self.n_dests:
            # The spawn helper is defined on this class. The former lookup via
            # ``state.rules['DestinationReach']`` targeted a rule that, as shown
            # in this module, does not define ``trigger_destination_spawn``.
            validity = self.trigger_destination_spawn(n_dest_spawn, state)
            return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
        # No spawn happened: return an empty list rather than falling through.
        return []

    @staticmethod
    def trigger_destination_spawn(n_dests, state, tiles=None):
        """Create destinations on the given (or on empty floor) tiles.

        :param n_dests: Number of empty floor tiles to use when ``tiles`` is falsy.
        :param state:   Game state; indexed with ``c.FLOOR`` and ``d.DESTINATION``.
        :param tiles:   Optional explicit tiles to place the destinations on.
        :return: ``c.VALID`` when at least one destination was added,
                 otherwise ``c.NOT_VALID``.
        """
        tiles = tiles or state[c.FLOOR].empty_tiles[:n_dests]
        if destinations := [Destination(tile) for tile in tiles]:
            state[d.DESTINATION].add_items(destinations)
            state.print(f'{n_dests} new destinations have been spawned')
            return c.VALID
        else:
            state.print('No Destiantions are spawning, limit is reached.')
            return c.NOT_VALID

View File

View File

@ -0,0 +1,28 @@
from typing import Union
from mfg_package.environment.actions import Action
from mfg_package.utils.results import ActionResult
from mfg_package.modules.doors import constants as d, rewards as r
from mfg_package.environment import constants as c
class DoorUse(Action):
    """Action with which an entity toggles a door near its position."""

    def __init__(self):
        super().__init__(d.ACTION_DOOR_USE)

    def do(self, entity, state) -> Union[None, ActionResult]:
        """Use the first door found near ``entity``.

        :param entity: The acting entity; its ``pos`` and ``name`` are read.
        :param state:  Game state; ``state.entities.get_near_pos`` supplies the
                       candidate entities.
        :return: An ``ActionResult`` with the validity reported by the door and
                 ``r.USE_DOOR_VALID`` when a door was found, otherwise a
                 ``c.NOT_VALID`` result with ``r.USE_DOOR_FAIL``.
        """
        # Check if the agent really is standing next to a door.
        # NOTE(review): the prefix match on ``d.DOOR`` would presumably also hit
        # other entities whose name starts the same way -- confirm how entity
        # names are built.
        nearby = state.entities.get_near_pos(entity.pos)
        door = next((candidate for candidate in nearby if candidate.name.startswith(d.DOOR)), None)

        if door is None:
            # When it doesn't...
            state.print(f'{entity.name} just tried to use a door at {entity.pos}, but there is none.')
            return ActionResult(entity=entity, identifier=self._identifier,
                                validity=c.NOT_VALID, reward=r.USE_DOOR_FAIL)

        valid = door.use()
        state.print(f'{entity.name} just used a {door.name} at {door.pos}')
        return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.USE_DOOR_VALID)

View File

@ -0,0 +1,18 @@
# Names / Identifiers
DOOR = 'Door' # Identifier of single door entities (DoorUse matches entity names by this prefix).
DOORS = 'Doors' # Identifier of the door collection (the group holding the door entities).
# Symbols (in map)
SYMBOL_DOOR = 'D' # Door symbol used for resolving the string-based map files.
# Values
VALUE_ACCESS_INDICATOR = 1 / 3 # Access-door-cell value used in observation
VALUE_OPEN_DOOR = 2 / 3 # Open-door-cell value used in observation
VALUE_CLOSED_DOOR = 3 / 3 # Closed-door-cell value used in observation
# States
STATE_CLOSED = 'closed' # Identifier to compare against for the door-is-closed state
STATE_OPEN = 'open' # Identifier to compare against for the door-is-open state
# Actions
ACTION_DOOR_USE = 'use_door' # Identifier handed to the Action base class by DoorUse

Binary file not shown.

After

Width:  |  Height:  |  Size: 699 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.1 KiB

View File

@ -0,0 +1,97 @@
from mfg_package.environment.entity.entity import Entity
from mfg_package.utils.render import RenderEntity
from mfg_package.environment import constants as c
from mfg_package.modules.doors import constants as d
class DoorIndicator(Entity):
    """Entity a ``Door`` places on its neighbouring floor tiles when asked to indicate its area.

    It is encoded as ``d.VALUE_ACCESS_INDICATOR`` and is never rendered.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Remove the ``move`` attribute from this instance.
        # NOTE(review): this raises AttributeError unless the base initialiser
        # stored ``move`` on the instance itself -- confirm against Entity.
        delattr(self, 'move')

    @property
    def encoding(self):
        """Observation value of an indicator cell."""
        return d.VALUE_ACCESS_INDICATOR

    def render(self):
        """Indicators have no visual representation."""
        return None
class Door(Entity):
    """Door entity that is either open or closed and closes itself after a countdown.

    While closed the door blocks positions and light and can collide; while
    open it does none of these.
    """

    def __init__(self, *args, closed_on_init=True, auto_close_interval=10, indicate_area=False, **kwargs):
        """
        :param closed_on_init:      Start closed (default) or open.
        :param auto_close_interval: Value ``time_to_close`` is reset to whenever
                                    the door opens.
        :param indicate_area:       When true, add a ``DoorIndicator`` for every
                                    neighbouring floor tile to the door's collection.
        """
        # Set before the base initialiser runs -- presumably because it already
        # reads state-dependent properties (TODO confirm).
        self._state = d.STATE_CLOSED
        super().__init__(*args, **kwargs)
        self.auto_close_interval = auto_close_interval
        self.time_to_close = 0
        if not closed_on_init:
            self._open()
        if indicate_area:
            indicators = [DoorIndicator(x) for x in self.tile.neighboring_floor]
            self._collection.add_items(indicators)

    # -- state queries -------------------------------------------------------
    @property
    def is_closed(self):
        return self._state == d.STATE_CLOSED

    @property
    def is_open(self):
        return self._state == d.STATE_OPEN

    @property
    def status(self):
        """The raw state identifier."""
        return self._state

    @property
    def str_state(self):
        """Human readable state, ``'open'`` or ``'closed'``."""
        if self.is_open:
            return 'open'
        return 'closed'

    # -- properties that depend on the state ---------------------------------
    @property
    def is_blocking_pos(self):
        return not self.is_open

    @property
    def is_blocking_light(self):
        return not self.is_open

    @property
    def can_collide(self):
        return not self.is_open

    @property
    def encoding(self):
        """Observation value: closed-door value when closed, open-door value otherwise."""
        if self.is_closed:
            return d.VALUE_CLOSED_DOOR
        return d.VALUE_OPEN_DOOR

    # -- reporting -----------------------------------------------------------
    def summarize_state(self):
        """Extend the base summary with ``state`` and ``time_to_close``."""
        summary = super().summarize_state()
        summary.update({'state': str(self.str_state), 'time_to_close': int(self.time_to_close)})
        return summary

    def render(self):
        """Build the ``RenderEntity`` for the current state."""
        sprite = 'door_open' if self.is_open else 'door_closed'
        return RenderEntity(sprite, self.pos, 1, 'none', 'blank', self.identifier_int + 1)

    # -- behaviour -----------------------------------------------------------
    def use(self):
        """Toggle the door and report ``c.VALID``."""
        toggle = self._close if self._state == d.STATE_OPEN else self._open
        toggle()
        return c.VALID

    def tick(self):
        """Advance the auto-close countdown.

        Only an open door whose tile has length 1 is considered. While the
        countdown is non-zero it is decremented; once it reached zero the door
        is toggled via ``use``.

        :return: ``c.VALID`` when the door was toggled in this tick, otherwise
                 ``c.NOT_VALID``.
        """
        if not (self.is_open and len(self.tile) == 1):
            return c.NOT_VALID
        if self.time_to_close:
            self.time_to_close -= 1
            return c.NOT_VALID
        self.use()
        return c.VALID

    def _open(self):
        """Switch to the open state and restart the auto-close countdown."""
        self._state = d.STATE_OPEN
        self.time_to_close = self.auto_close_interval

    def _close(self):
        """Switch to the closed state."""
        self._state = d.STATE_CLOSED

View File

@ -0,0 +1,28 @@
from typing import Union
from mfg_package.environment.groups.env_objects import EnvObjects
from mfg_package.environment.groups.mixins import PositionMixin
from mfg_package.modules.doors import constants as d
from mfg_package.modules.doors.entitites import Door
class Doors(PositionMixin, EnvObjects):
    """Collection of ``Door`` entities."""

    symbol = d.SYMBOL_DOOR
    _entity = Door

    def __init__(self, *args, **kwargs):
        """Forward everything to the parent, forcing ``can_collide=True``."""
        super().__init__(*args, can_collide=True, **kwargs)

    def get_near_position(self, position: (int, int)) -> Union[None, Door]:
        """Return the first door that lists ``position`` among its neighbouring floor positions.

        :param position: Position to look up.
        :return: The matching door, or ``None`` when there is none.
        """
        for door in self:
            if position in door.tile.neighboring_floor_pos:
                return door
        return None

    def tick_doors(self):
        """Tick every door once.

        :return: Dict mapping each door's name to the value its ``tick`` returned.
        """
        return {door.name: door.tick() for door in self}

View File

@ -0,0 +1,2 @@
# Reward attached by DoorUse.do when a door was found and used (currently zero).
USE_DOOR_VALID: float = -0.00
# Reward attached by DoorUse.do when no door was found near the acting entity.
USE_DOOR_FAIL: float = -0.01

View File

@ -0,0 +1,21 @@
from mfg_package.environment.rules import Rule
from mfg_package.environment import constants as c
from mfg_package.utils.results import TickResult
from mfg_package.modules.doors import constants as d
class DoorAutoClose(Rule):
    """Rule that ticks all doors every step so that open doors close themselves."""

    def __init__(self, close_frequency: int = 10):
        """
        :param close_frequency: Stored on the rule; not read by this class.
        """
        super().__init__()
        self.close_frequency = close_frequency

    def tick_step(self, state):
        """Tick the door collection and log which doors closed.

        :param state: Game state; indexed with ``d.DOORS``.
        :return: A single valid ``TickResult`` when the door collection is
                 truthy, otherwise an empty list.
        """
        doors = state[d.DOORS]
        if not doors:
            state.print('There are no doors, but you loaded the corresponding Module')
            return []

        tick_results = doors.tick_doors()
        closed_names = [name for name, did_close in tick_results.items() if did_close]
        if closed_names:
            message = f'{closed_names} were auto-closed'
        else:
            message = 'No Doors were auto-closed'
        state.print(message)
        return [TickResult(self.name, validity=c.VALID, value=0)]

View File

View File

@ -0,0 +1,37 @@
from typing import Union
from mfg_package.environment.actions import Action
from mfg_package.utils.results import ActionResult
from mfg_package.modules.items import constants as i, rewards as r
from mfg_package.environment import constants as c
class ItemAction(Action):
    """Combined item action: drop off at a drop-off location, otherwise pick up an item."""

    def __init__(self):
        super().__init__(i.ITEM_ACTION)

    def do(self, entity, state) -> Union[None, ActionResult]:
        """Drop off or pick up an item at the entity's position.

        A drop-off location at the position takes precedence over an item
        lying there.

        :param entity: The acting entity; its ``pos`` and ``name`` are read.
        :param state:  Game state; indexed with ``i.INVENTORY``, ``i.DROP_OFF``
                       and ``i.ITEM``.
        :return: An ``ActionResult`` with the matching drop-off / pick-up reward.
        """
        inventory = state[i.INVENTORY].by_entity(entity)

        # Case 1: the entity stands on a drop-off location.
        drop_off = state[i.DROP_OFF].by_pos(entity.pos)
        if drop_off:
            # An empty inventory means there is nothing to drop.
            valid = drop_off.place_item(inventory.pop()) if inventory else c.NOT_VALID
            if valid:
                state.print(f'{entity.name} just dropped of an item at {drop_off.pos}.')
                reward = r.DROP_OFF_VALID
            else:
                state.print(f'{entity.name} just tried to drop off at {entity.pos}, but failed.')
                reward = r.DROP_OFF_FAIL
            return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=reward)

        # Case 2: no drop-off here, so try to pick up an item instead.
        item = state[i.ITEM].by_pos(entity.pos)
        if not item:
            state.print(f'{entity.name} just tried to pick up an item at {entity.pos}, but failed.')
            return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.PICK_UP_FAIL)

        # Move the item into the inventory and take it off the map.
        item.change_parent_collection(inventory)
        item.set_tile_to(state.NO_POS_TILE)
        state.print(f'{entity.name} just picked up an item at {entity.pos}')
        return ActionResult(entity=entity, identifier=self._identifier, validity=c.VALID, reward=r.PICK_UP_VALID)

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.0 KiB

View File

@ -0,0 +1,11 @@
from typing import NamedTuple
# Encoding values
# NOTE(review): SYMBOL_NO_ITEM is not referenced by the item module code shown
# next to this file; presumably the observation value of "no item here".
SYMBOL_NO_ITEM = 0
SYMBOL_DROP_OFF = 1  # Returned by DropOffLocation.encoding
# Item Env
ITEM = 'Items'  # Key of the item collection in the state; also passed to RenderEntity by Item.render
INVENTORY = 'Inventories'  # Key of the inventories collection in the state
DROP_OFF = 'DropOffLocations'  # Key of the drop-off collection; also passed to RenderEntity by DropOffLocation.render
ITEM_ACTION = 'ITEMACTION'  # Identifier handed to the Action base class by ItemAction

View File

@ -0,0 +1,64 @@
from collections import deque
from mfg_package.environment.entity.entity import Entity
from mfg_package.environment import constants as c
from mfg_package.utils.render import RenderEntity
from mfg_package.modules.items import constants as i
class Item(Entity):
    """Item entity with an optional auto-despawn countdown."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Starts at -1; changed through ``set_auto_despawn``.
        self._auto_despawn = -1

    def render(self):
        """Return a ``RenderEntity`` unless the item sits at ``c.VALUE_NO_POS``."""
        if self.pos != c.VALUE_NO_POS:
            return RenderEntity(i.ITEM, self.tile.pos)
        return None

    @property
    def auto_despawn(self):
        """Current value of the auto-despawn countdown."""
        return self._auto_despawn

    @property
    def encoding(self):
        # Edit this if you want items to be drawn in the ops differently
        return 1

    def set_auto_despawn(self, auto_despawn):
        """Overwrite the auto-despawn countdown."""
        self._auto_despawn = auto_despawn

    def set_tile_to(self, no_pos_tile):
        """Replace the item's tile by ``no_pos_tile``."""
        self._tile = no_pos_tile

    def summarize_state(self) -> dict:
        """Extend the base summary with the ``auto_despawn`` value."""
        summary = super().summarize_state()
        summary.update(auto_despawn=self.auto_despawn)
        return summary
class DropOffLocation(Entity):
    """Entity at which items can be dropped off; keeps them in a bounded storage."""

    def render(self):
        """Build the ``RenderEntity`` for this drop-off location."""
        return RenderEntity(i.DROP_OFF, self.tile.pos)

    @property
    def encoding(self):
        """Observation value of a drop-off cell."""
        return i.SYMBOL_DROP_OFF

    def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs):
        """
        :param storage_size_until_full:    Capacity of the storage; a falsy
                                           value makes the storage unbounded.
        :param auto_item_despawn_interval: Auto-despawn value set on every item
                                           placed here.
        """
        super(DropOffLocation, self).__init__(*args, **kwargs)
        self.auto_item_despawn_interval = auto_item_despawn_interval
        # ``maxlen=None`` creates an unbounded deque.
        self.storage = deque(maxlen=storage_size_until_full or None)

    def place_item(self, item: Item):
        """Store ``item`` and start its auto-despawn countdown.

        :param item: The item to store.
        :return: ``c.VALID`` once the item is stored.
        :raises RuntimeWarning: If the storage is full.
        """
        if self.is_full:
            # The unreachable ``return bc.NOT_VALID`` that followed this raise
            # referenced the undefined name ``bc`` and has been removed.
            raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
        self.storage.append(item)
        item.set_auto_despawn(self.auto_item_despawn_interval)
        return c.VALID

    @property
    def is_full(self):
        """``True`` when a bounded storage holds ``maxlen`` items; always ``False`` when unbounded."""
        maxlen = self.storage.maxlen
        return bool(maxlen) and maxlen == len(self.storage)

View File

@ -0,0 +1,101 @@
from typing import List
from mfg_package.environment.groups.env_objects import EnvObjects
from mfg_package.environment.groups.objects import Objects
from mfg_package.environment.groups.mixins import PositionMixin, IsBoundMixin, HasBoundedMixin
from mfg_package.environment.entity.wall_floor import Floor
from mfg_package.environment.entity.agent import Agent
from mfg_package.modules.items.entitites import Item, DropOffLocation
class Items(PositionMixin, EnvObjects):
    """Collection of ``Item`` entities."""

    _entity = Item
    is_blocking_light: bool = False
    can_collide: bool = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def spawn_items(self, tiles: List[Floor]):
        """Create one item per tile and add them all to this collection."""
        self.add_items([self._entity(tile) for tile in tiles])

    def despawn_items(self, items: List[Item]):
        """Delete the given items from this collection.

        :param items: A list of items; a single ``Item`` is accepted as well.
        """
        if isinstance(items, Item):
            items = [items]
        for item in items:
            del self[item]
class Inventory(IsBoundMixin, EnvObjects):
    """Item collection that is bound to a single agent."""

    _accepted_objects = Item

    @property
    def obs_tag(self):
        """The inventory's name serves as its observation tag."""
        return self.name

    def __init__(self, agent: Agent, *args, **kwargs):
        """
        :param agent: The agent this inventory gets bound to.
        """
        super().__init__(*args, **kwargs)
        self._collection = None
        self.bind(agent)

    def summarize_states(self, **kwargs):
        """Summarise the inventory as a plain dict.

        :return: All public instance attributes except ``data``, plus ``items``
                 (the summaries of the contained items), ``name`` and
                 ``belongs_to`` (the name of the bound entity).
        """
        summary = {}
        for key, val in self.__dict__.items():
            if key.startswith('_') or key == 'data':
                continue
            summary[key] = val
        summary['items'] = [val.summarize_state(**kwargs) for _, val in self.items()]
        summary['name'] = self.name
        summary['belongs_to'] = self._bound_entity.name
        return summary

    def pop(self):
        """Remove and return the item at index 0."""
        first_item = self[0]
        self.delete_env_object(first_item)
        return first_item

    def set_collection(self, collection):
        """Store the given collection on the inventory."""
        self._collection = collection
class Inventories(HasBoundedMixin, Objects):
    """Collection holding one ``Inventory`` per agent."""

    _entity = Inventory
    can_move = False

    @property
    def obs_pairs(self):
        """List of ``(name, inventory)`` tuples for all inventories."""
        return [(x.name, x) for x in self]

    def __init__(self, size, *args, **kwargs):
        """
        :param size: Passed on to every inventory that gets spawned.
        """
        super().__init__(*args, **kwargs)
        self.size = size
        self._obs = None
        self._lazy_eval_transforms = []

    def spawn_inventories(self, agents):
        """Create an inventory for each agent and add them to this collection."""
        self.add_items([self._entity(agent, self.size) for agent in agents])

    def idx_by_entity(self, entity):
        """Return the index of the inventory belonging to ``entity``, or ``None``."""
        for idx, inv in enumerate(self):
            if inv.belongs_to_entity(entity):
                return idx
        return None

    def by_entity(self, entity):
        """Return the inventory belonging to ``entity``, or ``None``."""
        for inv in self:
            if inv.belongs_to_entity(entity):
                return inv
        return None

    def summarize_states(self, **kwargs):
        """Return the state summaries of all inventories as a list."""
        return [inventory.summarize_states(**kwargs) for _, inventory in self.items()]
class DropOffLocations(PositionMixin, EnvObjects):
    """Collection of ``DropOffLocation`` entities."""

    _entity = DropOffLocation
    is_blocking_light: bool = False
    can_collide: bool = False

    def __init__(self, *args, **kwargs):
        """Forward all arguments untouched to the parent initialiser."""
        super().__init__(*args, **kwargs)

View File

@ -0,0 +1,4 @@
# Rewards attached by ItemAction.do to its ActionResult.
# Drop-off at a drop-off location succeeded / failed:
DROP_OFF_VALID: float = 0.1
DROP_OFF_FAIL: float = -0.1
# Pick-up failed (no item at the position) / succeeded:
PICK_UP_FAIL: float = -0.1
PICK_UP_VALID: float = 0.1

View File

@ -0,0 +1,79 @@
from typing import List
from mfg_package.environment.rules import Rule
from mfg_package.environment import constants as c
from mfg_package.utils.results import TickResult
from mfg_package.modules.items import constants as i
from mfg_package.modules.items.entitites import DropOffLocation
class ItemRules(Rule):
    """Rule that spawns drop-off locations, inventories and items and ages items.

    NOTE(review): ``tick_step`` and ``tick_post_step`` advance the same
    counters (item auto-despawn and the spawn timer). If the environment calls
    both hooks in one step, both counters run twice per step -- confirm that
    this is intended.
    """

    def __init__(self, n_items: int = 5, spawn_frequency: int = 15,
                 n_locations: int = 5, max_dropoff_storage_size: int = 0):
        """
        :param n_items:                  Number of items that should exist.
        :param spawn_frequency:          Value the spawn timer is reset to
                                         after items were spawned.
        :param n_locations:              Number of drop-off locations to spawn.
        :param max_dropoff_storage_size: Stored on the rule.
                                         NOTE(review): not handed to the spawned
                                         ``DropOffLocation`` objects, which
                                         therefore use their own default.
        """
        super().__init__()
        self.spawn_frequency = spawn_frequency
        self._next_item_spawn = spawn_frequency
        self.n_items = n_items
        self.max_dropoff_storage_size = max_dropoff_storage_size
        self.n_locations = n_locations

    def on_init(self, state):
        """Spawn drop-off locations, inventories and the first batch of items."""
        self.trigger_drop_off_location_spawn(state)
        self._next_item_spawn = self.spawn_frequency
        self.trigger_inventory_spawn(state)
        self.trigger_item_spawn(state)

    @staticmethod
    def _advance_auto_despawn(state):
        """Count down item auto-despawn values and delete items that reached zero.

        Items with a negative value are left untouched.
        """
        # Iterate over a snapshot: items may be deleted inside the loop.
        for item in list(state[i.ITEM].values()):
            if item.auto_despawn >= 1:
                item.set_auto_despawn(item.auto_despawn - 1)
            elif not item.auto_despawn:
                state[i.ITEM].delete_env_object(item)

    def tick_step(self, state):
        """Age items and spawn new ones when the spawn timer has run out.

        :return: Always an empty list.
        """
        self._advance_auto_despawn(state)

        if not self._next_item_spawn:
            self.trigger_item_spawn(state)
        else:
            self._next_item_spawn = max(0, self._next_item_spawn - 1)
        return []

    def trigger_item_spawn(self, state):
        """Fill the item collection up to ``n_items`` using empty floor tiles.

        :return: Number of items spawned; 0 when the item limit is reached.
        """
        if item_to_spawns := max(0, (self.n_items - len(state[i.ITEM]))):
            empty_tiles = state[c.FLOOR].empty_tiles[:item_to_spawns]
            state[i.ITEM].spawn_items(empty_tiles)
            self._next_item_spawn = self.spawn_frequency
            # Report what was actually spawned: fewer empty tiles than
            # requested items may have been available.
            n_spawned = len(empty_tiles)
            state.print(f'{n_spawned} new items have been spawned; next spawn in {self._next_item_spawn}')
            return n_spawned
        else:
            state.print('No Items are spawning, limit is reached.')
            return 0

    @staticmethod
    def trigger_inventory_spawn(state):
        """Create one inventory per agent."""
        state[i.INVENTORY].spawn_inventories(state[c.AGENT])

    def tick_post_step(self, state) -> List[TickResult]:
        """Age items and spawn new ones, reporting the outcome of a spawn attempt.

        :return: A valid ``TickResult`` carrying the number of spawned items, a
                 ``c.NOT_VALID`` one when a due spawn produced nothing, or an
                 empty list while the spawn timer is still running.
        """
        self._advance_auto_despawn(state)

        if not self._next_item_spawn:
            if spawned_items := self.trigger_item_spawn(state):
                return [TickResult(self.name, validity=c.VALID, value=spawned_items, entity=None)]
            return [TickResult(self.name, validity=c.NOT_VALID, value=0, entity=None)]

        self._next_item_spawn = max(0, self._next_item_spawn - 1)
        return []

    def trigger_drop_off_location_spawn(self, state):
        """Place up to ``n_locations`` drop-off locations on empty floor tiles."""
        empty_tiles = state[c.FLOOR].empty_tiles[:self.n_locations]
        do_entites = state[i.DROP_OFF]
        drop_offs = [DropOffLocation(tile) for tile in empty_tiles]
        do_entites.add_items(drop_offs)

Some files were not shown because too many files have changed in this diff Show More