refactoring and init.py
0
mfg_package/__init__.py
Normal file
1
mfg_package/algorithms/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
import os, sys; sys.path.append(os.path.dirname(os.path.realpath(__file__)))
|
1
mfg_package/algorithms/marl/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from mfg_package.algorithms.marl.memory import MARLActorCriticMemory
|
221
mfg_package/algorithms/marl/base_ac.py
Normal file
@ -0,0 +1,221 @@
|
||||
import torch
|
||||
from typing import Union, List, Dict
|
||||
import numpy as np
|
||||
from torch.distributions import Categorical
|
||||
from mfg_package.algorithms.marl.memory import MARLActorCriticMemory
|
||||
from mfg_package.algorithms.utils import add_env_props, instantiate_class
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
from collections import deque
|
||||
|
||||
|
||||
class Names:
|
||||
REWARD = 'reward'
|
||||
DONE = 'done'
|
||||
ACTION = 'action'
|
||||
OBSERVATION = 'observation'
|
||||
LOGITS = 'logits'
|
||||
HIDDEN_ACTOR = 'hidden_actor'
|
||||
HIDDEN_CRITIC = 'hidden_critic'
|
||||
AGENT = 'agent'
|
||||
ENV = 'environment'
|
||||
N_AGENTS = 'n_agents'
|
||||
ALGORITHM = 'algorithm'
|
||||
MAX_STEPS = 'max_steps'
|
||||
N_STEPS = 'n_steps'
|
||||
BUFFER_SIZE = 'buffer_size'
|
||||
CRITIC = 'critic'
|
||||
BATCH_SIZE = 'bnatch_size'
|
||||
N_ACTIONS = 'n_actions'
|
||||
|
||||
nms = Names
|
||||
ListOrTensor = Union[List, torch.Tensor]
|
||||
|
||||
|
||||
class BaseActorCritic:
|
||||
def __init__(self, cfg):
|
||||
add_env_props(cfg)
|
||||
self.__training = True
|
||||
self.cfg = cfg
|
||||
self.n_agents = cfg[nms.ENV][nms.N_AGENTS]
|
||||
self.reset_memory_after_epoch = True
|
||||
self.setup()
|
||||
|
||||
def setup(self):
|
||||
self.net = instantiate_class(self.cfg[nms.AGENT])
|
||||
self.optimizer = torch.optim.RMSprop(self.net.parameters(), lr=3e-4, eps=1e-5)
|
||||
|
||||
@classmethod
|
||||
def _as_torch(cls, x):
|
||||
if isinstance(x, np.ndarray):
|
||||
return torch.from_numpy(x)
|
||||
elif isinstance(x, List):
|
||||
return torch.tensor(x)
|
||||
elif isinstance(x, (int, float)):
|
||||
return torch.tensor([x])
|
||||
return x
|
||||
|
||||
def train(self):
|
||||
self.__training = False
|
||||
networks = [self.net] if not isinstance(self.net, List) else self.net
|
||||
for net in networks:
|
||||
net.train()
|
||||
|
||||
def eval(self):
|
||||
self.__training = False
|
||||
networks = [self.net] if not isinstance(self.net, List) else self.net
|
||||
for net in networks:
|
||||
net.eval()
|
||||
|
||||
def load_state_dict(self, path: Path):
|
||||
pass
|
||||
|
||||
def get_actions(self, out) -> ListOrTensor:
|
||||
actions = [Categorical(logits=logits).sample().item() for logits in out[nms.LOGITS]]
|
||||
return actions
|
||||
|
||||
def init_hidden(self) -> Dict[str, ListOrTensor]:
|
||||
pass
|
||||
|
||||
def forward(self,
|
||||
observations: ListOrTensor,
|
||||
actions: ListOrTensor,
|
||||
hidden_actor: ListOrTensor,
|
||||
hidden_critic: ListOrTensor
|
||||
) -> Dict[str, ListOrTensor]:
|
||||
pass
|
||||
|
||||
@torch.no_grad()
|
||||
def train_loop(self, checkpointer=None):
|
||||
env = instantiate_class(self.cfg[nms.ENV])
|
||||
n_steps, max_steps = [self.cfg[nms.ALGORITHM][k] for k in [nms.N_STEPS, nms.MAX_STEPS]]
|
||||
tm = MARLActorCriticMemory(self.n_agents, self.cfg[nms.ALGORITHM].get(nms.BUFFER_SIZE, n_steps))
|
||||
global_steps, episode, df_results = 0, 0, []
|
||||
reward_queue = deque(maxlen=2000)
|
||||
|
||||
while global_steps < max_steps:
|
||||
obs = env.reset()
|
||||
last_hiddens = self.init_hidden()
|
||||
last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
|
||||
done, rew_log = [False] * self.n_agents, 0
|
||||
|
||||
if self.reset_memory_after_epoch:
|
||||
tm.reset()
|
||||
|
||||
tm.add(observation=obs, action=last_action,
|
||||
logits=torch.zeros(self.n_agents, 1, self.cfg[nms.AGENT][nms.N_ACTIONS]),
|
||||
values=torch.zeros(self.n_agents, 1), reward=reward, done=done, **last_hiddens)
|
||||
|
||||
while not all(done):
|
||||
out = self.forward(obs, last_action, **last_hiddens)
|
||||
action = self.get_actions(out)
|
||||
next_obs, reward, done, info = env.step(action)
|
||||
done = [done] * self.n_agents if isinstance(done, bool) else done
|
||||
|
||||
last_hiddens = dict(hidden_actor =out[nms.HIDDEN_ACTOR],
|
||||
hidden_critic=out[nms.HIDDEN_CRITIC])
|
||||
|
||||
|
||||
tm.add(observation=obs, action=action, reward=reward, done=done,
|
||||
logits=out.get(nms.LOGITS, None), values=out.get(nms.CRITIC, None),
|
||||
**last_hiddens)
|
||||
|
||||
obs = next_obs
|
||||
last_action = action
|
||||
|
||||
if (global_steps+1) % n_steps == 0 or all(done):
|
||||
with torch.inference_mode(False):
|
||||
self.learn(tm)
|
||||
|
||||
global_steps += 1
|
||||
rew_log += sum(reward)
|
||||
reward_queue.extend(reward)
|
||||
|
||||
if checkpointer is not None:
|
||||
checkpointer.step([
|
||||
(f'agent#{i}', agent)
|
||||
for i, agent in enumerate([self.net] if not isinstance(self.net, List) else self.net)
|
||||
])
|
||||
|
||||
if global_steps >= max_steps:
|
||||
break
|
||||
print(f'reward at episode: {episode} = {rew_log}')
|
||||
episode += 1
|
||||
df_results.append([episode, rew_log, *reward])
|
||||
df_results = pd.DataFrame(df_results, columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]])
|
||||
if checkpointer is not None:
|
||||
df_results.to_csv(checkpointer.path / 'results.csv', index=False)
|
||||
return df_results
|
||||
|
||||
@torch.inference_mode(True)
|
||||
def eval_loop(self, n_episodes, render=False):
|
||||
env = instantiate_class(self.cfg[nms.ENV])
|
||||
episode, results = 0, []
|
||||
while episode < n_episodes:
|
||||
obs = env.reset()
|
||||
last_hiddens = self.init_hidden()
|
||||
last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
|
||||
done, rew_log, eps_rew = [False] * self.n_agents, 0, torch.zeros(self.n_agents)
|
||||
while not all(done):
|
||||
if render: env.render()
|
||||
|
||||
out = self.forward(obs, last_action, **last_hiddens)
|
||||
action = self.get_actions(out)
|
||||
next_obs, reward, done, info = env.step(action)
|
||||
|
||||
if isinstance(done, bool): done = [done] * obs.shape[0]
|
||||
obs = next_obs
|
||||
last_action = action
|
||||
last_hiddens = dict(hidden_actor=out.get(nms.HIDDEN_ACTOR, None),
|
||||
hidden_critic=out.get(nms.HIDDEN_CRITIC, None)
|
||||
)
|
||||
eps_rew += torch.tensor(reward)
|
||||
results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode])
|
||||
episode += 1
|
||||
agent_columns = [f'agent#{i}' for i in range(self.cfg['environment']['n_agents'])]
|
||||
results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode'])
|
||||
results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'], value_name='reward', var_name='agent')
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def compute_advantages(critic, reward, done, gamma, gae_coef=0.0):
|
||||
tds = (reward + gamma * (1.0 - done) * critic[:, 1:].detach()) - critic[:, :-1]
|
||||
|
||||
if gae_coef <= 0:
|
||||
return tds
|
||||
|
||||
gae = torch.zeros_like(tds[:, -1])
|
||||
gaes = []
|
||||
for t in range(tds.shape[1]-1, -1, -1):
|
||||
gae = tds[:, t] + gamma * gae_coef * (1.0 - done[:, t]) * gae
|
||||
gaes.insert(0, gae)
|
||||
gaes = torch.stack(gaes, dim=1)
|
||||
return gaes
|
||||
|
||||
def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
|
||||
obs, actions, done, reward = tm.observation, tm.action, tm.done[:, 1:], tm.reward[:, 1:]
|
||||
|
||||
out = network(obs, actions, tm.hidden_actor[:, 0], tm.hidden_critic[:, 0])
|
||||
logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1}
|
||||
critic = out[nms.CRITIC]
|
||||
|
||||
entropy_loss = Categorical(logits=logits).entropy().mean(-1)
|
||||
advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
|
||||
value_loss = advantages.pow(2).mean(-1) # n_agent
|
||||
|
||||
# policy loss
|
||||
log_ap = torch.log_softmax(logits, -1)
|
||||
log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze()
|
||||
a2c_loss = -(advantages.detach() * log_ap).mean(-1)
|
||||
# weighted loss
|
||||
loss = a2c_loss + vf_coef*value_loss - entropy_coef * entropy_loss
|
||||
return loss.mean()
|
||||
|
||||
def learn(self, tm: MARLActorCriticMemory, **kwargs):
|
||||
loss = self.actor_critic(tm, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
|
||||
# remove next_obs, will be added in next iter
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
|
||||
self.optimizer.step()
|
||||
|
24
mfg_package/algorithms/marl/example_config.yaml
Normal file
@ -0,0 +1,24 @@
|
||||
agent:
|
||||
classname: algorithms.marl.networks.RecurrentAC
|
||||
n_agents: 2
|
||||
obs_emb_size: 96
|
||||
action_emb_size: 16
|
||||
hidden_size_actor: 64
|
||||
hidden_size_critic: 64
|
||||
use_agent_embedding: False
|
||||
env:
|
||||
classname: environments.factory.make
|
||||
env_name: "DirtyFactory-v0"
|
||||
n_agents: 2
|
||||
max_steps: 250
|
||||
pomdp_r: 2
|
||||
stack_n_frames: 0
|
||||
individual_rewards: True
|
||||
method: algorithms.marl.LoopSEAC
|
||||
algorithm:
|
||||
gamma: 0.99
|
||||
entropy_coef: 0.01
|
||||
vf_coef: 0.5
|
||||
n_steps: 5
|
||||
max_steps: 1000000
|
||||
|
57
mfg_package/algorithms/marl/iac.py
Normal file
@ -0,0 +1,57 @@
|
||||
import torch
|
||||
from mfg_package.algorithms.marl.base_ac import BaseActorCritic, nms
|
||||
from mfg_package.algorithms.utils import instantiate_class
|
||||
from pathlib import Path
|
||||
from natsort import natsorted
|
||||
from mfg_package.algorithms.marl.memory import MARLActorCriticMemory
|
||||
|
||||
|
||||
class LoopIAC(BaseActorCritic):
|
||||
|
||||
def __init__(self, cfg):
|
||||
super(LoopIAC, self).__init__(cfg)
|
||||
|
||||
def setup(self):
|
||||
self.net = [
|
||||
instantiate_class(self.cfg[nms.AGENT]) for _ in range(self.n_agents)
|
||||
]
|
||||
self.optimizer = [
|
||||
torch.optim.RMSprop(self.net[ag_i].parameters(), lr=3e-4, eps=1e-5) for ag_i in range(self.n_agents)
|
||||
]
|
||||
|
||||
def load_state_dict(self, path: Path):
|
||||
paths = natsorted(list(path.glob('*.pt')))
|
||||
for path, net in zip(paths, self.net):
|
||||
net.load_state_dict(torch.load(path))
|
||||
|
||||
@staticmethod
|
||||
def merge_dicts(ds): # todo could be recursive for more than 1 hierarchy
|
||||
d = {}
|
||||
for k in ds[0].keys():
|
||||
d[k] = [d[k] for d in ds]
|
||||
return d
|
||||
|
||||
def init_hidden(self):
|
||||
ha = [net.init_hidden_actor() for net in self.net]
|
||||
hc = [net.init_hidden_critic() for net in self.net]
|
||||
return dict(hidden_actor=ha, hidden_critic=hc)
|
||||
|
||||
def forward(self, observations, actions, hidden_actor, hidden_critic):
|
||||
outputs = [
|
||||
net(
|
||||
self._as_torch(observations[ag_i]).unsqueeze(0).unsqueeze(0), # agent x time
|
||||
self._as_torch(actions[ag_i]).unsqueeze(0),
|
||||
hidden_actor[ag_i],
|
||||
hidden_critic[ag_i]
|
||||
) for ag_i, net in enumerate(self.net)
|
||||
]
|
||||
return self.merge_dicts(outputs)
|
||||
|
||||
def learn(self, tms: MARLActorCriticMemory, **kwargs):
|
||||
for ag_i in range(self.n_agents):
|
||||
tm, net = tms(ag_i), self.net[ag_i]
|
||||
loss = self.actor_critic(tm, net, **self.cfg[nms.ALGORITHM], **kwargs)
|
||||
self.optimizer[ag_i].zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(net.parameters(), 0.5)
|
||||
self.optimizer[ag_i].step()
|
66
mfg_package/algorithms/marl/mappo.py
Normal file
@ -0,0 +1,66 @@
|
||||
from mfg_package.algorithms.marl.base_ac import Names as nms
|
||||
from mfg_package.algorithms.marl.snac import LoopSNAC
|
||||
from mfg_package.algorithms.marl.memory import MARLActorCriticMemory
|
||||
import torch
|
||||
from torch.distributions import Categorical
|
||||
from mfg_package.algorithms.utils import instantiate_class
|
||||
|
||||
|
||||
class LoopMAPPO(LoopSNAC):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(LoopMAPPO, self).__init__(*args, **kwargs)
|
||||
self.reset_memory_after_epoch = False
|
||||
|
||||
def setup(self):
|
||||
self.net = instantiate_class(self.cfg[nms.AGENT])
|
||||
self.optimizer = torch.optim.Adam(self.net.parameters(), lr=3e-4, eps=1e-5)
|
||||
|
||||
def learn(self, tm: MARLActorCriticMemory, **kwargs):
|
||||
if len(tm) >= self.cfg['algorithm']['buffer_size']:
|
||||
# only learn when buffer is full
|
||||
for batch_i in range(self.cfg['algorithm']['n_updates']):
|
||||
batch = tm.chunk_dataloader(chunk_len=self.cfg['algorithm']['n_steps'],
|
||||
k=self.cfg['algorithm']['batch_size'])
|
||||
loss = self.mappo(batch, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
|
||||
self.optimizer.step()
|
||||
|
||||
def monte_carlo_returns(self, rewards, done, gamma):
|
||||
rewards_ = []
|
||||
discounted_reward = torch.zeros_like(rewards[:, -1])
|
||||
for t in range(rewards.shape[1]-1, -1, -1):
|
||||
discounted_reward = rewards[:, t] + (gamma * (1.0 - done[:, t]) * discounted_reward)
|
||||
rewards_.insert(0, discounted_reward)
|
||||
rewards_ = torch.stack(rewards_, dim=1)
|
||||
return rewards_
|
||||
|
||||
def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **kwargs):
|
||||
out = network(batch[nms.OBSERVATION], batch[nms.ACTION], batch[nms.HIDDEN_ACTOR], batch[nms.HIDDEN_CRITIC])
|
||||
logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1}
|
||||
|
||||
old_log_probs = torch.log_softmax(batch[nms.LOGITS], -1)
|
||||
old_log_probs = torch.gather(old_log_probs, index=batch[nms.ACTION][:, 1:].unsqueeze(-1), dim=-1).squeeze()
|
||||
|
||||
# monte carlo returns
|
||||
mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma)
|
||||
mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) #todo: norm across agent ok?
|
||||
advantages = mc_returns - out[nms.CRITIC][:, :-1]
|
||||
|
||||
# policy loss
|
||||
log_ap = torch.log_softmax(logits, -1)
|
||||
log_ap = torch.gather(log_ap, dim=-1, index=batch[nms.ACTION][:, 1:].unsqueeze(-1)).squeeze()
|
||||
ratio = (log_ap - old_log_probs).exp()
|
||||
surr1 = ratio * advantages.detach()
|
||||
surr2 = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantages.detach()
|
||||
policy_loss = -torch.min(surr1, surr2).mean(-1)
|
||||
|
||||
# entropy & value loss
|
||||
entropy_loss = Categorical(logits=logits).entropy().mean(-1)
|
||||
value_loss = advantages.pow(2).mean(-1) # n_agent
|
||||
|
||||
# weighted loss
|
||||
loss = policy_loss + vf_coef*value_loss - entropy_coef * entropy_loss
|
||||
|
||||
return loss.mean()
|
221
mfg_package/algorithms/marl/memory.py
Normal file
@ -0,0 +1,221 @@
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
import torch
|
||||
from typing import Union
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset, ConcatDataset
|
||||
import random
|
||||
|
||||
|
||||
class ActorCriticMemory(object):
|
||||
def __init__(self, capacity=10):
|
||||
self.capacity = capacity
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.__actions = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||
self.__hidden_actor = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||
self.__hidden_critic = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||
self.__states = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||
self.__rewards = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||
self.__dones = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||
self.__logits = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||
self.__values = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.__rewards) - 1
|
||||
|
||||
@property
|
||||
def observation(self, sls=slice(0, None)): # add time dimension through stacking
|
||||
return self.__states[sls].unsqueeze(0) # 1 x time x hidden dim
|
||||
|
||||
@property
|
||||
def hidden_actor(self, sls=slice(0, None)): # 1 x n_layers x dim
|
||||
return self.__hidden_actor[sls].unsqueeze(0) # 1 x time x n_layers x dim
|
||||
|
||||
@property
|
||||
def hidden_critic(self, sls=slice(0, None)): # 1 x n_layers x dim
|
||||
return self.__hidden_critic[sls].unsqueeze(0) # 1 x time x n_layers x dim
|
||||
|
||||
@property
|
||||
def reward(self, sls=slice(0, None)):
|
||||
return self.__rewards[sls].squeeze().unsqueeze(0) # 1 x time
|
||||
|
||||
@property
|
||||
def action(self, sls=slice(0, None)):
|
||||
return self.__actions[sls].long().squeeze().unsqueeze(0) # 1 x time
|
||||
|
||||
@property
|
||||
def done(self, sls=slice(0, None)):
|
||||
return self.__dones[sls].float().squeeze().unsqueeze(0) # 1 x time
|
||||
|
||||
@property
|
||||
def logits(self, sls=slice(0, None)): # assumes a trailing 1 for time dimension - common when using output from NN
|
||||
return self.__logits[sls].squeeze().unsqueeze(0) # 1 x time x actions
|
||||
|
||||
@property
|
||||
def values(self, sls=slice(0, None)):
|
||||
return self.__values[sls].squeeze().unsqueeze(0) # 1 x time x actions
|
||||
|
||||
def add_observation(self, state: Union[Tensor, np.ndarray]):
|
||||
self.__states.append(state if isinstance(state, Tensor) else torch.from_numpy(state))
|
||||
|
||||
def add_hidden_actor(self, hidden: Tensor):
|
||||
# layers x hidden dim
|
||||
self.__hidden_actor.append(hidden)
|
||||
|
||||
def add_hidden_critic(self, hidden: Tensor):
|
||||
# layers x hidden dim
|
||||
self.__hidden_critic.append(hidden)
|
||||
|
||||
def add_action(self, action: Union[int, Tensor]):
|
||||
if not isinstance(action, Tensor):
|
||||
action = torch.tensor(action)
|
||||
self.__actions.append(action)
|
||||
|
||||
def add_reward(self, reward: Union[float, Tensor]):
|
||||
if not isinstance(reward, Tensor):
|
||||
reward = torch.tensor(reward)
|
||||
self.__rewards.append(reward)
|
||||
|
||||
def add_done(self, done: bool):
|
||||
if not isinstance(done, Tensor):
|
||||
done = torch.tensor(done)
|
||||
self.__dones.append(done)
|
||||
|
||||
def add_logits(self, logits: Tensor):
|
||||
self.__logits.append(logits)
|
||||
|
||||
def add_values(self, values: Tensor):
|
||||
self.__values.append(values)
|
||||
|
||||
def add(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
func = getattr(ActorCriticMemory, f'add_{k}')
|
||||
func(self, v)
|
||||
|
||||
|
||||
class MARLActorCriticMemory(object):
|
||||
def __init__(self, n_agents, capacity):
|
||||
self.n_agents = n_agents
|
||||
self.memories = [
|
||||
ActorCriticMemory(capacity) for _ in range(n_agents)
|
||||
]
|
||||
|
||||
def __call__(self, agent_i):
|
||||
return self.memories[agent_i]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.memories[0]) # todo add assertion check!
|
||||
|
||||
def reset(self):
|
||||
for mem in self.memories:
|
||||
mem.reset()
|
||||
|
||||
def add(self, **kwargs):
|
||||
for agent_i in range(self.n_agents):
|
||||
for k, v in kwargs.items():
|
||||
func = getattr(ActorCriticMemory, f'add_{k}')
|
||||
func(self.memories[agent_i], v[agent_i])
|
||||
|
||||
def __getattr__(self, attr):
|
||||
all_attrs = [getattr(mem, attr) for mem in self.memories]
|
||||
return torch.cat(all_attrs, 0) # agent x time ...
|
||||
|
||||
def chunk_dataloader(self, chunk_len, k):
|
||||
datasets = [ExperienceChunks(mem, chunk_len, k) for mem in self.memories]
|
||||
dataset = ConcatDataset(datasets)
|
||||
data = [dataset[i] for i in range(len(dataset))]
|
||||
data = custom_collate_fn(data)
|
||||
return data
|
||||
|
||||
|
||||
def custom_collate_fn(batch):
|
||||
elem = batch[0]
|
||||
return {key: torch.cat([d[key] for d in batch], dim=0) for key in elem}
|
||||
|
||||
|
||||
class ExperienceChunks(Dataset):
|
||||
def __init__(self, memory, chunk_len, k):
|
||||
assert chunk_len <= len(memory), 'chunk_len cannot be longer than the size of the memory'
|
||||
self.memory = memory
|
||||
self.chunk_len = chunk_len
|
||||
self.k = k
|
||||
|
||||
@property
|
||||
def whitelist(self):
|
||||
whitelist = torch.ones(len(self.memory) - self.chunk_len)
|
||||
for d in self.memory.done.squeeze().nonzero().flatten():
|
||||
whitelist[max((0, d-self.chunk_len-1)):d+2] = 0
|
||||
whitelist[0] = 0
|
||||
return whitelist.tolist()
|
||||
|
||||
def sample(self, start=1):
|
||||
cl = self.chunk_len
|
||||
sample = dict(observation=self.memory.observation[:, start:start+cl+1],
|
||||
action=self.memory.action[:, start-1:start+cl],
|
||||
hidden_actor=self.memory.hidden_actor[:, start-1],
|
||||
hidden_critic=self.memory.hidden_critic[:, start-1],
|
||||
reward=self.memory.reward[:, start:start + cl],
|
||||
done=self.memory.done[:, start:start + cl],
|
||||
logits=self.memory.logits[:, start:start + cl],
|
||||
values=self.memory.values[:, start:start + cl])
|
||||
return sample
|
||||
|
||||
def __len__(self):
|
||||
return self.k
|
||||
|
||||
def __getitem__(self, i):
|
||||
idx = random.choices(range(0, len(self.memory) - self.chunk_len), weights=self.whitelist, k=1)
|
||||
return self.sample(idx[0])
|
||||
|
||||
|
||||
class LazyTensorFiFoQueue:
|
||||
def __init__(self, maxlen):
|
||||
self.maxlen = maxlen
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.__lazy_queue = deque(maxlen=self.maxlen)
|
||||
self.shape = None
|
||||
self.queue = None
|
||||
|
||||
def shape_init(self, tensor: Tensor):
|
||||
self.shape = torch.Size([self.maxlen, *tensor.shape])
|
||||
|
||||
def build_tensor_queue(self):
|
||||
if len(self.__lazy_queue) > 0:
|
||||
block = torch.stack(list(self.__lazy_queue), dim=0)
|
||||
l = block.shape[0]
|
||||
if self.queue is None:
|
||||
self.queue = block
|
||||
elif self.true_len() <= self.maxlen:
|
||||
self.queue = torch.cat((self.queue, block), dim=0)
|
||||
else:
|
||||
self.queue = torch.cat((self.queue[l:], block), dim=0)
|
||||
self.__lazy_queue.clear()
|
||||
|
||||
def append(self, data):
|
||||
if self.shape is None:
|
||||
self.shape_init(data)
|
||||
self.__lazy_queue.append(data)
|
||||
if len(self.__lazy_queue) >= self.maxlen:
|
||||
self.build_tensor_queue()
|
||||
|
||||
def true_len(self):
|
||||
return len(self.__lazy_queue) + (0 if self.queue is None else self.queue.shape[0])
|
||||
|
||||
def __len__(self):
|
||||
return min((self.true_len(), self.maxlen))
|
||||
|
||||
def __str__(self):
|
||||
return f'LazyTensorFiFoQueue\tmaxlen: {self.maxlen}, shape: {self.shape}, ' \
|
||||
f'len: {len(self)}, true_len: {self.true_len()}, elements in lazy queue: {len(self.__lazy_queue)}'
|
||||
|
||||
def __getitem__(self, item_or_slice):
|
||||
self.build_tensor_queue()
|
||||
return self.queue[item_or_slice]
|
||||
|
||||
|
||||
|
||||
|
104
mfg_package/algorithms/marl/networks.py
Normal file
@ -0,0 +1,104 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import spectral_norm
|
||||
|
||||
|
||||
class RecurrentAC(nn.Module):
|
||||
def __init__(self, observation_size, n_actions, obs_emb_size,
|
||||
action_emb_size, hidden_size_actor, hidden_size_critic,
|
||||
n_agents, use_agent_embedding=True):
|
||||
super(RecurrentAC, self).__init__()
|
||||
observation_size = np.prod(observation_size)
|
||||
self.n_layers = 1
|
||||
self.n_actions = n_actions
|
||||
self.use_agent_embedding = use_agent_embedding
|
||||
self.hidden_size_actor = hidden_size_actor
|
||||
self.hidden_size_critic = hidden_size_critic
|
||||
self.action_emb_size = action_emb_size
|
||||
self.obs_proj = nn.Linear(observation_size, obs_emb_size)
|
||||
self.action_emb = nn.Embedding(n_actions+1, action_emb_size, padding_idx=0)
|
||||
self.agent_emb = nn.Embedding(n_agents, action_emb_size)
|
||||
mix_in_size = obs_emb_size+action_emb_size if not use_agent_embedding else obs_emb_size+n_agents*action_emb_size
|
||||
self.mix = nn.Sequential(nn.Tanh(),
|
||||
nn.Linear(mix_in_size, obs_emb_size),
|
||||
nn.Tanh(),
|
||||
nn.Linear(obs_emb_size, obs_emb_size)
|
||||
)
|
||||
self.gru_actor = nn.GRU(obs_emb_size, hidden_size_actor, batch_first=True, num_layers=self.n_layers)
|
||||
self.gru_critic = nn.GRU(obs_emb_size, hidden_size_critic, batch_first=True, num_layers=self.n_layers)
|
||||
self.action_head = nn.Sequential(
|
||||
nn.Linear(hidden_size_actor, hidden_size_actor),
|
||||
nn.Tanh(),
|
||||
nn.Linear(hidden_size_actor, n_actions)
|
||||
)
|
||||
# spectral_norm(nn.Linear(hidden_size_actor, hidden_size_actor)),
|
||||
self.critic_head = nn.Sequential(
|
||||
nn.Linear(hidden_size_critic, hidden_size_critic),
|
||||
nn.Tanh(),
|
||||
nn.Linear(hidden_size_critic, 1)
|
||||
)
|
||||
#self.action_head[-1].weight.data.uniform_(-3e-3, 3e-3)
|
||||
#self.action_head[-1].bias.data.uniform_(-3e-3, 3e-3)
|
||||
|
||||
def init_hidden_actor(self):
|
||||
return torch.zeros(1, self.n_layers, self.hidden_size_actor)
|
||||
|
||||
def init_hidden_critic(self):
|
||||
return torch.zeros(1, self.n_layers, self.hidden_size_critic)
|
||||
|
||||
def forward(self, observations, actions, hidden_actor=None, hidden_critic=None):
|
||||
n_agents, t, *_ = observations.shape
|
||||
obs_emb = self.obs_proj(observations.view(n_agents, t, -1).float())
|
||||
action_emb = self.action_emb(actions+1) # shift by one due to padding idx
|
||||
|
||||
if not self.use_agent_embedding:
|
||||
x_t = torch.cat((obs_emb, action_emb), -1)
|
||||
else:
|
||||
agent_emb = self.agent_emb(
|
||||
torch.cat([torch.arange(0, n_agents, 1).view(-1, 1)] * t, 1)
|
||||
)
|
||||
x_t = torch.cat((obs_emb, agent_emb, action_emb), -1)
|
||||
|
||||
mixed_x_t = self.mix(x_t)
|
||||
output_p, _ = self.gru_actor(input=mixed_x_t, hx=hidden_actor.swapaxes(1, 0))
|
||||
output_c, _ = self.gru_critic(input=mixed_x_t, hx=hidden_critic.swapaxes(1, 0))
|
||||
|
||||
logits = self.action_head(output_p)
|
||||
critic = self.critic_head(output_c).squeeze(-1)
|
||||
return dict(logits=logits, critic=critic, hidden_actor=output_p, hidden_critic=output_c)
|
||||
|
||||
|
||||
class RecurrentACL2(RecurrentAC):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.action_head = nn.Sequential(
|
||||
nn.Linear(self.hidden_size_actor, self.hidden_size_actor),
|
||||
nn.Tanh(),
|
||||
NormalizedLinear(self.hidden_size_actor, self.n_actions, trainable_magnitude=True)
|
||||
)
|
||||
|
||||
|
||||
class NormalizedLinear(nn.Linear):
|
||||
def __init__(self, in_features: int, out_features: int,
|
||||
device=None, dtype=None, trainable_magnitude=False):
|
||||
super(NormalizedLinear, self).__init__(in_features, out_features, False, device, dtype)
|
||||
self.d_sqrt = in_features**0.5
|
||||
self.trainable_magnitude = trainable_magnitude
|
||||
self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude)
|
||||
|
||||
def forward(self, input):
|
||||
normalized_input = F.normalize(input, dim=-1, p=2, eps=1e-5)
|
||||
normalized_weight = F.normalize(self.weight, dim=-1, p=2, eps=1e-5)
|
||||
return F.linear(normalized_input, normalized_weight) * self.d_sqrt * self.scale
|
||||
|
||||
|
||||
class L2Norm(nn.Module):
|
||||
def __init__(self, in_features, trainable_magnitude=False):
|
||||
super(L2Norm, self).__init__()
|
||||
self.d_sqrt = in_features**0.5
|
||||
self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude)
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=-1, p=2, eps=1e-5) * self.d_sqrt * self.scale
|
56
mfg_package/algorithms/marl/seac.py
Normal file
@ -0,0 +1,56 @@
|
||||
import torch
|
||||
from torch.distributions import Categorical
|
||||
from mfg_package.algorithms.marl.iac import LoopIAC
|
||||
from mfg_package.algorithms.marl.base_ac import nms
|
||||
from mfg_package.algorithms.marl.memory import MARLActorCriticMemory
|
||||
|
||||
|
||||
class LoopSEAC(LoopIAC):
|
||||
def __init__(self, cfg):
|
||||
super(LoopSEAC, self).__init__(cfg)
|
||||
|
||||
def actor_critic(self, tm, networks, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
|
||||
obs, actions, done, reward = tm.observation, tm.action, tm.done[:, 1:], tm.reward[:, 1:]
|
||||
outputs = [net(obs, actions, tm.hidden_actor[:, 0], tm.hidden_critic[:, 0]) for net in networks]
|
||||
|
||||
with torch.inference_mode(True):
|
||||
true_action_logp = torch.stack([
|
||||
torch.log_softmax(out[nms.LOGITS][ag_i, :-1], -1)
|
||||
.gather(index=actions[ag_i, 1:, None], dim=-1)
|
||||
for ag_i, out in enumerate(outputs)
|
||||
], 0).squeeze()
|
||||
|
||||
losses = []
|
||||
|
||||
for ag_i, out in enumerate(outputs):
|
||||
logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1}
|
||||
critic = out[nms.CRITIC]
|
||||
|
||||
entropy_loss = Categorical(logits=logits[ag_i]).entropy().mean()
|
||||
advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
|
||||
|
||||
# policy loss
|
||||
log_ap = torch.log_softmax(logits, -1)
|
||||
log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze()
|
||||
|
||||
# importance weights
|
||||
iw = (log_ap - true_action_logp).exp().detach() # importance_weights
|
||||
|
||||
a2c_loss = (-iw*log_ap * advantages.detach()).mean(-1)
|
||||
|
||||
|
||||
value_loss = (iw*advantages.pow(2)).mean(-1) # n_agent
|
||||
|
||||
# weighted loss
|
||||
loss = (a2c_loss + vf_coef*value_loss - entropy_coef * entropy_loss).mean()
|
||||
losses.append(loss)
|
||||
|
||||
return losses
|
||||
|
||||
def learn(self, tms: MARLActorCriticMemory, **kwargs):
|
||||
losses = self.actor_critic(tms, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
|
||||
for ag_i, loss in enumerate(losses):
|
||||
self.optimizer[ag_i].zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.net[ag_i].parameters(), 0.5)
|
||||
self.optimizer[ag_i].step()
|
33
mfg_package/algorithms/marl/snac.py
Normal file
@ -0,0 +1,33 @@
|
||||
from mfg_package.algorithms.marl.base_ac import BaseActorCritic
|
||||
from mfg_package.algorithms.marl.base_ac import nms
|
||||
import torch
|
||||
from torch.distributions import Categorical
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class LoopSNAC(BaseActorCritic):
|
||||
def __init__(self, cfg):
|
||||
super().__init__(cfg)
|
||||
|
||||
def load_state_dict(self, path: Path):
|
||||
path2weights = list(path.glob('*.pt'))
|
||||
assert len(path2weights) == 1, f'Expected a single set of weights but got {len(path2weights)}'
|
||||
self.net.load_state_dict(torch.load(path2weights[0]))
|
||||
|
||||
def init_hidden(self):
|
||||
hidden_actor = self.net.init_hidden_actor()
|
||||
hidden_critic = self.net.init_hidden_critic()
|
||||
return dict(hidden_actor=torch.cat([hidden_actor] * self.n_agents, 0),
|
||||
hidden_critic=torch.cat([hidden_critic] * self.n_agents, 0)
|
||||
)
|
||||
|
||||
def get_actions(self, out):
|
||||
actions = Categorical(logits=out[nms.LOGITS]).sample().squeeze()
|
||||
return actions
|
||||
|
||||
def forward(self, observations, actions, hidden_actor, hidden_critic):
|
||||
out = self.net(self._as_torch(observations).unsqueeze(1),
|
||||
self._as_torch(actions).unsqueeze(1),
|
||||
hidden_actor, hidden_critic
|
||||
)
|
||||
return out
|
130
mfg_package/algorithms/static/TSP_base_agent.py
Normal file
@ -0,0 +1,130 @@
|
||||
import itertools
|
||||
from random import choice
|
||||
|
||||
import numpy as np
|
||||
|
||||
import networkx as nx
|
||||
from networkx.algorithms.approximation import traveling_salesman as tsp
|
||||
|
||||
from mfg_package.modules.doors import constants as do
|
||||
from mfg_package.environment import constants as c
|
||||
from mfg_package.utils.helpers import MOVEMAP
|
||||
|
||||
from abc import abstractmethod, ABC
|
||||
|
||||
future_planning = 7
|
||||
|
||||
|
||||
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
|
||||
"""
|
||||
Given a set of coordinates, this function contructs a non-directed graph, by conncting adjected points.
|
||||
There are three combinations of settings:
|
||||
Allow all neigbors: Distance(a, b) <= sqrt(2)
|
||||
Allow only manhattan: Distance(a, b) == 1
|
||||
Allow only euclidean: Distance(a, b) == sqrt(2)
|
||||
|
||||
|
||||
:param coordiniates_or_tiles: A set of coordinates.
|
||||
:type coordiniates_or_tiles: Tiles
|
||||
:param allow_euclidean_connections: Whether to regard diagonal adjected cells as neighbors
|
||||
:type: bool
|
||||
:param allow_manhattan_connections: Whether to regard directly adjected cells as neighbors
|
||||
:type: bool
|
||||
|
||||
:return: A graph with nodes that are conneceted as specified by the parameters.
|
||||
:rtype: nx.Graph
|
||||
"""
|
||||
assert allow_euclidean_connections or allow_manhattan_connections
|
||||
if hasattr(coordiniates_or_tiles, 'positions'):
|
||||
coordiniates_or_tiles = coordiniates_or_tiles.positions
|
||||
possible_connections = itertools.combinations(coordiniates_or_tiles, 2)
|
||||
graph = nx.Graph()
|
||||
for a, b in possible_connections:
|
||||
diff = np.linalg.norm(np.asarray(a)-np.asarray(b))
|
||||
if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2):
|
||||
graph.add_edge(a, b)
|
||||
elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2):
|
||||
graph.add_edge(a, b)
|
||||
elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1:
|
||||
graph.add_edge(a, b)
|
||||
return graph
|
||||
|
||||
|
||||
class TSPBaseAgent(ABC):
|
||||
|
||||
def __init__(self, state, agent_i, static_problem: bool = True):
|
||||
self.static_problem = static_problem
|
||||
self.local_optimization = True
|
||||
self._env = state
|
||||
self.state = self._env.state[c.AGENT][agent_i]
|
||||
self._floortile_graph = points_to_graph(self._env[c.FLOOR].positions)
|
||||
self._static_route = None
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, *_, **__) -> int:
|
||||
return 0
|
||||
|
||||
def _use_door_or_move(self, door, target):
|
||||
if door.is_closed:
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
action = do.ACTION_DOOR_USE
|
||||
else:
|
||||
action = self._predict_move(target)
|
||||
return action
|
||||
|
||||
def calculate_tsp_route(self, target_identifier):
|
||||
positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS]
|
||||
if self.local_optimization:
|
||||
nodes = \
|
||||
[self.state.pos] + \
|
||||
[x for x in positions if max(abs(np.subtract(x, self.state.pos))) < 3]
|
||||
try:
|
||||
while len(nodes) < 7:
|
||||
nodes += [next(x for x in positions if x not in nodes)]
|
||||
except StopIteration:
|
||||
nodes = [self.state.pos] + positions
|
||||
|
||||
else:
|
||||
nodes = [self.state.pos] + positions
|
||||
route = tsp.traveling_salesman_problem(self._floortile_graph,
|
||||
nodes=nodes, cycle=True, method=tsp.greedy_tsp)
|
||||
return route
|
||||
|
||||
def _door_is_close(self):
|
||||
try:
|
||||
return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def _has_targets(self, target_identifier):
|
||||
return bool(len([x for x in self._env.state[target_identifier] if x.pos != c.VALUE_NO_POS]) >= 1)
|
||||
|
||||
def _predict_move(self, target_identifier):
|
||||
if self._has_targets(target_identifier):
|
||||
if self.static_problem:
|
||||
if not self._static_route:
|
||||
self._static_route = self.calculate_tsp_route(target_identifier)
|
||||
else:
|
||||
pass
|
||||
next_pos = self._static_route.pop(0)
|
||||
while next_pos == self.state.pos:
|
||||
next_pos = self._static_route.pop(0)
|
||||
else:
|
||||
if not self._static_route:
|
||||
self._static_route = self.calculate_tsp_route(target_identifier)[:7]
|
||||
next_pos = self._static_route.pop(0)
|
||||
while next_pos == self.state.pos:
|
||||
next_pos = self._static_route.pop(0)
|
||||
|
||||
diff = np.subtract(next_pos, self.state.pos)
|
||||
# Retrieve action based on the pos dif (like in: What do I have to do to get there?)
|
||||
try:
|
||||
action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff))
|
||||
except StopIteration:
|
||||
print(f'diff: {diff}')
|
||||
print('This Should not happen!')
|
||||
action = choice(self.state.actions).name
|
||||
else:
|
||||
action = choice(self.state.actions).name
|
||||
# noinspection PyUnboundLocalVariable
|
||||
return action
|
27
mfg_package/algorithms/static/TSP_dirt_agent.py
Normal file
@ -0,0 +1,27 @@
|
||||
from mfg_package.algorithms.static.TSP_base_agent import TSPBaseAgent
|
||||
|
||||
from mfg_package.modules.clean_up import constants as di
|
||||
|
||||
future_planning = 7
|
||||
|
||||
|
||||
class TSPDirtAgent(TSPBaseAgent):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(TSPDirtAgent, self).__init__(*args, **kwargs)
|
||||
|
||||
def predict(self, *_, **__):
|
||||
if self._env.state[di.DIRT].by_pos(self.state.pos) is not None:
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
action = di.CLEAN_UP
|
||||
elif door := self._door_is_close():
|
||||
action = self._use_door_or_move(door, di.DIRT)
|
||||
else:
|
||||
action = self._predict_move(di.DIRT)
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
try:
|
||||
action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action)
|
||||
except (StopIteration, UnboundLocalError):
|
||||
print('Will not happen')
|
||||
raise EnvironmentError
|
||||
return action_obj
|
59
mfg_package/algorithms/static/TSP_item_agent.py
Normal file
@ -0,0 +1,59 @@
|
||||
import numpy as np
|
||||
|
||||
from mfg_package.algorithms.static.TSP_base_agent import TSPBaseAgent
|
||||
|
||||
from mfg_package.modules.items import constants as i
|
||||
|
||||
future_planning = 7
|
||||
inventory_size = 3
|
||||
|
||||
MODE_GET = 'Mode_Get'
|
||||
MODE_BRING = 'Mode_Bring'
|
||||
|
||||
|
||||
class TSPItemAgent(TSPBaseAgent):
|
||||
|
||||
def __init__(self, *args, mode=MODE_GET, **kwargs):
|
||||
super(TSPItemAgent, self).__init__(*args, **kwargs)
|
||||
self.mode = mode
|
||||
|
||||
def predict(self, *_, **__):
|
||||
if self._env.state[i.ITEM].by_pos(self.state.pos) is not None:
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
action = i.ITEM_ACTION
|
||||
elif self._env.state[i.DROP_OFF].by_pos(self.state.pos) is not None:
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
action = i.ITEM_ACTION
|
||||
elif door := self._door_is_close():
|
||||
action = self._use_door_or_move(door, i.DROP_OFF if self.mode == MODE_BRING else i.ITEM)
|
||||
else:
|
||||
action = self._choose()
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
try:
|
||||
action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action)
|
||||
except (StopIteration, UnboundLocalError):
|
||||
print('Will not happen')
|
||||
raise EnvironmentError
|
||||
# noinspection PyUnboundLocalVariable
|
||||
if self.mode == MODE_BRING and len(self._env[i.INVENTORY].by_entity(self.state)):
|
||||
pass
|
||||
elif self.mode == MODE_BRING and not len(self._env[i.INVENTORY].by_entity(self.state)):
|
||||
self.mode = MODE_GET
|
||||
elif self.mode == MODE_GET and len(self._env[i.INVENTORY].by_entity(self.state)) > inventory_size:
|
||||
self.mode = MODE_BRING
|
||||
else:
|
||||
pass
|
||||
return action_obj
|
||||
|
||||
def _choose(self):
|
||||
target = i.DROP_OFF if self.mode == MODE_BRING else i.ITEM
|
||||
if len(self._env.state[i.ITEM]) >= 1:
|
||||
action = self._predict_move(target)
|
||||
|
||||
elif len(self._env[i.INVENTORY].by_entity(self.state)):
|
||||
self.mode = MODE_BRING
|
||||
action = self._predict_move(target)
|
||||
else:
|
||||
action = int(np.random.randint(self._env.action_space.n))
|
||||
# noinspection PyUnboundLocalVariable
|
||||
return action
|
32
mfg_package/algorithms/static/TSP_target_agent.py
Normal file
@ -0,0 +1,32 @@
|
||||
from mfg_package.algorithms.static.TSP_base_agent import TSPBaseAgent
|
||||
|
||||
from mfg_package.modules.destinations import constants as d
|
||||
from mfg_package.modules.doors import constants as do
|
||||
|
||||
future_planning = 7
|
||||
|
||||
|
||||
class TSPTargetAgent(TSPBaseAgent):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(TSPTargetAgent, self).__init__(*args, **kwargs)
|
||||
|
||||
def _handle_doors(self):
|
||||
|
||||
try:
|
||||
return next(y for x in self.state.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def predict(self, *_, **__):
|
||||
if door := self._door_is_close():
|
||||
action = self._use_door_or_move(door, d.DESTINATION)
|
||||
else:
|
||||
action = self._predict_move(d.DESTINATION)
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
try:
|
||||
action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action)
|
||||
except (StopIteration, UnboundLocalError):
|
||||
print('Will not happen')
|
||||
return action_obj
|
||||
|
0
mfg_package/algorithms/static/__init__.py
Normal file
15
mfg_package/algorithms/static/random_agent.py
Normal file
@ -0,0 +1,15 @@
|
||||
from random import randint
|
||||
|
||||
from mfg_package.algorithms.static.TSP_base_agent import TSPBaseAgent
|
||||
|
||||
future_planning = 7
|
||||
|
||||
|
||||
class TSPRandomAgent(TSPBaseAgent):
|
||||
|
||||
def __init__(self, n_actions, *args, **kwargs):
|
||||
super(TSPRandomAgent, self).__init__(*args, **kwargs)
|
||||
self.n_action = n_actions
|
||||
|
||||
def predict(self, *_, **__):
|
||||
return randint(0, self.n_action - 1)
|
85
mfg_package/algorithms/utils.py
Normal file
@ -0,0 +1,85 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def load_class(classname):
|
||||
from importlib import import_module
|
||||
module_path, class_name = classname.rsplit(".", 1)
|
||||
module = import_module(module_path)
|
||||
c = getattr(module, class_name)
|
||||
return c
|
||||
|
||||
|
||||
def instantiate_class(arguments):
|
||||
from importlib import import_module
|
||||
|
||||
d = dict(arguments)
|
||||
classname = d["classname"]
|
||||
del d["classname"]
|
||||
module_path, class_name = classname.rsplit(".", 1)
|
||||
module = import_module(module_path)
|
||||
c = getattr(module, class_name)
|
||||
return c(**d)
|
||||
|
||||
|
||||
def get_class(arguments):
|
||||
from importlib import import_module
|
||||
|
||||
if isinstance(arguments, dict):
|
||||
classname = arguments["classname"]
|
||||
module_path, class_name = classname.rsplit(".", 1)
|
||||
module = import_module(module_path)
|
||||
c = getattr(module, class_name)
|
||||
return c
|
||||
else:
|
||||
classname = arguments.classname
|
||||
module_path, class_name = classname.rsplit(".", 1)
|
||||
module = import_module(module_path)
|
||||
c = getattr(module, class_name)
|
||||
return c
|
||||
|
||||
|
||||
def get_arguments(arguments):
|
||||
from importlib import import_module
|
||||
d = dict(arguments)
|
||||
if "classname" in d:
|
||||
del d["classname"]
|
||||
return d
|
||||
|
||||
|
||||
def load_yaml_file(path: Path):
|
||||
with path.open() as stream:
|
||||
cfg = yaml.load(stream, Loader=yaml.FullLoader)
|
||||
return cfg
|
||||
|
||||
|
||||
def add_env_props(cfg):
|
||||
env = instantiate_class(cfg['environment'].copy())
|
||||
cfg['agent'].update(dict(observation_size=list(env.observation_space.shape),
|
||||
n_actions=env.action_space.n))
|
||||
|
||||
|
||||
class Checkpointer(object):
|
||||
def __init__(self, experiment_name, root, config, total_steps, n_checkpoints):
|
||||
self.path = root / experiment_name
|
||||
self.checkpoint_indices = list(np.linspace(1, total_steps, n_checkpoints, dtype=int) - 1)
|
||||
self.__current_checkpoint = 0
|
||||
self.__current_step = 0
|
||||
self.path.mkdir(exist_ok=True, parents=True)
|
||||
with (self.path / 'config.yaml').open('w') as outfile:
|
||||
yaml.dump(config, outfile, default_flow_style=False)
|
||||
|
||||
def save_experiment(self, name: str, model):
|
||||
cpt_path = self.path / f'checkpoint_{self.__current_checkpoint}'
|
||||
cpt_path.mkdir(exist_ok=True, parents=True)
|
||||
torch.save(model.state_dict(), cpt_path / f'{name}.pt')
|
||||
|
||||
def step(self, to_save):
|
||||
if self.__current_step in self.checkpoint_indices:
|
||||
print(f'Checkpointing #{self.__current_checkpoint}')
|
||||
for name, model in to_save:
|
||||
self.save_experiment(name, model)
|
||||
self.__current_checkpoint += 1
|
||||
self.__current_step += 1
|
0
mfg_package/environment/__init__.py
Normal file
100
mfg_package/environment/actions.py
Normal file
@ -0,0 +1,100 @@
|
||||
import abc
|
||||
from typing import Union
|
||||
|
||||
from mfg_package.environment import rewards as r, constants as c
|
||||
from mfg_package.utils.helpers import MOVEMAP
|
||||
from mfg_package.utils.results import ActionResult
|
||||
|
||||
|
||||
class Action(abc.ABC):
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._identifier
|
||||
|
||||
@abc.abstractmethod
|
||||
def __init__(self, identifier: str):
|
||||
self._identifier = identifier
|
||||
|
||||
@abc.abstractmethod
|
||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||
return
|
||||
|
||||
def __repr__(self):
|
||||
return f'Action[{self._identifier}]'
|
||||
|
||||
|
||||
class Noop(Action):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(c.NOOP)
|
||||
|
||||
def do(self, entity, *_) -> Union[None, ActionResult]:
|
||||
return ActionResult(identifier=self._identifier, validity=c.VALID,
|
||||
reward=r.NOOP, entity=entity)
|
||||
|
||||
|
||||
class Move(Action, abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def do(self, entity, env):
|
||||
new_pos = self._calc_new_pos(entity.pos)
|
||||
if next_tile := env[c.FLOOR].by_pos(new_pos):
|
||||
# noinspection PyUnresolvedReferences
|
||||
valid = entity.move(next_tile)
|
||||
else:
|
||||
valid = c.NOT_VALID
|
||||
reward = r.MOVEMENTS_VALID if valid else r.MOVEMENTS_FAIL
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=reward)
|
||||
|
||||
def _calc_new_pos(self, pos):
|
||||
x_diff, y_diff = MOVEMAP[self._identifier]
|
||||
return pos[0] + x_diff, pos[1] + y_diff
|
||||
|
||||
|
||||
class North(Move):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(c.NORTH, *args, **kwargs)
|
||||
|
||||
|
||||
class NorthEast(Move):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(c.NORTHEAST, *args, **kwargs)
|
||||
|
||||
|
||||
class East(Move):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(c.EAST, *args, **kwargs)
|
||||
|
||||
|
||||
class SouthEast(Move):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(c.SOUTHEAST, *args, **kwargs)
|
||||
|
||||
|
||||
class South(Move):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(c.SOUTH, *args, **kwargs)
|
||||
|
||||
|
||||
class SouthWest(Move):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(c.SOUTHWEST, *args, **kwargs)
|
||||
|
||||
|
||||
class West(Move):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(c.WEST, *args, **kwargs)
|
||||
|
||||
|
||||
class NorthWest(Move):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(c.NORTHWEST, *args, **kwargs)
|
||||
|
||||
|
||||
Move4 = [North, East, South, West]
|
||||
# noinspection PyTypeChecker
|
||||
Move8 = Move4 + [NorthEast, SouthEast, SouthWest, NorthWest]
|
0
mfg_package/environment/assets/__init__.py
Normal file
0
mfg_package/environment/assets/agent/__init__.py
Normal file
BIN
mfg_package/environment/assets/agent/adversary.png
Normal file
After Width: | Height: | Size: 8.3 KiB |
BIN
mfg_package/environment/assets/agent/agent.png
Normal file
After Width: | Height: | Size: 3.3 KiB |
BIN
mfg_package/environment/assets/agent/agent_collision.png
Normal file
After Width: | Height: | Size: 18 KiB |
BIN
mfg_package/environment/assets/agent/idle.png
Normal file
After Width: | Height: | Size: 1.6 KiB |
BIN
mfg_package/environment/assets/agent/invalid.png
Normal file
After Width: | Height: | Size: 1.6 KiB |
BIN
mfg_package/environment/assets/agent/move.png
Normal file
After Width: | Height: | Size: 5.8 KiB |
BIN
mfg_package/environment/assets/agent/valid.png
Normal file
After Width: | Height: | Size: 5.6 KiB |
BIN
mfg_package/environment/assets/wall.png
Normal file
After Width: | Height: | Size: 1.4 KiB |
60
mfg_package/environment/constants.py
Normal file
@ -0,0 +1,60 @@
|
||||
# Names
|
||||
DANGER_ZONE = 'x' # Dange Zone tile _identifier for resolving the string based map files.
|
||||
DEFAULTS = 'Defaults'
|
||||
SELF = 'Self'
|
||||
PLACEHOLDER = 'Placeholder'
|
||||
FLOOR = 'Floor' # Identifier of Floor-objects and groups (groups).
|
||||
FLOORS = 'Floors' # Identifier of Floor-objects and groups (groups).
|
||||
WALL = 'Wall' # Identifier of Wall-objects and groups (groups).
|
||||
WALLS = 'Walls' # Identifier of Wall-objects and groups (groups).
|
||||
LEVEL = 'Level' # Identifier of Level-objects and groups (groups).
|
||||
AGENT = 'Agent' # Identifier of Agent-objects and groups (groups).
|
||||
AGENTS = 'Agents' # Identifier of Agent-objects and groups (groups).
|
||||
OTHERS = 'Other'
|
||||
COMBINED = 'Combined'
|
||||
GLOBAL_POSITION = 'GLOBAL_POSITION' # Identifier of the global position slice
|
||||
|
||||
|
||||
# Attributes
|
||||
IS_BLOCKING_LIGHT = 'is_blocking_light'
|
||||
HAS_POSITION = 'has_position'
|
||||
HAS_NO_POSITION = 'has_no_position'
|
||||
ALL = 'All'
|
||||
|
||||
# Symbols (Read from map-files)
|
||||
SYMBOL_WALL = '#'
|
||||
SYMBOL_FLOOR = '-'
|
||||
|
||||
VALID = True # Identifier to rename boolean values in the context of actions.
|
||||
NOT_VALID = False # Identifier to rename boolean values in the context of actions.
|
||||
VALUE_FREE_CELL = 0 # Free-Cell value used in observation
|
||||
VALUE_OCCUPIED_CELL = 1 # Occupied-Cell value used in observation
|
||||
VALUE_NO_POS = (-9999, -9999) # Invalid Position value used in the environment (smth. is off-grid)
|
||||
|
||||
|
||||
ACTION = 'action' # Identifier of Action-objects and groups (groups).
|
||||
COLLISION = 'Collision' # Identifier to use in the context of collitions.
|
||||
LAST_POS = 'LAST_POS' # Identifiert for retrieving an enitites last pos.
|
||||
VALIDITY = 'VALIDITY' # Identifiert for retrieving the Validity of Action, Tick, etc. ...
|
||||
|
||||
# Actions
|
||||
# Movements
|
||||
NORTH = 'north'
|
||||
EAST = 'east'
|
||||
SOUTH = 'south'
|
||||
WEST = 'west'
|
||||
NORTHEAST = 'north_east'
|
||||
SOUTHEAST = 'south_east'
|
||||
SOUTHWEST = 'south_west'
|
||||
NORTHWEST = 'north_west'
|
||||
|
||||
# Move Groups
|
||||
MOVE8 = 'Move8'
|
||||
MOVE4 = 'Move4'
|
||||
|
||||
# No-Action / Wait
|
||||
NOOP = 'Noop'
|
||||
|
||||
# Result Identifier
|
||||
MOVEMENTS_VALID = 'motion_valid'
|
||||
MOVEMENTS_FAIL = 'motion_not_valid'
|
0
mfg_package/environment/entity/__init__.py
Normal file
76
mfg_package/environment/entity/agent.py
Normal file
@ -0,0 +1,76 @@
|
||||
from typing import List, Union
|
||||
|
||||
from mfg_package.environment import constants as c
|
||||
from mfg_package.environment.actions import Action
|
||||
from mfg_package.environment.entity.entity import Entity
|
||||
from mfg_package.utils.render import RenderEntity
|
||||
from mfg_package.utils import renderer
|
||||
from mfg_package.utils.helpers import is_move
|
||||
from mfg_package.utils.results import ActionResult, Result
|
||||
|
||||
|
||||
class Agent(Entity):
|
||||
|
||||
@property
|
||||
def obs_tag(self):
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def actions(self):
|
||||
return self._actions
|
||||
|
||||
@property
|
||||
def observations(self):
|
||||
return self._observations
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return True
|
||||
|
||||
def step_result(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def collection(self):
|
||||
return self._collection
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self._state or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)
|
||||
|
||||
def __init__(self, actions: List[Action], observations: List[str], *args, **kwargs):
|
||||
super(Agent, self).__init__(*args, **kwargs)
|
||||
self.step_result = dict()
|
||||
self._actions = actions
|
||||
self._observations = observations
|
||||
self._state: Union[Result, None] = None
|
||||
|
||||
# noinspection PyAttributeOutsideInit
|
||||
def clear_temp_state(self):
|
||||
self._state = None
|
||||
return self
|
||||
|
||||
def summarize_state(self):
|
||||
state_dict = super().summarize_state()
|
||||
state_dict.update(valid=bool(self.state.validity), action=str(self.state.identifier))
|
||||
return state_dict
|
||||
|
||||
def set_state(self, action_result):
|
||||
self._state = action_result
|
||||
|
||||
def render(self):
|
||||
i = next(idx for idx, x in enumerate(self._collection) if x.name == self.name)
|
||||
curr_state = self.state
|
||||
if curr_state.identifier == c.COLLISION:
|
||||
render_state = renderer.STATE_COLLISION
|
||||
elif curr_state.validity:
|
||||
if curr_state.identifier == c.NOOP:
|
||||
render_state = renderer.STATE_IDLE
|
||||
elif is_move(curr_state.identifier):
|
||||
render_state = renderer.STATE_MOVE
|
||||
else:
|
||||
render_state = renderer.STATE_VALID
|
||||
else:
|
||||
render_state = renderer.STATE_INVALID
|
||||
|
||||
return RenderEntity(c.AGENT, self.pos, 1, 'none', render_state, i + 1, real_name=self.name)
|
79
mfg_package/environment/entity/entity.py
Normal file
@ -0,0 +1,79 @@
|
||||
import abc
|
||||
|
||||
from mfg_package.environment import constants as c
|
||||
from mfg_package.environment.entity.object import EnvObject
|
||||
from mfg_package.utils.render import RenderEntity
|
||||
|
||||
|
||||
class Entity(EnvObject, abc.ABC):
|
||||
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
|
||||
|
||||
@property
|
||||
def has_position(self):
|
||||
return self.pos != c.VALUE_NO_POS
|
||||
|
||||
@property
|
||||
def x(self):
|
||||
return self.pos[0]
|
||||
|
||||
@property
|
||||
def y(self):
|
||||
return self.pos[1]
|
||||
|
||||
@property
|
||||
def pos(self):
|
||||
return self._tile.pos
|
||||
|
||||
@property
|
||||
def tile(self):
|
||||
return self._tile
|
||||
|
||||
@property
|
||||
def last_tile(self):
|
||||
try:
|
||||
return self._last_tile
|
||||
except AttributeError:
|
||||
# noinspection PyAttributeOutsideInit
|
||||
self._last_tile = None
|
||||
return self._last_tile
|
||||
|
||||
@property
|
||||
def last_pos(self):
|
||||
try:
|
||||
return self.last_tile.pos
|
||||
except AttributeError:
|
||||
return c.VALUE_NO_POS
|
||||
|
||||
@property
|
||||
def direction_of_view(self):
|
||||
last_x, last_y = self.last_pos
|
||||
curr_x, curr_y = self.pos
|
||||
return last_x - curr_x, last_y - curr_y
|
||||
|
||||
def move(self, next_tile):
|
||||
curr_tile = self.tile
|
||||
if not_same_tile := curr_tile != next_tile:
|
||||
if valid := next_tile.enter(self):
|
||||
curr_tile.leave(self)
|
||||
self._tile = next_tile
|
||||
self._last_tile = curr_tile
|
||||
for observer in self.observers:
|
||||
observer.notify_change_pos(self)
|
||||
return valid
|
||||
return not_same_tile
|
||||
|
||||
def __init__(self, tile, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._tile = tile
|
||||
tile.enter(self)
|
||||
|
||||
def summarize_state(self) -> dict:
|
||||
return dict(name=str(self.name), x=int(self.x), y=int(self.y),
|
||||
tile=str(self.tile.name), can_collide=bool(self.can_collide))
|
||||
|
||||
@abc.abstractmethod
|
||||
def render(self):
|
||||
return RenderEntity(self.__class__.__name__.lower(), self.pos)
|
||||
|
||||
def __repr__(self):
|
||||
return super(Entity, self).__repr__() + f'(@{self.pos})'
|
18
mfg_package/environment/entity/mixin.py
Normal file
@ -0,0 +1,18 @@
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit
|
||||
class BoundEntityMixin:
|
||||
|
||||
@property
|
||||
def bound_entity(self):
|
||||
return self._bound_entity
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return f'{self.__class__.__name__}({self._bound_entity.name})'
|
||||
|
||||
def belongs_to_entity(self, entity):
|
||||
return entity == self.bound_entity
|
||||
|
||||
def bind_to(self, entity):
|
||||
self._bound_entity = entity
|
126
mfg_package/environment/entity/object.py
Normal file
@ -0,0 +1,126 @@
|
||||
from collections import defaultdict
|
||||
from typing import Union
|
||||
|
||||
from mfg_package.environment import constants as c
|
||||
|
||||
|
||||
class Object:
|
||||
|
||||
"""Generell Objects for Organisation and Maintanance such as Actions etc..."""
|
||||
|
||||
_u_idx = defaultdict(lambda: 0)
|
||||
|
||||
def __bool__(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def observers(self):
|
||||
return self._observers
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
if self._str_ident is not None:
|
||||
return f'{self.__class__.__name__}[{self._str_ident}]'
|
||||
return f'{self.__class__.__name__}#{self.identifier_int}'
|
||||
|
||||
@property
|
||||
def identifier(self):
|
||||
if self._str_ident is not None:
|
||||
return self._str_ident
|
||||
else:
|
||||
return self.name
|
||||
|
||||
def __init__(self, str_ident: Union[str, None] = None, **kwargs):
|
||||
self._observers = []
|
||||
self._str_ident = str_ident
|
||||
self.identifier_int = self._identify_and_count_up()
|
||||
self._collection = None
|
||||
|
||||
if kwargs:
|
||||
print(f'Following kwargs were passed, but ignored: {kwargs}')
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.name}'
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
return other == self.identifier
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.identifier)
|
||||
|
||||
def _identify_and_count_up(self):
|
||||
idx = Object._u_idx[self.__class__.__name__]
|
||||
Object._u_idx[self.__class__.__name__] += 1
|
||||
return idx
|
||||
|
||||
def set_collection(self, collection):
|
||||
self._collection = collection
|
||||
|
||||
def add_observer(self, observer):
|
||||
self.observers.append(observer)
|
||||
observer.notify_change_pos(self)
|
||||
|
||||
def del_observer(self, observer):
|
||||
self.observers.remove(observer)
|
||||
|
||||
|
||||
class EnvObject(Object):
|
||||
|
||||
"""Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc..."""
|
||||
|
||||
_u_idx = defaultdict(lambda: 0)
|
||||
|
||||
@property
|
||||
def obs_tag(self):
|
||||
try:
|
||||
return self._collection.name or self.name
|
||||
except AttributeError:
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def is_blocking_light(self):
|
||||
try:
|
||||
return self._collection.is_blocking_light or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def can_move(self):
|
||||
try:
|
||||
return self._collection.can_move or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_blocking_pos(self):
|
||||
try:
|
||||
return self._collection.is_blocking_pos or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def has_position(self):
|
||||
try:
|
||||
return self._collection.has_position or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
try:
|
||||
return self._collection.can_collide or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return c.VALUE_OCCUPIED_CELL
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(EnvObject, self).__init__(**kwargs)
|
||||
|
||||
def change_parent_collection(self, other_collection):
|
||||
other_collection.add_item(self)
|
||||
self._collection.delete_env_object(self)
|
||||
self._collection = other_collection
|
||||
return self._collection == other_collection
|
45
mfg_package/environment/entity/util.py
Normal file
@ -0,0 +1,45 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mfg_package.environment.entity.mixin import BoundEntityMixin
|
||||
from mfg_package.environment.entity.object import Object, EnvObject
|
||||
|
||||
|
||||
##########################################################################
|
||||
# ####################### Objects and Entitys ########################## #
|
||||
##########################################################################
|
||||
|
||||
|
||||
class PlaceHolder(Object):
|
||||
|
||||
def __init__(self, *args, fill_value=0, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._fill_value = fill_value
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return self._fill_value
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return "PlaceHolder"
|
||||
|
||||
|
||||
class GlobalPosition(BoundEntityMixin, EnvObject):
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
if self._normalized:
|
||||
return tuple(np.divide(self._bound_entity.pos, self._level_shape))
|
||||
else:
|
||||
return self.bound_entity.pos
|
||||
|
||||
def __init__(self, *args, normalized: bool = True, **kwargs):
|
||||
super(GlobalPosition, self).__init__(*args, **kwargs)
|
||||
self._level_shape = math.sqrt(self.size)
|
||||
self._normalized = normalized
|
131
mfg_package/environment/entity/wall_floor.py
Normal file
@ -0,0 +1,131 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mfg_package.environment import constants as c
|
||||
from mfg_package.environment.entity.object import EnvObject
|
||||
from mfg_package.utils.render import RenderEntity
|
||||
from mfg_package.utils import helpers as h
|
||||
|
||||
|
||||
class Floor(EnvObject):
|
||||
|
||||
@property
|
||||
def has_position(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def can_move(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_blocking_pos(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def neighboring_floor_pos(self):
|
||||
return [x.pos for x in self.neighboring_floor]
|
||||
|
||||
@property
|
||||
def neighboring_floor(self):
|
||||
if self._neighboring_floor:
|
||||
pass
|
||||
else:
|
||||
self._neighboring_floor = [x for x in [self._collection.by_pos(np.add(self.pos, pos))
|
||||
for pos in h.POS_MASK.reshape(-1, 2)
|
||||
if not np.all(pos == [0, 0])]
|
||||
if x]
|
||||
return self._neighboring_floor
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return c.VALUE_OCCUPIED_CELL
|
||||
|
||||
@property
|
||||
def guests_that_can_collide(self):
|
||||
return [x for x in self.guests if x.can_collide]
|
||||
|
||||
@property
|
||||
def guests(self):
|
||||
return self._guests.values()
|
||||
|
||||
@property
|
||||
def x(self):
|
||||
return self.pos[0]
|
||||
|
||||
@property
|
||||
def y(self):
|
||||
return self.pos[1]
|
||||
|
||||
@property
|
||||
def is_blocked(self):
|
||||
return any([x.is_blocking_pos for x in self.guests])
|
||||
|
||||
def __init__(self, pos, **kwargs):
|
||||
super(Floor, self).__init__(**kwargs)
|
||||
self._guests = dict()
|
||||
self.pos = tuple(pos)
|
||||
self._neighboring_floor: List[Floor] = list()
|
||||
self._blocked_by = None
|
||||
|
||||
def __len__(self):
|
||||
return len(self._guests)
|
||||
|
||||
def is_empty(self):
|
||||
return not len(self._guests)
|
||||
|
||||
def is_occupied(self):
|
||||
return bool(len(self._guests))
|
||||
|
||||
def enter(self, guest):
|
||||
if (guest.name not in self._guests and not self.is_blocked) and not (guest.is_blocking_pos and self.is_occupied()):
|
||||
self._guests.update({guest.name: guest})
|
||||
return c.VALID
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
|
||||
def leave(self, guest):
|
||||
try:
|
||||
del self._guests[guest.name]
|
||||
except (ValueError, KeyError):
|
||||
return c.NOT_VALID
|
||||
return c.VALID
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.name}(@{self.pos})'
|
||||
|
||||
def summarize_state(self, **_):
|
||||
return dict(name=self.name, x=int(self.x), y=int(self.y))
|
||||
|
||||
def render(self):
|
||||
return None
|
||||
|
||||
|
||||
class Wall(Floor):
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return c.VALUE_OCCUPIED_CELL
|
||||
|
||||
def render(self):
|
||||
return RenderEntity(c.WALL, self.pos)
|
||||
|
||||
@property
|
||||
def is_blocking_pos(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_blocking_light(self):
|
||||
return True
|
201
mfg_package/environment/factory.py
Normal file
@ -0,0 +1,201 @@
|
||||
import shutil
|
||||
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from mfg_package.utils.level_parser import LevelParser
|
||||
from mfg_package.utils.observation_builder import OBSBuilder
|
||||
from mfg_package.utils.config_parser import FactoryConfigParser
|
||||
from mfg_package.utils import helpers as h
|
||||
import mfg_package.environment.constants as c
|
||||
|
||||
from mfg_package.utils.states import Gamestate
|
||||
|
||||
REC_TAC = 'rec_'
|
||||
|
||||
|
||||
class BaseFactory(gym.Env):
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
return self.state[c.AGENT].action_space
|
||||
|
||||
@property
|
||||
def named_action_space(self):
|
||||
return self.state[c.AGENT].named_action_space
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
return self.obs_builder.observation_space(self.state)
|
||||
|
||||
@property
|
||||
def named_observation_space(self):
|
||||
return self.obs_builder.named_observation_space(self.state)
|
||||
|
||||
@property
|
||||
def params(self) -> dict:
|
||||
import yaml
|
||||
config_path = Path(self._config_file)
|
||||
config_dict = yaml.safe_load(config_path.open())
|
||||
return config_dict
|
||||
|
||||
@property
|
||||
def summarize_header(self):
|
||||
summary_dict = self._summarize_state(stateless_entities=True)
|
||||
return summary_dict
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
def __init__(self, config_file: Union[str, PathLike]):
|
||||
self._config_file = config_file
|
||||
self.conf = FactoryConfigParser(self._config_file)
|
||||
# Attribute Assignment
|
||||
self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt'
|
||||
self._renderer = None # expensive - don't use; unless required !
|
||||
|
||||
parsed_entities = self.conf.load_entities()
|
||||
self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r)
|
||||
|
||||
# Init for later usage:
|
||||
self.state: Gamestate
|
||||
self.map: LevelParser
|
||||
self.obs_builder: OBSBuilder
|
||||
|
||||
# TODO: Reset ---> document this
|
||||
self.reset()
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.state.entities[item]
|
||||
|
||||
def reset(self) -> (dict, dict):
|
||||
self.state = None
|
||||
|
||||
# Init entity:
|
||||
entities = self.map.do_init()
|
||||
|
||||
# Grab all rules:
|
||||
rules = self.conf.load_rules()
|
||||
|
||||
# Agents
|
||||
# noinspection PyAttributeOutsideInit
|
||||
self.state = Gamestate(entities, rules, self.conf.env_seed)
|
||||
|
||||
agents = self.conf.load_agents(self.map.size, self[c.FLOOR].empty_tiles)
|
||||
self.state.entities.add_item({c.AGENT: agents})
|
||||
|
||||
# All is set up, trigger additional init (after agent entity spawn etc)
|
||||
self.state.rules.do_all_init(self.state)
|
||||
|
||||
# Observations
|
||||
# noinspection PyAttributeOutsideInit
|
||||
self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r)
|
||||
return self.obs_builder.refresh_and_build_for_all(self.state)
|
||||
|
||||
def step(self, actions):
|
||||
|
||||
if not isinstance(actions, list):
|
||||
actions = [int(actions)]
|
||||
|
||||
# Apply rules, do actions, tick the state, etc...
|
||||
tick_result = self.state.tick(actions)
|
||||
|
||||
# Check Done Conditions
|
||||
done_results = self.state.check_done()
|
||||
|
||||
# Finalize
|
||||
reward, reward_info, done = self.summarize_step_results(tick_result, done_results)
|
||||
|
||||
info = reward_info
|
||||
|
||||
info.update(step_reward=sum(reward), step=self.state.curr_step)
|
||||
# TODO:
|
||||
# if self._record_episodes:
|
||||
# info.update(self._summarize_state())
|
||||
|
||||
obs, reset_info = self.obs_builder.refresh_and_build_for_all(self.state)
|
||||
info.update(reset_info)
|
||||
return None, [x for x in obs.values()], reward, done, info
|
||||
|
||||
def summarize_step_results(self, tick_results: list, done_check_results: list) -> (int, dict, bool):
|
||||
# Returns: Reward, Info
|
||||
rewards = defaultdict(lambda: 0.0)
|
||||
|
||||
# Gather per agent environment rewards and
|
||||
# Combine Info dicts into a global one
|
||||
combined_info_dict = defaultdict(lambda: 0.0)
|
||||
for result in chain(tick_results, done_check_results):
|
||||
if result.reward is not None:
|
||||
try:
|
||||
rewards[result.entity.name] += result.reward
|
||||
except AttributeError:
|
||||
rewards['global'] += result.reward
|
||||
infos = result.get_infos()
|
||||
for info in infos:
|
||||
assert isinstance(info.value, (float, int))
|
||||
combined_info_dict[info.identifier] += info.value
|
||||
|
||||
# Check Done Rule Results
|
||||
try:
|
||||
done_reason = next(x for x in done_check_results if x.validity)
|
||||
done = True
|
||||
self.state.print(f'Env done, Reason: {done_reason.name}.')
|
||||
except StopIteration:
|
||||
done = False
|
||||
|
||||
if self.conf.individual_rewards:
|
||||
global_rewards = rewards['global']
|
||||
del rewards['global']
|
||||
reward = [rewards[agent.name] for agent in self.state[c.AGENT]]
|
||||
reward = [x + global_rewards for x in reward]
|
||||
self.state.print(f"rewards are {rewards}")
|
||||
return reward, combined_info_dict, done
|
||||
else:
|
||||
reward = sum(rewards.values())
|
||||
self.state.print(f"reward is {reward}")
|
||||
return reward, combined_info_dict, done
|
||||
|
||||
def start_recording(self):
|
||||
self.conf.do_record = True
|
||||
return self.conf.do_record
|
||||
|
||||
def stop_recording(self):
|
||||
self.conf.do_record = False
|
||||
return not self.conf.do_record
|
||||
|
||||
# noinspection PyGlobalUndefined
|
||||
def render(self, mode='human'):
|
||||
if not self._renderer: # lazy init
|
||||
from mfg_package.utils.renderer import Renderer
|
||||
global Renderer
|
||||
self._renderer = Renderer(self.map.level_shape, view_radius=self.conf.pomdp_r, fps=10)
|
||||
|
||||
render_entities = self.state.entities.render()
|
||||
if self.conf.pomdp_r:
|
||||
for render_entity in render_entities:
|
||||
if render_entity.name == c.AGENT:
|
||||
render_entity.aux = self.obs_builder.curr_lightmaps[render_entity.real_name]
|
||||
return self._renderer.render(render_entities)
|
||||
|
||||
def _summarize_state(self, stateless_entities=False):
|
||||
summary = {f'{REC_TAC}step': self.state.curr_step}
|
||||
|
||||
for entity_group in self.state:
|
||||
if entity_group.is_stateless == stateless_entities:
|
||||
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()})
|
||||
return summary
|
||||
|
||||
def print(self, string):
|
||||
if self.conf.verbose:
|
||||
print(string)
|
||||
|
||||
def save_params(self, filepath: Path):
|
||||
# noinspection PyProtectedMember
|
||||
filepath = Path(filepath)
|
||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copyfile(self._config_file, filepath)
|
0
mfg_package/environment/groups/__init__.py
Normal file
29
mfg_package/environment/groups/agents.py
Normal file
@ -0,0 +1,29 @@
|
||||
from mfg_package.environment.groups.env_objects import EnvObjects
|
||||
from mfg_package.environment.groups.mixins import PositionMixin
|
||||
from mfg_package.environment.entity.agent import Agent
|
||||
|
||||
|
||||
class Agents(PositionMixin, EnvObjects):
|
||||
_entity = Agent
|
||||
is_blocking_light = False
|
||||
can_move = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def obs_pairs(self):
|
||||
return [(a.name, a) for a in self]
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
from gymnasium import spaces
|
||||
space = spaces.Tuple([spaces.Discrete(len(x.actions)) for x in self])
|
||||
return space
|
||||
|
||||
@property
|
||||
def named_action_space(self):
|
||||
named_space = dict()
|
||||
for agent in self:
|
||||
named_space[agent.name] = {action.name: idx for idx, action in enumerate(agent.actions)}
|
||||
return named_space
|
33
mfg_package/environment/groups/env_objects.py
Normal file
@ -0,0 +1,33 @@
|
||||
from mfg_package.environment.groups.objects import Objects
|
||||
from mfg_package.environment.entity.object import EnvObject
|
||||
|
||||
|
||||
class EnvObjects(Objects):
|
||||
|
||||
_entity = EnvObject
|
||||
is_blocking_light: bool = False
|
||||
can_collide: bool = False
|
||||
has_position: bool = False
|
||||
can_move: bool = False
|
||||
|
||||
@property
|
||||
def encodings(self):
|
||||
return [x.encoding for x in self]
|
||||
|
||||
def __init__(self, size, *args, **kwargs):
|
||||
super(EnvObjects, self).__init__(*args, **kwargs)
|
||||
self.size = size
|
||||
|
||||
def add_item(self, item: EnvObject):
|
||||
assert self.has_position or (len(self) <= self.size)
|
||||
super(EnvObjects, self).add_item(item)
|
||||
return self
|
||||
|
||||
def summarize_states(self):
|
||||
return [entity.summarize_state() for entity in self.values()]
|
||||
|
||||
def delete_env_object(self, env_object: EnvObject):
|
||||
del self[env_object.name]
|
||||
|
||||
def delete_env_object_by_name(self, name):
|
||||
del self[name]
|
63
mfg_package/environment/groups/global_entities.py
Normal file
@ -0,0 +1,63 @@
|
||||
from collections import defaultdict
|
||||
from operator import itemgetter
|
||||
from typing import Dict
|
||||
|
||||
from mfg_package.environment.groups.objects import Objects
|
||||
from mfg_package.utils.helpers import POS_MASK
|
||||
|
||||
|
||||
class Entities(Objects):
|
||||
_entity = Objects
|
||||
|
||||
@staticmethod
|
||||
def neighboring_positions(pos):
|
||||
return (POS_MASK + pos).reshape(-1, 2)
|
||||
|
||||
def get_near_pos(self, pos):
|
||||
return [y for x in itemgetter(*(tuple(x) for x in self.neighboring_positions(pos)))(self.pos_dict) for y in x]
|
||||
|
||||
def render(self):
|
||||
return [y for x in self for y in x.render() if x is not None]
|
||||
|
||||
@property
|
||||
def names(self):
|
||||
return list(self._data.keys())
|
||||
|
||||
def __init__(self):
|
||||
self.pos_dict = defaultdict(list)
|
||||
super().__init__()
|
||||
|
||||
def iter_entities(self):
|
||||
return iter((x for sublist in self.values() for x in sublist))
|
||||
|
||||
def add_items(self, items: Dict):
|
||||
return self.add_item(items)
|
||||
|
||||
def add_item(self, item: dict):
|
||||
assert_str = 'This group of entity has already been added!'
|
||||
assert not any([key for key in item.keys() if key in self.keys()]), assert_str
|
||||
self._data.update(item)
|
||||
for val in item.values():
|
||||
val.add_observer(self)
|
||||
return self
|
||||
|
||||
def __delitem__(self, name):
|
||||
assert_str = 'This group of entity does not exist in this collection!'
|
||||
assert any([key for key in name.keys() if key in self.keys()]), assert_str
|
||||
self[name]._observers.delete(self)
|
||||
for entity in self[name]:
|
||||
entity.del_observer(self)
|
||||
return super(Entities, self).__delitem__(name)
|
||||
|
||||
@property
|
||||
def obs_pairs(self):
|
||||
return [y for x in self for y in x.obs_pairs]
|
||||
|
||||
def by_pos(self, pos: (int, int)):
|
||||
return self.pos_dict[pos]
|
||||
# found_entities = [y for y in (x.by_pos(pos) for x in self.values() if hasattr(x, 'by_pos')) if y is not None]
|
||||
# return found_entities
|
||||
|
||||
@property
|
||||
def positions(self):
|
||||
return [k for k, v in self.pos_dict.items() for _ in v]
|
97
mfg_package/environment/groups/mixins.py
Normal file
@ -0,0 +1,97 @@
|
||||
from mfg_package.environment import constants as c
|
||||
|
||||
from mfg_package.environment.entity.entity import Entity
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences,PyTypeChecker,PyArgumentList
|
||||
class PositionMixin:
|
||||
|
||||
_entity = Entity
|
||||
is_blocking_light: bool = True
|
||||
can_collide: bool = True
|
||||
has_position: bool = True
|
||||
|
||||
def render(self):
|
||||
return [y for y in [x.render() for x in self] if y is not None]
|
||||
|
||||
@classmethod
|
||||
def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
|
||||
collection = cls(*args, **kwargs)
|
||||
entities = [cls._entity(tile, str_ident=i,
|
||||
**entity_kwargs if entity_kwargs is not None else {})
|
||||
for i, tile in enumerate(tiles)]
|
||||
collection.add_items(entities)
|
||||
return collection
|
||||
|
||||
@classmethod
|
||||
def from_coordinates(cls, positions: [(int, int)], tiles, *args, entity_kwargs=None, **kwargs, ):
|
||||
return cls.from_tiles([tiles.by_pos(position) for position in positions], tiles.size, *args,
|
||||
entity_kwargs=entity_kwargs,
|
||||
**kwargs)
|
||||
|
||||
@property
|
||||
def tiles(self):
|
||||
return [entity.tile for entity in self]
|
||||
|
||||
def __delitem__(self, name):
|
||||
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
|
||||
obj.tile.leave(obj)
|
||||
super().__delitem__(name)
|
||||
|
||||
def by_pos(self, pos: (int, int)):
|
||||
pos = tuple(pos)
|
||||
try:
|
||||
return next(e for e in self if e.pos == pos)
|
||||
except StopIteration:
|
||||
pass
|
||||
except ValueError:
|
||||
print()
|
||||
|
||||
@property
|
||||
def positions(self):
|
||||
return [e.pos for e in self]
|
||||
|
||||
def notify_del_entity(self, entity: Entity):
|
||||
try:
|
||||
self.pos_dict[entity.pos].remove(entity)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences,PyTypeChecker
|
||||
class IsBoundMixin:
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return f'{self.__class__.__name__}({self._bound_entity.name})'
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}#{self._bound_entity.name}({self._data})'
|
||||
|
||||
def bind(self, entity):
|
||||
# noinspection PyAttributeOutsideInit
|
||||
self._bound_entity = entity
|
||||
return c.VALID
|
||||
|
||||
def belongs_to_entity(self, entity):
|
||||
return self._bound_entity == entity
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences,PyTypeChecker
|
||||
class HasBoundedMixin:
|
||||
|
||||
@property
|
||||
def obs_names(self):
|
||||
return [x.name for x in self]
|
||||
|
||||
def by_entity(self, entity):
|
||||
try:
|
||||
return next((x for x in self if x.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
141
mfg_package/environment/groups/objects.py
Normal file
@ -0,0 +1,141 @@
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mfg_package.environment.entity.object import Object
|
||||
|
||||
|
||||
class Objects:
|
||||
_entity = Object
|
||||
|
||||
@property
|
||||
def observers(self):
|
||||
return self._observers
|
||||
|
||||
@property
|
||||
def obs_tag(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
@staticmethod
|
||||
def render():
|
||||
return []
|
||||
|
||||
@property
|
||||
def obs_pairs(self):
|
||||
return [(self.name, self)]
|
||||
|
||||
@property
|
||||
def names(self):
|
||||
# noinspection PyUnresolvedReferences
|
||||
return [x.name for x in self]
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return f'{self.__class__.__name__}'
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._data = defaultdict(lambda: None)
|
||||
self._observers = list()
|
||||
self.pos_dict = defaultdict(list)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.values())
|
||||
|
||||
def add_item(self, item: _entity):
|
||||
assert_str = f'All item names have to be of type {self._entity}, but were {item.__class__}.,'
|
||||
assert isinstance(item, self._entity), assert_str
|
||||
assert self._data[item.name] is None, f'{item.name} allready exists!!!'
|
||||
self._data.update({item.name: item})
|
||||
item.set_collection(self)
|
||||
for observer in self.observers:
|
||||
observer.notify_add_entity(item)
|
||||
return self
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
def del_observer(self, observer):
|
||||
self.observers.remove(observer)
|
||||
for entity in self:
|
||||
if observer in entity.observers:
|
||||
entity.del_observer(observer)
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
def add_observer(self, observer):
|
||||
self.observers.append(observer)
|
||||
for entity in self:
|
||||
if observer not in entity.observers:
|
||||
entity.add_observer(observer)
|
||||
|
||||
def __delitem__(self, name):
|
||||
for observer in self.observers:
|
||||
observer.notify_del_entity(name)
|
||||
# noinspection PyTypeChecker
|
||||
del self._data[name]
|
||||
|
||||
def add_items(self, items: List[_entity]):
|
||||
for item in items:
|
||||
self.add_item(item)
|
||||
return self
|
||||
|
||||
def keys(self):
|
||||
return self._data.keys()
|
||||
|
||||
def values(self):
|
||||
return self._data.values()
|
||||
|
||||
def items(self):
|
||||
return self._data.items()
|
||||
|
||||
def _get_index(self, item):
|
||||
try:
|
||||
return next(i for i, v in enumerate(self._data.values()) if v == item)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def __getitem__(self, item):
|
||||
if isinstance(item, (int, np.int64, np.int32)):
|
||||
if item < 0:
|
||||
item = len(self._data) - abs(item)
|
||||
try:
|
||||
return next(v for i, v in enumerate(self._data.values()) if i == item)
|
||||
except StopIteration:
|
||||
return None
|
||||
try:
|
||||
return self._data[item]
|
||||
except KeyError:
|
||||
return None
|
||||
except TypeError:
|
||||
print('Ups')
|
||||
raise TypeError
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}[{dict(self._data)}]'
|
||||
|
||||
def notify_change_pos(self, entity: object):
|
||||
try:
|
||||
self.pos_dict[entity.last_pos].remove(entity)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
if entity.has_position:
|
||||
try:
|
||||
self.pos_dict[entity.pos].append(entity)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
def notify_del_entity(self, entity: Object):
|
||||
try:
|
||||
self.pos_dict[entity.pos].remove(entity)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
def notify_add_entity(self, entity: Object):
|
||||
try:
|
||||
entity.add_observer(self)
|
||||
self.pos_dict[entity.pos].append(entity)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
77
mfg_package/environment/groups/utils.py
Normal file
@ -0,0 +1,77 @@
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mfg_package.environment.groups.env_objects import EnvObjects
|
||||
from mfg_package.environment.groups.objects import Objects
|
||||
from mfg_package.environment.groups.mixins import HasBoundedMixin, PositionMixin
|
||||
from mfg_package.environment.entity.util import GlobalPosition
|
||||
from mfg_package.utils import helpers as h
|
||||
from mfg_package.environment import constants as c
|
||||
|
||||
|
||||
class Combined(PositionMixin, EnvObjects):
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return f'{super().name}({self._ident or self._names})'
|
||||
|
||||
@property
|
||||
def names(self):
|
||||
return self._names
|
||||
|
||||
def __init__(self, names: List[str], *args, identifier: Union[None, str] = None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._ident = identifier
|
||||
self._names = names or list()
|
||||
|
||||
@property
|
||||
def obs_tag(self):
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def obs_pairs(self):
|
||||
return [(name, None) for name in self.names]
|
||||
|
||||
|
||||
class GlobalPositions(HasBoundedMixin, EnvObjects):
|
||||
|
||||
_entity = GlobalPosition
|
||||
is_blocking_light = False,
|
||||
can_collide = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(GlobalPositions, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class Zones(Objects):
|
||||
|
||||
@property
|
||||
def accounting_zones(self):
|
||||
return [self[idx] for idx, name in self.items() if name != c.DANGER_ZONE]
|
||||
|
||||
def __init__(self, parsed_level):
|
||||
raise NotImplementedError('This needs a Rework')
|
||||
super(Zones, self).__init__()
|
||||
slices = list()
|
||||
self._accounting_zones = list()
|
||||
self._danger_zones = list()
|
||||
for symbol in np.unique(parsed_level):
|
||||
if symbol == c.VALUE_OCCUPIED_CELL:
|
||||
continue
|
||||
elif symbol == c.DANGER_ZONE:
|
||||
self + symbol
|
||||
slices.append(h.one_hot_level(parsed_level, symbol))
|
||||
self._danger_zones.append(symbol)
|
||||
else:
|
||||
self + symbol
|
||||
slices.append(h.one_hot_level(parsed_level, symbol))
|
||||
self._accounting_zones.append(symbol)
|
||||
|
||||
self._zone_slices = np.stack(slices)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self._zone_slices[item]
|
||||
|
||||
def add_items(self, other: Union[str, List[str]]):
|
||||
raise AttributeError('You are not allowed to add additional Zones in runtime.')
|
54
mfg_package/environment/groups/wall_n_floors.py
Normal file
@ -0,0 +1,54 @@
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
from mfg_package.environment import constants as c
|
||||
from mfg_package.environment.groups.env_objects import EnvObjects
|
||||
from mfg_package.environment.groups.mixins import PositionMixin
|
||||
from mfg_package.environment.entity.wall_floor import Wall, Floor
|
||||
|
||||
|
||||
class Walls(PositionMixin, EnvObjects):
|
||||
_entity = Wall
|
||||
symbol = c.SYMBOL_WALL
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Walls, self).__init__(*args, **kwargs)
|
||||
self._value = c.VALUE_OCCUPIED_CELL
|
||||
|
||||
@classmethod
|
||||
def from_coordinates(cls, argwhere_coordinates, *args, **kwargs):
|
||||
tiles = cls(*args, **kwargs)
|
||||
# noinspection PyTypeChecker
|
||||
tiles.add_items([cls._entity(pos) for pos in argwhere_coordinates])
|
||||
return tiles
|
||||
|
||||
@classmethod
|
||||
def from_tiles(cls, tiles, *args, **kwargs):
|
||||
raise RuntimeError()
|
||||
|
||||
|
||||
class Floors(Walls):
|
||||
_entity = Floor
|
||||
symbol = c.SYMBOL_FLOOR
|
||||
is_blocking_light: bool = False
|
||||
can_collide: bool = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Floors, self).__init__(*args, **kwargs)
|
||||
self._value = c.VALUE_FREE_CELL
|
||||
|
||||
@property
|
||||
def occupied_tiles(self):
|
||||
tiles = [tile for tile in self if tile.is_occupied()]
|
||||
random.shuffle(tiles)
|
||||
return tiles
|
||||
|
||||
@property
|
||||
def empty_tiles(self) -> List[Floor]:
|
||||
tiles = [tile for tile in self if tile.is_empty()]
|
||||
random.shuffle(tiles)
|
||||
return tiles
|
||||
|
||||
@classmethod
|
||||
def from_tiles(cls, tiles, *args, **kwargs):
|
||||
raise RuntimeError()
|
4
mfg_package/environment/rewards.py
Normal file
@ -0,0 +1,4 @@
|
||||
MOVEMENTS_VALID: float = -0.001
|
||||
MOVEMENTS_FAIL: float = -0.05
|
||||
NOOP: float = -0.01
|
||||
COLLISION: float = -0.5
|
82
mfg_package/environment/rules.py
Normal file
@ -0,0 +1,82 @@
|
||||
import abc
|
||||
from typing import List
|
||||
|
||||
from mfg_package.utils.results import TickResult, DoneResult
|
||||
from mfg_package.environment import rewards as r, constants as c
|
||||
|
||||
|
||||
class Rule(abc.ABC):
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.name}'
|
||||
|
||||
def on_init(self, state):
|
||||
return []
|
||||
|
||||
def on_reset(self):
|
||||
return []
|
||||
|
||||
def tick_pre_step(self, state) -> List[TickResult]:
|
||||
return []
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
return []
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
return []
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
return []
|
||||
|
||||
|
||||
class MaxStepsReached(Rule):
|
||||
|
||||
def __init__(self, max_steps: int = 500):
|
||||
super().__init__()
|
||||
self.max_steps = max_steps
|
||||
|
||||
def on_init(self, state):
|
||||
pass
|
||||
|
||||
def on_check_done(self, state):
|
||||
if self.max_steps <= state.curr_step:
|
||||
return [DoneResult(validity=c.VALID, identifier=self.name, reward=0)]
|
||||
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
|
||||
|
||||
|
||||
class Collision(Rule):
|
||||
|
||||
def __init__(self, done_at_collisions: bool = False):
|
||||
super().__init__()
|
||||
self.done_at_collisions = done_at_collisions
|
||||
self.curr_done = False
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
self.curr_done = False
|
||||
tiles_with_collisions = state.get_all_tiles_with_collisions()
|
||||
results = list()
|
||||
for tile in tiles_with_collisions:
|
||||
guests = tile.guests_that_can_collide
|
||||
if len(guests) >= 2:
|
||||
for i, guest in enumerate(guests):
|
||||
try:
|
||||
guest.set_state(TickResult(identifier=c.COLLISION, reward=r.COLLISION,
|
||||
validity=c.NOT_VALID, entity=self))
|
||||
except AttributeError:
|
||||
pass
|
||||
results.append(TickResult(entity=guest, identifier=c.COLLISION,
|
||||
reward=r.COLLISION, validity=c.VALID))
|
||||
self.curr_done = True
|
||||
return results
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
if self.curr_done and self.done_at_collisions:
|
||||
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)]
|
||||
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
|
0
mfg_package/logging/__init__.py
Normal file
64
mfg_package/logging/envmonitor.py
Normal file
@ -0,0 +1,64 @@
|
||||
import pickle
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from gymnasium import Wrapper
|
||||
|
||||
from mfg_package.utils.helpers import IGNORED_DF_COLUMNS
|
||||
from mfg_package.environment.factory import REC_TAC
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from mfg_package.plotting.compare_runs import plot_single_run
|
||||
|
||||
|
||||
class EnvMonitor(Wrapper):
|
||||
|
||||
ext = 'png'
|
||||
|
||||
def __init__(self, env, filepath: Union[str, PathLike] = None):
|
||||
super(EnvMonitor, self).__init__(env)
|
||||
self._filepath = filepath
|
||||
self._monitor_df = pd.DataFrame()
|
||||
self._monitor_dict = dict()
|
||||
|
||||
def __getattr__(self, item):
|
||||
return getattr(self.unwrapped, item)
|
||||
|
||||
def step(self, action):
|
||||
obs_type, obs, reward, done, info = self.env.step(action)
|
||||
self._read_info(info)
|
||||
self._read_done(done)
|
||||
return obs_type, obs, reward, done, info
|
||||
|
||||
def reset(self):
|
||||
return self.unwrapped.reset()
|
||||
|
||||
def _read_info(self, info: dict):
|
||||
self._monitor_dict[len(self._monitor_dict)] = {
|
||||
key: val for key, val in info.items() if
|
||||
key not in ['terminal_observation', 'episode'] and not key.startswith(REC_TAC)}
|
||||
return
|
||||
|
||||
def _read_done(self, done):
|
||||
if done:
|
||||
env_monitor_df = pd.DataFrame.from_dict(self._monitor_dict, orient='index')
|
||||
self._monitor_dict = dict()
|
||||
columns = [col for col in env_monitor_df.columns if col not in IGNORED_DF_COLUMNS]
|
||||
env_monitor_df = env_monitor_df.aggregate(
|
||||
{col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
|
||||
)
|
||||
env_monitor_df['episode'] = len(self._monitor_df)
|
||||
self._monitor_df = self._monitor_df.append([env_monitor_df])
|
||||
else:
|
||||
pass
|
||||
return
|
||||
|
||||
def save_run(self, filepath: Union[Path, str, None] = None, auto_plotting_keys=None):
|
||||
filepath = Path(filepath or self._filepath)
|
||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with filepath.open('wb') as f:
|
||||
pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
if auto_plotting_keys:
|
||||
plot_single_run(filepath, column_keys=auto_plotting_keys)
|
152
mfg_package/logging/recorder.py
Normal file
@ -0,0 +1,152 @@
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
from gymnasium import Wrapper
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from mfg_package.environment.factory import REC_TAC
|
||||
|
||||
|
||||
class EnvRecorder(Wrapper):
|
||||
|
||||
def __init__(self, env, entities: str = 'all', filepath: Union[str, PathLike] = None, freq: int = 0):
|
||||
super(EnvRecorder, self).__init__(env)
|
||||
self.filepath = filepath
|
||||
self.freq = freq
|
||||
self._recorder_dict = defaultdict(list)
|
||||
self._recorder_out_list = list()
|
||||
self._episode_counter = 1
|
||||
self._do_record_dict = defaultdict(lambda: False)
|
||||
if isinstance(entities, str):
|
||||
if entities.lower() == 'all':
|
||||
self._entities = None
|
||||
else:
|
||||
self._entities = [entities]
|
||||
else:
|
||||
self._entities = entities
|
||||
|
||||
def __getattr__(self, item):
|
||||
return getattr(self.unwrapped, item)
|
||||
|
||||
def reset(self):
|
||||
self._on_training_start()
|
||||
return self.unwrapped.reset()
|
||||
|
||||
def _on_training_start(self) -> None:
|
||||
assert self.start_recording()
|
||||
|
||||
def _read_info(self, env_idx, info: dict):
|
||||
if info_dict := {key.replace(REC_TAC, ''): val for key, val in info.items() if key.startswith(f'{REC_TAC}')}:
|
||||
if self._entities:
|
||||
info_dict = {k: v for k, v in info_dict.items() if k in self._entities}
|
||||
self._recorder_dict[env_idx].append(info_dict)
|
||||
else:
|
||||
pass
|
||||
return True
|
||||
|
||||
def _read_done(self, env_idx, done):
|
||||
if done:
|
||||
self._recorder_out_list.append({'steps': self._recorder_dict[env_idx],
|
||||
'episode': self._episode_counter})
|
||||
self._recorder_dict[env_idx] = list()
|
||||
else:
|
||||
pass
|
||||
|
||||
def step(self, actions):
|
||||
step_result = self.unwrapped.step(actions)
|
||||
if self.do_record_episode(0):
|
||||
info = step_result[-1]
|
||||
self._read_info(0, info)
|
||||
if self._do_record_dict[0]:
|
||||
self._read_done(0, step_result[-2])
|
||||
return step_result
|
||||
|
||||
def finalize(self):
|
||||
self._on_training_end()
|
||||
return True
|
||||
|
||||
def save_records(self, filepath: Union[Path, str, None] = None,
|
||||
only_deltas=True,
|
||||
save_occupation_map=False,
|
||||
save_trajectory_map=False,
|
||||
):
|
||||
filepath = Path(filepath or self.filepath)
|
||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||
# cls.out_file.unlink(missing_ok=True)
|
||||
with filepath.open('w') as f:
|
||||
if only_deltas:
|
||||
from deepdiff import DeepDiff
|
||||
diff_dict = [DeepDiff(t1,t2, ignore_order=True)
|
||||
for t1, t2 in zip(self._recorder_out_list, self._recorder_out_list[1:])
|
||||
]
|
||||
out_dict = {'episodes': diff_dict}
|
||||
|
||||
else:
|
||||
out_dict = {'episodes': self._recorder_out_list}
|
||||
out_dict.update(
|
||||
{'n_episodes': self._episode_counter,
|
||||
'env_params': self.env.params,
|
||||
'header': self.env.summarize_header
|
||||
})
|
||||
try:
|
||||
yaml.dump(out_dict, f, indent=4)
|
||||
except TypeError:
|
||||
print('Shit')
|
||||
|
||||
if save_occupation_map:
|
||||
a = np.zeros((15, 15))
|
||||
# noinspection PyTypeChecker
|
||||
for episode in out_dict['episodes']:
|
||||
df = pd.DataFrame([y for x in episode['steps'] for y in x['Agents']])
|
||||
|
||||
b = list(df[['x', 'y']].to_records(index=False))
|
||||
|
||||
np.add.at(a, tuple(zip(*b)), 1)
|
||||
|
||||
# a = np.rot90(a)
|
||||
import seaborn as sns
|
||||
from matplotlib import pyplot as plt
|
||||
hm = sns.heatmap(data=a)
|
||||
hm.set_title('Very Nice Heatmap')
|
||||
plt.show()
|
||||
|
||||
if save_trajectory_map:
|
||||
raise NotImplementedError('This has not yet been implemented.')
|
||||
|
||||
def do_record_episode(self, env_idx):
|
||||
if not self._recorder_dict[env_idx]:
|
||||
if self.freq:
|
||||
self._do_record_dict[env_idx] = (self.freq == -1) or (self._episode_counter % self.freq) == 0
|
||||
else:
|
||||
self._do_record_dict[env_idx] = False
|
||||
warnings.warn('You did wrap your Environment with a recorder, but set the freq to zero\n'
|
||||
'Nothing will be recorded')
|
||||
self._episode_counter += 1
|
||||
else:
|
||||
pass
|
||||
return self._do_record_dict[env_idx]
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
for env_idx, info in enumerate(self.locals.get('infos', [])):
|
||||
if self._do_record_dict[env_idx]:
|
||||
self._read_info(env_idx, info)
|
||||
dones = list(enumerate(self.locals.get('dones', [])))
|
||||
dones.extend(list(enumerate(self.locals.get('done', []))))
|
||||
for env_idx, done in dones:
|
||||
if self._do_record_dict[env_idx]:
|
||||
self._read_done(env_idx, done)
|
||||
|
||||
return True
|
||||
|
||||
def _on_training_end(self) -> None:
|
||||
for env_idx in range(len(self._recorder_dict)):
|
||||
if self._recorder_dict[env_idx]:
|
||||
self._recorder_out_list.append({'steps': self._recorder_dict[env_idx],
|
||||
'episode': self._episode_counter})
|
||||
pass
|
0
mfg_package/modules/__init__.py
Normal file
0
mfg_package/modules/_template/__init__.py
Normal file
11
mfg_package/modules/_template/constants.py
Normal file
@ -0,0 +1,11 @@
|
||||
TEMPLATE = '#' # TEMPLATE _identifier. Define your own!
|
||||
|
||||
# Movements
|
||||
NORTH = 'north'
|
||||
EAST = 'east'
|
||||
SOUTH = 'south'
|
||||
WEST = 'west'
|
||||
NORTHEAST = 'north_east'
|
||||
SOUTHEAST = 'south_east'
|
||||
SOUTHWEST = 'south_west'
|
||||
NORTHWEST = 'north_west'
|
24
mfg_package/modules/_template/rules.py
Normal file
@ -0,0 +1,24 @@
|
||||
from typing import List
|
||||
from mfg_package.environment.rules import Rule
|
||||
from mfg_package.utils.results import TickResult, DoneResult
|
||||
|
||||
|
||||
class TemplateRule(Rule):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(TemplateRule, self).__init__(*args, **kwargs)
|
||||
|
||||
def on_init(self, state):
|
||||
pass
|
||||
|
||||
def tick_pre_step(self, state) -> List[TickResult]:
|
||||
pass
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
pass
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
pass
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
pass
|
0
mfg_package/modules/batteries/__init__.py
Normal file
26
mfg_package/modules/batteries/actions.py
Normal file
@ -0,0 +1,26 @@
|
||||
from typing import Union
|
||||
|
||||
from mfg_package.environment.actions import Action
|
||||
from mfg_package.utils.results import ActionResult
|
||||
|
||||
from mfg_package.modules.batteries import constants as b, rewards as r
|
||||
from mfg_package.environment import constants as c
|
||||
|
||||
|
||||
class BtryCharge(Action):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(b.CHARGE)
|
||||
|
||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||
if charge_pod := state[b.CHARGE_PODS].by_pos(entity.pos):
|
||||
valid = charge_pod.charge_battery(state[b.BATTERIES].by_entity(entity))
|
||||
if valid:
|
||||
state.print(f'{entity.name} just charged batteries at {charge_pod.name}.')
|
||||
else:
|
||||
state.print(f'{entity.name} failed to charged batteries at {charge_pod.name}.')
|
||||
else:
|
||||
valid = c.NOT_VALID
|
||||
state.print(f'{entity.name} failed to charged batteries at {entity.pos}.')
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=valid,
|
||||
reward=r.CHARGE_VALID if valid else r.CHARGE_FAIL)
|
19
mfg_package/modules/batteries/constants.py
Normal file
@ -0,0 +1,19 @@
|
||||
from typing import NamedTuple, Union
|
||||
|
||||
# Battery Env
|
||||
CHARGE_PODS = 'ChargePods'
|
||||
BATTERIES = 'Batteries'
|
||||
BATTERY_DISCHARGED = 'DISCHARGED'
|
||||
CHARGE_POD_SYMBOL = 1
|
||||
|
||||
|
||||
CHARGE = 'do_charge_action'
|
||||
|
||||
|
||||
class BatteryProperties(NamedTuple):
|
||||
initial_charge: float = 0.8 #
|
||||
charge_rate: float = 0.4 #
|
||||
charge_locations: int = 20 #
|
||||
per_action_costs: Union[dict, float] = 0.02
|
||||
done_when_discharged: bool = False
|
||||
multi_charge: bool = False
|
75
mfg_package/modules/batteries/entitites.py
Normal file
@ -0,0 +1,75 @@
|
||||
from mfg_package.environment.entity.mixin import BoundEntityMixin
|
||||
from mfg_package.environment.entity.object import EnvObject
|
||||
from mfg_package.environment.entity.entity import Entity
|
||||
from mfg_package.environment import constants as c
|
||||
from mfg_package.utils.render import RenderEntity
|
||||
|
||||
from mfg_package.modules.batteries import constants as b
|
||||
|
||||
|
||||
class Battery(BoundEntityMixin, EnvObject):
|
||||
|
||||
@property
|
||||
def is_discharged(self):
|
||||
return self.charge_level == 0
|
||||
|
||||
@property
|
||||
def obs_tag(self):
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return self.charge_level
|
||||
|
||||
def __init__(self, initial_charge_level: float, owner: Entity, *args, **kwargs):
|
||||
super(Battery, self).__init__(*args, **kwargs)
|
||||
self.charge_level = initial_charge_level
|
||||
self.bind_to(owner)
|
||||
|
||||
def do_charge_action(self, amount):
|
||||
if self.charge_level < 1:
|
||||
# noinspection PyTypeChecker
|
||||
self.charge_level = min(1, amount + self.charge_level)
|
||||
return c.VALID
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
|
||||
def decharge(self, amount) -> float:
|
||||
if self.charge_level != 0:
|
||||
# noinspection PyTypeChecker
|
||||
self.charge_level = max(0, amount + self.charge_level)
|
||||
return c.VALID
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
|
||||
def summarize_state(self, **_):
|
||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||
attr_dict.update(dict(name=self.name, belongs_to=self._bound_entity.name))
|
||||
return attr_dict
|
||||
|
||||
def render(self):
|
||||
return None
|
||||
|
||||
|
||||
class ChargePod(Entity):
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return b.CHARGE_POD_SYMBOL
|
||||
|
||||
def __init__(self, *args, charge_rate: float = 0.4,
|
||||
multi_charge: bool = False, **kwargs):
|
||||
super(ChargePod, self).__init__(*args, **kwargs)
|
||||
self.charge_rate = charge_rate
|
||||
self.multi_charge = multi_charge
|
||||
|
||||
def charge_battery(self, battery: Battery):
|
||||
if battery.charge_level == 1.0:
|
||||
return c.NOT_VALID
|
||||
if sum(guest for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1:
|
||||
return c.NOT_VALID
|
||||
valid = battery.do_charge_action(self.charge_rate)
|
||||
return valid
|
||||
|
||||
def render(self):
|
||||
return RenderEntity(b.CHARGE_PODS, self.pos)
|
36
mfg_package/modules/batteries/groups.py
Normal file
@ -0,0 +1,36 @@
|
||||
from mfg_package.environment.groups.env_objects import EnvObjects
|
||||
from mfg_package.environment.groups.mixins import PositionMixin, HasBoundedMixin
|
||||
from mfg_package.modules.batteries.entitites import ChargePod, Battery
|
||||
|
||||
|
||||
class Batteries(HasBoundedMixin, EnvObjects):
|
||||
|
||||
_entity = Battery
|
||||
is_blocking_light: bool = False
|
||||
can_collide: bool = False
|
||||
|
||||
@property
|
||||
def obs_tag(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
@property
|
||||
def obs_pairs(self):
|
||||
return [(x.name, x) for x in self]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Batteries, self).__init__(*args, **kwargs)
|
||||
|
||||
def spawn_batteries(self, agents, initial_charge_level):
|
||||
batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
|
||||
self.add_items(batteries)
|
||||
|
||||
|
||||
class ChargePods(PositionMixin, EnvObjects):
|
||||
|
||||
_entity = ChargePod
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ChargePods, self).__init__(*args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return super(ChargePods, self).__repr__()
|
3
mfg_package/modules/batteries/rewards.py
Normal file
@ -0,0 +1,3 @@
|
||||
CHARGE_VALID: float = 0.1
|
||||
CHARGE_FAIL: float = -0.1
|
||||
BATTERY_DISCHARGED: float = -1.0
|
61
mfg_package/modules/batteries/rules.py
Normal file
@ -0,0 +1,61 @@
|
||||
from typing import List, Union
|
||||
from mfg_package.environment.rules import Rule
|
||||
from mfg_package.utils.results import TickResult, DoneResult
|
||||
|
||||
from mfg_package.environment import constants as c
|
||||
from mfg_package.modules.batteries import constants as b, rewards as r
|
||||
|
||||
|
||||
class Btry(Rule):
|
||||
|
||||
def __init__(self, initial_charge: float = 0.8, per_action_costs: Union[dict, float] = 0.02):
|
||||
super().__init__()
|
||||
self.per_action_costs = per_action_costs
|
||||
self.initial_charge = initial_charge
|
||||
|
||||
def on_init(self, state):
|
||||
state[b.BATTERIES].spawn_batteries(state[c.AGENT], self.initial_charge)
|
||||
|
||||
def tick_pre_step(self, state) -> List[TickResult]:
|
||||
pass
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
# Decharge
|
||||
batteries = state[b.BATTERIES]
|
||||
results = []
|
||||
|
||||
for agent in state[c.AGENT]:
|
||||
if isinstance(self.per_action_costs, dict):
|
||||
energy_consumption = self.per_action_costs[agent.step_result()['action']]
|
||||
else:
|
||||
energy_consumption = self.per_action_costs
|
||||
|
||||
batteries.by_entity(agent).decharge(energy_consumption)
|
||||
|
||||
results.append(TickResult(self.name, reward=0, entity=agent, validity=c.VALID))
|
||||
|
||||
return results
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
results = []
|
||||
for btry in state[b.BATTERIES]:
|
||||
if btry.is_discharged:
|
||||
state.print(f'Battery of {btry.bound_entity.name} is discharged!')
|
||||
results.append(
|
||||
TickResult(self.name, entity=btry.bound_entity, reward=r.BATTERY_DISCHARGED, validity=c.VALID))
|
||||
else:
|
||||
pass
|
||||
return results
|
||||
|
||||
|
||||
class BtryDoneAtDischarge(Rule):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
if btry_done := any(battery.is_discharged for battery in state[b.BATTERIES]):
|
||||
return [DoneResult(self.name, validity=c.VALID, reward=r.BATTERY_DISCHARGED)]
|
||||
else:
|
||||
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
|
||||
|
0
mfg_package/modules/clean_up/__init__.py
Normal file
36
mfg_package/modules/clean_up/actions.py
Normal file
@ -0,0 +1,36 @@
|
||||
from typing import Union
|
||||
|
||||
from mfg_package.environment.actions import Action
|
||||
from mfg_package.utils.results import ActionResult
|
||||
|
||||
from mfg_package.modules.clean_up import constants as d, rewards as r
|
||||
|
||||
from mfg_package.environment import constants as c
|
||||
|
||||
|
||||
class CleanUp(Action):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(d.CLEAN_UP)
|
||||
|
||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||
if dirt := state[d.DIRT].by_pos(entity.pos):
|
||||
new_dirt_amount = dirt.amount - state[d.DIRT].clean_amount
|
||||
|
||||
if new_dirt_amount <= 0:
|
||||
state[d.DIRT].delete_env_object(dirt)
|
||||
else:
|
||||
dirt.set_new_amount(max(new_dirt_amount, c.VALUE_FREE_CELL))
|
||||
valid = c.VALID
|
||||
print_str = f'{entity.name} did just clean up some dirt at {entity.pos}.'
|
||||
state.print(print_str)
|
||||
reward = r.CLEAN_UP_VALID
|
||||
identifier = d.CLEAN_UP
|
||||
else:
|
||||
valid = c.NOT_VALID
|
||||
print_str = f'{entity.name} just tried to clean up some dirt at {entity.pos}, but failed.'
|
||||
state.print(print_str)
|
||||
reward = r.CLEAN_UP_FAIL
|
||||
identifier = d.CLEAN_UP_FAIL
|
||||
|
||||
return ActionResult(identifier=identifier, validity=valid, reward=reward, entity=entity)
|
7
mfg_package/modules/clean_up/constants.py
Normal file
@ -0,0 +1,7 @@
|
||||
DIRT = 'DirtPiles'
|
||||
|
||||
CLEAN_UP = 'do_cleanup_action'
|
||||
|
||||
CLEAN_UP_VALID = 'clean_up_valid'
|
||||
CLEAN_UP_FAIL = 'clean_up_fail'
|
||||
CLEAN_UP_ALL = 'all_cleaned_up'
|
BIN
mfg_package/modules/clean_up/dirtpiles.png
Normal file
After Width: | Height: | Size: 38 KiB |
35
mfg_package/modules/clean_up/entitites.py
Normal file
@ -0,0 +1,35 @@
|
||||
from numpy import random
|
||||
|
||||
from mfg_package.environment.entity.entity import Entity
|
||||
from mfg_package.utils.render import RenderEntity
|
||||
from mfg_package.modules.clean_up import constants as d
|
||||
|
||||
|
||||
class DirtPile(Entity):
|
||||
|
||||
@property
|
||||
def amount(self):
|
||||
return self._amount
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
# Edit this if you want items to be drawn in the ops differntly
|
||||
return self._amount
|
||||
|
||||
def __init__(self, *args, max_local_amount=5, initial_amount=2, spawn_variation=0.05, **kwargs):
|
||||
super(DirtPile, self).__init__(*args, **kwargs)
|
||||
self._amount = abs(initial_amount + (
|
||||
random.normal(loc=0, scale=spawn_variation, size=1).item() * initial_amount)
|
||||
)
|
||||
self.max_local_amount = max_local_amount
|
||||
|
||||
def set_new_amount(self, amount):
|
||||
self._amount = min(amount, self.max_local_amount)
|
||||
|
||||
def summarize_state(self):
|
||||
state_dict = super().summarize_state()
|
||||
state_dict.update(amount=float(self.amount))
|
||||
return state_dict
|
||||
|
||||
def render(self):
|
||||
return RenderEntity(d.DIRT, self.tile.pos, min(0.15 + self.amount, 1.5), 'scale')
|
64
mfg_package/modules/clean_up/groups.py
Normal file
@ -0,0 +1,64 @@
|
||||
from mfg_package.environment.groups.env_objects import EnvObjects
|
||||
from mfg_package.environment.groups.mixins import PositionMixin
|
||||
from mfg_package.environment.entity.wall_floor import Floor
|
||||
from mfg_package.modules.clean_up.entitites import DirtPile
|
||||
|
||||
from mfg_package.environment import constants as c
|
||||
|
||||
|
||||
class DirtPiles(PositionMixin, EnvObjects):
|
||||
|
||||
_entity = DirtPile
|
||||
is_blocking_light: bool = False
|
||||
can_collide: bool = False
|
||||
|
||||
@property
|
||||
def amount(self):
|
||||
return sum([dirt.amount for dirt in self])
|
||||
|
||||
def __init__(self, *args,
|
||||
initial_amount=2,
|
||||
initial_dirt_ratio=0.05,
|
||||
dirt_spawn_r_var=0.1,
|
||||
max_local_amount=5,
|
||||
clean_amount=1,
|
||||
max_global_amount: int = 20, **kwargs):
|
||||
super(DirtPiles, self).__init__(*args, **kwargs)
|
||||
self.clean_amount = clean_amount
|
||||
self.initial_amount = initial_amount
|
||||
self.initial_dirt_ratio = initial_dirt_ratio
|
||||
self.dirt_spawn_r_var = dirt_spawn_r_var
|
||||
self.max_global_amount = max_global_amount
|
||||
self.max_local_amount = max_local_amount
|
||||
|
||||
def spawn_dirt(self, then_dirty_tiles, amount) -> bool:
|
||||
if isinstance(then_dirty_tiles, Floor):
|
||||
then_dirty_tiles = [then_dirty_tiles]
|
||||
for tile in then_dirty_tiles:
|
||||
if not self.amount > self.max_global_amount:
|
||||
if dirt := self.by_pos(tile.pos):
|
||||
new_value = dirt.amount + amount
|
||||
dirt.set_new_amount(new_value)
|
||||
else:
|
||||
dirt = DirtPile(tile, initial_amount=amount, spawn_variation=self.dirt_spawn_r_var)
|
||||
self.add_item(dirt)
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
return c.VALID
|
||||
|
||||
def trigger_dirt_spawn(self, state, initial_spawn=False) -> bool:
|
||||
free_for_dirt = [x for x in state[c.FLOOR]
|
||||
if len(x.guests) == 0 or (
|
||||
len(x.guests) == 1 and
|
||||
isinstance(next(y for y in x.guests), DirtPile))
|
||||
]
|
||||
state.rng.shuffle(free_for_dirt)
|
||||
|
||||
var = self.dirt_spawn_r_var
|
||||
new_spawn = abs(self.initial_dirt_ratio + (state.rng.uniform(-var, var) if initial_spawn else 0))
|
||||
n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt)))
|
||||
return self.spawn_dirt(free_for_dirt[:n_dirt_tiles], self.initial_amount)
|
||||
|
||||
def __repr__(self):
|
||||
s = super(DirtPiles, self).__repr__()
|
||||
return f'{s[:-1]}, {self.amount})'
|
3
mfg_package/modules/clean_up/rewards.py
Normal file
@ -0,0 +1,3 @@
|
||||
CLEAN_UP_VALID: float = 0.5
|
||||
CLEAN_UP_FAIL: float = -0.1
|
||||
CLEAN_UP_ALL: float = 4.5
|
15
mfg_package/modules/clean_up/rule_done_on_all_clean.py
Normal file
@ -0,0 +1,15 @@
|
||||
from mfg_package.environment import constants as c
|
||||
from mfg_package.environment.rules import Rule
|
||||
from mfg_package.utils.results import DoneResult
|
||||
from mfg_package.modules.clean_up import constants as d, rewards as r
|
||||
|
||||
|
||||
class DirtAllCleanDone(Rule):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def on_check_done(self, state) -> [DoneResult]:
|
||||
if len(state[d.DIRT]) == 0 and state.curr_step:
|
||||
return [DoneResult(validity=c.VALID, identifier=self.name, reward=r.CLEAN_UP_ALL)]
|
||||
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
|
28
mfg_package/modules/clean_up/rule_respawn.py
Normal file
@ -0,0 +1,28 @@
|
||||
from mfg_package.environment.rules import Rule
|
||||
from mfg_package.utils.results import TickResult
|
||||
|
||||
from mfg_package.modules.clean_up import constants as d
|
||||
|
||||
|
||||
class DirtRespawnRule(Rule):
|
||||
|
||||
def __init__(self, spawn_freq=15):
|
||||
super().__init__()
|
||||
self.spawn_freq = spawn_freq
|
||||
self._next_dirt_spawn = spawn_freq
|
||||
|
||||
def on_init(self, state) -> str:
|
||||
state[d.DIRT].trigger_dirt_spawn(state, initial_spawn=True)
|
||||
return f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}'
|
||||
|
||||
def tick_step(self, state):
|
||||
if self._next_dirt_spawn < 0:
|
||||
pass # No DirtPile Spawn
|
||||
elif not self._next_dirt_spawn:
|
||||
validity = state[d.DIRT].trigger_dirt_spawn(state)
|
||||
|
||||
return [TickResult(entity=None, validity=validity, identifier=self.name, reward=0)]
|
||||
self._next_dirt_spawn = self.spawn_freq
|
||||
else:
|
||||
self._next_dirt_spawn -= 1
|
||||
return []
|
24
mfg_package/modules/clean_up/rule_smear_on_move.py
Normal file
@ -0,0 +1,24 @@
|
||||
from mfg_package.environment.rules import Rule
|
||||
from mfg_package.utils.helpers import is_move
|
||||
from mfg_package.utils.results import TickResult
|
||||
|
||||
from mfg_package.environment import constants as c
|
||||
from mfg_package.modules.clean_up import constants as d
|
||||
|
||||
|
||||
class DirtSmearOnMove(Rule):
|
||||
|
||||
def __init__(self, smear_amount: float = 0.2):
|
||||
super().__init__()
|
||||
self.smear_amount = smear_amount
|
||||
|
||||
def tick_post_step(self, state):
|
||||
results = list()
|
||||
for entity in state.moving_entites:
|
||||
if is_move(entity.state.identifier) and entity.state.validity == c.VALID:
|
||||
if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos):
|
||||
if smeared_dirt := round(old_pos_dirt.amount * self.smear_amount, 2):
|
||||
if state[d.DIRT].spawn_dirt(entity.tile, amount=smeared_dirt):
|
||||
results.append(TickResult(identifier=self.name, entity=entity,
|
||||
reward=0, validity=c.VALID))
|
||||
return results
|
0
mfg_package/modules/destinations/__init__.py
Normal file
23
mfg_package/modules/destinations/actions.py
Normal file
@ -0,0 +1,23 @@
|
||||
from typing import Union
|
||||
|
||||
from mfg_package.environment.actions import Action
|
||||
from mfg_package.utils.results import ActionResult
|
||||
|
||||
from mfg_package.modules.destinations import constants as d, rewards as r
|
||||
from mfg_package.environment import constants as c
|
||||
|
||||
|
||||
class DestAction(Action):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(d.DESTINATION)
|
||||
|
||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||
if destination := state[d.DESTINATION].by_pos(entity.pos):
|
||||
valid = destination.do_wait_action(entity)
|
||||
state.print(f'{entity.name} just waited at {entity.pos}')
|
||||
else:
|
||||
valid = c.NOT_VALID
|
||||
state.print(f'{entity.name} just tried to do_wait_action do_wait_action at {entity.pos} but failed')
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=valid,
|
||||
reward=r.WAIT_VALID if valid else r.WAIT_FAIL)
|
14
mfg_package/modules/destinations/constants.py
Normal file
@ -0,0 +1,14 @@
|
||||
|
||||
# Destination Env
|
||||
DESTINATION = 'Destinations'
|
||||
DEST_SYMBOL = 1
|
||||
DEST_REACHED_REWARD = 0.5
|
||||
DEST_REACHED = 'ReachedDestinations'
|
||||
|
||||
WAIT_ON_DEST = 'WAIT'
|
||||
|
||||
MODE_SINGLE = 'SINGLE'
|
||||
MODE_GROUPED = 'GROUPED'
|
||||
|
||||
DONE_ALL = 'DONE_ALL'
|
||||
DONE_SINGLE = 'DONE_SINGLE'
|
BIN
mfg_package/modules/destinations/destinations.png
Normal file
After Width: | Height: | Size: 6.9 KiB |
51
mfg_package/modules/destinations/entitites.py
Normal file
@ -0,0 +1,51 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from mfg_package.environment.entity.agent import Agent
|
||||
from mfg_package.environment.entity.entity import Entity
|
||||
from mfg_package.environment import constants as c
|
||||
from mfg_package.utils.render import RenderEntity
|
||||
from mfg_package.modules.destinations import constants as d
|
||||
|
||||
|
||||
class Destination(Entity):
|
||||
|
||||
@property
|
||||
def any_agent_has_dwelled(self):
|
||||
return bool(len(self._per_agent_times))
|
||||
|
||||
@property
|
||||
def currently_dwelling_names(self):
|
||||
return list(self._per_agent_times.keys())
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return d.DEST_SYMBOL
|
||||
|
||||
def __init__(self, *args, dwell_time: int = 0, **kwargs):
|
||||
super(Destination, self).__init__(*args, **kwargs)
|
||||
self.dwell_time = dwell_time
|
||||
self._per_agent_times = defaultdict(lambda: dwell_time)
|
||||
|
||||
def do_wait_action(self, agent: Agent):
|
||||
self._per_agent_times[agent.name] -= 1
|
||||
return c.VALID
|
||||
|
||||
def leave(self, agent: Agent):
|
||||
del self._per_agent_times[agent.name]
|
||||
|
||||
@property
|
||||
def is_considered_reached(self):
|
||||
agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
|
||||
return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values())
|
||||
|
||||
def agent_is_dwelling(self, agent: Agent):
|
||||
return self._per_agent_times[agent.name] < self.dwell_time
|
||||
|
||||
def summarize_state(self) -> dict:
|
||||
state_summary = super().summarize_state()
|
||||
state_summary.update(per_agent_times=[
|
||||
dict(belongs_to=key, time=val) for key, val in self._per_agent_times.keys()], dwell_time=self.dwell_time)
|
||||
return state_summary
|
||||
|
||||
def render(self):
|
||||
return RenderEntity(d.DESTINATION, self.pos)
|
28
mfg_package/modules/destinations/groups.py
Normal file
@ -0,0 +1,28 @@
|
||||
from mfg_package.environment.groups.env_objects import EnvObjects
|
||||
from mfg_package.environment.groups.mixins import PositionMixin
|
||||
from mfg_package.modules.destinations.entitites import Destination
|
||||
|
||||
|
||||
class Destinations(PositionMixin, EnvObjects):
|
||||
|
||||
_entity = Destination
|
||||
is_blocking_light: bool = False
|
||||
can_collide: bool = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return super(Destinations, self).__repr__()
|
||||
|
||||
|
||||
class ReachedDestinations(Destinations):
|
||||
_entity = Destination
|
||||
is_blocking_light = False
|
||||
can_collide = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ReachedDestinations, self).__init__(*args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return super(ReachedDestinations, self).__repr__()
|
3
mfg_package/modules/destinations/rewards.py
Normal file
@ -0,0 +1,3 @@
|
||||
WAIT_VALID: float = 0.1
|
||||
WAIT_FAIL: float = -0.1
|
||||
DEST_REACHED: float = 5.0
|
89
mfg_package/modules/destinations/rules.py
Normal file
@ -0,0 +1,89 @@
|
||||
from typing import List, Union
|
||||
from mfg_package.environment.rules import Rule
|
||||
from mfg_package.utils.results import TickResult, DoneResult
|
||||
from mfg_package.environment import constants as c
|
||||
|
||||
from mfg_package.modules.destinations import constants as d, rewards as r
|
||||
from mfg_package.modules.destinations.entitites import Destination
|
||||
|
||||
|
||||
class DestinationReach(Rule):
|
||||
|
||||
def __init__(self, n_dests: int = 1, tiles: Union[List, None] = None):
|
||||
super(DestinationReach, self).__init__()
|
||||
self.n_dests = n_dests or len(tiles)
|
||||
self._tiles = tiles
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
|
||||
for dest in list(state[d.DESTINATION].values()):
|
||||
if dest.is_considered_reached:
|
||||
dest.change_parent_collection(state[d.DEST_REACHED])
|
||||
state.print(f'{dest.name} is reached now, removing...')
|
||||
else:
|
||||
for agent_name in dest.currently_dwelling_names:
|
||||
agent = state[c.AGENT][agent_name]
|
||||
if agent.pos == dest.pos:
|
||||
state.print(f'{agent.name} is still waiting.')
|
||||
pass
|
||||
else:
|
||||
dest.leave(agent)
|
||||
state.print(f'{agent.name} left the destination early.')
|
||||
return [TickResult(self.name, validity=c.VALID, reward=0, entity=None)]
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
results = list()
|
||||
for reached_dest in state[d.DEST_REACHED]:
|
||||
for guest in reached_dest.tile.guests:
|
||||
if guest in state[c.AGENT]:
|
||||
state.print(f'{guest.name} just reached destination at {guest.pos}')
|
||||
state[d.DEST_REACHED].delete_env_object(reached_dest)
|
||||
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=guest))
|
||||
return results
|
||||
|
||||
|
||||
class DestinationDone(Rule):
|
||||
|
||||
def __init__(self):
|
||||
super(DestinationDone, self).__init__()
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
if not len(state[d.DESTINATION]):
|
||||
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
|
||||
return []
|
||||
|
||||
|
||||
class DestinationSpawn(Rule):
|
||||
|
||||
def __init__(self, spawn_frequency: int = 5, n_dests: int = 1,
|
||||
spawn_mode: str = d.MODE_GROUPED):
|
||||
super(DestinationSpawn, self).__init__()
|
||||
self.spawn_frequency = spawn_frequency
|
||||
self.n_dests = n_dests
|
||||
self.spawn_mode = spawn_mode
|
||||
|
||||
def on_init(self, state):
|
||||
# noinspection PyAttributeOutsideInit
|
||||
self._dest_spawn_timer = self.spawn_frequency
|
||||
self.trigger_destination_spawn(self.n_dests, state)
|
||||
pass
|
||||
|
||||
def tick_pre_step(self, state) -> List[TickResult]:
|
||||
pass
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
if n_dest_spawn := max(0, self.n_dests - len(state[d.DESTINATION])):
|
||||
if self.spawn_mode == d.MODE_GROUPED and n_dest_spawn == self.n_dests:
|
||||
validity = state.rules['DestinationReach'].trigger_destination_spawn(n_dest_spawn, state)
|
||||
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
|
||||
|
||||
@staticmethod
|
||||
def trigger_destination_spawn(n_dests, state, tiles=None):
|
||||
tiles = tiles or state[c.FLOOR].empty_tiles[:n_dests]
|
||||
if destinations := [Destination(tile) for tile in tiles]:
|
||||
state[d.DESTINATION].add_items(destinations)
|
||||
state.print(f'{n_dests} new destinations have been spawned')
|
||||
return c.VALID
|
||||
else:
|
||||
state.print('No Destiantions are spawning, limit is reached.')
|
||||
return c.NOT_VALID
|
0
mfg_package/modules/doors/__init__.py
Normal file
28
mfg_package/modules/doors/actions.py
Normal file
@ -0,0 +1,28 @@
|
||||
from typing import Union
|
||||
|
||||
from mfg_package.environment.actions import Action
|
||||
from mfg_package.utils.results import ActionResult
|
||||
|
||||
from mfg_package.modules.doors import constants as d, rewards as r
|
||||
from mfg_package.environment import constants as c
|
||||
|
||||
|
||||
class DoorUse(Action):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(d.ACTION_DOOR_USE)
|
||||
|
||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||
# Check if agent really is standing on a door:
|
||||
e = state.entities.get_near_pos(entity.pos)
|
||||
try:
|
||||
door = next(x for x in e if x.name.startswith(d.DOOR))
|
||||
valid = door.use()
|
||||
state.print(f'{entity.name} just used a {door.name} at {door.pos}')
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.USE_DOOR_VALID)
|
||||
|
||||
except StopIteration:
|
||||
# When he doesn't...
|
||||
state.print(f'{entity.name} just tried to use a door at {entity.pos}, but there is none.')
|
||||
return ActionResult(entity=entity, identifier=self._identifier,
|
||||
validity=c.NOT_VALID, reward=r.USE_DOOR_FAIL)
|
18
mfg_package/modules/doors/constants.py
Normal file
@ -0,0 +1,18 @@
|
||||
# Names / Identifiers
|
||||
DOOR = 'Door' # Identifier of Single-Door Entities.
|
||||
DOORS = 'Doors' # Identifier of Door-objects and groups (groups).
|
||||
|
||||
# Symbols (in map)
|
||||
SYMBOL_DOOR = 'D' # Door _identifier for resolving the string based map files.
|
||||
|
||||
# Values
|
||||
VALUE_ACCESS_INDICATOR = 1 / 3 # Access-door-Cell value used in observation
|
||||
VALUE_OPEN_DOOR = 2 / 3 # Open-door-Cell value used in observation
|
||||
VALUE_CLOSED_DOOR = 3 / 3 # Closed-door-Cell value used in observation
|
||||
|
||||
# States
|
||||
STATE_CLOSED = 'closed' # Identifier to compare door-is-closed state
|
||||
STATE_OPEN = 'open' # Identifier to compare door-is-open state
|
||||
|
||||
# Actions
|
||||
ACTION_DOOR_USE = 'use_door'
|
BIN
mfg_package/modules/doors/door_closed.png
Normal file
After Width: | Height: | Size: 699 B |
BIN
mfg_package/modules/doors/door_open.png
Normal file
After Width: | Height: | Size: 4.1 KiB |
97
mfg_package/modules/doors/entitites.py
Normal file
@ -0,0 +1,97 @@
|
||||
from mfg_package.environment.entity.entity import Entity
|
||||
from mfg_package.utils.render import RenderEntity
|
||||
from mfg_package.environment import constants as c
|
||||
|
||||
from mfg_package.modules.doors import constants as d
|
||||
|
||||
|
||||
class DoorIndicator(Entity):
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return d.VALUE_ACCESS_INDICATOR
|
||||
|
||||
def render(self):
|
||||
return None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__delattr__('move')
|
||||
|
||||
|
||||
class Door(Entity):
|
||||
|
||||
@property
|
||||
def is_blocking_pos(self):
|
||||
return False if self.is_open else True
|
||||
|
||||
@property
|
||||
def is_blocking_light(self):
|
||||
return False if self.is_open else True
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return False if self.is_open else True
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return d.VALUE_CLOSED_DOOR if self.is_closed else d.VALUE_OPEN_DOOR
|
||||
|
||||
@property
|
||||
def str_state(self):
|
||||
return 'open' if self.is_open else 'closed'
|
||||
|
||||
def __init__(self, *args, closed_on_init=True, auto_close_interval=10, indicate_area=False, **kwargs):
|
||||
self._state = d.STATE_CLOSED
|
||||
super(Door, self).__init__(*args, **kwargs)
|
||||
self.auto_close_interval = auto_close_interval
|
||||
self.time_to_close = 0
|
||||
if not closed_on_init:
|
||||
self._open()
|
||||
if indicate_area:
|
||||
self._collection.add_items([DoorIndicator(x) for x in self.tile.neighboring_floor])
|
||||
|
||||
def summarize_state(self):
|
||||
state_dict = super().summarize_state()
|
||||
state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close))
|
||||
return state_dict
|
||||
|
||||
@property
|
||||
def is_closed(self):
|
||||
return self._state == d.STATE_CLOSED
|
||||
|
||||
@property
|
||||
def is_open(self):
|
||||
return self._state == d.STATE_OPEN
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
return self._state
|
||||
|
||||
def render(self):
|
||||
name, state = 'door_open' if self.is_open else 'door_closed', 'blank'
|
||||
return RenderEntity(name, self.pos, 1, 'none', state, self.identifier_int + 1)
|
||||
|
||||
def use(self):
|
||||
if self._state == d.STATE_OPEN:
|
||||
self._close()
|
||||
else:
|
||||
self._open()
|
||||
return c.VALID
|
||||
|
||||
def tick(self):
|
||||
if self.is_open and len(self.tile) == 1 and self.time_to_close:
|
||||
self.time_to_close -= 1
|
||||
return c.NOT_VALID
|
||||
elif self.is_open and not self.time_to_close and len(self.tile) == 1:
|
||||
self.use()
|
||||
return c.VALID
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
|
||||
def _open(self):
|
||||
self._state = d.STATE_OPEN
|
||||
self.time_to_close = self.auto_close_interval
|
||||
|
||||
def _close(self):
|
||||
self._state = d.STATE_CLOSED
|
28
mfg_package/modules/doors/groups.py
Normal file
@ -0,0 +1,28 @@
|
||||
from typing import Union
|
||||
|
||||
from mfg_package.environment.groups.env_objects import EnvObjects
|
||||
from mfg_package.environment.groups.mixins import PositionMixin
|
||||
from mfg_package.modules.doors import constants as d
|
||||
from mfg_package.modules.doors.entitites import Door
|
||||
|
||||
|
||||
class Doors(PositionMixin, EnvObjects):
|
||||
|
||||
symbol = d.SYMBOL_DOOR
|
||||
_entity = Door
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Doors, self).__init__(*args, can_collide=True, **kwargs)
|
||||
|
||||
def get_near_position(self, position: (int, int)) -> Union[None, Door]:
|
||||
try:
|
||||
return next(door for door in self if position in door.tile.neighboring_floor_pos)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def tick_doors(self):
|
||||
result_dict = dict()
|
||||
for door in self:
|
||||
did_tick = door.tick()
|
||||
result_dict.update({door.name: did_tick})
|
||||
return result_dict
|
2
mfg_package/modules/doors/rewards.py
Normal file
@ -0,0 +1,2 @@
|
||||
USE_DOOR_VALID: float = -0.00
|
||||
USE_DOOR_FAIL: float = -0.01
|
21
mfg_package/modules/doors/rule_door_auto_close.py
Normal file
@ -0,0 +1,21 @@
|
||||
from mfg_package.environment.rules import Rule
|
||||
from mfg_package.environment import constants as c
|
||||
from mfg_package.utils.results import TickResult
|
||||
from mfg_package.modules.doors import constants as d
|
||||
|
||||
|
||||
class DoorAutoClose(Rule):
|
||||
|
||||
def __init__(self, close_frequency: int = 10):
|
||||
super().__init__()
|
||||
self.close_frequency = close_frequency
|
||||
|
||||
def tick_step(self, state):
|
||||
if doors := state[d.DOORS]:
|
||||
doors_tick_result = doors.tick_doors()
|
||||
doors_that_ticked = [key for key, val in doors_tick_result.items() if val]
|
||||
state.print(f'{doors_that_ticked} were auto-closed'
|
||||
if doors_that_ticked else 'No Doors were auto-closed')
|
||||
return [TickResult(self.name, validity=c.VALID, value=0)]
|
||||
state.print('There are no doors, but you loaded the corresponding Module')
|
||||
return []
|
0
mfg_package/modules/items/__init__.py
Normal file
37
mfg_package/modules/items/actions.py
Normal file
@ -0,0 +1,37 @@
|
||||
from typing import Union
|
||||
|
||||
from mfg_package.environment.actions import Action
|
||||
from mfg_package.utils.results import ActionResult
|
||||
|
||||
from mfg_package.modules.items import constants as i, rewards as r
|
||||
from mfg_package.environment import constants as c
|
||||
|
||||
|
||||
class ItemAction(Action):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(i.ITEM_ACTION)
|
||||
|
||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||
inventory = state[i.INVENTORY].by_entity(entity)
|
||||
if drop_off := state[i.DROP_OFF].by_pos(entity.pos):
|
||||
if inventory:
|
||||
valid = drop_off.place_item(inventory.pop())
|
||||
else:
|
||||
valid = c.NOT_VALID
|
||||
if valid:
|
||||
state.print(f'{entity.name} just dropped of an item at {drop_off.pos}.')
|
||||
else:
|
||||
state.print(f'{entity.name} just tried to drop off at {entity.pos}, but failed.')
|
||||
reward = r.DROP_OFF_VALID if valid else r.DROP_OFF_FAIL
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=reward)
|
||||
|
||||
elif item := state[i.ITEM].by_pos(entity.pos):
|
||||
item.change_parent_collection(inventory)
|
||||
item.set_tile_to(state.NO_POS_TILE)
|
||||
state.print(f'{entity.name} just picked up an item at {entity.pos}')
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=c.VALID, reward=r.PICK_UP_VALID)
|
||||
|
||||
else:
|
||||
state.print(f'{entity.name} just tried to pick up an item at {entity.pos}, but failed.')
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.PICK_UP_FAIL)
|
BIN
mfg_package/modules/items/assets/charge_pod.png
Normal file
After Width: | Height: | Size: 6.5 KiB |
BIN
mfg_package/modules/items/assets/dropofflocations.png
Normal file
After Width: | Height: | Size: 2.3 KiB |
BIN
mfg_package/modules/items/assets/items.png
Normal file
After Width: | Height: | Size: 3.0 KiB |
11
mfg_package/modules/items/constants.py
Normal file
@ -0,0 +1,11 @@
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
SYMBOL_NO_ITEM = 0
|
||||
SYMBOL_DROP_OFF = 1
|
||||
# Item Env
|
||||
ITEM = 'Items'
|
||||
INVENTORY = 'Inventories'
|
||||
DROP_OFF = 'DropOffLocations'
|
||||
|
||||
ITEM_ACTION = 'ITEMACTION'
|
64
mfg_package/modules/items/entitites.py
Normal file
@ -0,0 +1,64 @@
|
||||
from collections import deque
|
||||
|
||||
from mfg_package.environment.entity.entity import Entity
|
||||
from mfg_package.environment import constants as c
|
||||
from mfg_package.utils.render import RenderEntity
|
||||
from mfg_package.modules.items import constants as i
|
||||
|
||||
|
||||
class Item(Entity):
|
||||
|
||||
def render(self):
|
||||
return RenderEntity(i.ITEM, self.tile.pos) if self.pos != c.VALUE_NO_POS else None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._auto_despawn = -1
|
||||
|
||||
@property
|
||||
def auto_despawn(self):
|
||||
return self._auto_despawn
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
# Edit this if you want items to be drawn in the ops differently
|
||||
return 1
|
||||
|
||||
def set_auto_despawn(self, auto_despawn):
|
||||
self._auto_despawn = auto_despawn
|
||||
|
||||
def set_tile_to(self, no_pos_tile):
|
||||
self._tile = no_pos_tile
|
||||
|
||||
def summarize_state(self) -> dict:
|
||||
super_summarization = super(Item, self).summarize_state()
|
||||
super_summarization.update(dict(auto_despawn=self.auto_despawn))
|
||||
return super_summarization
|
||||
|
||||
|
||||
class DropOffLocation(Entity):
|
||||
|
||||
def render(self):
|
||||
return RenderEntity(i.DROP_OFF, self.tile.pos)
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return i.SYMBOL_DROP_OFF
|
||||
|
||||
def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs):
|
||||
super(DropOffLocation, self).__init__(*args, **kwargs)
|
||||
self.auto_item_despawn_interval = auto_item_despawn_interval
|
||||
self.storage = deque(maxlen=storage_size_until_full or None)
|
||||
|
||||
def place_item(self, item: Item):
|
||||
if self.is_full:
|
||||
raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
|
||||
return bc.NOT_VALID
|
||||
else:
|
||||
self.storage.append(item)
|
||||
item.set_auto_despawn(self.auto_item_despawn_interval)
|
||||
return c.VALID
|
||||
|
||||
@property
|
||||
def is_full(self):
|
||||
return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage)
|
101
mfg_package/modules/items/groups.py
Normal file
@ -0,0 +1,101 @@
|
||||
from typing import List
|
||||
|
||||
from mfg_package.environment.groups.env_objects import EnvObjects
|
||||
from mfg_package.environment.groups.objects import Objects
|
||||
from mfg_package.environment.groups.mixins import PositionMixin, IsBoundMixin, HasBoundedMixin
|
||||
from mfg_package.environment.entity.wall_floor import Floor
|
||||
from mfg_package.environment.entity.agent import Agent
|
||||
from mfg_package.modules.items.entitites import Item, DropOffLocation
|
||||
|
||||
|
||||
class Items(PositionMixin, EnvObjects):
|
||||
|
||||
_entity = Item
|
||||
is_blocking_light: bool = False
|
||||
can_collide: bool = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def spawn_items(self, tiles: List[Floor]):
|
||||
items = [self._entity(tile) for tile in tiles]
|
||||
self.add_items(items)
|
||||
|
||||
def despawn_items(self, items: List[Item]):
|
||||
items = [items] if isinstance(items, Item) else items
|
||||
for item in items:
|
||||
del self[item]
|
||||
|
||||
|
||||
class Inventory(IsBoundMixin, EnvObjects):
|
||||
|
||||
_accepted_objects = Item
|
||||
|
||||
@property
|
||||
def obs_tag(self):
|
||||
return self.name
|
||||
|
||||
def __init__(self, agent: Agent, *args, **kwargs):
|
||||
super(Inventory, self).__init__(*args, **kwargs)
|
||||
self._collection = None
|
||||
self.bind(agent)
|
||||
|
||||
def summarize_states(self, **kwargs):
|
||||
attr_dict = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||
attr_dict.update(dict(items=[val.summarize_state(**kwargs) for key, val in self.items()]))
|
||||
attr_dict.update(dict(name=self.name, belongs_to=self._bound_entity.name))
|
||||
return attr_dict
|
||||
|
||||
def pop(self):
|
||||
item_to_pop = self[0]
|
||||
self.delete_env_object(item_to_pop)
|
||||
return item_to_pop
|
||||
|
||||
def set_collection(self, collection):
|
||||
self._collection = collection
|
||||
|
||||
|
||||
class Inventories(HasBoundedMixin, Objects):
|
||||
|
||||
_entity = Inventory
|
||||
can_move = False
|
||||
|
||||
@property
|
||||
def obs_pairs(self):
|
||||
return [(x.name, x) for x in self]
|
||||
|
||||
def __init__(self, size, *args, **kwargs):
|
||||
super(Inventories, self).__init__(*args, **kwargs)
|
||||
self.size = size
|
||||
self._obs = None
|
||||
self._lazy_eval_transforms = []
|
||||
|
||||
def spawn_inventories(self, agents):
|
||||
inventories = [self._entity(agent, self.size,)
|
||||
for _, agent in enumerate(agents)]
|
||||
self.add_items(inventories)
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def by_entity(self, entity):
|
||||
try:
|
||||
return next((inv for inv in self if inv.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def summarize_states(self, **kwargs):
|
||||
return [val.summarize_states(**kwargs) for key, val in self.items()]
|
||||
|
||||
|
||||
class DropOffLocations(PositionMixin, EnvObjects):
|
||||
|
||||
_entity = DropOffLocation
|
||||
is_blocking_light: bool = False
|
||||
can_collide: bool = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DropOffLocations, self).__init__(*args, **kwargs)
|
4
mfg_package/modules/items/rewards.py
Normal file
@ -0,0 +1,4 @@
|
||||
DROP_OFF_VALID: float = 0.1
|
||||
DROP_OFF_FAIL: float = -0.1
|
||||
PICK_UP_FAIL: float = -0.1
|
||||
PICK_UP_VALID: float = 0.1
|
79
mfg_package/modules/items/rules.py
Normal file
@ -0,0 +1,79 @@
|
||||
from typing import List
|
||||
|
||||
from mfg_package.environment.rules import Rule
|
||||
from mfg_package.environment import constants as c
|
||||
from mfg_package.utils.results import TickResult
|
||||
from mfg_package.modules.items import constants as i
|
||||
from mfg_package.modules.items.entitites import DropOffLocation
|
||||
|
||||
|
||||
class ItemRules(Rule):
|
||||
|
||||
def __init__(self, n_items: int = 5, spawn_frequency: int = 15,
|
||||
n_locations: int = 5, max_dropoff_storage_size: int = 0):
|
||||
super().__init__()
|
||||
self.spawn_frequency = spawn_frequency
|
||||
self._next_item_spawn = spawn_frequency
|
||||
self.n_items = n_items
|
||||
self.max_dropoff_storage_size = max_dropoff_storage_size
|
||||
self.n_locations = n_locations
|
||||
|
||||
def on_init(self, state):
|
||||
self.trigger_drop_off_location_spawn(state)
|
||||
self._next_item_spawn = self.spawn_frequency
|
||||
self.trigger_inventory_spawn(state)
|
||||
self.trigger_item_spawn(state)
|
||||
|
||||
def tick_step(self, state):
|
||||
for item in list(state[i.ITEM].values()):
|
||||
if item.auto_despawn >= 1:
|
||||
item.set_auto_despawn(item.auto_despawn - 1)
|
||||
elif not item.auto_despawn:
|
||||
state[i.ITEM].delete_env_object(item)
|
||||
else:
|
||||
pass
|
||||
|
||||
if not self._next_item_spawn:
|
||||
self.trigger_item_spawn(state)
|
||||
else:
|
||||
self._next_item_spawn = max(0, self._next_item_spawn - 1)
|
||||
return []
|
||||
|
||||
def trigger_item_spawn(self, state):
|
||||
if item_to_spawns := max(0, (self.n_items - len(state[i.ITEM]))):
|
||||
empty_tiles = state[c.FLOOR].empty_tiles[:item_to_spawns]
|
||||
state[i.ITEM].spawn_items(empty_tiles)
|
||||
self._next_item_spawn = self.spawn_frequency
|
||||
state.print(f'{item_to_spawns} new items have been spawned; next spawn in {self._next_item_spawn}')
|
||||
return len(empty_tiles)
|
||||
else:
|
||||
state.print('No Items are spawning, limit is reached.')
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def trigger_inventory_spawn(state):
|
||||
state[i.INVENTORY].spawn_inventories(state[c.AGENT])
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
for item in list(state[i.ITEM].values()):
|
||||
if item.auto_despawn >= 1:
|
||||
item.set_auto_despawn(item.auto_despawn-1)
|
||||
elif not item.auto_despawn:
|
||||
state[i.ITEM].delete_env_object(item)
|
||||
else:
|
||||
pass
|
||||
|
||||
if not self._next_item_spawn:
|
||||
if spawned_items := self.trigger_item_spawn(state):
|
||||
return [TickResult(self.name, validity=c.VALID, value=spawned_items, entity=None)]
|
||||
else:
|
||||
return [TickResult(self.name, validity=c.NOT_VALID, value=0, entity=None)]
|
||||
else:
|
||||
self._next_item_spawn = max(0, self._next_item_spawn-1)
|
||||
return []
|
||||
|
||||
def trigger_drop_off_location_spawn(self, state):
|
||||
empty_tiles = state[c.FLOOR].empty_tiles[:self.n_locations]
|
||||
do_entites = state[i.DROP_OFF]
|
||||
drop_offs = [DropOffLocation(tile) for tile in empty_tiles]
|
||||
do_entites.add_items(drop_offs)
|