Reset tsp route caching + renamed and moved configs + removed unnecessary files
@@ -1,4 +1,3 @@
|
||||
from .quickstart import init
|
||||
from marl_factory_grid.environment.factory import Factory
|
||||
"""
|
||||
Main module of the 'marl-factory-grid'-environment.
|
||||
|
||||
@@ -1 +1 @@
|
||||
from marl_factory_grid.algorithms.marl.memory import MARLActorCriticMemory
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ import numpy as np
|
||||
from torch.distributions import Categorical
|
||||
|
||||
from marl_factory_grid.algorithms.marl.base_a2c import PolicyGradient, cumulate_discount
|
||||
from marl_factory_grid.algorithms.marl.memory import MARLActorCriticMemory
|
||||
from marl_factory_grid.algorithms.utils import add_env_props, instantiate_class
|
||||
from pathlib import Path
|
||||
from collections import deque
|
||||
|
||||
@@ -2,8 +2,6 @@ import numpy as np; import torch as th; import scipy as sp;
|
||||
from collections import deque
|
||||
from torch import nn
|
||||
|
||||
# RLLab Magic for calculating the discounted return G(t) = R(t) + gamma * G(t+1)
|
||||
# cf. https://github.com/rll/rllab/blob/ba78e4c16dc492982e648f117875b22af3965579/rllab/misc/special.py#L107
|
||||
cumulate_discount = lambda x, gamma: sp.signal.lfilter([1], [1, - gamma], x[::-1], axis=0)[::-1]
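A quick standalone check of what this filter computes (sketch only, not part of the file; scipy.signal is imported explicitly here to keep it self-contained):
import numpy as np
from scipy.signal import lfilter
rewards, gamma = np.array([1., 1., 1.]), 0.9
print(lfilter([1], [1, -gamma], rewards[::-1], axis=0)[::-1])  # ~ [2.71, 1.9, 1.0]
# Same values via the recursion G(t) = R(t) + gamma * G(t+1):
g, returns = 0.0, []
for r in rewards[::-1]:
    g = r + gamma * g
    returns.insert(0, g)
print(returns)  # ~ [2.71, 1.9, 1.0]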
|
||||
|
||||
class Net(th.nn.Module):
|
||||
|
||||
@@ -1,242 +0,0 @@
|
||||
import torch
|
||||
from typing import Union, List, Dict
|
||||
import numpy as np
|
||||
from torch.distributions import Categorical
|
||||
from marl_factory_grid.algorithms.marl.memory import MARLActorCriticMemory
|
||||
from marl_factory_grid.algorithms.utils import add_env_props, instantiate_class
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
from collections import deque
|
||||
|
||||
|
||||
class Names:
|
||||
REWARD = 'reward'
|
||||
DONE = 'done'
|
||||
ACTION = 'action'
|
||||
OBSERVATION = 'observation'
|
||||
LOGITS = 'logits'
|
||||
HIDDEN_ACTOR = 'hidden_actor'
|
||||
HIDDEN_CRITIC = 'hidden_critic'
|
||||
AGENT = 'agent'
|
||||
ENV = 'env'
|
||||
ENV_NAME = 'env_name'
|
||||
N_AGENTS = 'n_agents'
|
||||
ALGORITHM = 'algorithm'
|
||||
MAX_STEPS = 'max_steps'
|
||||
N_STEPS = 'n_steps'
|
||||
BUFFER_SIZE = 'buffer_size'
|
||||
CRITIC = 'critic'
|
||||
BATCH_SIZE = 'batch_size'
|
||||
N_ACTIONS = 'n_actions'
|
||||
TRAIN_RENDER = 'train_render'
|
||||
EVAL_RENDER = 'eval_render'
|
||||
|
||||
|
||||
nms = Names
|
||||
ListOrTensor = Union[List, torch.Tensor]
|
||||
|
||||
|
||||
class BaseActorCritic:
|
||||
def __init__(self, cfg):
|
||||
self.factory = add_env_props(cfg)
|
||||
self.__training = True
|
||||
self.cfg = cfg
|
||||
self.n_agents = cfg[nms.AGENT][nms.N_AGENTS]
|
||||
self.reset_memory_after_epoch = True
|
||||
self.setup()
|
||||
|
||||
def setup(self):
|
||||
self.net = instantiate_class(self.cfg[nms.AGENT])
|
||||
self.optimizer = torch.optim.RMSprop(self.net.parameters(), lr=3e-4, eps=1e-5)
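For orientation, a rough sketch of the cfg layout this class reads; the key names come from the Names constants and the config diffs below, but the concrete values and the agent classname are illustrative assumptions:
cfg_sketch = dict(
    agent=dict(classname='marl_factory_grid.algorithms.marl.networks.RecurrentAC',  # assumed target of instantiate_class
               n_agents=2, n_actions=5, hidden_size_actor=64, hidden_size_critic=64),
    env=dict(env_name='rl/dirt_quadrant_train_config', n_agents=2, train_render=False, eval_render=False),
    algorithm=dict(gamma=0.99, entropy_coef=0.01, vf_coef=0.05, n_steps=5, max_steps=10000, buffer_size=5),
)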
|
||||
|
||||
@classmethod
|
||||
def _as_torch(cls, x):
|
||||
if isinstance(x, np.ndarray):
|
||||
return torch.from_numpy(x)
|
||||
elif isinstance(x, List):
|
||||
return torch.tensor(x)
|
||||
elif isinstance(x, (int, float)):
|
||||
return torch.tensor([x])
|
||||
return x
|
||||
|
||||
def train(self):
|
||||
self.__training = True
|
||||
networks = [self.net] if not isinstance(self.net, List) else self.net
|
||||
for net in networks:
|
||||
net.train()
|
||||
|
||||
def eval(self):
|
||||
self.__training = False
|
||||
networks = [self.net] if not isinstance(self.net, List) else self.net
|
||||
for net in networks:
|
||||
net.eval()
|
||||
|
||||
def load_state_dict(self, path: Path):
|
||||
pass
|
||||
|
||||
def get_actions(self, out) -> ListOrTensor:
|
||||
actions = [Categorical(logits=logits).sample().item() for logits in out[nms.LOGITS]]
|
||||
return actions
|
||||
|
||||
def init_hidden(self) -> Dict[str, ListOrTensor]:
|
||||
pass
|
||||
|
||||
def forward(self,
|
||||
observations: ListOrTensor,
|
||||
actions: ListOrTensor,
|
||||
hidden_actor: ListOrTensor,
|
||||
hidden_critic: ListOrTensor
|
||||
) -> Dict[str, ListOrTensor]:
|
||||
pass
|
||||
|
||||
@torch.no_grad()
|
||||
def train_loop(self, checkpointer=None):
|
||||
env = self.factory
|
||||
if self.cfg[nms.ENV][nms.TRAIN_RENDER]:
|
||||
env.render()
|
||||
n_steps, max_steps = [self.cfg[nms.ALGORITHM][k] for k in [nms.N_STEPS, nms.MAX_STEPS]]
|
||||
tm = MARLActorCriticMemory(self.n_agents, self.cfg[nms.ALGORITHM].get(nms.BUFFER_SIZE, n_steps))
|
||||
global_steps, episode, df_results = 0, 0, []
|
||||
reward_queue = deque(maxlen=2000)
|
||||
|
||||
while global_steps < max_steps:
|
||||
obs = env.reset()
|
||||
obs = list(obs.values())
|
||||
last_hiddens = self.init_hidden()
|
||||
last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
|
||||
done, rew_log = [False] * self.n_agents, 0
|
||||
|
||||
if self.reset_memory_after_epoch:
|
||||
tm.reset()
|
||||
|
||||
tm.add(observation=obs, action=last_action,
|
||||
logits=torch.zeros(self.n_agents, 1, self.cfg[nms.AGENT][nms.N_ACTIONS]),
|
||||
values=torch.zeros(self.n_agents, 1), reward=reward, done=done, **last_hiddens)
|
||||
|
||||
while not all(done):
|
||||
out = self.forward(obs, last_action, **last_hiddens)
|
||||
action = self.get_actions(out)
|
||||
_, next_obs, reward, done, info = env.step(action)
|
||||
done = [done] * self.n_agents if isinstance(done, bool) else done
|
||||
|
||||
if self.cfg[nms.ENV][nms.TRAIN_RENDER]:
|
||||
env.render()
|
||||
|
||||
last_hiddens = dict(hidden_actor=out[nms.HIDDEN_ACTOR],
|
||||
hidden_critic=out[nms.HIDDEN_CRITIC])
|
||||
|
||||
logits = torch.stack([tensor.squeeze(0) for tensor in out.get(nms.LOGITS, None)], dim=0)
|
||||
values = torch.stack([tensor.squeeze(0) for tensor in out.get(nms.CRITIC, None)], dim=0)
|
||||
|
||||
tm.add(observation=obs, action=action, reward=reward, done=done,
|
||||
logits=logits, values=values,
|
||||
**last_hiddens)
|
||||
|
||||
obs = next_obs
|
||||
last_action = action
|
||||
|
||||
if (global_steps+1) % n_steps == 0 or all(done):
|
||||
with torch.inference_mode(False):
|
||||
self.learn(tm)
|
||||
|
||||
global_steps += 1
|
||||
rew_log += sum(reward)
|
||||
reward_queue.extend(reward)
|
||||
|
||||
if checkpointer is not None:
|
||||
checkpointer.step([
|
||||
(f'agent#{i}', agent)
|
||||
for i, agent in enumerate([self.net] if not isinstance(self.net, List) else self.net)
|
||||
])
|
||||
|
||||
if global_steps >= max_steps:
|
||||
break
|
||||
if global_steps%100 == 0:
|
||||
print(f'reward at episode: {episode} = {rew_log}')
|
||||
episode += 1
|
||||
df_results.append([episode, rew_log, *reward])
|
||||
df_results = pd.DataFrame(df_results,
|
||||
columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]]
|
||||
)
|
||||
if checkpointer is not None:
|
||||
df_results.to_csv(checkpointer.path / 'results.csv', index=False)
|
||||
return df_results
|
||||
|
||||
@torch.inference_mode(True)
|
||||
def eval_loop(self, n_episodes, render=False):
|
||||
env = self.factory
|
||||
if self.cfg[nms.ENV][nms.EVAL_RENDER]:
|
||||
env.render()
|
||||
episode, results = 0, []
|
||||
while episode < n_episodes:
|
||||
obs = env.reset()
|
||||
obs = list(obs.values())
|
||||
last_hiddens = self.init_hidden()
|
||||
last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
|
||||
done, rew_log, eps_rew = [False] * self.n_agents, 0, torch.zeros(self.n_agents)
|
||||
while not all(done):
|
||||
out = self.forward(obs, last_action, **last_hiddens)
|
||||
action = self.get_actions(out)
|
||||
_, next_obs, reward, done, info = env.step(action)
|
||||
|
||||
if self.cfg[nms.ENV][nms.EVAL_RENDER]:
|
||||
env.render()
|
||||
|
||||
if isinstance(done, bool):
|
||||
done = [done] * obs[0].shape[0]
|
||||
obs = next_obs
|
||||
last_action = action
|
||||
last_hiddens = dict(hidden_actor=out.get(nms.HIDDEN_ACTOR, None),
|
||||
hidden_critic=out.get(nms.HIDDEN_CRITIC, None)
|
||||
)
|
||||
eps_rew += torch.tensor(reward)
|
||||
results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode])
|
||||
episode += 1
|
||||
agent_columns = [f'agent#{i}' for i in range(self.cfg[nms.ENV][nms.N_AGENTS])]
|
||||
results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode'])
|
||||
results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'],
|
||||
value_name='reward', var_name='agent')
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def compute_advantages(critic, reward, done, gamma, gae_coef=0.0):
|
||||
tds = (reward + gamma * (1.0 - done) * critic[:, 1:].detach()) - critic[:, :-1]
|
||||
|
||||
if gae_coef <= 0:
|
||||
return tds
|
||||
|
||||
gae = torch.zeros_like(tds[:, -1])
|
||||
gaes = []
|
||||
for t in range(tds.shape[1]-1, -1, -1):
|
||||
gae = tds[:, t] + gamma * gae_coef * (1.0 - done[:, t]) * gae
|
||||
gaes.insert(0, gae)
|
||||
gaes = torch.stack(gaes, dim=1)
|
||||
return gaes
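A minimal numeric check of the TD term above, using made-up tensors shaped agents x time(+1):
check_critic = torch.tensor([[0.5, 0.6, 0.7]])   # v(t), v(t+1), v(t+2)
check_reward = torch.tensor([[1.0, 0.0]])
check_done = torch.tensor([[0.0, 1.0]])
check_tds = (check_reward + 0.99 * (1.0 - check_done) * check_critic[:, 1:]) - check_critic[:, :-1]
# -> [[1.094, -0.6]]: bootstrapped where not done, plain (r - v) at the terminal step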
|
||||
|
||||
def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
|
||||
obs, actions, done, reward = tm.observation, tm.action, tm.done[:, 1:], tm.reward[:, 1:]
|
||||
|
||||
out = network(obs, actions, tm.hidden_actor[:, 0].squeeze(0), tm.hidden_critic[:, 0].squeeze(0))
|
||||
logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1}
|
||||
critic = out[nms.CRITIC]
|
||||
|
||||
entropy_loss = Categorical(logits=logits).entropy().mean(-1)
|
||||
advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
|
||||
value_loss = advantages.pow(2).mean(-1) # n_agent
|
||||
|
||||
# policy loss
|
||||
log_ap = torch.log_softmax(logits, -1)
|
||||
log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze()
|
||||
a2c_loss = -(advantages.detach() * log_ap).mean(-1)
|
||||
# weighted loss
|
||||
loss = a2c_loss + vf_coef*value_loss - entropy_coef * entropy_loss
|
||||
return loss.mean()
|
||||
|
||||
def learn(self, tm: MARLActorCriticMemory, **kwargs):
|
||||
loss = self.actor_critic(tm, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
|
||||
# remove next_obs, will be added in next iter
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
|
||||
self.optimizer.step()
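Taken together, the loop classes below (snac/iac/seac/mappo) are driven roughly like this; a sketch assuming a cfg dict as outlined above:
agent = LoopSNAC(cfg)                           # or LoopIAC / LoopSEAC / LoopMAPPO
train_df = agent.train_loop(checkpointer=None)
eval_df = agent.eval_loop(n_episodes=10)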
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
marl_factory_grid>environment>rules.py#SpawnEntity.on_reset()
|
||||
marl_factory_grid>environment>rewards.py
|
||||
marl_factory_grid>modules>clean_up>groups.py#DirtPiles.trigger_spawn()
|
||||
marl_factory_grid>environment>rules.py#AgentSpawnRule
|
||||
marl_factory_grid>utils>states.py#GameState.__init__()
|
||||
marl_factory_grid>environment>factory.py>Factory#render
|
||||
marl_factory_grid>environment>factory.py>Factory#set_recorder
|
||||
marl_factory_grid>utils>renderer.py>Renderer#render
|
||||
@@ -1,57 +0,0 @@
|
||||
import torch
|
||||
from marl_factory_grid.algorithms.marl.base_ac import BaseActorCritic, nms
|
||||
from marl_factory_grid.algorithms.utils import instantiate_class
|
||||
from pathlib import Path
|
||||
from natsort import natsorted
|
||||
from marl_factory_grid.algorithms.marl.memory import MARLActorCriticMemory
|
||||
|
||||
|
||||
class LoopIAC(BaseActorCritic):
|
||||
|
||||
def __init__(self, cfg):
|
||||
super(LoopIAC, self).__init__(cfg)
|
||||
|
||||
def setup(self):
|
||||
self.net = [
|
||||
instantiate_class(self.cfg[nms.AGENT]) for _ in range(self.n_agents)
|
||||
]
|
||||
self.optimizer = [
|
||||
torch.optim.RMSprop(self.net[ag_i].parameters(), lr=3e-4, eps=1e-5) for ag_i in range(self.n_agents)
|
||||
]
|
||||
|
||||
def load_state_dict(self, path: Path):
|
||||
paths = natsorted(list(path.glob('*.pt')))
|
||||
for path, net in zip(paths, self.net):
|
||||
net.load_state_dict(torch.load(path))
|
||||
|
||||
@staticmethod
|
||||
def merge_dicts(ds): # todo could be recursive for more than 1 hierarchy
|
||||
d = {}
|
||||
for k in ds[0].keys():
|
||||
d[k] = [d_[k] for d_ in ds]  # collect the value for k from every per-agent dict
|
||||
return d
|
||||
|
||||
def init_hidden(self):
|
||||
ha = [net.init_hidden_actor() for net in self.net]
|
||||
hc = [net.init_hidden_critic() for net in self.net]
|
||||
return dict(hidden_actor=ha, hidden_critic=hc)
|
||||
|
||||
def forward(self, observations, actions, hidden_actor, hidden_critic):
|
||||
outputs = [
|
||||
net(
|
||||
self._as_torch(observations[ag_i]).unsqueeze(0).unsqueeze(0), # agent x time
|
||||
self._as_torch(actions[ag_i]).unsqueeze(0),
|
||||
hidden_actor[ag_i],
|
||||
hidden_critic[ag_i]
|
||||
) for ag_i, net in enumerate(self.net)
|
||||
]
|
||||
return self.merge_dicts(outputs)
|
||||
|
||||
def learn(self, tms: MARLActorCriticMemory, **kwargs):
|
||||
for ag_i in range(self.n_agents):
|
||||
tm, net = tms(ag_i), self.net[ag_i]
|
||||
loss = self.actor_critic(tm, net, **self.cfg[nms.ALGORITHM], **kwargs)
|
||||
self.optimizer[ag_i].zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(net.parameters(), 0.5)
|
||||
self.optimizer[ag_i].step()
|
||||
@@ -1,66 +0,0 @@
|
||||
from marl_factory_grid.algorithms.marl.base_ac import Names as nms
|
||||
from marl_factory_grid.algorithms.marl.snac import LoopSNAC
|
||||
from marl_factory_grid.algorithms.marl.memory import MARLActorCriticMemory
|
||||
import torch
|
||||
from torch.distributions import Categorical
|
||||
from marl_factory_grid.algorithms.utils import instantiate_class
|
||||
|
||||
|
||||
class LoopMAPPO(LoopSNAC):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(LoopMAPPO, self).__init__(*args, **kwargs)
|
||||
self.reset_memory_after_epoch = False
|
||||
|
||||
def setup(self):
|
||||
self.net = instantiate_class(self.cfg[nms.AGENT])
|
||||
self.optimizer = torch.optim.Adam(self.net.parameters(), lr=3e-4, eps=1e-5)
|
||||
|
||||
def learn(self, tm: MARLActorCriticMemory, **kwargs):
|
||||
if len(tm) >= self.cfg['algorithm']['buffer_size']:
|
||||
# only learn when buffer is full
|
||||
for batch_i in range(self.cfg['algorithm']['n_updates']):
|
||||
batch = tm.chunk_dataloader(chunk_len=self.cfg['algorithm']['n_steps'],
|
||||
k=self.cfg['algorithm']['batch_size'])
|
||||
loss = self.mappo(batch, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
|
||||
self.optimizer.step()
|
||||
|
||||
def monte_carlo_returns(self, rewards, done, gamma):
|
||||
rewards_ = []
|
||||
discounted_reward = torch.zeros_like(rewards[:, -1])
|
||||
for t in range(rewards.shape[1]-1, -1, -1):
|
||||
discounted_reward = rewards[:, t] + (gamma * (1.0 - done[:, t]) * discounted_reward)
|
||||
rewards_.insert(0, discounted_reward)
|
||||
rewards_ = torch.stack(rewards_, dim=1)
|
||||
return rewards_
|
||||
|
||||
def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **__):
|
||||
out = network(batch[nms.OBSERVATION], batch[nms.ACTION], batch[nms.HIDDEN_ACTOR], batch[nms.HIDDEN_CRITIC])
|
||||
logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1}
|
||||
|
||||
old_log_probs = torch.log_softmax(batch[nms.LOGITS], -1)
|
||||
old_log_probs = torch.gather(old_log_probs, index=batch[nms.ACTION][:, 1:].unsqueeze(-1), dim=-1).squeeze()
|
||||
|
||||
# monte carlo returns
|
||||
mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma)
|
||||
mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) # todo: norm across agent ok?
|
||||
advantages = mc_returns - out[nms.CRITIC][:, :-1]
|
||||
|
||||
# policy loss
|
||||
log_ap = torch.log_softmax(logits, -1)
|
||||
log_ap = torch.gather(log_ap, dim=-1, index=batch[nms.ACTION][:, 1:].unsqueeze(-1)).squeeze()
|
||||
ratio = (log_ap - old_log_probs).exp()
|
||||
surr1 = ratio * advantages.detach()
|
||||
surr2 = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantages.detach()
|
||||
policy_loss = -torch.min(surr1, surr2).mean(-1)
|
||||
|
||||
# entropy & value loss
|
||||
entropy_loss = Categorical(logits=logits).entropy().mean(-1)
|
||||
value_loss = advantages.pow(2).mean(-1) # n_agent
|
||||
|
||||
# weighted loss
|
||||
loss = policy_loss + vf_coef*value_loss - entropy_coef * entropy_loss
|
||||
|
||||
return loss.mean()
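The clipped surrogate above keeps the update close to the behaviour policy; for a single action it reduces to this arithmetic (illustrative numbers):
check_ratio = torch.tensor(1.3)                          # new_prob / old_prob
check_adv = torch.tensor(2.0)
surr1 = check_ratio * check_adv                          # 2.6
surr2 = torch.clamp(check_ratio, 0.8, 1.2) * check_adv   # 2.4 with clip_range = 0.2
policy_loss = -torch.min(surr1, surr2)                   # -2.4, i.e. the clipped branch wins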
|
||||
@@ -1,221 +0,0 @@
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
import torch
|
||||
from typing import Union
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset, ConcatDataset
|
||||
import random
|
||||
|
||||
|
||||
class ActorCriticMemory(object):
|
||||
def __init__(self, capacity=10):
|
||||
self.capacity = capacity
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.__actions = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||
self.__hidden_actor = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||
self.__hidden_critic = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||
self.__states = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||
self.__rewards = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||
self.__dones = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||
self.__logits = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||
self.__values = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.__rewards) - 1
|
||||
|
||||
@property
|
||||
def observation(self, sls=slice(0, None)): # add time dimension through stacking
|
||||
return self.__states[sls].unsqueeze(0) # 1 x time x hidden dim
|
||||
|
||||
@property
|
||||
def hidden_actor(self, sls=slice(0, None)): # 1 x n_layers x dim
|
||||
return self.__hidden_actor[sls].unsqueeze(0) # 1 x time x n_layers x dim
|
||||
|
||||
@property
|
||||
def hidden_critic(self, sls=slice(0, None)): # 1 x n_layers x dim
|
||||
return self.__hidden_critic[sls].unsqueeze(0) # 1 x time x n_layers x dim
|
||||
|
||||
@property
|
||||
def reward(self, sls=slice(0, None)):
|
||||
return self.__rewards[sls].squeeze().unsqueeze(0) # 1 x time
|
||||
|
||||
@property
|
||||
def action(self, sls=slice(0, None)):
|
||||
return self.__actions[sls].long().squeeze().unsqueeze(0) # 1 x time
|
||||
|
||||
@property
|
||||
def done(self, sls=slice(0, None)):
|
||||
return self.__dones[sls].float().squeeze().unsqueeze(0) # 1 x time
|
||||
|
||||
@property
|
||||
def logits(self, sls=slice(0, None)): # assumes a trailing 1 for time dimension - common when using output from NN
|
||||
return self.__logits[sls].squeeze().unsqueeze(0) # 1 x time x actions
|
||||
|
||||
@property
|
||||
def values(self, sls=slice(0, None)):
|
||||
return self.__values[sls].squeeze().unsqueeze(0) # 1 x time x actions
|
||||
|
||||
def add_observation(self, state: Union[Tensor, np.ndarray]):
|
||||
self.__states.append(state if isinstance(state, Tensor) else torch.from_numpy(state))
|
||||
|
||||
def add_hidden_actor(self, hidden: Tensor):
|
||||
# layers x hidden dim
|
||||
self.__hidden_actor.append(hidden)
|
||||
|
||||
def add_hidden_critic(self, hidden: Tensor):
|
||||
# layers x hidden dim
|
||||
self.__hidden_critic.append(hidden)
|
||||
|
||||
def add_action(self, action: Union[int, Tensor]):
|
||||
if not isinstance(action, Tensor):
|
||||
action = torch.tensor(action)
|
||||
self.__actions.append(action)
|
||||
|
||||
def add_reward(self, reward: Union[float, Tensor]):
|
||||
if not isinstance(reward, Tensor):
|
||||
reward = torch.tensor(reward)
|
||||
self.__rewards.append(reward)
|
||||
|
||||
def add_done(self, done: bool):
|
||||
if not isinstance(done, Tensor):
|
||||
done = torch.tensor(done)
|
||||
self.__dones.append(done)
|
||||
|
||||
def add_logits(self, logits: Tensor):
|
||||
self.__logits.append(logits)
|
||||
|
||||
def add_values(self, values: Tensor):
|
||||
self.__values.append(values)
|
||||
|
||||
def add(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
func = getattr(ActorCriticMemory, f'add_{k}')
|
||||
func(self, v)
|
||||
|
||||
|
||||
class MARLActorCriticMemory(object):
|
||||
def __init__(self, n_agents, capacity):
|
||||
self.n_agents = n_agents
|
||||
self.memories = [
|
||||
ActorCriticMemory(capacity) for _ in range(n_agents)
|
||||
]
|
||||
|
||||
def __call__(self, agent_i):
|
||||
return self.memories[agent_i]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.memories[0]) # todo add assertion check!
|
||||
|
||||
def reset(self):
|
||||
for mem in self.memories:
|
||||
mem.reset()
|
||||
|
||||
def add(self, **kwargs):
|
||||
for agent_i in range(self.n_agents):
|
||||
for k, v in kwargs.items():
|
||||
func = getattr(ActorCriticMemory, f'add_{k}')
|
||||
func(self.memories[agent_i], v[agent_i])
|
||||
|
||||
def __getattr__(self, attr):
|
||||
all_attrs = [getattr(mem, attr) for mem in self.memories]
|
||||
return torch.cat(all_attrs, 0) # agent x time ...
|
||||
|
||||
def chunk_dataloader(self, chunk_len, k):
|
||||
datasets = [ExperienceChunks(mem, chunk_len, k) for mem in self.memories]
|
||||
dataset = ConcatDataset(datasets)
|
||||
data = [dataset[i] for i in range(len(dataset))]
|
||||
data = custom_collate_fn(data)
|
||||
return data
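A rough sketch of the shapes this loader returns (sizes are illustrative; the leading dimension concatenates the k sampled chunks of every agent):
batch = tm.chunk_dataloader(chunk_len=5, k=16)
# with 2 agents: batch['observation'].shape[0] == 2 * 16
# observations/actions span chunk_len + 1 steps, rewards/dones/logits/values span chunk_len steps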
|
||||
|
||||
|
||||
def custom_collate_fn(batch):
|
||||
elem = batch[0]
|
||||
return {key: torch.cat([d[key] for d in batch], dim=0) for key in elem}
|
||||
|
||||
|
||||
class ExperienceChunks(Dataset):
|
||||
def __init__(self, memory, chunk_len, k):
|
||||
assert chunk_len <= len(memory), 'chunk_len cannot be longer than the size of the memory'
|
||||
self.memory = memory
|
||||
self.chunk_len = chunk_len
|
||||
self.k = k
|
||||
|
||||
@property
|
||||
def whitelist(self):
|
||||
whitelist = torch.ones(len(self.memory) - self.chunk_len)
|
||||
for d in self.memory.done.squeeze().nonzero().flatten():
|
||||
whitelist[max((0, d-self.chunk_len-1)):d+2] = 0
|
||||
whitelist[0] = 0
|
||||
return whitelist.tolist()
|
||||
|
||||
def sample(self, start=1):
|
||||
cl = self.chunk_len
|
||||
sample = dict(observation=self.memory.observation[:, start:start+cl+1],
|
||||
action=self.memory.action[:, start-1:start+cl],
|
||||
hidden_actor=self.memory.hidden_actor[:, start-1],
|
||||
hidden_critic=self.memory.hidden_critic[:, start-1],
|
||||
reward=self.memory.reward[:, start:start + cl],
|
||||
done=self.memory.done[:, start:start + cl],
|
||||
logits=self.memory.logits[:, start:start + cl],
|
||||
values=self.memory.values[:, start:start + cl])
|
||||
return sample
|
||||
|
||||
def __len__(self):
|
||||
return self.k
|
||||
|
||||
def __getitem__(self, i):
|
||||
idx = random.choices(range(0, len(self.memory) - self.chunk_len), weights=self.whitelist, k=1)
|
||||
return self.sample(idx[0])
|
||||
|
||||
|
||||
class LazyTensorFiFoQueue:
|
||||
def __init__(self, maxlen):
|
||||
self.maxlen = maxlen
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.__lazy_queue = deque(maxlen=self.maxlen)
|
||||
self.shape = None
|
||||
self.queue = None
|
||||
|
||||
def shape_init(self, tensor: Tensor):
|
||||
self.shape = torch.Size([self.maxlen, *tensor.shape])
|
||||
|
||||
def build_tensor_queue(self):
|
||||
if len(self.__lazy_queue) > 0:
|
||||
block = torch.stack(list(self.__lazy_queue), dim=0)
|
||||
l = block.shape[0]
|
||||
if self.queue is None:
|
||||
self.queue = block
|
||||
elif self.true_len() <= self.maxlen:
|
||||
self.queue = torch.cat((self.queue, block), dim=0)
|
||||
else:
|
||||
self.queue = torch.cat((self.queue[l:], block), dim=0)
|
||||
self.__lazy_queue.clear()
|
||||
|
||||
def append(self, data):
|
||||
if self.shape is None:
|
||||
self.shape_init(data)
|
||||
self.__lazy_queue.append(data)
|
||||
if len(self.__lazy_queue) >= self.maxlen:
|
||||
self.build_tensor_queue()
|
||||
|
||||
def true_len(self):
|
||||
return len(self.__lazy_queue) + (0 if self.queue is None else self.queue.shape[0])
|
||||
|
||||
def __len__(self):
|
||||
return min((self.true_len(), self.maxlen))
|
||||
|
||||
def __str__(self):
|
||||
return f'LazyTensorFiFoQueue\tmaxlen: {self.maxlen}, shape: {self.shape}, ' \
|
||||
f'len: {len(self)}, true_len: {self.true_len()}, elements in lazy queue: {len(self.__lazy_queue)}'
|
||||
|
||||
def __getitem__(self, item_or_slice):
|
||||
self.build_tensor_queue()
|
||||
return self.queue[item_or_slice]
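A short usage sketch of the queue (standalone, values made up): only the newest maxlen elements survive, and slicing first flushes the lazy buffer into the tensor queue.
q = LazyTensorFiFoQueue(maxlen=3)
for i in range(5):
    q.append(torch.tensor([float(i)]))
print(len(q))   # 3
print(q[:])     # tensor([[2.], [3.], [4.]])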
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -7,8 +7,8 @@ agent:
|
||||
hidden_size_critic: 64
|
||||
use_agent_embedding: False
|
||||
env:
|
||||
classname: marl_factory_grid.configs.custom
|
||||
env_name: "custom/MultiAgentConfigs/dirt_quadrant_train_config"
|
||||
classname: marl_factory_grid.environment.configs.marl_eval
|
||||
env_name: "marl_eval/dirt_quadrant_eval_config"
|
||||
n_agents: 2
|
||||
max_steps: 250
|
||||
pomdp_r: 2
|
||||
@@ -7,8 +7,8 @@ agent:
|
||||
hidden_size_critic: 64
|
||||
use_agent_embedding: False
|
||||
env:
|
||||
classname: marl_factory_grid.configs.custom
|
||||
env_name: "custom/two_rooms_one_door_modified_train_config"
|
||||
classname: marl_factory_grid.environment.configs.marl_eval
|
||||
env_name: "marl_eval/two_rooms_eval_config"
|
||||
n_agents: 2
|
||||
max_steps: 250
|
||||
pomdp_r: 2
|
||||
@@ -1,103 +0,0 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class RecurrentAC(nn.Module):
|
||||
def __init__(self, observation_size, n_actions, obs_emb_size,
|
||||
action_emb_size, hidden_size_actor, hidden_size_critic,
|
||||
n_agents, use_agent_embedding=True):
|
||||
super(RecurrentAC, self).__init__()
|
||||
observation_size = np.prod(observation_size)
|
||||
self.n_layers = 1
|
||||
self.n_actions = n_actions
|
||||
self.use_agent_embedding = use_agent_embedding
|
||||
self.hidden_size_actor = hidden_size_actor
|
||||
self.hidden_size_critic = hidden_size_critic
|
||||
self.action_emb_size = action_emb_size
|
||||
self.obs_proj = nn.Linear(observation_size, obs_emb_size)
|
||||
self.action_emb = nn.Embedding(n_actions+1, action_emb_size, padding_idx=0)
|
||||
self.agent_emb = nn.Embedding(n_agents, action_emb_size)
|
||||
mix_in_size = obs_emb_size+action_emb_size if not use_agent_embedding else obs_emb_size+n_agents*action_emb_size
|
||||
self.mix = nn.Sequential(nn.Tanh(),
|
||||
nn.Linear(mix_in_size, obs_emb_size),
|
||||
nn.Tanh(),
|
||||
nn.Linear(obs_emb_size, obs_emb_size)
|
||||
)
|
||||
self.gru_actor = nn.GRU(obs_emb_size, hidden_size_actor, batch_first=True, num_layers=self.n_layers)
|
||||
self.gru_critic = nn.GRU(obs_emb_size, hidden_size_critic, batch_first=True, num_layers=self.n_layers)
|
||||
self.action_head = nn.Sequential(
|
||||
nn.Linear(hidden_size_actor, hidden_size_actor),
|
||||
nn.Tanh(),
|
||||
nn.Linear(hidden_size_actor, n_actions)
|
||||
)
|
||||
# spectral_norm(nn.Linear(hidden_size_actor, hidden_size_actor)),
|
||||
self.critic_head = nn.Sequential(
|
||||
nn.Linear(hidden_size_critic, hidden_size_critic),
|
||||
nn.Tanh(),
|
||||
nn.Linear(hidden_size_critic, 1)
|
||||
)
|
||||
#self.action_head[-1].weight.data.uniform_(-3e-3, 3e-3)
|
||||
#self.action_head[-1].bias.data.uniform_(-3e-3, 3e-3)
|
||||
|
||||
def init_hidden_actor(self):
|
||||
return torch.zeros(1, self.n_layers, self.hidden_size_actor)
|
||||
|
||||
def init_hidden_critic(self):
|
||||
return torch.zeros(1, self.n_layers, self.hidden_size_critic)
|
||||
|
||||
def forward(self, observations, actions, hidden_actor=None, hidden_critic=None):
|
||||
n_agents, t, *_ = observations.shape
|
||||
obs_emb = self.obs_proj(observations.view(n_agents, t, -1).float())
|
||||
action_emb = self.action_emb(actions+1) # shift by one due to padding idx
|
||||
|
||||
if not self.use_agent_embedding:
|
||||
x_t = torch.cat((obs_emb, action_emb), -1)
|
||||
else:
|
||||
agent_emb = self.agent_emb(
|
||||
torch.cat([torch.arange(0, n_agents, 1).view(-1, 1)] * t, 1)
|
||||
)
|
||||
x_t = torch.cat((obs_emb, agent_emb, action_emb), -1)
|
||||
|
||||
mixed_x_t = self.mix(x_t)
|
||||
output_p, _ = self.gru_actor(input=mixed_x_t, hx=hidden_actor.swapaxes(1, 0))
|
||||
output_c, _ = self.gru_critic(input=mixed_x_t, hx=hidden_critic.swapaxes(1, 0))
|
||||
|
||||
logits = self.action_head(output_p)
|
||||
critic = self.critic_head(output_c).squeeze(-1)
|
||||
return dict(logits=logits, critic=critic, hidden_actor=output_p, hidden_critic=output_c)
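A shape sketch for the forward pass (hypothetical sizes; hidden states are tiled per agent the same way the loop classes do it):
net = RecurrentAC(observation_size=(3, 5, 5), n_actions=10, obs_emb_size=16,
                  action_emb_size=8, hidden_size_actor=32, hidden_size_critic=32,
                  n_agents=2, use_agent_embedding=False)
out = net(torch.zeros(2, 1, 3, 5, 5),                      # agents x time x obs
          torch.zeros(2, 1, dtype=torch.long),             # agents x time
          hidden_actor=net.init_hidden_actor().repeat(2, 1, 1),
          hidden_critic=net.init_hidden_critic().repeat(2, 1, 1))
# out['logits']: (2, 1, 10), out['critic']: (2, 1)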
|
||||
|
||||
|
||||
class RecurrentACL2(RecurrentAC):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.action_head = nn.Sequential(
|
||||
nn.Linear(self.hidden_size_actor, self.hidden_size_actor),
|
||||
nn.Tanh(),
|
||||
NormalizedLinear(self.hidden_size_actor, self.n_actions, trainable_magnitude=True)
|
||||
)
|
||||
|
||||
|
||||
class NormalizedLinear(nn.Linear):
|
||||
def __init__(self, in_features: int, out_features: int,
|
||||
device=None, dtype=None, trainable_magnitude=False):
|
||||
super(NormalizedLinear, self).__init__(in_features, out_features, False, device, dtype)
|
||||
self.d_sqrt = in_features**0.5
|
||||
self.trainable_magnitude = trainable_magnitude
|
||||
self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude)
|
||||
|
||||
def forward(self, in_array):
|
||||
normalized_input = F.normalize(in_array, dim=-1, p=2, eps=1e-5)
|
||||
normalized_weight = F.normalize(self.weight, dim=-1, p=2, eps=1e-5)
|
||||
return F.linear(normalized_input, normalized_weight) * self.d_sqrt * self.scale
|
||||
|
||||
|
||||
class L2Norm(nn.Module):
|
||||
def __init__(self, in_features, trainable_magnitude=False):
|
||||
super(L2Norm, self).__init__()
|
||||
self.d_sqrt = in_features**0.5
|
||||
self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude)
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=-1, p=2, eps=1e-5) * self.d_sqrt * self.scale
|
||||
@@ -1,55 +0,0 @@
|
||||
import torch
|
||||
from torch.distributions import Categorical
|
||||
from marl_factory_grid.algorithms.marl.iac import LoopIAC
|
||||
from marl_factory_grid.algorithms.marl.base_ac import nms
|
||||
from marl_factory_grid.algorithms.marl.memory import MARLActorCriticMemory
|
||||
|
||||
|
||||
class LoopSEAC(LoopIAC):
|
||||
def __init__(self, cfg):
|
||||
super(LoopSEAC, self).__init__(cfg)
|
||||
|
||||
def actor_critic(self, tm, networks, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
|
||||
obs, actions, done, reward = tm.observation, tm.action, tm.done[:, 1:], tm.reward[:, 1:]
|
||||
outputs = [net(obs, actions, tm.hidden_actor[:, 0], tm.hidden_critic[:, 0]) for net in networks]
|
||||
|
||||
with torch.inference_mode(True):
|
||||
true_action_logp = torch.stack([
|
||||
torch.log_softmax(out[nms.LOGITS][ag_i, :-1], -1)
|
||||
.gather(index=actions[ag_i, 1:, None], dim=-1)
|
||||
for ag_i, out in enumerate(outputs)
|
||||
], 0).squeeze()
|
||||
|
||||
losses = []
|
||||
|
||||
for ag_i, out in enumerate(outputs):
|
||||
logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1}
|
||||
critic = out[nms.CRITIC]
|
||||
|
||||
entropy_loss = Categorical(logits=logits[ag_i]).entropy().mean()
|
||||
advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
|
||||
|
||||
# policy loss
|
||||
log_ap = torch.log_softmax(logits, -1)
|
||||
log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze()
|
||||
|
||||
# importance weights
|
||||
iw = (log_ap - true_action_logp).exp().detach() # importance_weights
|
||||
|
||||
a2c_loss = (-iw*log_ap * advantages.detach()).mean(-1)
|
||||
|
||||
value_loss = (iw*advantages.pow(2)).mean(-1) # n_agent
|
||||
|
||||
# weighted loss
|
||||
loss = (a2c_loss + vf_coef*value_loss - entropy_coef * entropy_loss).mean()
|
||||
losses.append(loss)
|
||||
|
||||
return losses
|
||||
|
||||
def learn(self, tms: MARLActorCriticMemory, **kwargs):
|
||||
losses = self.actor_critic(tms, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
|
||||
for ag_i, loss in enumerate(losses):
|
||||
self.optimizer[ag_i].zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.net[ag_i].parameters(), 0.5)
|
||||
self.optimizer[ag_i].step()
|
||||
@@ -7,8 +7,8 @@ agent:
|
||||
hidden_size_critic: 64
|
||||
use_agent_embedding: False
|
||||
env:
|
||||
classname: marl_factory_grid.configs.custom
|
||||
env_name: "custom/dirt_quadrant_train_config"
|
||||
classname: marl_factory_grid.environment.configs.rl
|
||||
env_name: "rl/dirt_quadrant_train_config"
|
||||
n_agents: 1
|
||||
max_steps: 250
|
||||
pomdp_r: 2
|
||||
@@ -7,8 +7,8 @@ agent:
|
||||
hidden_size_critic: 64
|
||||
use_agent_embedding: False
|
||||
env:
|
||||
classname: marl_factory_grid.configs.custom
|
||||
env_name: "custom/two_rooms_one_door_modified_train_config"
|
||||
classname: marl_factory_grid.environment.configs.rl
|
||||
env_name: "rl/two_rooms_train_config"
|
||||
n_agents: 1
|
||||
max_steps: 250
|
||||
pomdp_r: 2
|
||||
@@ -1,33 +0,0 @@
|
||||
from marl_factory_grid.algorithms.marl.base_ac import BaseActorCritic
|
||||
from marl_factory_grid.algorithms.marl.base_ac import nms
|
||||
import torch
|
||||
from torch.distributions import Categorical
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class LoopSNAC(BaseActorCritic):
|
||||
def __init__(self, cfg):
|
||||
super().__init__(cfg)
|
||||
|
||||
def load_state_dict(self, path: Path):
|
||||
path2weights = list(path.glob('*.pt'))
|
||||
assert len(path2weights) == 1, f'Expected a single set of weights but got {len(path2weights)}'
|
||||
self.net.load_state_dict(torch.load(path2weights[0]))
|
||||
|
||||
def init_hidden(self):
|
||||
hidden_actor = self.net.init_hidden_actor()
|
||||
hidden_critic = self.net.init_hidden_critic()
|
||||
return dict(hidden_actor=torch.cat([hidden_actor] * self.n_agents, 0),
|
||||
hidden_critic=torch.cat([hidden_critic] * self.n_agents, 0)
|
||||
)
|
||||
|
||||
def get_actions(self, out):
|
||||
actions = Categorical(logits=out[nms.LOGITS]).sample().squeeze()
|
||||
return actions
|
||||
|
||||
def forward(self, observations, actions, hidden_actor, hidden_critic):
|
||||
out = self.net(self._as_torch(observations).unsqueeze(1),
|
||||
self._as_torch(actions).unsqueeze(1),
|
||||
hidden_actor, hidden_critic
|
||||
)
|
||||
return out
|
||||
@@ -37,7 +37,6 @@ class TSPBaseAgent(ABC):
|
||||
self._position_graph = self.generate_pos_graph()
|
||||
self._static_route = None
|
||||
self.cached_route = None
|
||||
self.fallback_action = None
|
||||
self.action_list = []
|
||||
|
||||
@abstractmethod
|
||||
@@ -50,46 +49,6 @@ class TSPBaseAgent(ABC):
|
||||
"""
|
||||
return 0
|
||||
|
||||
def calculate_tsp_route(self, target_identifier):
|
||||
"""
|
||||
Calculate the TSP route to reach a target.
|
||||
|
||||
:param target_identifier: Identifier of the target entity
|
||||
:type target_identifier: str
|
||||
|
||||
:return: TSP route
|
||||
:rtype: List[int]
|
||||
"""
|
||||
target_positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS]
|
||||
|
||||
# if there are cached routes, search for one matching the current and target position
|
||||
if self._env.state.route_cache and (
|
||||
route := self._env.state.get_cached_route(self.state.pos, target_positions)) is not None:
|
||||
# print(f"Retrieved cached route: {route}")
|
||||
return route
|
||||
# if none are found, calculate tsp route and cache it
|
||||
else:
|
||||
start_time = time.time()
|
||||
if self.local_optimization:
|
||||
nodes = \
|
||||
[self.state.pos] + \
|
||||
[x for x in target_positions if max(abs(np.subtract(x, self.state.pos))) < 3]
|
||||
try:
|
||||
while len(nodes) < 7:
|
||||
nodes += [next(x for x in target_positions if x not in nodes)]
|
||||
except StopIteration:
|
||||
nodes = [self.state.pos] + target_positions
|
||||
|
||||
else:
|
||||
nodes = [self.state.pos] + target_positions
|
||||
|
||||
route = tsp.traveling_salesman_problem(self._position_graph,
|
||||
nodes=nodes, cycle=True, method=tsp.greedy_tsp)
|
||||
duration = time.time() - start_time
|
||||
print("TSP calculation took {:.2f} seconds to execute".format(duration))
|
||||
self._env.state.cache_route(route)
|
||||
return route
|
||||
|
||||
def _use_door_or_move(self, door, target):
|
||||
"""
|
||||
Helper method to decide whether to use a door or move towards a target.
|
||||
@@ -108,6 +67,47 @@ class TSPBaseAgent(ABC):
|
||||
action = self._predict_move(target)
|
||||
return action
|
||||
|
||||
def calculate_tsp_route(self, target_identifier):
|
||||
"""
|
||||
Calculate the TSP route to reach a target.
|
||||
|
||||
:param target_identifier: Identifier of the target entity
|
||||
:type target_identifier: str
|
||||
|
||||
:return: TSP route
|
||||
:rtype: List[int]
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
if self.cached_route is not None:
|
||||
#print(f" Used cached route: {self.cached_route}")
|
||||
return copy.deepcopy(self.cached_route)
|
||||
|
||||
else:
|
||||
positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS]
|
||||
if self.local_optimization:
|
||||
nodes = \
|
||||
[self.state.pos] + \
|
||||
[x for x in positions if max(abs(np.subtract(x, self.state.pos))) < 3]
|
||||
try:
|
||||
while len(nodes) < 7:
|
||||
nodes += [next(x for x in positions if x not in nodes)]
|
||||
except StopIteration:
|
||||
nodes = [self.state.pos] + positions
|
||||
|
||||
else:
|
||||
nodes = [self.state.pos] + positions
|
||||
|
||||
route = tsp.traveling_salesman_problem(self._position_graph,
|
||||
nodes=nodes, cycle=True, method=tsp.greedy_tsp)
|
||||
self.cached_route = copy.deepcopy(route)
|
||||
#print(f"Cached route: {self.cached_route}")
|
||||
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
#print("TSP calculation took {:.2f} seconds to execute".format(duration))
|
||||
return route
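With this change the route is cached on the agent instance instead of the shared GameState; repeated calls on one agent then behave roughly like this (sketch: tsp_agent stands for any already constructed TSPBaseAgent subclass, and 'DirtPiles' is an illustrative target identifier):
first = tsp_agent.calculate_tsp_route('DirtPiles')    # computed once and stored in cached_route
second = tsp_agent.calculate_tsp_route('DirtPiles')   # served from the cache as a deepcopy
assert first == second and first is not second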
|
||||
|
||||
def _door_is_close(self, state):
|
||||
"""
|
||||
Check if a door is close to the agent's position.
|
||||
@@ -173,11 +173,8 @@ class TSPBaseAgent(ABC):
|
||||
action = next(action for action, pos_diff in MOVEMAP.items() if
|
||||
np.all(diff == pos_diff) and action in allowed_directions)
|
||||
except StopIteration:
|
||||
print(f"No valid action found for pos diff: {diff}. Using fallback action: {self.fallback_action}.")
|
||||
if self.fallback_action and any(self.fallback_action == action.name for action in self.state.actions):
|
||||
action = self.fallback_action
|
||||
else:
|
||||
action = choice(self.state.actions).name
|
||||
print(f"No valid action found for pos diff: {diff}. Using fallback action.")
|
||||
action = choice(self.state.actions).name
|
||||
else:
|
||||
action = choice(self.state.actions).name
|
||||
# noinspection PyUnboundLocalVariable
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.algorithms.static.TSP_base_agent import TSPBaseAgent
|
||||
|
||||
from marl_factory_grid.modules.items import constants as i
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
future_planning = 7
|
||||
inventory_size = 3
|
||||
|
||||
MODE_GET = 'Mode_Get'
|
||||
MODE_BRING = 'Mode_Bring'
|
||||
|
||||
|
||||
class TSPItemAgent(TSPBaseAgent):
|
||||
|
||||
def __init__(self, *args, mode=MODE_GET, **kwargs):
|
||||
"""
|
||||
Initializes a TSPItemAgent that collects items in the environment, stores them in its inventory and drops them off
|
||||
at a drop-off location.
|
||||
|
||||
:param mode: Mode of the agent, either MODE_GET or MODE_BRING.
|
||||
"""
|
||||
super(TSPItemAgent, self).__init__(*args, **kwargs)
|
||||
self.mode = mode
|
||||
self.fallback_action = c.NOOP
|
||||
|
||||
def predict(self, *_, **__):
|
||||
item_at_position = self._env.state[i.ITEM].by_pos(self.state.pos)
|
||||
dropoff_at_position = self._env.state[i.DROP_OFF].by_pos(self.state.pos)
|
||||
if item_at_position:
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
action = i.ITEM_ACTION
|
||||
elif dropoff_at_position:
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
action = i.ITEM_ACTION
|
||||
elif door := self._door_is_close(self._env.state):
|
||||
action = self._use_door_or_move(door, i.DROP_OFF if self.mode == MODE_BRING else i.ITEM)
|
||||
else:
|
||||
action = self._choose()
|
||||
self.action_list.append(action)
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
try:
|
||||
action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action)
|
||||
except (StopIteration, UnboundLocalError):
|
||||
print('Will not happen')
|
||||
raise EnvironmentError
|
||||
# noinspection PyUnboundLocalVariable
|
||||
if self.mode == MODE_BRING and len(self._env[i.INVENTORY].by_entity(self.state)):
|
||||
pass
|
||||
elif self.mode == MODE_BRING and not len(self._env[i.INVENTORY].by_entity(self.state)):
|
||||
self.mode = MODE_GET
|
||||
elif self.mode == MODE_GET and len(self._env[i.INVENTORY].by_entity(self.state)) > inventory_size:
|
||||
self.mode = MODE_BRING
|
||||
else:
|
||||
pass
|
||||
return action_obj
|
||||
|
||||
def _choose(self):
|
||||
"""
|
||||
Internal Usage. Chooses the action based on the agent's mode and the environment state.
|
||||
|
||||
:return: Chosen action.
|
||||
:rtype: int
|
||||
"""
|
||||
target = i.DROP_OFF if self.mode == MODE_BRING else i.ITEM
|
||||
if len(self._env.state[i.ITEM]) >= 1:
|
||||
action = self._predict_move(target)
|
||||
|
||||
elif len(self._env[i.INVENTORY].by_entity(self.state)):
|
||||
self.mode = MODE_BRING
|
||||
action = self._predict_move(target)
|
||||
else:
|
||||
action = int(np.random.randint(self._env.action_space.n))
|
||||
# noinspection PyUnboundLocalVariable
|
||||
return action
|
||||
@@ -1,27 +0,0 @@
|
||||
from random import randint
|
||||
|
||||
from marl_factory_grid.algorithms.static.TSP_base_agent import TSPBaseAgent
|
||||
|
||||
future_planning = 7
|
||||
|
||||
|
||||
class TSPRandomAgent(TSPBaseAgent):
|
||||
|
||||
def __init__(self, n_actions, *args, **kwargs):
|
||||
"""
|
||||
Initializes a TSPRandomAgent that performs random actions from within his action space.
|
||||
|
||||
:param n_actions: Number of possible actions.
|
||||
:type n_actions: int
|
||||
"""
|
||||
super(TSPRandomAgent, self).__init__(*args, **kwargs)
|
||||
self.n_action = n_actions
|
||||
|
||||
def predict(self, *_, **__):
|
||||
"""
|
||||
Predicts the next action randomly.
|
||||
|
||||
:return: Predicted action.
|
||||
:rtype: int
|
||||
"""
|
||||
return randint(0, self.n_action - 1)
|
||||
@@ -58,7 +58,7 @@ def load_yaml_file(path: Path):
|
||||
|
||||
def add_env_props(cfg):
|
||||
# Path to config File
|
||||
env_path = Path(f'../marl_factory_grid/configs/{cfg["env"]["env_name"]}.yaml')
|
||||
env_path = Path(f'../marl_factory_grid/environment/configs/{cfg["env"]["env_name"]}.yaml')
|
||||
|
||||
# Env Init
|
||||
factory = Factory(env_path)
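With the renamed folder, env_name now resolves relative to environment/configs; for example (sketch using an env_name from the config diffs above):
cfg = {'env': {'env_name': 'rl/dirt_quadrant_train_config'}}
env_path = Path(f'../marl_factory_grid/environment/configs/{cfg["env"]["env_name"]}.yaml')
# -> ../marl_factory_grid/environment/configs/rl/dirt_quadrant_train_config.yaml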
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
General:
|
||||
env_seed: 69
|
||||
individual_rewards: true
|
||||
level_name: obs_test_map
|
||||
pomdp_r: 0
|
||||
verbose: True
|
||||
tests: false
|
||||
|
||||
Agents:
|
||||
Wolfgang:
|
||||
Actions:
|
||||
- Noop
|
||||
Observations:
|
||||
- Walls
|
||||
- Doors
|
||||
- Other
|
||||
- DirtPiles
|
||||
Positions:
|
||||
- (1, 3)
|
||||
|
||||
Soeren:
|
||||
Actions:
|
||||
- Noop
|
||||
Observations:
|
||||
- Walls
|
||||
- Doors
|
||||
- Other
|
||||
- DirtPiles
|
||||
Positions:
|
||||
- (1, 1)
|
||||
|
||||
Juergen:
|
||||
Actions:
|
||||
- Noop
|
||||
Observations:
|
||||
- Walls
|
||||
- Doors
|
||||
- Other
|
||||
- DirtPiles
|
||||
Positions:
|
||||
- (1, 2)
|
||||
|
||||
Walter:
|
||||
Actions:
|
||||
- Noop
|
||||
Observations:
|
||||
- Walls
|
||||
- Doors
|
||||
- Other
|
||||
- DirtPiles
|
||||
Positions:
|
||||
- (1, 4)
|
||||
|
||||
|
||||
Entities:
|
||||
DirtPiles:
|
||||
Doors:
|
||||
|
||||
Rules:
|
||||
# Utilities
|
||||
WatchCollisions:
|
||||
done_at_collisions: false
|
||||
|
||||
# Done Conditions
|
||||
DoneAtMaxStepsReached:
|
||||
max_steps: 500
|
||||
@@ -1,92 +0,0 @@
|
||||
General:
|
||||
# RNG-seed to sample the same "random" numbers every time, to make the different runs comparable.
|
||||
env_seed: 69
|
||||
# Individual vs global rewards
|
||||
individual_rewards: true
|
||||
# The level.txt file to load from marl_factory_grid/levels
|
||||
level_name: rooms
|
||||
# Radius of Partially observable Markov decision process
|
||||
pomdp_r: 3
|
||||
# Print all messages and events
|
||||
verbose: false
|
||||
# Run tests
|
||||
tests: false
|
||||
|
||||
# In the "clean and bring" Scenario one agent aims to pick up all items and drop them at drop-off locations while all
|
||||
# other agents aim to clean dirt piles.
|
||||
Agents:
|
||||
# The clean agents
|
||||
Wolfgang:
|
||||
Actions:
|
||||
- Move8
|
||||
- DoorUse
|
||||
- Clean
|
||||
- Noop
|
||||
Observations:
|
||||
- Walls
|
||||
- Doors
|
||||
- Other
|
||||
- DirtPiles
|
||||
Clones: 8
|
||||
|
||||
# The item agent
|
||||
Juergen:
|
||||
Actions:
|
||||
- Move8
|
||||
- DoorUse
|
||||
- ItemAction
|
||||
- Noop
|
||||
Observations:
|
||||
- Walls
|
||||
- Doors
|
||||
- Other
|
||||
- Items
|
||||
- DropOffLocations
|
||||
- Inventory
|
||||
|
||||
Entities:
|
||||
DirtPiles:
|
||||
coords_or_quantity: 10
|
||||
initial_amount: 2
|
||||
clean_amount: 1
|
||||
dirt_spawn_r_var: 0.1
|
||||
max_global_amount: 20
|
||||
max_local_amount: 5
|
||||
Doors:
|
||||
DropOffLocations:
|
||||
coords_or_quantity: 1
|
||||
max_dropoff_storage_size: 0
|
||||
Inventories: { }
|
||||
Items:
|
||||
coords_or_quantity: 5
|
||||
|
||||
# Rules section specifies the rules governing the dynamics of the environment.
|
||||
Rules:
|
||||
# Environment Dynamics
|
||||
# When stepping over a dirt pile, entities carry a ratio of the dirt to their next position
|
||||
EntitiesSmearDirtOnMove:
|
||||
smear_ratio: 0.2
|
||||
# Doors automatically close after a certain number of time steps
|
||||
DoorAutoClose:
|
||||
close_frequency: 7
|
||||
|
||||
# Respawn Stuff
|
||||
# Define how dirt should respawn after the initial spawn
|
||||
RespawnDirt:
|
||||
respawn_freq: 30
|
||||
# Define how items should respawn after the initial spawn
|
||||
RespawnItems:
|
||||
respawn_freq: 50
|
||||
|
||||
# Utilities
|
||||
# This rule defines the collision mechanic, introduces a related DoneCondition and lets you specify rewards.
|
||||
# Can be omitted/ignored if you do not want to take care of collisions at all.
|
||||
WatchCollisions:
|
||||
done_at_collisions: false
|
||||
|
||||
# Done Conditions
|
||||
# Define the conditions for the environment to stop. Either success or a fail conditions.
|
||||
# The environment stops when all dirt is cleaned
|
||||
DoneOnAllDirtCleaned:
|
||||
DoneAtMaxStepsReached:
|
||||
max_steps: 500
|
||||
@@ -1,73 +0,0 @@
|
||||
General:
|
||||
# RNG-seed to sample the same "random" numbers every time, to make the different runs comparable.
|
||||
env_seed: 69
|
||||
# Individual vs global rewards
|
||||
individual_rewards: true
|
||||
# The level.txt file to load from marl_factory_grid/levels
|
||||
level_name: quadrant
|
||||
# Radius of Partially observable Markov decision process
|
||||
pomdp_r: 0 # default 3
|
||||
# Print all messages and events
|
||||
verbose: false
|
||||
# Run tests
|
||||
tests: false
|
||||
|
||||
# In the "clean and bring" Scenario one agent aims to pick up all items and drop them at drop-off locations while all
|
||||
# other agents aim to clean dirt piles.
|
||||
Agents:
|
||||
# The clean agents
|
||||
Sigmund:
|
||||
Actions:
|
||||
- Move4
|
||||
#- Clean
|
||||
#- Noop
|
||||
Observations:
|
||||
- DirtPiles
|
||||
- Self
|
||||
Positions:
|
||||
- (9,1)
|
||||
- (4,5)
|
||||
- (1,1)
|
||||
- (4,5)
|
||||
- (9,1)
|
||||
- (9,9)
|
||||
Wolfgang:
|
||||
Actions:
|
||||
- Move4
|
||||
#- Clean
|
||||
#- Noop
|
||||
Observations:
|
||||
- DirtPiles
|
||||
- Self
|
||||
Positions:
|
||||
- (9,5)
|
||||
- (4,5)
|
||||
- (1,1)
|
||||
- (4,5)
|
||||
- (9,5)
|
||||
- (9,9)
|
||||
|
||||
Entities:
|
||||
DirtPiles:
|
||||
coords_or_quantity: (9,9), (1,1), (4,5) # (4,7), (2,4), (1, 1) #(1, 1), (2,4), (4,7), (7,9), (9,9) # (1, 1), (1,2), (1,3), (2,4), (2,5), (3,6), (4,7), (5,8), (6,8), (7,9), (8,9), (9,9)
|
||||
initial_amount: 0.5 # <1 to ensure that the robot which first attempts to clean this field can remove the dirt in one action
|
||||
clean_amount: 1
|
||||
dirt_spawn_r_var: 0
|
||||
max_global_amount: 12
|
||||
max_local_amount: 1
|
||||
|
||||
# Rules section specifies the rules governing the dynamics of the environment.
|
||||
Rules:
|
||||
|
||||
# Utilities
|
||||
# This rule defines the collision mechanic, introduces a related DoneCondition and lets you specify rewards.
|
||||
# Can be omitted/ignored if you do not want to take care of collisions at all.
|
||||
WatchCollisions:
|
||||
done_at_collisions: false
|
||||
|
||||
# Done Conditions
|
||||
# Define the conditions for the environment to stop. Either success or a fail conditions.
|
||||
# The environment stops when all dirt is cleaned
|
||||
DoneOnAllDirtCleaned:
|
||||
#DoneAtMaxStepsReached: # An episode should last for at most max_steps steps
|
||||
#max_steps: 100
|
||||
@@ -1,146 +0,0 @@
|
||||
# Default Configuration File
|
||||
|
||||
General:
|
||||
# RNG-seed to sample the same "random" numbers every time, to make the different runs comparable.
|
||||
env_seed: 69
|
||||
# Individual vs global rewards
|
||||
individual_rewards: true
|
||||
# The level.txt file to load from marl_factory_grid/levels
|
||||
level_name: large
|
||||
# View Radius; 0 = full observability
|
||||
pomdp_r: 3
|
||||
# Print all messages and events
|
||||
verbose: false
|
||||
# Run tests
|
||||
tests: false
|
||||
|
||||
# Agents section defines the characteristics of different agents in the environment.
|
||||
|
||||
# An Agent requires a list of actions and observations.
|
||||
# Possible actions: Noop, Charge, Clean, DestAction, DoorUse, ItemAction, MachineAction, Move8, Move4, North, NorthEast, ...
|
||||
# Possible observations: All, Combined, GlobalPosition, Battery, ChargePods, DirtPiles, Destinations, Doors, Items, Inventory, DropOffLocations, Maintainers, ...
|
||||
# You can use 'clone' as the agent name to have multiple instances with either a list of names or an int specifying the number of clones.
|
||||
Agents:
|
||||
Wolfgang:
|
||||
Actions:
|
||||
- Noop
|
||||
- Charge
|
||||
- Clean
|
||||
- DestAction
|
||||
- DoorUse
|
||||
- ItemAction
|
||||
- Move8
|
||||
Observations:
|
||||
- Combined:
|
||||
- Other
|
||||
- Walls
|
||||
- GlobalPosition
|
||||
- Battery
|
||||
- ChargePods
|
||||
- DirtPiles
|
||||
- Destinations
|
||||
- Doors
|
||||
- Items
|
||||
- Inventory
|
||||
- DropOffLocations
|
||||
- Maintainers
|
||||
|
||||
# Entities section defines the initial parameters and behaviors of different entities in the environment.
|
||||
# Entities all spawn using coords_or_quantity, a number of entities or coordinates to place them.
|
||||
Entities:
|
||||
# Batteries: Entities representing power sources for agents.
|
||||
Batteries:
|
||||
initial_charge: 0.8
|
||||
per_action_costs: 0.02
|
||||
|
||||
# ChargePods: Entities representing charging stations for Batteries.
|
||||
ChargePods:
|
||||
coords_or_quantity: 2
|
||||
|
||||
# Destinations: Entities representing target locations for agents.
|
||||
# - spawn_mode: GROUPED or SINGLE. Determines how destinations are spawned.
|
||||
Destinations:
|
||||
coords_or_quantity: 1
|
||||
spawn_mode: GROUPED
|
||||
|
||||
# DirtPiles: Entities representing piles of dirt.
|
||||
# - initial_amount: Initial amount of dirt in each pile.
|
||||
# - clean_amount: Amount of dirt cleaned in each cleaning action.
|
||||
# - dirt_spawn_r_var: Random variation in dirt spawn amounts.
|
||||
# - max_global_amount: Maximum total amount of dirt allowed in the environment.
|
||||
# - max_local_amount: Maximum amount of dirt allowed in one position.
|
||||
DirtPiles:
|
||||
coords_or_quantity: 10
|
||||
initial_amount: 2
|
||||
clean_amount: 1
|
||||
dirt_spawn_r_var: 0.1
|
||||
max_global_amount: 20
|
||||
max_local_amount: 5
|
||||
|
||||
# Doors are spawned using the level map.
|
||||
Doors:
|
||||
|
||||
# DropOffLocations: Entities representing locations where agents can drop off items.
|
||||
# - max_dropoff_storage_size: Maximum storage capacity at each drop-off location.
|
||||
DropOffLocations:
|
||||
coords_or_quantity: 1
|
||||
max_dropoff_storage_size: 0
|
||||
|
||||
# GlobalPositions.
|
||||
GlobalPositions: { }
|
||||
|
||||
# Inventories: Entities representing inventories for agents.
|
||||
Inventories: { }
|
||||
|
||||
# Items: Entities representing items in the environment.
|
||||
Items:
|
||||
coords_or_quantity: 5
|
||||
|
||||
# Machines: Entities representing machines in the environment.
|
||||
Machines:
|
||||
coords_or_quantity: 2
|
||||
|
||||
# Maintainers: Entities representing maintainers that aim to maintain machines.
|
||||
Maintainers:
|
||||
coords_or_quantity: 1
|
||||
|
||||
|
||||
# Rules section specifies the rules governing the dynamics of the environment.
|
||||
Rules:
|
||||
# Environment Dynamics
|
||||
# When stepping over a dirt pile, entities carry a ratio of the dirt to their next position
|
||||
EntitiesSmearDirtOnMove:
|
||||
smear_ratio: 0.2
|
||||
# Doors automatically close after a certain number of time steps
|
||||
DoorAutoClose:
|
||||
close_frequency: 10
|
||||
# Maintainers move at every time step
|
||||
MoveMaintainers:
|
||||
|
||||
# Respawn Stuff
|
||||
# Define how dirt should respawn after the initial spawn
|
||||
RespawnDirt:
|
||||
respawn_freq: 15
|
||||
# Define how items should respawn after the initial spawn
|
||||
RespawnItems:
|
||||
respawn_freq: 15
|
||||
|
||||
# Utilities
|
||||
# This rule defines the collision mechanic, introduces a related DoneCondition and lets you specify rewards.
|
||||
# Can be omitted/ignored if you do not want to take care of collisions at all.
|
||||
WatchCollisions:
|
||||
done_at_collisions: false
|
||||
|
||||
# Done Conditions
|
||||
# Define the conditions for the environment to stop. Either success or a fail conditions.
|
||||
# The environment stops when an agent reaches a destination
|
||||
DoneAtDestinationReach:
|
||||
# The environment stops when all dirt is cleaned
|
||||
DoneOnAllDirtCleaned:
|
||||
# The environment stops when a battery is discharged
|
||||
DoneAtBatteryDischarge:
|
||||
# The environment stops when a maintainer reports a collision
|
||||
DoneAtMaintainerCollision:
|
||||
# The environment stops after max steps
|
||||
DoneAtMaxStepsReached:
|
||||
max_steps: 500
|
||||
@@ -1,89 +0,0 @@
|
||||
# General env. settings.
|
||||
General:
|
||||
# Just the best seed.
|
||||
env_seed: 69
|
||||
# Each agent receives an individual reward.
|
||||
individual_rewards: true
|
||||
# level file to load from .\levels\.
|
||||
level_name: eight_puzzle
|
||||
# Partial Observability. 0 = Full Observation.
|
||||
pomdp_r: 0
|
||||
# Please do not spam me.
|
||||
verbose: false
|
||||
# Do not touch, WIP
|
||||
tests: false
|
||||
|
||||
# RL Surrogates
|
||||
Agents:
|
||||
# This defines the name of the agent. UTF-8
|
||||
Wolfgang:
|
||||
# Section which defines the available Actions per Agent
|
||||
Actions:
|
||||
# Move4 adds 4 actions [`North`, `East`, `South`, `West`]
|
||||
Move4:
|
||||
# Reward specifications which differ from the default.
|
||||
# Agent does a valid move in the environment. He actually moves.
|
||||
valid_reward: -0.1
|
||||
# Agent wants to move, but fails.
|
||||
fail_reward: 0
|
||||
# NOOP aka agent does not do a thing.
|
||||
Noop:
|
||||
# The Agent decides to not do anything. Which is always valid.
|
||||
valid_reward: 0
|
||||
# Does not do anything, just using the same interface.
|
||||
fail_reward: 0
|
||||
# What the agent wants to see.
|
||||
Observations:
|
||||
# The agent...
|
||||
# sees other agents, but not himself.
|
||||
- Other
|
||||
# wants to see walls
|
||||
- Walls
|
||||
# sees his associated Destination (singular). Use the Plural for `see all destinations`.
|
||||
- Destination
|
||||
# You want to have 7 clones; it is also possible to name them by giving a list of names.
|
||||
Clones: 7
|
||||
# Agents are blocking their grid position from being entered by others.
|
||||
is_blocking_pos: true
|
||||
# Apart from agents, which additional entities do you want to load?
|
||||
Entities:
|
||||
# Observable destinations, which can be reached by stepping on the same position. Has additional parameters...
|
||||
Destinations:
|
||||
# Let them spawn on closed doors and agent positions
|
||||
ignore_blocking: true
|
||||
# For 8-Puzzle, we need a special spawn rule...
|
||||
spawnrule:
|
||||
# ...which spawns a single destination just underneath its associated agent.
|
||||
SpawnDestinationOnAgent: {} # There are no parameters, so we state empty kwargs.
|
||||
|
||||
# This section defines which operations are performed beside agent action.
|
||||
# Without this section nothing happens, not even Done-condition checks.
|
||||
# Also, situation-based rewards are specified this way.
|
||||
Rules:
|
||||
## Utilities
|
||||
# This rule defines the collision mechanic, introduces a related DoneCondition and lets you specify rewards.
|
||||
# Can be omitted/ignored if you do not want to take care of collisions at all.
|
||||
# This does not mean that agents cannot collide; it is just ignored.
|
||||
WatchCollisions:
|
||||
reward: 0
|
||||
done_at_collisions: false
|
||||
|
||||
# In 8 Puzzle, do not randomize the start positions; rather, move a random agent onto the single free position n times.
|
||||
DoRandomInitialSteps:
|
||||
# How many times?
|
||||
random_steps: 2
|
||||
|
||||
## Done Conditions
|
||||
# Maximum steps per episode. There is no reward for failing.
|
||||
DoneAtMaxStepsReached:
|
||||
# After how many steps should the episode end?
|
||||
max_steps: 200
|
||||
|
||||
# For 8 Puzzle we need a done condition that checks whether destinations have been reached, so...
|
||||
DoneAtDestinationReach:
|
||||
# On every step, should there be a reward for agents that reach their associated destination? No!
|
||||
dest_reach_reward: 0 # Do not touch. This is useful in other settings!
|
||||
# Reward should only be given when all destinations are reached in parallel!
|
||||
condition: "simultaneous"
|
||||
# Reward if this is the case. Granted to each agent when all agents are at their target position simultaneously.
|
||||
reward_at_done: 1
|
||||
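To make the condition: "simultaneous" semantics above concrete, here is an illustrative predicate in plain Python; it is not the library's implementation (which is not part of this diff), just the gist of the check:

def all_reached_simultaneously(agent_positions, destination_positions):
    # True only in a step where every agent sits on its associated destination at the same time.
    return all(pos == dest for pos, dest in zip(agent_positions, destination_positions))

# With condition: "simultaneous" and reward_at_done: 1, each agent receives the reward only in the
# step where this predicate holds for all agents at once; the later configs use condition: "all",
# which presumably also accepts destinations reached at different steps.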
@@ -1,92 +0,0 @@
|
||||
General:
|
||||
# Your Seed
|
||||
env_seed: 69
|
||||
# Individual vs global rewards
|
||||
individual_rewards: true
|
||||
# The level.txt file to load from marl_factory_grid/levels
|
||||
level_name: narrow_corridor
|
||||
# View Radius; 0 = full observability
|
||||
pomdp_r: 0
|
||||
# print all messages and events
|
||||
verbose: true
|
||||
# Run tests
|
||||
tests: false
|
||||
|
||||
Agents:
|
||||
# Agents are identified by their name
|
||||
Wolfgang:
|
||||
# The available actions for this particular agent
|
||||
Actions:
|
||||
# Able to do nothing
|
||||
- Noop
|
||||
# Able to move in all 8 directions
|
||||
- Move8
|
||||
# Stuff the agent can observe (per 2d slice)
|
||||
# use "Combined" if you want to merge multiple slices into one
|
||||
Observations:
|
||||
# He sees walls
|
||||
- Walls
|
||||
# he sees other agents; "Karl-Heinz" in this setting would be fine, too
|
||||
- Other
|
||||
# He can see Destinations that are assigned to him (hence the singular)
|
||||
- Destination
|
||||
# Available Spawn Positions as list
|
||||
Positions:
|
||||
- (2, 1)
|
||||
- (2, 5)
|
||||
# It is okay to collide with other agents, so that
|
||||
# they end up on the same position
|
||||
is_blocking_pos: true
|
||||
# See Above....
|
||||
Karl-Heinz:
|
||||
Actions:
|
||||
- Noop
|
||||
- Move8
|
||||
Observations:
|
||||
- Walls
|
||||
- Other
|
||||
- Destination
|
||||
Positions:
|
||||
- (2, 1)
|
||||
- (2, 5)
|
||||
is_blocking_pos: true
|
||||
|
||||
# Other noteworthy Entities
|
||||
Entities:
|
||||
# The destinations or positional targets to reach
|
||||
Destinations:
|
||||
# Let them spawn on closed doors and agent positions
|
||||
ignore_blocking: true
|
||||
# We need a special spawn rule...
|
||||
spawnrule:
|
||||
# ...which assigns the destinations per agent
|
||||
SpawnDestinationsPerAgent:
|
||||
# we use this parameter
|
||||
coords_or_quantity:
|
||||
# to enable and assign special positions per agent
|
||||
Wolfgang:
|
||||
- (2, 1)
|
||||
- (2, 5)
|
||||
Karl-Heinz:
|
||||
- (2, 1)
|
||||
- (2, 5)
|
||||
# Whether you want to provide a numeric Position observation.
|
||||
# GlobalPositions:
|
||||
# normalized: false
|
||||
|
||||
# Define the env. dynamics
|
||||
Rules:
|
||||
# Utilities
|
||||
# This rule Checks for Collision, also it assigns the (negative) reward
|
||||
WatchCollisions:
|
||||
reward: -0.1
|
||||
reward_at_done: -1
|
||||
done_at_collisions: false
|
||||
# Done Conditions
|
||||
# Load any of the rules, to check for done conditions.
|
||||
DoneAtDestinationReach:
|
||||
reward_at_done: 1
|
||||
# We want to give rewards only when all targets have been reached.
|
||||
condition: "all"
|
||||
DoneAtMaxStepsReached:
|
||||
max_steps: 200
|
||||
@@ -1,70 +0,0 @@
|
||||
General:
|
||||
# Your Seed
|
||||
env_seed: 69
|
||||
# Individual vs global rewards
|
||||
individual_rewards: true
|
||||
level_name: simple_crossing
|
||||
# View Radius; 0 = full observability
|
||||
pomdp_r: 0
|
||||
verbose: false
|
||||
tests: false
|
||||
|
||||
Agents:
|
||||
Agent_horizontal:
|
||||
Actions:
|
||||
- Noop
|
||||
- Move4
|
||||
Observations:
|
||||
- Walls
|
||||
- Other
|
||||
- Destination
|
||||
# Available Spawn Positions as list
|
||||
Positions:
|
||||
- (2,1)
|
||||
# It is okay to collide with other agents, so that
|
||||
# they end up on the same position
|
||||
is_blocking_pos: false
|
||||
Agent_vertical:
|
||||
Actions:
|
||||
- Noop
|
||||
- Move4
|
||||
Observations:
|
||||
- Walls
|
||||
- Other
|
||||
- Destination
|
||||
Positions:
|
||||
- (1,2)
|
||||
is_blocking_pos: false
|
||||
|
||||
# Other noteworthy Entities
|
||||
Entities:
|
||||
Destinations:
|
||||
# Let them spawn on closed doors and agent positions
|
||||
ignore_blocking: true
|
||||
spawnrule:
|
||||
SpawnDestinationsPerAgent:
|
||||
coords_or_quantity:
|
||||
Agent_horizontal:
|
||||
- (2,3)
|
||||
Agent_vertical:
|
||||
- (3,2)
|
||||
# Whether you want to provide a numeric Position observation.
|
||||
# GlobalPositions:
|
||||
# normalized: false
|
||||
|
||||
# Define the env. dynamics
|
||||
Rules:
|
||||
# Utilities
|
||||
# This rule Checks for Collision, also it assigns the (negative) reward
|
||||
WatchCollisions:
|
||||
reward: -0.1
|
||||
reward_at_done: -1
|
||||
done_at_collisions: false
|
||||
# Done Conditions
|
||||
# Load any of the rules, to check for done conditions.
|
||||
DoneAtDestinationReach:
|
||||
reward_at_done: 1
|
||||
# We want to give rewards only when all targets have been reached.
|
||||
condition: "all"
|
||||
DoneAtMaxStepsReached:
|
||||
max_steps: 200
|
||||
@@ -1,124 +0,0 @@
|
||||
Agents:
|
||||
# Clean test agent:
|
||||
# Actions:
|
||||
# - Noop
|
||||
# - Charge
|
||||
# - Clean
|
||||
# - DoorUse
|
||||
# - Move8
|
||||
# Observations:
|
||||
# - Combined:
|
||||
# - Other
|
||||
# - Walls
|
||||
# - GlobalPosition
|
||||
# - Battery
|
||||
# - ChargePods
|
||||
# - DirtPiles
|
||||
# - Destinations
|
||||
# - Doors
|
||||
# - Maintainers
|
||||
# Clones: 0
|
||||
# Item test agent:
|
||||
# Actions:
|
||||
# - Noop
|
||||
# - Charge
|
||||
# - DestAction
|
||||
# - DoorUse
|
||||
# - ItemAction
|
||||
# - Move8
|
||||
# Observations:
|
||||
# - Combined:
|
||||
# - Other
|
||||
# - Walls
|
||||
# - GlobalPosition
|
||||
# - Battery
|
||||
# - ChargePods
|
||||
# - Destinations
|
||||
# - Doors
|
||||
# - Items
|
||||
# - Inventory
|
||||
# - DropOffLocations
|
||||
# - Maintainers
|
||||
# Clones: 0
|
||||
Target test agent:
|
||||
Actions:
|
||||
- Noop
|
||||
- Charge
|
||||
- DoorUse
|
||||
- Move8
|
||||
Observations:
|
||||
- Combined:
|
||||
- Other
|
||||
- Walls
|
||||
- GlobalPosition
|
||||
- Battery
|
||||
- Destinations
|
||||
- Doors
|
||||
- Maintainers
|
||||
Clones: 1
|
||||
|
||||
Entities:
|
||||
|
||||
Batteries:
|
||||
initial_charge: 0.8
|
||||
per_action_costs: 0.02
|
||||
ChargePods:
|
||||
coords_or_quantity: 2
|
||||
Destinations:
|
||||
coords_or_quantity: 1
|
||||
spawn_mode: GROUPED
|
||||
DirtPiles:
|
||||
coords_or_quantity: 10
|
||||
initial_amount: 2
|
||||
clean_amount: 1
|
||||
dirt_spawn_r_var: 0.1
|
||||
max_global_amount: 20
|
||||
max_local_amount: 5
|
||||
Doors:
|
||||
DropOffLocations:
|
||||
coords_or_quantity: 1
|
||||
max_dropoff_storage_size: 0
|
||||
GlobalPositions: {}
|
||||
Inventories: {}
|
||||
Items:
|
||||
coords_or_quantity: 5
|
||||
Machines:
|
||||
coords_or_quantity: 2
|
||||
Maintainers:
|
||||
coords_or_quantity: 1
|
||||
|
||||
General:
|
||||
env_seed: 69
|
||||
individual_rewards: true
|
||||
level_name: quadrant
|
||||
pomdp_r: 3
|
||||
verbose: false
|
||||
tests: false
|
||||
|
||||
Rules:
|
||||
# Environment Dynamics
|
||||
EntitiesSmearDirtOnMove:
|
||||
smear_ratio: 0.2
|
||||
DoorAutoClose:
|
||||
close_frequency: 10
|
||||
MoveMaintainers:
|
||||
|
||||
# Respawn Stuff
|
||||
RespawnDirt:
|
||||
respawn_freq: 15
|
||||
RespawnItems:
|
||||
respawn_freq: 15
|
||||
|
||||
# Utilities
|
||||
WatchCollisions:
|
||||
done_at_collisions: false
|
||||
|
||||
# Done Conditions
|
||||
DoneAtMaxStepsReached:
|
||||
max_steps: 20
|
||||
|
||||
Tests:
|
||||
# MaintainerTest: {}
|
||||
# DirtAgentTest: {}
|
||||
# ItemAgentTest: {}
|
||||
# TargetAgentTest: {}
|
||||
@@ -1,69 +0,0 @@
|
||||
General:
|
||||
env_seed: 69
|
||||
# Individual vs global rewards
|
||||
individual_rewards: true
|
||||
# The level.txt file to load from marl_factory_grid/levels
|
||||
level_name: two_rooms
|
||||
# View Radius; 0 = full observability
|
||||
pomdp_r: 3
|
||||
# Print all messages and events
|
||||
verbose: false
|
||||
# Run tests
|
||||
tests: false
|
||||
|
||||
# In "two rooms one door" scenario 2 agents spawn in 2 different rooms that are connected by a single door. Their aim
|
||||
# is to reach the destination in the room they didn't spawn in, leading to a conflict at the door.
|
||||
Agents:
|
||||
Wolfgang:
|
||||
Actions:
|
||||
- Move8
|
||||
- Noop
|
||||
- DestAction
|
||||
- DoorUse
|
||||
Observations:
|
||||
- Walls
|
||||
- Other
|
||||
- Doors
|
||||
- Destination
|
||||
Sigmund:
|
||||
Actions:
|
||||
- Move8
|
||||
- Noop
|
||||
- DestAction
|
||||
- DoorUse
|
||||
Observations:
|
||||
- Combined:
|
||||
- Other
|
||||
- Walls
|
||||
- Destination
|
||||
- Doors
|
||||
|
||||
Entities:
|
||||
Destinations:
|
||||
spawnrule:
|
||||
SpawnDestinationsPerAgent:
|
||||
coords_or_quantity:
|
||||
Wolfgang:
|
||||
- (6,12)
|
||||
Sigmund:
|
||||
- (6, 2)
|
||||
|
||||
Doors: { }
|
||||
GlobalPositions: { }
|
||||
|
||||
Rules:
|
||||
# Environment Dynamics
|
||||
DoorAutoClose:
|
||||
close_frequency: 10
|
||||
|
||||
# Utilities
|
||||
# This rule defines the collision mechanic, introduces a related DoneCondition and lets you specify rewards.
|
||||
WatchCollisions:
|
||||
done_at_collisions: false
|
||||
|
||||
# Init
|
||||
AssignGlobalPositions: { }
|
||||
|
||||
# Done Conditions
|
||||
DoneAtMaxStepsReached:
|
||||
max_steps: 10
|
||||
@@ -3,7 +3,7 @@ General:
|
||||
# Individual vs global rewards
|
||||
individual_rewards: true
|
||||
# The level.txt file to load from marl_factory_grid/levels
|
||||
level_name: two_rooms_modified
|
||||
level_name: two_rooms
|
||||
# View Radius; 0 = full observability
|
||||
pomdp_r: 0
|
||||
# Print all messages and events
|
||||
@@ -3,7 +3,7 @@ General:
|
||||
# Individual vs global rewards
|
||||
individual_rewards: true
|
||||
# The level.txt file to load from marl_factory_grid/levels
|
||||
level_name: two_rooms_modified
|
||||
level_name: two_rooms
|
||||
# View Radius; 0 = full observability
|
||||
pomdp_r: 0
|
||||
# Print all messages and events
|
||||
@@ -3,7 +3,7 @@ General:
|
||||
# Individual vs global rewards
|
||||
individual_rewards: true
|
||||
# The level.txt file to load from marl_factory_grid/levels
|
||||
level_name: two_rooms_modified
|
||||
level_name: two_rooms
|
||||
# View Radius; 0 = full observability
|
||||
pomdp_r: 0
|
||||
# Print all messages and events
|
||||
@@ -3,7 +3,7 @@ General:
|
||||
# Individual vs global rewards
|
||||
individual_rewards: true
|
||||
# The level.txt file to load from marl_factory_grid/levels
|
||||
level_name: two_rooms_modified
|
||||
level_name: two_rooms
|
||||
# View Radius; 0 = full observability
|
||||
pomdp_r: 0
|
||||
# Print all messages and events
|
||||
@@ -3,7 +3,7 @@ General:
|
||||
# Individual vs global rewards
|
||||
individual_rewards: true
|
||||
# The level.txt file to load from marl_factory_grid/levels
|
||||
level_name: two_rooms_modified
|
||||
level_name: two_rooms
|
||||
# View Radius; 0 = full observability
|
||||
pomdp_r: 0
|
||||
# Print all messages and events
|
||||
@@ -109,7 +109,6 @@ class Factory(gym.Env):
|
||||
|
||||
# expensive - don't use unless required!
|
||||
self._renderer = None
|
||||
self._recorder = None
|
||||
|
||||
# Init entities
|
||||
entities = self.map.do_init()
|
||||
@@ -278,7 +277,7 @@ class Factory(gym.Env):
|
||||
for render_entity in render_entities:
|
||||
if render_entity.name == c.AGENT:
|
||||
render_entity.aux = self.obs_builder.curr_lightmaps[render_entity.real_name]
|
||||
return self._renderer.render(render_entities, self._recorder)
|
||||
return self._renderer.render(render_entities)
|
||||
|
||||
def set_recorder(self, recorder):
|
||||
self._recorder = recorder
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
MOVEMENTS_VALID: float = -1 # default: -0.001
|
||||
MOVEMENTS_FAIL: float = -1 # default: -0.05
|
||||
MOVEMENTS_VALID: float = -1
|
||||
MOVEMENTS_FAIL: float = -1
|
||||
NOOP: float = -1
|
||||
COLLISION: float = -1
|
||||
COLLISION_DONE: float = -1
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import marl_factory_grid.modules.maintenance.constants as M
|
||||
from marl_factory_grid.environment.entity.agent import Agent
|
||||
from marl_factory_grid.modules import Door, Machine, DirtPile, Item, DropOffLocation, ItemAction
|
||||
from marl_factory_grid.utils.results import TickResult, DoneResult, ActionResult
|
||||
import marl_factory_grid.environment.constants as c
|
||||
from marl_factory_grid.utils.results import TickResult, DoneResult
|
||||
|
||||
|
||||
class Test(unittest.TestCase):
|
||||
@@ -41,235 +36,3 @@ class Test(unittest.TestCase):
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
return []
|
||||
|
||||
|
||||
class MaintainerTest(Test):
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Tests whether the maintainer performs the correct actions and whether his actions register correctly in the env.
|
||||
"""
|
||||
super().__init__()
|
||||
self.temp_state_dict = {}
|
||||
pass
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
for maintainer in state.entities[M.MAINTAINERS]:
|
||||
self.assertIsInstance(maintainer.state, (ActionResult, TickResult))
|
||||
# print(f"state validity maintainer: {maintainer.state.validity}")
|
||||
|
||||
# will open doors when standing in front
|
||||
if maintainer._closed_door_in_path(state):
|
||||
self.assertEqual(maintainer.get_move_action(state).name, 'use_door')
|
||||
|
||||
# if maintainer._next and not maintainer._path:
|
||||
# finds valid targets when at target location
|
||||
# route = maintainer.calculate_route(maintainer._last[-1], state.floortile_graph)
|
||||
# if entities_at_target_location := [entity for entity in state.entities.by_pos(route[-1])]:
|
||||
# self.assertTrue(any(isinstance(e, Machine) for e in entities_at_target_location))
|
||||
return []
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
# do maintainers' actions have correct effects on environment i.e. doors open, machines heal
|
||||
for maintainer in state.entities[M.MAINTAINERS]:
|
||||
if maintainer._path and self.temp_state_dict != {}:
|
||||
if maintainer.identifier in self.temp_state_dict:
|
||||
last_action = self.temp_state_dict[maintainer.identifier]
|
||||
if last_action.identifier == 'DoorUse':
|
||||
if door := next((entity for entity in state.entities.get_entities_near_pos(maintainer.pos) if
|
||||
isinstance(entity, Door)), None):
|
||||
agents_near_door = [agent for agent in state.entities.get_entities_near_pos(door.pos) if
|
||||
isinstance(agent, Agent)]
|
||||
if len(agents_near_door) < 2:
|
||||
self.assertTrue(door.is_open)
|
||||
if last_action.identifier == 'MachineAction':
|
||||
if machine := next((entity for entity in state.entities.get_entities_near_pos(maintainer.pos) if
|
||||
isinstance(entity, Machine)), None):
|
||||
self.assertEqual(machine.health, 100)
|
||||
return []
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
# clear the dict: the maintainer identifier increments each run, so the dict would otherwise fill up over episodes
|
||||
self.temp_state_dict = {}
|
||||
for maintainer in state.entities[M.MAINTAINERS]:
|
||||
temp_state = maintainer._status
|
||||
if isinstance(temp_state, (ActionResult, TickResult)):
|
||||
# print(f"maintainer {temp_state}")
|
||||
self.temp_state_dict[maintainer.identifier] = temp_state
|
||||
else:
|
||||
self.temp_state_dict[maintainer.identifier] = None
|
||||
return []
|
||||
|
||||
|
||||
class DirtAgentTest(Test):
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Tests whether the dirt agent will perform the correct actions and whether the actions register correctly in the
|
||||
environment.
|
||||
"""
|
||||
super().__init__()
|
||||
self.temp_state_dict = {}
|
||||
pass
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
return []
|
||||
|
||||
def on_reset(self):
|
||||
return []
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
for dirtagent in [a for a in state.entities[c.AGENT] if "Clean" in a.identifier]: # isinstance TSPDirtAgent
|
||||
# state usually is an actionresult but after a crash, tickresults are reported
|
||||
self.assertIsInstance(dirtagent.state, (ActionResult, TickResult))
|
||||
# print(f"state validity dirtagent: {dirtagent.state.validity}")
|
||||
return []
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
# do agents' actions have correct effects on environment i.e. doors open, dirt is cleaned
|
||||
for dirtagent in [a for a in state.entities[c.AGENT] if "Clean" in a.identifier]: # isinstance TSPDirtAgent
|
||||
if self.temp_state_dict != {}:
|
||||
last_action = self.temp_state_dict[dirtagent.identifier]
|
||||
if last_action.identifier == 'DoorUse':
|
||||
if door := next((entity for entity in state.entities.get_entities_near_pos(dirtagent.pos) if
|
||||
isinstance(entity, Door)), None):
|
||||
agents_near_door = [agent for agent in state.entities.get_entities_near_pos(door.pos) if
|
||||
isinstance(agent, Agent)]
|
||||
if len(agents_near_door) < 2:
|
||||
# self.assertTrue(door.is_open)
|
||||
if door.is_closed:
|
||||
print("door should be open but seems closed.")
|
||||
if last_action.identifier == 'Clean':
|
||||
if dirt := next((entity for entity in state.entities.get_entities_near_pos(dirtagent.pos) if
|
||||
isinstance(entity, DirtPile)), None):
|
||||
# print(f"dirt left on pos: {dirt.amount}")
|
||||
self.assertTrue(dirt.amount < 5) # get dirt amount one step before - clean amount
|
||||
return []
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
for dirtagent in [a for a in state.entities[c.AGENT] if "Clean" in a.identifier]: # isinstance TSPDirtAgent
|
||||
temp_state = dirtagent._status
|
||||
if isinstance(temp_state, (ActionResult, TickResult)):
|
||||
# print(f"dirtagent {temp_state}")
|
||||
self.temp_state_dict[dirtagent.identifier] = temp_state
|
||||
else:
|
||||
self.temp_state_dict[dirtagent.identifier] = None
|
||||
return []
|
||||
|
||||
|
||||
class ItemAgentTest(Test):
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Tests whether the item agent will perform the correct actions and whether the actions register correctly in the
|
||||
environment.
|
||||
"""
|
||||
super().__init__()
|
||||
self.temp_state_dict = {}
|
||||
pass
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
return []
|
||||
|
||||
def on_reset(self):
|
||||
return []
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
for itemagent in [a for a in state.entities[c.AGENT] if "Item" in a.identifier]: # isinstance TSPItemAgent
|
||||
# state usually is an actionresult but after a crash, tickresults are reported
|
||||
self.assertIsInstance(itemagent.state, (ActionResult, TickResult))
|
||||
# self.assertEqual(agent.state.validity, True)
|
||||
# print(f"state validity itemagent: {itemagent.state.validity}")
|
||||
|
||||
return []
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
# do agents' actions have correct effects on environment i.e. doors open, items are picked up and dropped off
|
||||
for itemagent in [a for a in state.entities[c.AGENT] if "Item" in a.identifier]: # isinstance TSPItemAgent
|
||||
|
||||
if self.temp_state_dict != {}: # and
|
||||
last_action = self.temp_state_dict[itemagent.identifier]
|
||||
if last_action.identifier == 'DoorUse':
|
||||
if door := next((entity for entity in state.entities.get_entities_near_pos(itemagent.pos) if
|
||||
isinstance(entity, Door)), None):
|
||||
agents_near_door = [agent for agent in state.entities.get_entities_near_pos(door.pos) if
|
||||
isinstance(agent, Agent)]
|
||||
if len(agents_near_door) < 2:
|
||||
# self.assertTrue(door.is_open)
|
||||
if door.is_closed:
|
||||
print("door should be open but seems closed.")
|
||||
|
||||
# if last_action.identifier == 'ItemAction':
|
||||
# If it was a pick-up action the item should be in the agents inventory and not in his neighboring
|
||||
# positions anymore
|
||||
# nearby_items = [e for e in state.entities.get_entities_near_pos(itemagent.pos) if
|
||||
# isinstance(e, Item)]
|
||||
# self.assertNotIn(Item, nearby_items)
|
||||
# self.assertTrue(itemagent.bound_entity) # where is the inventory
|
||||
#
|
||||
# If it was a drop-off action the item should not be in the agents inventory anymore but instead in
|
||||
# the drop-off locations inventory
|
||||
#
|
||||
# if nearby_drop_offs := [e for e in state.entities.get_entities_near_pos(itemagent.pos) if
|
||||
# isinstance(e, DropOffLocation)]:
|
||||
# dol = nearby_drop_offs[0]
|
||||
# self.assertTrue(dol.bound_entity) # item in drop-off location?
|
||||
# self.assertNotIn(Item, state.entities.get_entities_near_pos(itemagent.pos))
|
||||
|
||||
return []
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
for itemagent in [a for a in state.entities[c.AGENT] if "Item" in a.identifier]: # isinstance TSPItemAgent
|
||||
temp_state = itemagent._status
|
||||
# print(f"itemagent {temp_state}")
|
||||
self.temp_state_dict[itemagent.identifier] = temp_state
|
||||
return []
|
||||
|
||||
|
||||
class TargetAgentTest(Test):
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Tests whether the target agent will perform the correct actions and whether the actions register correctly in the
|
||||
environment.
|
||||
"""
|
||||
super().__init__()
|
||||
self.temp_state_dict = {}
|
||||
pass
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
return []
|
||||
|
||||
def on_reset(self):
|
||||
return []
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
for targetagent in [a for a in state.entities[c.AGENT] if "Target" in a.identifier]:
|
||||
# state usually is an actionresult but after a crash, tickresults are reported
|
||||
self.assertIsInstance(targetagent.state, (ActionResult, TickResult))
|
||||
# print(f"state validity targetagent: {targetagent.state.validity}")
|
||||
return []
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
# do agents' actions have correct effects on environment i.e. doors open, targets are destinations
|
||||
for targetagent in [a for a in state.entities[c.AGENT] if "Target" in a.identifier]:
|
||||
if self.temp_state_dict != {}:
|
||||
last_action = self.temp_state_dict[targetagent.identifier]
|
||||
if last_action.identifier == 'DoorUse':
|
||||
if door := next((entity for entity in state.entities.get_entities_near_pos(targetagent.pos) if
|
||||
isinstance(entity, Door)), None):
|
||||
agents_near_door = [agent for agent in state.entities.get_entities_near_pos(door.pos) if
|
||||
isinstance(agent, Agent)]
|
||||
if len(agents_near_door) < 2:
|
||||
# self.assertTrue(door.is_open)
|
||||
if door.is_closed:
|
||||
print("door should be open but seems closed.")
|
||||
|
||||
return []
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
for targetagent in [a for a in state.entities[c.AGENT] if "Target" in a.identifier]:
|
||||
temp_state = targetagent._status
|
||||
# print(f"targetagent {temp_state}")
|
||||
self.temp_state_dict[targetagent.identifier] = temp_state
|
||||
return []
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
#####
|
||||
#---#
|
||||
#---#
|
||||
#---#
|
||||
#####
|
||||
@@ -1,24 +0,0 @@
|
||||
##############################################################
|
||||
#-----------#---#--------------------------------------------#
|
||||
#-----------#---#--------------------------------------------#
|
||||
#-----------#---#------##------##------##------##------##----#
|
||||
#-----------#---D------##------##------##------##------##----#
|
||||
#-----------D---#--------------------------------------------#
|
||||
#-----------#---#--------------------------------------------#
|
||||
#############---####################D####################D####
|
||||
#------------------------------------------------------------#
|
||||
#------------------------------------------------------------#
|
||||
#------------------------------------------------------------#
|
||||
####################-####################################D####
|
||||
#-----------------#---#------------------------------#-------#
|
||||
#-----------------#---D------------------------------#-------#
|
||||
#-----------------D---#------------------------------#-------#
|
||||
#-----------------#---#######D#############D##########-------#
|
||||
#-----------------#---D------------------------------D-------#
|
||||
###################---#------------------------------#-------#
|
||||
#-----------------#---#######D#############D##########-------#
|
||||
#-----------------D---#------------------------------#-------#
|
||||
#-----------------#---#------------------------------#-------#
|
||||
#-----------------#---#------------------------------D-------#
|
||||
#-----------------#---#------------------------------#-------#
|
||||
##############################################################
|
||||
@@ -1,47 +0,0 @@
|
||||
###########################################################################################################################
|
||||
#-----------#---#--------------------------------------------#-----------#---#--------------------------------------------#
|
||||
#-----------#---#--------------------------------------------#-----------#---#--------------------------------------------#
|
||||
#-----------#---#------##------##------##------##------##----#-----------#---#------##------##------##------##------##----#
|
||||
#-----------#---D------##------##------##------##------##----#-----------#---D------##------##------##------##------##----#
|
||||
#-----------D---#--------------------------------------------#-----------D---#--------------------------------------------#
|
||||
#-----------#---#--------------------------------------------#-----------#---#--------------------------------------------#
|
||||
#############---####################D####################D################---####################D####################D####
|
||||
#------------------------------------------------------------#------------------------------------------------------------#
|
||||
#------------------------------------------------------------D------------------------------------------------------------#
|
||||
#------------------------------------------------------------#------------------------------------------------------------#
|
||||
####################-####################################D#######################-####################################D####
|
||||
#-----------------#---#------------------------------#-------#-----------------#---#------------------------------#-------#
|
||||
#-----------------#---D------------------------------#-------#-----------------#---D------------------------------#-------#
|
||||
#-----------------D---#------------------------------#-------#-----------------D---#------------------------------#-------#
|
||||
#-----------------#---#######D#############D##########-------#-----------------#---#######D#############D##########-------#
|
||||
#-----------------#---D------------------------------D-------#-----------------#---D------------------------------D-------#
|
||||
###################---#------------------------------#-------###################---#------------------------------#-------#
|
||||
#-----------------#---#######D#############D##########-------#-----------------#---#######D#############D##########-------#
|
||||
#-----------------D---#------------------------------#-------D-----------------D---#------------------------------#-------#
|
||||
#-----------------#---#------------------------------#-------#-----------------#---#------------------------------#-------#
|
||||
#-----------------#---#------------------------------D-------#-----------------#---#------------------------------D-------#
|
||||
#-----------------#---#------------------------------#-------#-----------------#---#------------------------------#-------#
|
||||
##############D############################################################D###############################################
|
||||
#-----------#---#--------------------------------------------#-----------#---#--------------------------------------------#
|
||||
#-----------#---#--------------------------------------------#-----------#---#--------------------------------------------#
|
||||
#-----------#---#------##------##------##------##------##----#-----------#---#------##------##------##------##------##----#
|
||||
#-----------#---D------##------##------##------##------##----#-----------#---D------##------##------##------##------##----#
|
||||
#-----------D---#--------------------------------------------#-----------D---#--------------------------------------------#
|
||||
#-----------#---#--------------------------------------------#-----------#---#--------------------------------------------#
|
||||
#############---####################D####################D################---####################D####################D####
|
||||
#------------------------------------------------------------#------------------------------------------------------------#
|
||||
#------------------------------------------------------------D------------------------------------------------------------#
|
||||
#------------------------------------------------------------#------------------------------------------------------------#
|
||||
###################---###################################D######################---###################################D####
|
||||
#-----------------#---#------------------------------#-------#-----------------#---#------------------------------#-------#
|
||||
#-----------------#---D------------------------------#-------#-----------------#---D------------------------------#-------#
|
||||
#-----------------D---#------------------------------#-------#-----------------D---#------------------------------#-------#
|
||||
#-----------------#---#######D#############D##########-------#-----------------#---#######D#############D##########-------#
|
||||
#-----------------#---D------------------------------D-------#-----------------#---D------------------------------D-------#
|
||||
###################---#------------------------------#-------###################---#------------------------------#-------#
|
||||
#-----------------#---#######D#############D##########-------#-----------------#---#######D#############D##########-------#
|
||||
#-----------------D---#------------------------------#-------#-----------------D---#------------------------------#-------#
|
||||
#-----------------#---#------------------------------#-------#-----------------#---#------------------------------#-------#
|
||||
#-----------------#---#------------------------------D-------#-----------------#---#------------------------------D-------#
|
||||
#-----------------#---#------------------------------#-------#-----------------#---#------------------------------#-------#
|
||||
###########################################################################################################################
|
||||
@@ -1,5 +0,0 @@
|
||||
#######
|
||||
###-###
|
||||
#-----#
|
||||
###-###
|
||||
#######
|
||||
@@ -1,12 +0,0 @@
|
||||
############
|
||||
#----------#
|
||||
#-#######--#
|
||||
#-#-----D--#
|
||||
#-#######--#
|
||||
#-D-----D--#
|
||||
#-#-#-#-#-##
|
||||
#----------#
|
||||
#----------#
|
||||
#----------#
|
||||
#----------#
|
||||
############
|
||||
@@ -1,13 +0,0 @@
|
||||
###############
|
||||
#333x33#444444#
|
||||
#333#33#444444#
|
||||
#333333xx#4444#
|
||||
#333333#444444#
|
||||
#333333#444444#
|
||||
###x#######D###
|
||||
#1111##2222222#
|
||||
#11111#2222#22#
|
||||
#11111D2222222#
|
||||
#11111#2222222#
|
||||
#11111#2222222#
|
||||
###############
|
||||
@@ -1,13 +0,0 @@
|
||||
############
|
||||
#----------#
|
||||
#--######--#
|
||||
#----------#
|
||||
#--######--#
|
||||
#----------#
|
||||
#--######--#
|
||||
#----------#
|
||||
#--######--#
|
||||
#----------#
|
||||
#--######--#
|
||||
#----------#
|
||||
############
|
||||
@@ -1,12 +0,0 @@
|
||||
############
|
||||
#----------#
|
||||
#---#------#
|
||||
#--------#-#
|
||||
#----------#
|
||||
#--#-------#
|
||||
#----------#
|
||||
#----#-----#
|
||||
#----------#
|
||||
#-------#--#
|
||||
#----------#
|
||||
############
|
||||
@@ -1,5 +0,0 @@
|
||||
#####
|
||||
##-##
|
||||
#---#
|
||||
##-##
|
||||
#####
|
||||
@@ -1,13 +1,7 @@
|
||||
###############
|
||||
#111111#222222#
|
||||
#111111#222222#
|
||||
#111111#222222#
|
||||
#111111#222222#
|
||||
#111111#222222#
|
||||
#111111D222222#
|
||||
#111111#222222#
|
||||
#111111#222222#
|
||||
#111111#222222#
|
||||
#111111#222222#
|
||||
#111111#222222#
|
||||
#------#------#
|
||||
#------#------#
|
||||
#------D------#
|
||||
#------#------#
|
||||
#------#------#
|
||||
###############
|
||||
@@ -1,7 +0,0 @@
|
||||
###############
|
||||
#111111#222222#
|
||||
#111111#222222#
|
||||
#111111D222222#
|
||||
#111111#222222#
|
||||
#111111#222222#
|
||||
###############
|
||||
@@ -1,10 +1,6 @@
|
||||
from .batteries import *
|
||||
from .clean_up import *
|
||||
from .destinations import *
|
||||
from .doors import *
|
||||
from .items import *
|
||||
from .machines import *
|
||||
from .maintenance import *
|
||||
|
||||
"""
|
||||
modules
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
from .actions import Charge
|
||||
from .entitites import ChargePod, Battery
|
||||
from .groups import ChargePods, Batteries
|
||||
from .rules import DoneAtBatteryDischarge, BatteryDecharge
|
||||
@@ -1,31 +0,0 @@
|
||||
from typing import Union
|
||||
|
||||
from marl_factory_grid.environment.actions import Action
|
||||
from marl_factory_grid.utils.results import ActionResult
|
||||
|
||||
from marl_factory_grid.modules.batteries import constants as b
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
|
||||
|
||||
class Charge(Action):
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Checks if a charge pod is present at the agent's position.
|
||||
If found, it attempts to charge the battery using the charge pod.
|
||||
"""
|
||||
super().__init__(b.ACTION_CHARGE, b.REWARD_CHARGE_VALID, b.Reward_CHARGE_FAIL)
|
||||
|
||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||
if charge_pod := h.get_first(state[b.CHARGE_PODS].by_pos(entity.pos)):
|
||||
valid = charge_pod.charge_battery(entity, state)
|
||||
if valid:
|
||||
state.print(f'{entity.name} just charged batteries at {charge_pod.name}.')
|
||||
else:
|
||||
state.print(f'{entity.name} failed to charge batteries at {charge_pod.name}.')
|
||||
else:
|
||||
valid = c.NOT_VALID
|
||||
state.print(f'{entity.name} failed to charge batteries at {entity.pos}.')
|
||||
|
||||
return self.get_result(valid, entity)
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 7.9 KiB |
@@ -1,17 +0,0 @@
|
||||
# Battery Env
|
||||
CHARGE_PODS = 'ChargePods'
|
||||
BATTERIES = 'Batteries'
|
||||
BATTERY_DISCHARGED = 'DISCHARGED'
|
||||
CHARGE_POD_SYMBOL = 1
|
||||
|
||||
ACTION_CHARGE = 'do_charge_action'
|
||||
|
||||
REWARD_CHARGE_VALID: float = 0.1
|
||||
Reward_CHARGE_FAIL: float = -0.1
|
||||
REWARD_BATTERY_DISCHARGED: float = -1.0
|
||||
REWARD_DISCHARGE_DONE: float = -1.0
|
||||
|
||||
|
||||
GROUPED = "grouped"
|
||||
SINGLE = "single"
|
||||
MODES = [GROUPED, SINGLE]
|
||||
@@ -1,119 +0,0 @@
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.agent import Agent
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
from marl_factory_grid.modules.batteries import constants as b
|
||||
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||
|
||||
|
||||
class Battery(Object):
|
||||
|
||||
@property
|
||||
def var_can_be_bound(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_discharged(self) -> bool:
|
||||
"""
|
||||
Indicates whether the Battery's charge level is at 0 or not.
|
||||
|
||||
:return: Whether this battery is empty.
|
||||
"""
|
||||
return self.charge_level == 0
|
||||
|
||||
@property
|
||||
def obs_tag(self):
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return self.charge_level
|
||||
|
||||
def __init__(self, initial_charge_level, owner, *args, **kwargs):
|
||||
"""
|
||||
Represents a battery entity in the environment that can be bound to an agent and charged at charge pods.
|
||||
|
||||
:param initial_charge_level: The current charge level of the battery, ranging from 0 to 1.
|
||||
:type initial_charge_level: float
|
||||
|
||||
:param owner: The entity to which the battery is bound.
|
||||
:type owner: Entity
|
||||
"""
|
||||
super(Battery, self).__init__(*args, **kwargs)
|
||||
self.charge_level = initial_charge_level
|
||||
self.bind_to(owner)
|
||||
|
||||
def do_charge_action(self, amount) -> bool:
|
||||
"""
|
||||
Updates the Battery's charge level according to the passed value.
|
||||
|
||||
:param amount: Amount added to the Battery's charge level.
|
||||
:returns: Whether the battery could be charged. If not, it was already fully charged.
|
||||
"""
|
||||
if self.charge_level < 1:
|
||||
# noinspection PyTypeChecker
|
||||
self.charge_level = min(1, amount + self.charge_level)
|
||||
return c.VALID
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
|
||||
def decharge(self, amount) -> bool:
|
||||
"""
|
||||
Decreases the charge value of a battery. Currently only triggered by the battery-decharge rule.
|
||||
"""
|
||||
if self.charge_level != 0:
|
||||
# noinspection PyTypeChecker
|
||||
self.charge_level = max(0, self.charge_level - amount)  # subtract the consumed amount so the level actually decreases
|
||||
return c.VALID
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
|
||||
def summarize_state(self):
|
||||
summary = super().summarize_state()
|
||||
summary.update(dict(belongs_to=self._bound_entity.name, chargeLevel=self.charge_level))
|
||||
return summary
|
||||
|
||||
|
||||
class ChargePod(Entity):
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return b.CHARGE_POD_SYMBOL
|
||||
|
||||
def __init__(self, *args, charge_rate: float = 0.4, multi_charge: bool = False, **kwargs):
|
||||
"""
|
||||
Represents a charging pod for batteries in the environment.
|
||||
|
||||
:param charge_rate: The rate at which the charging pod charges batteries. Defaults to 0.4.
|
||||
:type charge_rate: float
|
||||
|
||||
:param multi_charge: Indicates whether the charging pod supports charging multiple batteries simultaneously.
|
||||
Defaults to False.
|
||||
:type multi_charge: bool
|
||||
"""
|
||||
super(ChargePod, self).__init__(*args, **kwargs)
|
||||
self.charge_rate = charge_rate
|
||||
self.multi_charge = multi_charge
|
||||
|
||||
def charge_battery(self, entity, state) -> bool:
|
||||
"""
|
||||
Triggers the battery charge action if possible. Impossible if the battery is already at full charge level or if more
|
||||
than one agent is at the charge pod's position.
|
||||
|
||||
:returns: whether the action was successful (valid) or not.
|
||||
"""
|
||||
battery = state[b.BATTERIES].by_entity(entity)
|
||||
if battery.charge_level >= 1.0:
|
||||
return c.NOT_VALID
|
||||
if len([x for x in state[c.AGENT].by_pos(entity.pos)]) > 1:
|
||||
return c.NOT_VALID
|
||||
valid = battery.do_charge_action(self.charge_rate)
|
||||
return valid
|
||||
|
||||
def render(self):
|
||||
return RenderEntity(b.CHARGE_PODS, self.pos)
|
||||
|
||||
def summarize_state(self) -> dict:
|
||||
summary = super().summarize_state()
|
||||
summary.update(charge_rate=self.charge_rate)
|
||||
return summary
|
||||
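A small worked example of the charge arithmetic above, using default values that appear elsewhere in this diff (initial_charge 0.8, per_action_costs 0.02, charge_rate 0.4). This is plain Python mirroring do_charge_action and decharge, not the library classes themselves:

charge_level = 0.8                             # initial_charge from the test config above
charge_level = max(0, charge_level - 0.02)     # decharge after one action at per_action_costs 0.02 -> 0.78
charge_level = min(1, charge_level + 0.4)      # do_charge_action with charge_rate 0.4 -> 1.0 (clamped)
print(charge_level)                            # 1.0; a further charge attempt would be rejected as NOT_VALID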
@@ -1,52 +0,0 @@
|
||||
from typing import Union, List, Tuple
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.groups.collection import Collection
|
||||
from marl_factory_grid.modules.batteries.entitites import ChargePod, Battery
|
||||
from marl_factory_grid.utils.results import Result
|
||||
|
||||
|
||||
class Batteries(Collection):
|
||||
_entity = Battery
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_be_bound(self):
|
||||
return True
|
||||
|
||||
def __init__(self, size, initial_charge_level=1.0, *args, **kwargs):
|
||||
"""
|
||||
A collection of batteries that is in charge of spawning batteries. (spawned batteries are bound to agents)
|
||||
|
||||
:param size: The maximum allowed size of the collection. Ensures that the collection does not exceed this size.
|
||||
:type size: int
|
||||
|
||||
:param initial_charge_level: The initial charge level of the battery.
|
||||
:type initial_charge_level: float
|
||||
"""
|
||||
super(Batteries, self).__init__(size, *args, **kwargs)
|
||||
self.initial_charge_level = initial_charge_level
|
||||
|
||||
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args, **entity_kwargs):
|
||||
batteries = [self._entity(self.initial_charge_level, agent) for _, agent in enumerate(entity_args[0])]
|
||||
self.add_items(batteries)
|
||||
|
||||
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs):
|
||||
self.spawn(0, state[c.AGENT])
|
||||
return Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))
|
||||
|
||||
|
||||
class ChargePods(Collection):
|
||||
_entity = ChargePod
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
A collection of charge pods in the environment.
|
||||
"""
|
||||
super(ChargePods, self).__init__(*args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return super(ChargePods, self).__repr__()
|
||||
@@ -1,128 +0,0 @@
|
||||
from typing import List, Union
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.modules.batteries import constants as b
|
||||
from marl_factory_grid.utils.results import TickResult, DoneResult
|
||||
|
||||
|
||||
class BatteryDecharge(Rule):
|
||||
|
||||
def __init__(self, initial_charge: float = 0.8, per_action_costs: Union[dict, float] = 0.02,
|
||||
battery_charge_reward: float = b.REWARD_CHARGE_VALID,
|
||||
battery_failed_reward: float = b.Reward_CHARGE_FAIL,
|
||||
battery_discharge_reward: float = b.REWARD_BATTERY_DISCHARGED,
|
||||
paralyze_agents_on_discharge: bool = False):
|
||||
f"""
|
||||
Enables the Battery Charge/Discharge functionality.
|
||||
|
||||
:type paralyze_agents_on_discharge: bool
|
||||
:param paralyze_agents_on_discharge: Whether agents are paralyzed (unable to perform actions) once their battery is discharged.
|
||||
:type per_action_costs: Union[dict, float] = 0.02
|
||||
:param per_action_costs: 1. dict: with an action name as key, provide a value for each
|
||||
(maybe walking is less tedious than opening a door? Just saying...).
|
||||
2. float: each action "costs" the same.
|
||||
----
|
||||
!!! Does not introduce any Env.-Done condition.
|
||||
!!! Batteries can only be charged if the agent possesses the "Charge" Action.
|
||||
!!! Batteries can only be charged if there are "Charge Pods" and they are spawned!
|
||||
----
|
||||
:type initial_charge: float
|
||||
:param initial_charge: How much juice they have.
|
||||
:type battery_discharge_reward: float
|
||||
:param battery_discharge_reward: Negative reward, when agents let their batteries discharge.
|
||||
Default: {b.REWARD_BATTERY_DISCHARGED}
|
||||
:type battery_failed_reward: float
|
||||
:param battery_failed_reward: Negative reward when an agent tries to charge but fails (battery already full, or not at a charge pod).
|
||||
Default: {b.Reward_CHARGE_FAIL}
|
||||
:type battery_charge_reward: float
|
||||
:param battery_charge_reward: Positive reward, when agent actually charge their battery.
|
||||
Default: {b.REWARD_CHARGE_VALID}
|
||||
"""
|
||||
super().__init__()
|
||||
self.paralyze_agents_on_discharge = paralyze_agents_on_discharge
|
||||
self.battery_discharge_reward = battery_discharge_reward
|
||||
self.battery_failed_reward = battery_failed_reward
|
||||
self.battery_charge_reward = battery_charge_reward
|
||||
self.per_action_costs = per_action_costs
|
||||
self.initial_charge = initial_charge
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
batteries = state[b.BATTERIES]
|
||||
results = []
|
||||
|
||||
for agent in state[c.AGENT]:
|
||||
if isinstance(self.per_action_costs, dict):
|
||||
energy_consumption = self.per_action_costs[agent.state.identifier]
|
||||
else:
|
||||
energy_consumption = self.per_action_costs
|
||||
|
||||
batteries.by_entity(agent).decharge(energy_consumption)
|
||||
|
||||
results.append(TickResult(self.name, entity=agent, validity=c.VALID, value=energy_consumption))
|
||||
|
||||
return results
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
results = []
|
||||
for btry in state[b.BATTERIES]:
|
||||
if btry.is_discharged:
|
||||
state.print(f'Battery of {btry.bound_entity.name} is discharged!')
|
||||
results.append(
|
||||
TickResult(self.name, entity=btry.bound_entity, reward=self.battery_discharge_reward,
|
||||
validity=c.VALID)
|
||||
)
|
||||
if self.paralyze_agents_on_discharge:
|
||||
btry.bound_entity.paralyze(self.name)
|
||||
results.append(
|
||||
TickResult("Paralyzed", entity=btry.bound_entity, validity=c.VALID)
|
||||
)
|
||||
state.print(f'{btry.bound_entity.name} has just been paralyzed!')
|
||||
if btry.bound_entity.var_is_paralyzed and not btry.is_discharged:
|
||||
btry.bound_entity.de_paralyze(self.name)
|
||||
results.append(
|
||||
TickResult("De-Paralyzed", entity=btry.bound_entity, validity=c.VALID)
|
||||
)
|
||||
state.print(f'{btry.bound_entity.name} has just been de-paralyzed!')
|
||||
return results
|
||||
|
||||
|
||||
class DoneAtBatteryDischarge(BatteryDecharge):
|
||||
|
||||
def __init__(self, reward_discharge_done=b.REWARD_DISCHARGE_DONE, mode: str = b.SINGLE, **kwargs):
|
||||
f"""
|
||||
Enables the Battery Charge/Discharge functionality. Additionally, this rule introduces a Done condition that triggers when batteries are discharged.
|
||||
|
||||
:type mode: str
|
||||
:param mode: Does this Done rule trigger when any battery is discharged, or only when all batteries are discharged?
|
||||
:type per_action_costs: Union[dict, float] = 0.02
|
||||
:param per_action_costs: 1. dict: with an action name as key, provide a value for each
|
||||
(maybe walking is less tedious than opening a door? Just saying...).
|
||||
2. float: each action "costs" the same.
|
||||
|
||||
:type initial_charge: float
|
||||
:param initial_charge: How much juice they have.
|
||||
:type reward_discharge_done: float
|
||||
:param reward_discharge_done: Global negative reward, when agents let their batteries discharge.
|
||||
Default: {b.REWARD_BATTERY_DISCHARGED}
|
||||
:type battery_discharge_reward: float
|
||||
:param battery_discharge_reward: Negative reward, when agents let their batteries discharge.
|
||||
Default: {b.REWARD_BATTERY_DISCHARGED}
|
||||
:type battery_failed_reward: float
|
||||
:param battery_failed_reward: Negative reward when an agent tries to charge but fails (battery already full, or not at a charge pod).
|
||||
Default: {b.Reward_CHARGE_FAIL}
|
||||
:type battery_charge_reward: float
|
||||
:param battery_charge_reward: Positive reward, when agent actually charge their battery.
|
||||
Default: {b.REWARD_CHARGE_VALID}
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self.mode = mode
|
||||
self.reward_discharge_done = reward_discharge_done
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
any_discharged = (self.mode == b.SINGLE and any(battery.is_discharged for battery in state[b.BATTERIES]))
|
||||
all_discharged = (self.mode == b.GROUPED and all(battery.is_discharged for battery in state[b.BATTERIES]))
|
||||
if any_discharged or all_discharged:
|
||||
return [DoneResult(self.name, validity=c.VALID, reward=self.reward_discharge_done)]
|
||||
else:
|
||||
return [DoneResult(self.name, validity=c.NOT_VALID)]
|
||||
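As the BatteryDecharge docstring above states, per_action_costs may be a flat float or a dict keyed by action name, looked up in tick_step via agent.state.identifier. A hedged sketch of the dict form; 'do_charge_action' (ACTION_CHARGE) and 'use_door' occur elsewhere in this diff, the remaining key is a placeholder for illustration only:

per_action_costs = {
    'do_charge_action': 0.0,   # ACTION_CHARGE from the battery constants above
    'use_door': 0.01,          # action name seen in the maintainer test earlier in this diff
    'noop': 0.005,             # hypothetical key, for illustration only
}
# BatteryDecharge then deducts per_action_costs[agent.state.identifier] from the battery each tick,
# or the same flat float for every action if a plain number is passed instead.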
@@ -1,11 +0,0 @@
|
||||
from .actions import ItemAction
|
||||
from .entitites import Item, DropOffLocation
|
||||
from .groups import DropOffLocations, Items, Inventory, Inventories
|
||||
|
||||
"""
|
||||
items
|
||||
=====
|
||||
|
||||
Todo
|
||||
|
||||
"""
|
||||
@@ -1,63 +0,0 @@
|
||||
from typing import Union
|
||||
|
||||
from marl_factory_grid.environment.actions import Action
|
||||
from marl_factory_grid.utils.results import ActionResult
|
||||
|
||||
from marl_factory_grid.modules.items import constants as i
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
|
||||
class ItemAction(Action):
|
||||
|
||||
def __init__(self, failed_dropoff_reward: float | None = None, valid_dropoff_reward: float | None = None, **kwargs):
|
||||
"""
|
||||
Allows an entity to pick up or drop off items in the environment.
|
||||
|
||||
:param failed_dropoff_reward: The reward assigned when a drop-off action fails. Default is None.
|
||||
:type failed_dropoff_reward: float | None
|
||||
:param valid_dropoff_reward: The reward assigned when a drop-off action is successful. Default is None.
|
||||
:type valid_dropoff_reward: float | None
|
||||
"""
|
||||
super().__init__(i.ITEM_ACTION, i.REWARD_PICK_UP_FAIL, i.REWARD_PICK_UP_VALID, **kwargs)
|
||||
self.failed_drop_off_reward = failed_dropoff_reward if failed_dropoff_reward is not None else i.REWARD_DROP_OFF_FAIL
|
||||
self.valid_drop_off_reward = valid_dropoff_reward if valid_dropoff_reward is not None else i.REWARD_DROP_OFF_VALID
|
||||
|
||||
def get_dropoff_result(self, validity, entity) -> ActionResult:
|
||||
"""
|
||||
Generates an ActionResult for a drop-off action based on its validity.
|
||||
|
||||
:param validity: Whether the drop-off action is valid.
|
||||
:type validity: bool
|
||||
|
||||
:param entity: The entity performing the action.
|
||||
:type entity: Entity
|
||||
|
||||
:return: ActionResult for the drop-off action.
|
||||
:rtype: ActionResult
|
||||
"""
|
||||
reward = self.valid_drop_off_reward if validity else self.failed_drop_off_reward
|
||||
return ActionResult(self.__class__.__name__, validity, reward=reward, entity=entity)
|
||||
|
||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||
inventory = state[i.INVENTORY].by_entity(entity)
|
||||
if drop_off := state[i.DROP_OFF].by_pos(entity.pos):
|
||||
if inventory:
|
||||
valid = drop_off.place_item(inventory.pop())
|
||||
else:
|
||||
valid = c.NOT_VALID
|
||||
if valid:
|
||||
state.print(f'{entity.name} just dropped off an item at {drop_off.pos}.')
|
||||
else:
|
||||
state.print(f'{entity.name} just tried to drop off at {entity.pos}, but failed.')
|
||||
return self.get_dropoff_result(valid, entity)
|
||||
|
||||
elif items := state[i.ITEM].by_pos(entity.pos):
|
||||
item = items[0]
|
||||
item.change_parent_collection(inventory)
|
||||
item.set_pos(c.VALUE_NO_POS)
|
||||
state.print(f'{entity.name} just picked up an item at {entity.pos}')
|
||||
return self.get_result(c.VALID, entity)
|
||||
|
||||
else:
|
||||
state.print(f'{entity.name} just tried to pick up an item at {entity.pos}, but failed.')
|
||||
return self.get_result(c.NOT_VALID, entity)
|
||||
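A simplified restatement of the branch order in ItemAction.do above (illustration only, not the library code): a drop-off is attempted first whenever the agent stands on a DropOffLocation, picking up an item comes second, and anything else fails:

def item_action_outcome(on_drop_off, inventory_nonempty, on_item):
    # Mirrors the if/elif/else structure of ItemAction.do in simplified form.
    if on_drop_off:
        return "drop_off" if inventory_nonempty else "failed_drop_off"
    elif on_item:
        return "pick_up"
    else:
        return "failed_pick_up"

Note that standing on a drop-off location with an empty inventory still counts as a (failed) drop-off and yields the failed drop-off reward rather than the pick-up rewards.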
Binary file not shown.
|
Before Width: | Height: | Size: 6.5 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 2.3 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 3.0 KiB |
@@ -1,14 +0,0 @@
|
||||
SYMBOL_NO_ITEM = 0
|
||||
SYMBOL_DROP_OFF = 1
|
||||
# Item Env
|
||||
ITEM = 'Items'
|
||||
INVENTORY = 'Inventories'
|
||||
DROP_OFF = 'DropOffLocations'
|
||||
|
||||
ITEM_ACTION = 'ITEMACTION'
|
||||
|
||||
# Rewards
|
||||
REWARD_DROP_OFF_VALID: float = 0.1
|
||||
REWARD_DROP_OFF_FAIL: float = -0.1
|
||||
REWARD_PICK_UP_FAIL: float = -0.1
|
||||
REWARD_PICK_UP_VALID: float = 0.1
|
||||
@@ -1,59 +0,0 @@
|
||||
from collections import deque
|
||||
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||
from marl_factory_grid.modules.items import constants as i
|
||||
|
||||
|
||||
class Item(Entity):
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return 1
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
An item that can be picked up or dropped by agents. If picked up, it enters the agent's inventory.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def render(self):
|
||||
return RenderEntity(i.ITEM, self.pos) if self.pos != c.VALUE_NO_POS else None
|
||||
|
||||
|
||||
class DropOffLocation(Entity):
|
||||
@property
|
||||
def encoding(self):
|
||||
return i.SYMBOL_DROP_OFF
|
||||
|
||||
@property
|
||||
def is_full(self) -> bool:
|
||||
"""
|
||||
Checks whether the drop-off location is full or whether another item can be dropped here.
|
||||
"""
|
||||
return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage)
|
||||
|
||||
def __init__(self, *args, storage_size_until_full=5, **kwargs):
|
||||
"""
|
||||
Represents a drop-off location in the environment that agents aim to drop items at.
|
||||
|
||||
:param storage_size_until_full: The number of items that can be dropped here until it is considered full.
|
||||
:type storage_size_until_full: int
|
||||
"""
|
||||
super(DropOffLocation, self).__init__(*args, **kwargs)
|
||||
self.storage = deque(maxlen=storage_size_until_full or None)
|
||||
|
||||
def place_item(self, item: Item) -> bool:
|
||||
"""
|
||||
If the storage of the drop-off location is not full, the item is placed. Otherwise, a RuntimeWarning is raised.
|
||||
"""
|
||||
if self.is_full:
|
||||
raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
|
||||
return c.NOT_VALID
|
||||
else:
|
||||
self.storage.append(item)
|
||||
return c.VALID
|
||||
|
||||
def render(self):
|
||||
return RenderEntity(i.DROP_OFF, self.pos)
|
||||
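To see the is_full logic above in isolation, here is a plain-Python illustration of the underlying deque with the default storage_size_until_full=5 (not the DropOffLocation class itself):

from collections import deque

storage = deque(maxlen=5)
for n in range(5):
    storage.append(f"item_{n}")           # five successful drop-offs
print(len(storage) == storage.maxlen)     # True -> is_full; a sixth place_item would raise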
@@ -1,180 +0,0 @@
|
||||
from typing import Dict, Any
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.agent import Agent
|
||||
from marl_factory_grid.environment.groups.collection import Collection
|
||||
from marl_factory_grid.environment.groups.mixins import IsBoundMixin
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
from marl_factory_grid.modules.items import constants as i
|
||||
from marl_factory_grid.modules.items.entitites import Item, DropOffLocation
|
||||
from marl_factory_grid.utils.results import Result
|
||||
|
||||
|
||||
class Items(Collection):
|
||||
_entity = Item
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
return False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
A collection of items that triggers their spawn.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, **entity_kwargs) -> [Result]:
|
||||
coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity
|
||||
assert coords_or_quantity
|
||||
|
||||
if item_to_spawns := max(0, (coords_or_quantity - len(self))):
|
||||
return super().trigger_spawn(state,
|
||||
*entity_args,
|
||||
coords_or_quantity=item_to_spawns,
|
||||
**entity_kwargs)
|
||||
else:
|
||||
state.print('No Items are spawning, limit is reached.')
|
||||
return Result(identifier=f'{self.name}_spawn', validity=c.NOT_VALID, value=coords_or_quantity)
|
||||
|
||||
|
||||
class Inventory(IsBoundMixin, Collection):
|
||||
_accepted_objects = Item
|
||||
|
||||
@property
|
||||
def var_can_be_bound(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def obs_tag(self):
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return f'{self.__class__.__name__}[{self._bound_entity.name}]'
|
||||
|
||||
def __init__(self, agent, *args, **kwargs):
|
||||
"""
|
||||
An inventory that can hold items picked up by the agent it is bound to.
|
||||
|
||||
:param agent: The agent this inventory is bound to and belongs to.
|
||||
:type agent: Agent
|
||||
"""
|
||||
super(Inventory, self).__init__(*args, **kwargs)
|
||||
self._collection = None
|
||||
self.bind(agent)
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}#{self._bound_entity.name}({dict(self._data)})'
|
||||
|
||||
def summarize_states(self, **kwargs):
|
||||
attr_dict = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||
attr_dict.update(dict(items=[val.summarize_state(**kwargs) for key, val in self.items()]))
|
||||
attr_dict.update(dict(name=self.name, belongs_to=self._bound_entity.name))
|
||||
return attr_dict
|
||||
|
||||
def pop(self) -> Item:
|
||||
"""
|
||||
Removes and returns the first item in the inventory.
|
||||
"""
|
||||
item_to_pop = self[0]
|
||||
self.delete_env_object(item_to_pop)
|
||||
return item_to_pop
|
||||
|
||||
def set_collection(self, collection):
|
||||
"""
|
||||
No usage
|
||||
"""
|
||||
self._collection = collection
|
||||
|
||||
def clear_temp_state(self):
|
||||
"""
|
||||
Entities need this, but inventories have no state.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Inventories(Objects):
|
||||
_entity = Inventory
|
||||
symbol = None
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def spawn_rule(self) -> dict[Any, dict[str, Any]]:
|
||||
"""
|
||||
:returns: a dict containing the specified spawn rule and its arguments.
|
||||
:rtype: dict(dict(collection=self, coords_or_quantity=None))
|
||||
"""
|
||||
return {c.SPAWN_ENTITY_RULE: dict(collection=self, coords_or_quantity=None)}
|
||||
|
||||
def __init__(self, size: int, *args, **kwargs):
|
||||
"""
|
||||
A collection of all inventories used to spawn an inventory per agent.
|
||||
"""
|
||||
super(Inventories, self).__init__(*args, **kwargs)
|
||||
self.size = size
|
||||
self._obs = None
|
||||
self._lazy_eval_transforms = []
|
||||
|
||||
def spawn(self, agents, *args, **kwargs) -> [Result]:
|
||||
self.add_items([self._entity(agent, self.size, *args, **kwargs) for _, agent in enumerate(agents)])
|
||||
return [Result(identifier=f'{self.name}_spawn', validity=c.VALID, value=len(self))]
|
||||
|
||||
def trigger_spawn(self, state, *args, **kwargs) -> [Result]:
|
||||
return self.spawn(state[c.AGENT], *args, **kwargs)
|
||||
|
||||
def by_entity(self, entity):
|
||||
try:
|
||||
return next((inv for inv in self if inv.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def summarize_states(self, **kwargs):
|
||||
return [val.summarize_states(**kwargs) for key, val in self.items()]
|
||||
|
||||
|
||||
class DropOffLocations(Collection):
|
||||
_entity = DropOffLocation
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
A Collection of Drop-off locations that can trigger their spawn.
|
||||
"""
|
||||
super(DropOffLocations, self).__init__(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def trigger_drop_off_location_spawn(state, n_locations):
|
||||
empty_positions = state.entities.empty_positions[:n_locations]
|
||||
do_entites = state[i.DROP_OFF]
|
||||
drop_offs = [DropOffLocation(pos) for pos in empty_positions]
|
||||
do_entites.add_items(drop_offs)
|
||||
@@ -1,44 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils.results import TickResult
|
||||
from marl_factory_grid.modules.items import constants as i
|
||||
|
||||
|
||||
class RespawnItems(Rule):
|
||||
|
||||
def __init__(self, n_items: int = 5, respawn_freq: int = 15, n_locations: int = 5):
|
||||
"""
|
||||
Defines the respawning behaviour of items.
|
||||
|
||||
:param n_items: Specifies how many items should respawn.
|
||||
:type n_items: int
|
||||
:param respawn_freq: Specifies how often items should respawn.
|
||||
:type respawn_freq: int
|
||||
:param n_locations: Specifies at how many locations items should be able to respawn.
|
||||
:type: int
|
||||
"""
|
||||
super().__init__()
|
||||
self.spawn_frequency = respawn_freq
|
||||
self._next_item_spawn = respawn_freq
|
||||
self.n_items = n_items
|
||||
self.n_locations = n_locations
|
||||
|
||||
def tick_step(self, state):
if not self._next_item_spawn:
state[i.ITEM].trigger_spawn(state, self.n_items, self.spawn_frequency)
else:
self._next_item_spawn = max(0, self._next_item_spawn - 1)
return []

def tick_post_step(self, state) -> List[TickResult]:
if not self._next_item_spawn:
if spawned_items := state[i.ITEM].trigger_spawn(state, self.n_items, self.spawn_frequency):
return [TickResult(self.name, validity=c.VALID, value=spawned_items.value)]
else:
return [TickResult(self.name, validity=c.NOT_VALID, value=0)]
else:
self._next_item_spawn = max(0, self._next_item_spawn - 1)
return []

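The two tick hooks above implement a simple countdown: the rule waits respawn_freq steps, then asks the item collection to spawn again. A standalone toy loop of that countdown (hypothetical values, not part of this diff):

next_item_spawn = 3                       # respawn_freq
for step in range(6):
    if not next_item_spawn:
        print(f'step {step}: counter hit zero -> trigger_spawn would be called')
    else:
        next_item_spawn = max(0, next_item_spawn - 1)
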
@@ -1,10 +0,0 @@
|
||||
from .entitites import Machine
|
||||
from .groups import Machines
|
||||
|
||||
"""
|
||||
machines
|
||||
========
|
||||
|
||||
Todo
|
||||
|
||||
"""
|
||||
@@ -1,25 +0,0 @@
|
||||
from typing import Union

from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.actions import Action
from marl_factory_grid.modules.machines import constants as m
from marl_factory_grid.utils import helpers as h
from marl_factory_grid.utils.results import ActionResult


class MachineAction(Action):

def __init__(self):
"""
When performing this action, the maintainer attempts to maintain the machine at its current position and
returns an action result indicating whether the maintenance succeeded.
"""
super().__init__(m.MACHINE_ACTION, m.MAINTAIN_VALID, m.MAINTAIN_FAIL)

def do(self, entity, state) -> Union[None, ActionResult]:
if machine := h.get_first(state[m.MACHINES].by_pos(entity.pos)):
valid = machine.maintain()
return self.get_result(valid, entity)
else:
return self.get_result(c.NOT_VALID, entity)

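MachineAction.do follows the common pattern of resolving the entity at the actor's position via by_pos and a first-or-None helper. A self-contained sketch of that lookup (the helper and the position index below are hypothetical stand-ins, not the repo's implementations):

def get_first(iterable, filter_by=lambda x: True):
    # stand-in for h.get_first: first matching element or None
    return next((x for x in iterable if filter_by(x)), None)

machines_by_pos = {(2, 3): ['Machine[0]']}        # pretend collection index
agent_pos = (2, 3)

if machine := get_first(machines_by_pos.get(agent_pos, [])):
    print(f'maintain {machine}')                  # -> maintain Machine[0]
else:
    print('no machine here -> NOT_VALID')
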
@@ -1,17 +0,0 @@
|
||||
|
||||
MACHINES = 'Machines'
|
||||
MACHINE = 'Machine'
|
||||
|
||||
MACHINE_ACTION = 'Maintain'
|
||||
|
||||
STATE_WORK = 'working'
|
||||
STATE_IDLE = 'idling'
|
||||
STATE_MAINTAIN = 'maintenance'
|
||||
|
||||
SYMBOL_WORK = 1
|
||||
SYMBOL_IDLE = 0.6
|
||||
SYMBOL_MAINTAIN = 0.3
|
||||
MAINTAIN_VALID: float = 0.5
|
||||
MAINTAIN_FAIL: float = -0.1
|
||||
FAIL_MISSING_MAINTENANCE: float = -0.5
|
||||
NONE: float = 0
|
||||
@@ -1,79 +0,0 @@
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from ...utils.utility_classes import RenderEntity
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils.results import TickResult
|
||||
|
||||
from . import constants as m
|
||||
|
||||
|
||||
class Machine(Entity):
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return self._encodings[self.status]
|
||||
|
||||
def __init__(self, *args, work_interval: int = 10, pause_interval: int = 15, **kwargs):
"""
Represents a machine entity that the maintainer will try to maintain by performing the maintenance action.
Machines' health depletes over time.

:param work_interval: How long the machine should work before pausing.
:type work_interval: int
:param pause_interval: How long the machine should pause before continuing to work.
:type pause_interval: int
"""
super(Machine, self).__init__(*args, **kwargs)
self._intervals = dict({m.STATE_IDLE: pause_interval, m.STATE_WORK: work_interval})
# Observation encoding per status (see the SYMBOL_* constants).
self._encodings = dict({m.STATE_IDLE: m.SYMBOL_IDLE, m.STATE_WORK: m.SYMBOL_WORK, m.STATE_MAINTAIN: m.SYMBOL_MAINTAIN})

self.status = m.STATE_IDLE
self.health = 100
self._counter = 0

def maintain(self) -> bool:
"""
Attempts to maintain the machine by restoring its health, which is only possible if the machine is not
currently working and has at most 98/100 HP.
"""
if self.status == m.STATE_WORK:
return c.NOT_VALID
if self.health <= 98:
self.health = 100
return c.VALID
else:
return c.NOT_VALID

def tick(self, state):
"""
Updates the machine's mode (work, pause) depending on its current counter and whether an agent is currently on
its position. If no agent is standing on the machine's position, it decrements its own health.

:param state: The current game state.
:type state: GameState
:return: The result of the tick operation on the machine.
:rtype: TickResult | None
"""
others = state.entities.pos_dict[self.pos]
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in others]):
return TickResult(identifier=self.name, validity=c.VALID, entity=self)
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in others]):
self.status = m.STATE_WORK
self.reset_counter()
return None
elif self._counter:
self._counter -= 1
self.health -= 1
return None
else:
self.status = m.STATE_WORK if self.status == m.STATE_IDLE else m.STATE_IDLE
self.reset_counter()
return None

def reset_counter(self):
"""
Internal Usage
"""
self._counter = self._intervals[self.status]

def render(self):
return RenderEntity(m.MACHINE, self.pos)

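Machine.tick is effectively a two-state machine: while _counter runs down the machine keeps its status and loses health, and when the counter reaches zero the status flips and the counter is reset from _intervals. A toy version of that cycle (hypothetical interval values, not part of this diff):

intervals = {'working': 4, 'idling': 2}
status, counter, health = 'idling', 0, 100

for step in range(10):
    if counter:
        counter -= 1
        health -= 1                        # health depletes while the counter runs down
    else:
        status = 'working' if status == 'idling' else 'idling'
        counter = intervals[status]        # reset_counter()
    print(step, status, counter, health)
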
@@ -1,27 +0,0 @@
|
||||
from marl_factory_grid.environment.groups.collection import Collection
|
||||
|
||||
from .entitites import Machine
|
||||
|
||||
|
||||
class Machines(Collection):
|
||||
|
||||
_entity = Machine
|
||||
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
A Collection of Machines.
|
||||
"""
|
||||
super(Machines, self).__init__(*args, **kwargs)
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 8.5 KiB |
@@ -1,9 +0,0 @@
|
||||
from .entities import Maintainer
|
||||
from .groups import Maintainers
|
||||
"""
|
||||
maintenance
|
||||
===========
|
||||
|
||||
Todo
|
||||
|
||||
"""
|
||||
@@ -1,4 +0,0 @@
|
||||
MAINTAINER = 'Maintainer' # TEMPLATE _identifier. Define your own!
|
||||
MAINTAINERS = 'Maintainers' # TEMPLATE _identifier. Define your own!
|
||||
|
||||
MAINTAINER_COLLISION_REWARD = -5
|
||||
@@ -1,139 +0,0 @@
|
||||
from random import shuffle

import networkx as nx
import numpy as np

from ...environment import constants as c
from ...environment.actions import Action, ALL_BASEACTIONS
from ...environment.entity.entity import Entity
from ..doors import constants as do
from ..maintenance import constants as mi
from ...utils import helpers as h
from ...utils.utility_classes import RenderEntity, Floor
from ..doors import DoorUse


class Maintainer(Entity):

def __init__(self, objective, action, *args, **kwargs):
"""
Represents the maintainer entity that aims to maintain machines. The maintainer calculates its route using nx
shortest path and restores the health of machines it visits to 100.

:param objective: The maintainer's objective, e.g., "Machines".
:type objective: str
:param action: The default action to be performed by the maintainer.
:type action: Action
"""
super().__init__(*args, **kwargs)
self.action = action
self.actions = [x() for x in ALL_BASEACTIONS] + [DoorUse()]
self.objective = objective
self._path = None
self._next = []
self._last = []
self._last_serviced = 'None'

def tick(self, state):
|
||||
"""
|
||||
If there is an objective at the current position, the maintainer performs its action on the objective.
|
||||
If the objective has changed since the last servicing, the maintainer performs the action and updates
|
||||
the last serviced objective. Otherwise, it calculates a move action and performs it.
|
||||
|
||||
:param state: The current game state.
|
||||
:type state: GameState
|
||||
:return: The result of the action performed by the maintainer.
|
||||
:rtype: ActionResult
|
||||
"""
|
||||
if found_objective := h.get_first(state[self.objective].by_pos(self.pos)):
|
||||
if found_objective.name != self._last_serviced:
|
||||
result = self.action.do(self, state)
|
||||
self._last_serviced = found_objective.name
|
||||
else:
|
||||
action = self.get_move_action(state)
|
||||
result = action.do(self, state)
|
||||
else:
|
||||
action = self.get_move_action(state)
|
||||
result = action.do(self, state)
|
||||
self.set_state(result)
|
||||
return result
|
||||
|
||||
def set_state(self, action_result):
|
||||
"""
|
||||
Updates the maintainers own status with an action result.
|
||||
"""
|
||||
self._status = action_result
|
||||
|
||||
def get_move_action(self, state) -> Action:
|
||||
"""
|
||||
Retrieves the next move action for the agent.
|
||||
|
||||
If a path is not already determined, the agent calculates the shortest path to its objective, considering doors
|
||||
and obstacles. If a closed door is found in the calculated path, the agent attempts to open it.
|
||||
|
||||
:param state: The current state of the environment.
|
||||
:type state: GameState
|
||||
|
||||
:return: The chosen move action for the agent.
|
||||
:rtype: Action
|
||||
"""
|
||||
if self._path is None or not len(self._path):
|
||||
if not self._next:
|
||||
self._next = list(state[self.objective].values()) + [Floor(*state.random_free_position)]
|
||||
shuffle(self._next)
|
||||
self._last = []
|
||||
self._last.append(self._next.pop())
|
||||
state.print("Calculating shortest path....")
|
||||
self._path = self.calculate_route(self._last[-1], state.floortile_graph)
|
||||
if not self._path:
|
||||
self._last.append(self._next.pop())
|
||||
state.print("Calculating shortest path.... Again....")
|
||||
self._path = self.calculate_route(self._last[-1], state.floortile_graph)
|
||||
|
||||
if door := self._closed_door_in_path(state):
|
||||
state.print(f"{self} found {door} that is closed. Attempt to open.")
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
action = do.ACTION_DOOR_USE
|
||||
else:
|
||||
action = self._predict_move(state)
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
try:
|
||||
action_obj = h.get_first(self.actions, lambda x: x.name == action)
|
||||
except (StopIteration, UnboundLocalError):
|
||||
print('Will not happen')
|
||||
raise EnvironmentError
|
||||
return action_obj
|
||||
|
||||
def calculate_route(self, entity, floortile_graph) -> list:
"""
:returns: the path to the given entity, excluding the current position but including the target position
:rtype: list
"""
route = nx.shortest_path(floortile_graph, self.pos, entity.pos)
return route[1:]

def _closed_door_in_path(self, state):
"""
Internal Use
"""
if self._path:
return h.get_first(state[do.DOORS].by_pos(self._path[0]), lambda x: x.is_closed)
else:
return None

def _predict_move(self, state) -> Action:
"""
Internal Use
"""
next_pos = self._path[0]
if any(x.var_can_collide for x in state.entities.pos_dict[next_pos]):
action = c.NOOP
else:
next_pos = self._path.pop(0)
diff = np.subtract(next_pos, self.pos)
# Retrieve the action based on the position difference (i.e., what do I have to do to get there?)
action = next(action for action, pos_diff in h.MOVEMAP.items() if np.all(diff == pos_diff))
return action

def render(self):
return RenderEntity(mi.MAINTAINER, self.pos)

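_predict_move turns the positional difference to the next waypoint into a movement action by matching it against a move map. A sketch of that lookup (the MOVEMAP below is a hypothetical stand-in for h.MOVEMAP, not the repo's actual mapping):

import numpy as np

MOVEMAP = {'north': (-1, 0), 'south': (1, 0), 'west': (0, -1), 'east': (0, 1)}

current_pos, next_pos = (4, 4), (4, 5)
diff = np.subtract(next_pos, current_pos)
action = next(name for name, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff))
print(action)   # -> 'east'
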
@@ -1,35 +0,0 @@
|
||||
from typing import Union, List, Tuple, Dict
|
||||
|
||||
from marl_factory_grid.environment.groups.collection import Collection
|
||||
from .entities import Maintainer
|
||||
from ..machines import constants as mc
|
||||
from ..machines.actions import MachineAction
|
||||
|
||||
|
||||
class Maintainers(Collection):
|
||||
_entity = Maintainer
|
||||
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
A collection of maintainers that is used to spawn them.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args):
|
||||
self.add_items([self._entity(mc.MACHINES, MachineAction(), pos) for pos in coords_or_quantity])
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 22 KiB |
@@ -1,41 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.utils.results import TickResult, DoneResult
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from . import constants as M
|
||||
|
||||
|
||||
class MoveMaintainers(Rule):
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
This rule is responsible for moving the maintainers at every step of the environment.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
move_results = []
|
||||
for maintainer in state[M.MAINTAINERS]:
|
||||
result = maintainer.tick(state)
|
||||
move_results.append(result)
|
||||
return move_results
|
||||
|
||||
|
||||
class DoneAtMaintainerCollision(Rule):
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
When active, this rule stops the environment after a maintainer reports a collision with another entity.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
agents = list(state[c.AGENT].values())
|
||||
m_pos = state[M.MAINTAINERS].positions
|
||||
done_results = []
|
||||
for agent in agents:
|
||||
if agent.pos in m_pos:
|
||||
done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name,
|
||||
reward=M.MAINTAINER_COLLISION_REWARD))
|
||||
return done_results
|
||||
@@ -1,19 +0,0 @@
|
||||
import os
import shutil
from pathlib import Path

from marl_factory_grid.utils.tools import ConfigExplainer


def init():
print('Retrieving available options...')
ce = ConfigExplainer()
cwd = Path(os.getcwd())
ce.save_all(cwd / 'full_config.yaml')
template_path = Path(__file__).parent / 'modules' / '_template'
print(f'Available config options saved to: {(cwd / "full_config.yaml").resolve()}')
print('-----------------------------')
print('Copying Templates....')
shutil.copytree(template_path, cwd / template_path.name)
print(f'Templates copied to: {cwd / template_path.name}')
print(':wave:')

@@ -1,7 +0,0 @@
|
||||
"""
|
||||
logging
|
||||
=======
|
||||
|
||||
Todo
|
||||
|
||||
"""
|
||||
@@ -1,74 +0,0 @@
|
||||
import pickle
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from gymnasium import Wrapper
|
||||
|
||||
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run
|
||||
|
||||
|
||||
class EnvMonitor(Wrapper):
|
||||
|
||||
ext = 'png'
|
||||
|
||||
def __init__(self, env, filepath: Union[str, PathLike] = None):
|
||||
"""
|
||||
EnvMonitor is a wrapper for Gymnasium environments that monitors and logs key information during interactions.
|
||||
"""
|
||||
super(EnvMonitor, self).__init__(env)
|
||||
self._filepath = filepath
|
||||
self._monitor_df = pd.DataFrame()
|
||||
self._monitor_dict = dict()
|
||||
|
||||
def step(self, action):
|
||||
obs_type, obs, reward, done, info = self.env.step(action)
|
||||
self._read_info(info)
|
||||
self._read_done(done)
|
||||
return obs_type, obs, reward, done, info
|
||||
|
||||
def reset(self):
|
||||
return self.env.reset()
|
||||
|
||||
def _read_info(self, info: dict):
self._monitor_dict[len(self._monitor_dict)] = {
key: val for key, val in info.items() if
key not in ['terminal_observation', 'episode']}
return

def _read_done(self, done):
if done:
env_monitor_df = pd.DataFrame.from_dict(self._monitor_dict, orient='index')
self._monitor_dict = dict()
columns = [col for col in env_monitor_df.columns if col not in IGNORED_DF_COLUMNS]
env_monitor_df = env_monitor_df.aggregate(
{col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
)
env_monitor_df['episode'] = len(self._monitor_df)
self._monitor_df = pd.concat([self._monitor_df, pd.DataFrame([env_monitor_df])], ignore_index=True)
else:
pass
return

def save_monitor(self, filepath: Union[Path, str, None] = None, auto_plotting_keys=None):
|
||||
"""
|
||||
Saves the monitoring data to a file and optionally generates plots.
|
||||
|
||||
:param filepath: The path to save the monitoring data file.
|
||||
:type filepath: Union[Path, str, None]
|
||||
:param auto_plotting_keys: Keys to use for automatic plot generation.
|
||||
:type auto_plotting_keys: Any
|
||||
"""
|
||||
filepath = Path(filepath or self._filepath)
|
||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with filepath.open('wb') as f:
|
||||
pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
if auto_plotting_keys:
|
||||
plot_single_run(filepath, column_keys=auto_plotting_keys)
|
||||
|
||||
def report_possible_colum_keys(self):
|
||||
print(self._monitor_df.columns)
|
||||
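_read_done collapses the per-step info of one episode into a single row: columns whose name ends in 'ount' are averaged, everything else is summed. A small pandas sketch of that aggregation (hypothetical column names, not taken from the environment):

import pandas as pd

step_info = pd.DataFrame({'step_reward': [0.1, 0.0, 0.4], 'agent_count': [2, 2, 2]})
agg = step_info.aggregate(
    {col: 'mean' if col.endswith('ount') else 'sum' for col in step_info.columns}
)
print(agg['step_reward'], agg['agent_count'])   # 0.5 (summed) and 2.0 (averaged)
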
@@ -1,190 +0,0 @@
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from gymnasium import Wrapper
|
||||
|
||||
|
||||
class EnvRecorder(Wrapper):
|
||||
|
||||
def __init__(self, env, filepath: Union[str, PathLike] = None,
|
||||
episodes: Union[List[int], None] = None):
|
||||
"""
|
||||
EnvRecorder is a wrapper for OpenAI Gym environments that records state summaries during interactions.
|
||||
|
||||
:param env: The environment to record.
|
||||
:type env: gym.Env
|
||||
:param filepath: The path to save the recording data file.
|
||||
:type filepath: Union[str, PathLike]
|
||||
:param episodes: A list of episode numbers to record. If None, records all episodes.
|
||||
:type episodes: Union[List[int], None]
|
||||
"""
|
||||
super(EnvRecorder, self).__init__(env)
|
||||
self.filepath = filepath
|
||||
self.episodes = episodes
|
||||
self._curr_episode = 0
|
||||
self._curr_ep_recorder = list()
|
||||
self._recorder_out_list = list()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Overrides the reset method to reset the environment and recording lists.
|
||||
"""
|
||||
self._curr_ep_recorder = list()
|
||||
self._recorder_out_list = list()
|
||||
self._curr_episode += 1
|
||||
return self.env.reset()
|
||||
|
||||
def step(self, actions):
|
||||
"""
|
||||
Overrides the step method to record state summaries during each step.
|
||||
|
||||
:param actions: The actions taken in the environment.
|
||||
:type actions: Any
|
||||
:return: The observation, reward, done flag, and additional information.
|
||||
:rtype: Tuple
|
||||
"""
|
||||
obs_type, obs, reward, done, info = self.env.step(actions)
|
||||
if not self.episodes or self._curr_episode in self.episodes:
|
||||
summary: dict = self.env.unwrapped.summarize_state()
|
||||
# summary.update(done=done)
|
||||
# summary.update({'episode': self._curr_episode})
|
||||
# TODO Protobuff Adjustments ######
|
||||
# summary.update(info)
|
||||
self._curr_ep_recorder.append(summary)
|
||||
if done:
|
||||
self._recorder_out_list.append({'steps': self._curr_ep_recorder,
|
||||
'episode_nr': self._curr_episode})
|
||||
self._curr_ep_recorder = list()
|
||||
return obs_type, obs, reward, done, info
|
||||
|
||||
def _finalize(self):
|
||||
if self._curr_ep_recorder:
|
||||
self._recorder_out_list.append({'steps': self._curr_ep_recorder.copy(),
|
||||
'episode_nr': len(self._recorder_out_list)})
|
||||
|
||||
def save_records(self, filepath: Union[Path, str, None] = None,
|
||||
only_deltas=False,
|
||||
save_occupation_map=False,
|
||||
save_trajectory_map=False,
|
||||
):
|
||||
"""
|
||||
Saves the recorded data to a file.
|
||||
|
||||
:param filepath: The path to save the recording data file.
|
||||
:type filepath: Union[Path, str, None]
|
||||
:param only_deltas: If True, saves only the differences between consecutive episodes.
|
||||
:type only_deltas: bool
|
||||
:param save_occupation_map: If True, saves an occupation map as a heatmap.
|
||||
:type save_occupation_map: bool
|
||||
:param save_trajectory_map: If True, saves a trajectory map.
|
||||
:type save_trajectory_map: bool
|
||||
"""
|
||||
self._finalize()
|
||||
filepath = Path(filepath or self.filepath)
|
||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||
# cls.out_file.unlink(missing_ok=True)
|
||||
with filepath.open('wb') as f:
|
||||
if only_deltas:
|
||||
from deepdiff import DeepDiff
|
||||
diff_dict = [DeepDiff(t1, t2, ignore_order=True)
|
||||
for t1, t2 in zip(self._recorder_out_list, self._recorder_out_list[1:])
|
||||
]
|
||||
out_dict = {'episodes': diff_dict}
|
||||
|
||||
else:
|
||||
# TODO Protobuff Adjustments Revert
|
||||
dest_prop = dict(
|
||||
n_dests=0,
|
||||
dwell_time=0,
|
||||
spawn_frequency=0,
|
||||
spawn_mode=''
|
||||
)
|
||||
rewards_dest = dict(
|
||||
WAIT_VALID=0.00,
|
||||
WAIT_FAIL=0.00,
|
||||
DEST_REACHED=0.00,
|
||||
)
|
||||
mv_prop = dict(
|
||||
allow_square_movement=False,
|
||||
allow_diagonal_movement=False,
|
||||
allow_no_op=False,
|
||||
)
|
||||
obs_prop = dict(
|
||||
render_agents='',
|
||||
omit_agent_self=False,
|
||||
additional_agent_placeholder=0,
|
||||
cast_shadows=False,
|
||||
frames_to_stack=0,
|
||||
pomdp_r=self.env.params['General']['pomdp_r'],
|
||||
indicate_door_area=False,
|
||||
show_global_position_info=False,
|
||||
|
||||
)
|
||||
rewards_base = dict(
|
||||
MOVEMENTS_VALID=0.00,
|
||||
MOVEMENTS_FAIL=0.00,
|
||||
NOOP=0.00,
|
||||
USE_DOOR_VALID=0.00,
|
||||
USE_DOOR_FAIL=0.00,
|
||||
COLLISION=0.00,
|
||||
|
||||
)
|
||||
|
||||
out_dict = {'episodes': self._recorder_out_list}
|
||||
out_dict.update(
|
||||
{'n_episodes': self._curr_episode,
|
||||
'metadata': dict(
|
||||
level_name=self.env.params['General']['level_name'],
|
||||
verbose=False,
|
||||
n_agents=len(self.env.params['Agents']),
|
||||
max_steps=100,
|
||||
done_at_collision=False,
|
||||
parse_doors=True,
|
||||
doors_have_area=False,
|
||||
individual_rewards=True,
|
||||
class_name='Where does this end up?',
|
||||
env_seed=69,
|
||||
|
||||
dest_prop=dest_prop,
|
||||
rewards_dest=rewards_dest,
|
||||
mv_prop=mv_prop,
|
||||
obs_prop=obs_prop,
|
||||
rewards_base=rewards_base,
|
||||
),
|
||||
# 'env_params': self.env.params,
|
||||
'header': self.env.summarize_header()
|
||||
})
|
||||
try:
|
||||
from marl_factory_grid.utils.proto import fiksProto_pb2
|
||||
from google.protobuf import json_format
|
||||
|
||||
bulk = fiksProto_pb2.Bulk()
|
||||
json_format.ParseDict(out_dict, bulk)
|
||||
f.write(bulk.SerializeToString())
|
||||
# yaml.dump(out_dict, f, indent=4)
|
||||
except TypeError:
|
||||
print('Shit')
|
||||
print('done')
|
||||
|
||||
if save_occupation_map:
|
||||
a = np.zeros((15, 15))
|
||||
# noinspection PyTypeChecker
|
||||
for episode in out_dict['episodes']:
|
||||
df = pd.DataFrame([y for x in episode['steps'] for y in x['Agents']])
|
||||
|
||||
b = list(df[['x', 'y']].to_records(index=False))
|
||||
|
||||
np.add.at(a, tuple(zip(*b)), 1)
|
||||
|
||||
# a = np.rot90(a)
|
||||
import seaborn as sns
|
||||
from matplotlib import pyplot as plt
|
||||
hm = sns.heatmap(data=a)
|
||||
hm.set_title('Very Nice Heatmap')
|
||||
plt.show()
|
||||
|
||||
if save_trajectory_map:
|
||||
raise NotImplementedError('This has not yet been implemented.')
|
||||
@@ -1,201 +0,0 @@
|
||||
import pickle
|
||||
import re
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union, List
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
|
||||
from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
|
||||
|
||||
MODEL_MAP = None
|
||||
|
||||
|
||||
def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
|
||||
"""
|
||||
|
||||
Compare multiple runs with different seeds by generating a line plot that shows the evolution of scores (step rewards)
|
||||
across episodes.
|
||||
|
||||
:param run_path: The path to the directory containing the monitor files for each run.
|
||||
:type run_path: Union[str, PathLike]
|
||||
:param use_tex: A boolean indicating whether to use TeX formatting in the plot.
|
||||
:type use_tex: bool
|
||||
"""
|
||||
run_path = Path(run_path)
|
||||
df_list = list()
|
||||
for run, monitor_file in enumerate(run_path.rglob('monitor*.pick')):
|
||||
with monitor_file.open('rb') as f:
|
||||
monitor_df = pickle.load(f)
|
||||
|
||||
monitor_df['run'] = run
|
||||
monitor_df = monitor_df.fillna(0)
|
||||
df_list.append(monitor_df)
|
||||
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'}).sort_values(['Run', 'Episode'])
|
||||
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
|
||||
|
||||
roll_n = 50
|
||||
|
||||
non_overlapp_window = df.groupby(['Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
|
||||
|
||||
df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run'],
|
||||
value_vars=columns, var_name="Measurement",
|
||||
value_name="Score")
|
||||
|
||||
if df_melted['Episode'].max() > 800:
|
||||
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||
|
||||
run_path.mkdir(parents=True, exist_ok=True)
|
||||
if run_path.exists() and run_path.is_file():
|
||||
prepare_plot(run_path.parent / f'{run_path.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||
else:
|
||||
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||
print('Plotting done.')
|
||||
|
||||
|
||||
def compare_model_runs(run_path: Path, run_identifier: Union[str, int], parameter: Union[str, List[str]],
|
||||
use_tex: bool = False):
|
||||
"""
|
||||
Compares multiple model runs based on specified parameters by generating a line plot showing the evolution of scores (step rewards)
|
||||
across episodes.
|
||||
|
||||
:param run_path: The path to the directory containing the monitor files for each model run.
|
||||
:type run_path: Path
|
||||
:param run_identifier: A string or integer identifying the runs to compare.
|
||||
:type run_identifier: Union[str, int]
|
||||
:param parameter: A single parameter or a list of parameters to compare.
|
||||
:type parameter: Union[str, List[str]]
|
||||
:param use_tex: A boolean indicating whether to use TeX formatting in the plot.
|
||||
:type use_tex: bool
|
||||
"""
|
||||
run_path = Path(run_path)
|
||||
df_list = list()
|
||||
parameter = [parameter] if isinstance(parameter, str) else parameter
|
||||
for path in run_path.iterdir():
|
||||
if path.is_dir() and str(run_identifier) in path.name:
|
||||
for run, monitor_file in enumerate(path.rglob('monitor*.pick')):
|
||||
with monitor_file.open('rb') as f:
|
||||
monitor_df = pickle.load(f)
|
||||
|
||||
monitor_df['run'] = run
|
||||
monitor_df['model'] = next((x for x in path.name.split('_') if x in MODEL_MAP.keys()))
|
||||
monitor_df = monitor_df.fillna(0)
|
||||
df_list.append(monitor_df)
|
||||
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run', 'model': 'Model'})
|
||||
columns = [col for col in df.columns if col in parameter]
|
||||
|
||||
last_episode_to_report = min(df.groupby(['Model'])['Episode'].max())
|
||||
df = df[df['Episode'] < last_episode_to_report]
|
||||
|
||||
roll_n = 40
|
||||
non_overlapp_window = df.groupby(['Model', 'Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
|
||||
|
||||
df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run', 'Model'],
|
||||
value_vars=columns, var_name="Measurement",
|
||||
value_name="Score")
|
||||
|
||||
if df_melted['Episode'].max() > 80:
|
||||
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||
|
||||
style = 'Measurement' if len(columns) > 1 else None
|
||||
prepare_plot(run_path / f'{run_identifier}_compare_{parameter}.png', df_melted, hue='Model', style=style,
|
||||
use_tex=use_tex)
|
||||
print('Plotting done.')
|
||||
|
||||
|
||||
def compare_all_parameter_runs(run_root_path: Path, parameter: Union[str, List[str]],
|
||||
param_names: Union[List[str], None] = None, str_to_ignore='', use_tex: bool = False):
|
||||
"""
|
||||
Compares model runs across different parameter settings by generating a line plot showing the evolution of scores across episodes.
|
||||
|
||||
:param run_root_path: The root path to the directory containing the monitor files for all model runs.
|
||||
:type run_root_path: Path
|
||||
:param parameter: The parameter(s) to compare across different runs.
|
||||
:type parameter: Union[str, List[str]]
|
||||
:param param_names: A list of custom names for the parameters to be used as labels in the plot. If None, default names will be assigned.
|
||||
:type param_names: Union[List[str], None]
|
||||
:param str_to_ignore: A string pattern to ignore in parameter names.
|
||||
:type str_to_ignore: str
|
||||
:param use_tex: A boolean indicating whether to use TeX formatting in the plot.
|
||||
:type use_tex: bool
|
||||
"""
|
||||
run_root_path = Path(run_root_path)
|
||||
df_list = list()
|
||||
parameter = [parameter] if isinstance(parameter, str) else parameter
|
||||
for monitor_idx, monitor_file in enumerate(run_root_path.rglob('monitor*.pick')):
|
||||
with monitor_file.open('rb') as f:
|
||||
monitor_df = pickle.load(f)
|
||||
|
||||
params = [x.name for x in monitor_file.parents if x.parent not in run_root_path.parents]
|
||||
if str_to_ignore:
|
||||
params = [re.sub(f'_*({str_to_ignore})', '', param) for param in params]
|
||||
|
||||
if monitor_idx == 0:
|
||||
if param_names is not None:
|
||||
if len(param_names) < len(params):
|
||||
# FIXME: Missing Seed Detection, see below @111
|
||||
param_names = [next(param_names) if param not in MODEL_MAP.keys() else 'Model' for param in params]
|
||||
elif len(param_names) == len(params):
|
||||
pass
|
||||
else:
|
||||
raise ValueError
|
||||
else:
|
||||
param_names = []
|
||||
for param_idx, param in enumerate(params):
|
||||
dtype = None
|
||||
if param in MODEL_MAP.keys():
|
||||
param_name = 'Model'
|
||||
elif '_' in param:
|
||||
param_split = param.split('_')
|
||||
if len(param_split) == 2 and any(split in MODEL_MAP.keys() for split in param_split):
|
||||
# Extract the seed
|
||||
param = int(next(x for x in param_split if x not in MODEL_MAP))
|
||||
param_name = 'Seed'
|
||||
dtype = int
|
||||
else:
|
||||
param_name = f'param_{param_idx}'
|
||||
else:
|
||||
param_name = f'param_{param_idx}'
|
||||
dtype = dtype if dtype is not None else str
|
||||
monitor_df[param_name] = str(param)
|
||||
monitor_df[param_name] = monitor_df[param_name].astype(dtype)
|
||||
if monitor_idx == 0:
|
||||
param_names.append(param_name)
|
||||
|
||||
monitor_df = monitor_df.fillna(0)
|
||||
df_list.append(monitor_df)
|
||||
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
|
||||
|
||||
for param_name in param_names:
|
||||
df[param_name] = df[param_name].astype(str)
|
||||
columns = [col for col in df.columns if col in parameter]
|
||||
|
||||
last_episode_to_report = min(df.groupby(['Model'])['Episode'].max())
|
||||
df = df[df['Episode'] < last_episode_to_report]
|
||||
|
||||
if df['Episode'].max() > 80:
|
||||
skip_n = round(df['Episode'].max() * 0.02)
|
||||
df = df[df['Episode'] % skip_n == 0]
|
||||
combinations = [x for x in param_names if x not in ['Model', 'Seed']]
|
||||
df['Parameter Combination'] = df[combinations].apply(lambda row: '_'.join(row.values.astype(str)), axis=1)
|
||||
df.drop(columns=combinations, inplace=True)
|
||||
|
||||
# non_overlapp_window = df.groupby(param_names).sum()
|
||||
|
||||
df_melted = df.reset_index().melt(id_vars=['Parameter Combination', 'Episode'],
|
||||
value_vars=columns, var_name="Measurement",
|
||||
value_name="Score")
|
||||
|
||||
style = 'Measurement' if len(columns) > 1 else None
|
||||
prepare_plot(run_root_path / f'compare_{parameter}.png', df_melted, hue='Parameter Combination',
|
||||
style=style, use_tex=use_tex)
|
||||
print('Plotting done.')
|
||||
@@ -9,7 +9,6 @@ import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
|
||||
from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
|
||||
|
||||
from marl_factory_grid.utils.renderer import Renderer
|
||||
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||
@@ -17,59 +16,6 @@ from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||
from marl_factory_grid.modules.clean_up import constants as d
|
||||
|
||||
|
||||
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None,
|
||||
file_key: str = 'monitor', file_ext: str = 'pkl'):
|
||||
"""
|
||||
Plots the Epoch score (step reward) over a single run based on monitoring data stored in a file.
|
||||
|
||||
:param run_path: The path to the directory containing monitoring data or directly to the monitoring file.
|
||||
:type run_path: Union[str, PathLike]
|
||||
:param use_tex: Flag indicating whether to use TeX for plotting.
|
||||
:type use_tex: bool, optional
|
||||
:param column_keys: Specific columns to include in the plot. If None, includes all columns except ignored ones.
|
||||
:type column_keys: list or None, optional
|
||||
:param file_key: The keyword to identify the monitoring file.
|
||||
:type file_key: str, optional
|
||||
:param file_ext: The extension of the monitoring file.
|
||||
:type file_ext: str, optional
|
||||
"""
|
||||
run_path = Path(run_path)
|
||||
df_list = list()
|
||||
if run_path.is_dir():
|
||||
monitor_file = next(run_path.glob(f'*{file_key}*.{file_ext}'))
|
||||
elif run_path.exists() and run_path.is_file():
|
||||
monitor_file = run_path
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
with monitor_file.open('rb') as f:
|
||||
monitor_df = pickle.load(f)
|
||||
|
||||
monitor_df = monitor_df.fillna(0)
|
||||
df_list.append(monitor_df)
|
||||
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
|
||||
if column_keys is not None:
|
||||
columns = [col for col in column_keys if col in df.columns]
|
||||
else:
|
||||
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
|
||||
|
||||
# roll_n = 50
|
||||
# non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean()
|
||||
|
||||
df_melted = df[columns + ['Episode']].reset_index().melt(
|
||||
id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
|
||||
)
|
||||
|
||||
if df_melted['Episode'].max() > 800:
|
||||
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||
|
||||
prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||
print('Plotting done.')
|
||||
|
||||
|
||||
def plot_routes(factory, agents):
|
||||
"""
|
||||
Creates a plot of the agents' actions on the level map by creating a Renderer and Render Entities that hold the
|
||||
|
||||
@@ -1,151 +0,0 @@
|
||||
import seaborn as sns
|
||||
import matplotlib as mpl
|
||||
from matplotlib import pyplot as plt
|
||||
PALETTE = 10 * (
|
||||
"#377eb8",
|
||||
"#4daf4a",
|
||||
"#984ea3",
|
||||
"#e41a1c",
|
||||
"#ff7f00",
|
||||
"#a65628",
|
||||
"#f781bf",
|
||||
"#888888",
|
||||
"#a6cee3",
|
||||
"#b2df8a",
|
||||
"#cab2d6",
|
||||
"#fb9a99",
|
||||
"#fdbf6f",
|
||||
)
|
||||
|
||||
|
||||
def plot(filepath, ext='png'):
|
||||
"""
|
||||
Saves the current plot to a file and displays it.
|
||||
|
||||
:param filepath: The path to save the plot file.
|
||||
:type filepath: str
|
||||
:param ext: The file extension of the saved plot. Default is 'png'.
|
||||
:type ext: str
|
||||
"""
|
||||
plt.tight_layout()
|
||||
figure = plt.gcf()
|
||||
ax = plt.gca()
|
||||
legends = [c for c in ax.get_children() if isinstance(c, mpl.legend.Legend)]
|
||||
|
||||
if legends:
|
||||
figure.savefig(str(filepath), format=ext, bbox_extra_artists=(*legends,), bbox_inches='tight')
|
||||
else:
|
||||
figure.savefig(str(filepath), format=ext)
|
||||
|
||||
plt.show()
|
||||
plt.clf()
|
||||
|
||||
|
||||
def prepare_tex(df, hue, style, hue_order):
|
||||
"""
|
||||
Prepares a line plot for rendering in LaTeX.
|
||||
|
||||
:param df: The DataFrame containing the data to be plotted.
|
||||
:type df: pandas.DataFrame
|
||||
:param hue: Grouping variable that will produce lines with different colors.
|
||||
:type hue: str
|
||||
:param style: Grouping variable that will produce lines with different styles.
|
||||
:type style: str
|
||||
:param hue_order: Order for the levels of the hue variable in the plot.
|
||||
:type hue_order: list
|
||||
:return: The prepared line plot.
|
||||
:rtype: matplotlib.axes._subplots.AxesSubplot
|
||||
"""
|
||||
sns.set(rc={'text.usetex': True}, style='whitegrid')
|
||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
|
||||
hue_order=hue_order, hue=hue, style=style)
|
||||
lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
||||
plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
|
||||
plt.tight_layout()
|
||||
return lineplot
|
||||
|
||||
|
||||
def prepare_plt(df, hue, style, hue_order):
|
||||
"""
|
||||
Prepares a line plot using matplotlib.
|
||||
|
||||
:param df: The DataFrame containing the data to be plotted.
|
||||
:type df: pandas.DataFrame
|
||||
:param hue: Grouping variable that will produce lines with different colors.
|
||||
:type hue: str
|
||||
:param style: Grouping variable that will produce lines with different styles.
|
||||
:type style: str
|
||||
:param hue_order: Order for the levels of the hue variable in the plot.
|
||||
:type hue_order: list
|
||||
:return: The prepared line plot.
|
||||
:rtype: matplotlib.axes._subplots.AxesSubplot
|
||||
"""
|
||||
print('Struggling to plot Figure using LaTeX - going back to normal.')
|
||||
plt.close('all')
|
||||
sns.set(rc={'text.usetex': False}, style='whitegrid')
|
||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
|
||||
errorbar=('ci', 95), palette=PALETTE, hue_order=hue_order, )
|
||||
plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
|
||||
plt.tight_layout()
|
||||
# lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
||||
return lineplot
|
||||
|
||||
|
||||
def prepare_center_double_column_legend(df, hue, style, hue_order):
|
||||
"""
|
||||
Prepares a line plot with a legend centered at the bottom and spread across two columns.
|
||||
|
||||
:param df: The DataFrame containing the data to be plotted.
|
||||
:type df: pandas.DataFrame
|
||||
:param hue: Grouping variable that will produce lines with different colors.
|
||||
:type hue: str
|
||||
:param style: Grouping variable that will produce lines with different styles.
|
||||
:type style: str
|
||||
:param hue_order: Order for the levels of the hue variable in the plot.
|
||||
:type hue_order: list
|
||||
:return: The prepared line plot.
|
||||
:rtype: matplotlib.axes._subplots.AxesSubplot
|
||||
"""
|
||||
print('Struggling to plot Figure using LaTeX - going back to normal.')
|
||||
plt.close('all')
|
||||
sns.set(rc={'text.usetex': False}, style='whitegrid')
|
||||
_ = plt.figure(figsize=(10, 11))
|
||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
|
||||
ci=95, palette=PALETTE, hue_order=hue_order, legend=False)
|
||||
# plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
|
||||
lineplot.legend(hue_order, ncol=3, loc='lower center', title='Parameter Combinations', bbox_to_anchor=(0.5, -0.43))
|
||||
plt.tight_layout()
|
||||
return lineplot
|
||||
|
||||
|
||||
def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None, use_tex=False):
|
||||
"""
|
||||
Prepares a line plot for visualization. Based on the use tex parameter calls the prepare_tex or prepare_plot
|
||||
function accordingly, followed by the plot function to save the plot.
|
||||
|
||||
:param filepath: The file path where the plot will be saved.
|
||||
:type filepath: str
|
||||
:param results_df: The DataFrame containing the data to be plotted.
|
||||
:type results_df: pandas.DataFrame
|
||||
:param ext: The file extension of the saved plot (default is 'png').
|
||||
:type ext: str
|
||||
:param hue: The variable to determine the color of the lines in the plot.
|
||||
:type hue: str
|
||||
:param style: The variable to determine the style of the lines in the plot (default is None).
|
||||
:type style: str or None
|
||||
:param use_tex: Whether to use LaTeX for text rendering (default is False).
|
||||
:type use_tex: bool
|
||||
"""
|
||||
df = results_df.copy()
|
||||
df[hue] = df[hue].str.replace('_', '-')
|
||||
hue_order = sorted(list(df[hue].unique()))
|
||||
if use_tex:
|
||||
try:
|
||||
_ = prepare_tex(df, hue, style, hue_order)
|
||||
plot(filepath, ext=ext) # plot raises errors not lineplot!
|
||||
except (FileNotFoundError, RuntimeError):
|
||||
_ = prepare_plt(df, hue, style, hue_order)
|
||||
plot(filepath, ext=ext)
|
||||
else:
|
||||
_ = prepare_plt(df, hue, style, hue_order)
|
||||
plot(filepath, ext=ext)
|
||||
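prepare_plot expects a long-format DataFrame with 'Episode', 'Score' and a hue column such as 'Measurement'. A hedged usage sketch with toy data (the output path is hypothetical, and the call is commented out so the snippet runs without the package installed):

import pandas as pd
# from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot

toy = pd.DataFrame({
    'Episode': [1, 2, 3, 1, 2, 3],
    'Measurement': ['step_reward'] * 3 + ['collision'] * 3,
    'Score': [0.1, 0.3, 0.5, 2, 1, 0],
})
# prepare_plot('out/toy_lineplot.png', toy, hue='Measurement', use_tex=False)
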
@@ -196,7 +196,7 @@ class Renderer:
|
||||
rects.append(dict(source=shape_surf, dest=visibility_rect))
|
||||
return rects
|
||||
|
||||
def render(self, entities, recorder):
|
||||
def render(self, entities):
|
||||
"""
|
||||
Renders the entities on the screen.
|
||||
|
||||
@@ -230,11 +230,6 @@ class Renderer:
|
||||
for blit in blits:
|
||||
self.screen.blit(**blit)
|
||||
|
||||
if recorder:
|
||||
frame = pygame.surfarray.array3d(self.screen)
|
||||
frame = np.transpose(frame, (1, 0, 2)) # Transpose to (height, width, channels)
|
||||
recorder.append_data(frame)
|
||||
|
||||
pygame.display.flip()
|
||||
self.clock.tick(self.fps)
|
||||
rgb_obs = pygame.surfarray.array3d(self.screen)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import copy
|
||||
from itertools import islice
|
||||
from typing import List, Tuple
|
||||
|
||||
@@ -117,7 +116,6 @@ class Gamestate(object):
|
||||
self.rng = np.random.default_rng(env_seed)
|
||||
self.rules = StepRules(*rules)
|
||||
self._floortile_graph = None
|
||||
self.route_cache = []
|
||||
self.tests = StepTests(*tests)
|
||||
|
||||
# Pointer that defines current spawn points of agents
|
||||
@@ -322,42 +320,6 @@ class Gamestate(object):
|
||||
# json_file.seek(0)
|
||||
# json.dump(existing_content, json_file, indent=4)
|
||||
|
||||
def cache_route(self, route):
"""
Saves a route in the env-level cache so that other agents can access it.

:param route: The route to be saved
"""
self.route_cache.append(copy.deepcopy(route))
# print(f"Cached route: {route}")

def get_cached_route(self, current_pos, target_positions, route_cutting=False):
"""
Reuses a cached route if it contains both the current position and one of the targets.

:param current_pos: The agent's current position and thus the first position of possibly cached routes
:param target_positions: The positions of targets the agent wants to visit
:param route_cutting: If True, the found route is cut so that it ends at the target. If False, the remainder
of the cached route beyond the target is kept as well.

:returns: A cached route from the agent's position to the first target if one exists, otherwise None
"""
if not self.route_cache:
return None

for route in self.route_cache:
if current_pos in route:
targets = [target for target in target_positions if target in route]
if targets:
first_target = targets[0]
index_start = route.index(current_pos)

if route_cutting:
index_end = route.index(first_target) + 1
return copy.deepcopy(route[index_start:index_end])
else:
return copy.deepcopy(route[index_start:])
return None

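get_cached_route trades a fresh shortest-path computation for a membership check on previously cached routes: if a cached route passes through the current position and one of the targets, the relevant slice is returned. A toy lookup with hypothetical grid positions (not part of this diff):

import copy

route_cache = [[(0, 0), (0, 1), (0, 2), (1, 2), (2, 2)]]
current_pos, targets = (0, 1), [(1, 2)]

for route in route_cache:
    if current_pos in route:
        hits = [t for t in targets if t in route]
        if hits:
            start, end = route.index(current_pos), route.index(hits[0]) + 1
            print(copy.deepcopy(route[start:end]))   # [(0, 1), (0, 2), (1, 2)] with route_cutting=True
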
class StepTests:
|
||||
def __init__(self, *args):
|
||||
|
||||
@@ -1,247 +0,0 @@
|
||||
import importlib
|
||||
import inspect
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils.helpers import locate_and_import_class
|
||||
|
||||
ACTION = 'Action'
|
||||
GENERAL = 'General'
|
||||
ENTITIES = 'Objects'
|
||||
OBSERVATIONS = 'Observations'
|
||||
RULES = 'Rule'
|
||||
TESTS = 'Tests'
|
||||
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls', 'Gamestate', 'Path',
|
||||
'Iterable', 'Move', 'Result', 'TemplateRule', 'Entities', 'EnvObjects', 'Collection',
|
||||
'State', 'Object', 'default_valid_reward', 'default_fail_reward', 'size']
|
||||
|
||||
|
||||
class ConfigExplainer:
|
||||
|
||||
def __init__(self, custom_path: None | PathLike = None):
|
||||
"""
|
||||
This utility serves as a helper for debugging and exploring available modules and classes.
|
||||
Does not do anything unless told.
|
||||
The functions get_xxxxx() retrieves and returns the information and save_xxxxx() dumps them to disk.
|
||||
|
||||
get_all() and save_all() helps geting a general overview.
|
||||
|
||||
When provided with a custom path, your own modules become available.
|
||||
|
||||
:param custom_path: Path to your custom module folder.
|
||||
"""
|
||||
|
||||
self.base_path = Path(__file__).parent.parent.resolve() /'environment'
|
||||
self.modules_path = Path(__file__).parent.parent.resolve() / 'modules'
|
||||
self.custom_path = Path(custom_path) if custom_path is not None else custom_path
|
||||
self.searchspace = [ACTION, GENERAL, ENTITIES, OBSERVATIONS, RULES, TESTS]
|
||||
|
||||
@staticmethod
def _explain_module(class_to_explain):
"""
INTERNAL USE ONLY
"""
this_search = class_to_explain
parameters = dict(inspect.signature(class_to_explain).parameters)
while this_search.__bases__:
base_class = this_search.__bases__[0]
parameters.update(dict(inspect.signature(base_class).parameters))
this_search = base_class

explained = {class_to_explain.__name__:
{key: val.default if val.default != inspect._empty else '!' for key, val in parameters.items()
if key not in EXCLUDED}
}
return explained

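_explain_module gathers constructor parameters along the single-inheritance chain via inspect.signature and marks parameters without a default with '!'. A self-contained sketch with toy classes (not taken from the repo):

import inspect

class Base:
    def __init__(self, size=5): ...

class Door(Base):
    def __init__(self, closed_on_init=True, **kwargs): ...

params = dict(inspect.signature(Door).parameters)
cls = Door
while cls.__bases__:
    cls = cls.__bases__[0]
    try:
        params.update(dict(inspect.signature(cls).parameters))
    except ValueError:
        break   # object() exposes no inspectable signature

print({k: (v.default if v.default is not inspect.Parameter.empty else '!') for k, v in params.items()})
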
def _get_by_identifier(self, identifier):
|
||||
"""
|
||||
INTERNAL USE ONLY
|
||||
"""
|
||||
entities_base_cls = locate_and_import_class(identifier, self.base_path)
|
||||
module_paths = [x.resolve() for x in self.modules_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
|
||||
base_paths = [x.resolve() for x in self.base_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
|
||||
found_entities = self._load_and_compare(entities_base_cls, base_paths)
|
||||
found_entities.update(self._load_and_compare(entities_base_cls, module_paths))
|
||||
if self.custom_path is not None:
|
||||
module_paths = [x.resolve() for x in self.custom_path.rglob('*.py') if x.is_file()
|
||||
and '__init__' not in x.name]
|
||||
found_entities.update(self._load_and_compare(entities_base_cls, module_paths))
|
||||
return found_entities
|
||||
|
||||
def _load_and_compare(self, compare_class, paths):
|
||||
"""
|
||||
INTERNAL USE ONLY
|
||||
"""
|
||||
conf = {}
|
||||
package_pos = next(idx for idx, x in enumerate(Path(__file__).resolve().parts) if x == 'marl_factory_grid')
|
||||
for module_path in paths:
|
||||
module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos]
|
||||
mods = importlib.import_module('.'.join(module_parts))
|
||||
for key in mods.__dict__.keys():
|
||||
if key not in EXCLUDED and not key.startswith('_'):
|
||||
mod = mods.__getattribute__(key)
|
||||
try:
|
||||
if issubclass(mod, compare_class) and mod != compare_class:
|
||||
conf.update(self._explain_module(mod))
|
||||
except TypeError:
|
||||
pass
|
||||
return conf
|
||||
|
||||
@staticmethod
|
||||
def _save_to_file(data: dict, filepath: PathLike, tag: str = ''):
|
||||
"""
|
||||
INTERNAL USE ONLY
|
||||
"""
|
||||
filepath = Path(filepath)
|
||||
yaml.Dumper.ignore_aliases = lambda *args: True
|
||||
with filepath.open('w') as f:
|
||||
yaml.dump(data, f, encoding='utf-8')
|
||||
print(f'Example config {"for " + tag + " " if tag else " "}dumped')
|
||||
print(f'See file: {filepath}')
|
||||
|
||||
def get_actions(self) -> dict[str]:
|
||||
"""
|
||||
Retrieve all actions from module folders.
|
||||
|
||||
:returns: A list of all available actions.
|
||||
"""
|
||||
actions = self._get_by_identifier(ACTION)
|
||||
actions.update({c.MOVE8: {}, c.MOVE4: {}})
|
||||
return actions
|
||||
|
||||
def get_all(self) -> dict[str]:
|
||||
"""
|
||||
Retrieve all available configurations from module folders.
|
||||
|
||||
:returns: A dictionary of all available configurations.
|
||||
"""
|
||||
|
||||
config_dict = {
|
||||
'General': self.get_general_section(),
|
||||
'Agents': self.get_agent_section(),
|
||||
'Entities': self.get_entities(),
|
||||
'Rules': self.get_rules()
|
||||
}
|
||||
return config_dict
|
||||
|
||||
def get_entities(self):
|
||||
"""
|
||||
Retrieve all entities from module folders.
|
||||
|
||||
:returns: A list of all available entities.
|
||||
"""
|
||||
entities = self._get_by_identifier(ENTITIES)
|
||||
for key in ['Combined', 'Agents', 'Inventory']:
|
||||
del entities[key]
|
||||
return entities
|
||||
|
||||
@staticmethod
|
||||
def get_general_section():
|
||||
"""
|
||||
Build the general section.
|
||||
|
||||
:returns: A list of all available entities.
|
||||
"""
|
||||
general = {'level_name': 'rooms', 'env_seed': 69, 'verbose': False,
|
||||
'pomdp_r': 3, 'individual_rewards': True, 'tests': False}
|
||||
return general
|
||||
|
||||
def get_agent_section(self):
"""
Build the Agent section and retrieve all available actions and observations from module folders.

:returns: The Agent section.
"""
agents = dict(
ExampleAgentName=dict(
Actions=self.get_actions(),
Observations=self.get_observations()))
return agents

def get_rules(self) -> dict[str]:
|
||||
"""
|
||||
Retrieve all rules from module folders.
|
||||
|
||||
:returns: All available rules.
|
||||
"""
|
||||
rules = self._get_by_identifier(RULES)
|
||||
return rules
|
||||
|
||||
def get_observations(self) -> list[str]:
|
||||
"""
|
||||
Retrieve all agent observations from module folders.
|
||||
|
||||
:returns: A list of all available observations.
|
||||
"""
|
||||
names = [c.ALL, c.COMBINED, c.SELF, c.OTHERS, "Agent['ExampleAgentName']"]
|
||||
for key, val in self.get_entities().items():
|
||||
try:
|
||||
e = locate_and_import_class(key, self.base_path)(level_shape=(0, 0), pomdp_r=0).obs_pairs
|
||||
except TypeError:
|
||||
e = [key]
|
||||
except AttributeError as err:
|
||||
try:
|
||||
e = locate_and_import_class(key, self.modules_path)(level_shape=(0, 0), pomdp_r=0).obs_pairs
|
||||
except TypeError:
|
||||
e = [key]
|
||||
except AttributeError as err2:
|
||||
if self.custom_path is not None:
|
||||
try:
|
||||
e = locate_and_import_class(key, self.base_path)(level_shape=(0, 0), pomdp_r=0).obs_pairs
|
||||
except TypeError:
|
||||
e = [key]
|
||||
else:
|
||||
print(err.args)
|
||||
print(err2.args)
|
||||
exit(-9999)
|
||||
names.extend(e)
|
||||
return names
|
||||
|
||||
def save_actions(self, output_conf_file: PathLike = Path('../../quickstart') / 'actions.yml'):
"""
Write all available actions to a file.
:param output_conf_file: File to write to. Defaults to ../../quickstart/actions.yml
"""
self._save_to_file(self.get_actions(), output_conf_file, ACTION)

def save_entities(self, output_conf_file: PathLike = Path('../../quickstart') / 'entities.yml'):
"""
Write all available entities to a file.
:param output_conf_file: File to write to. Defaults to ../../quickstart/entities.yml
"""
self._save_to_file(self.get_entities(), output_conf_file, ENTITIES)

def save_observations(self, output_conf_file: PathLike = Path('../../quickstart') / 'observations.yml'):
"""
Write all available observations to a file.
:param output_conf_file: File to write to. Defaults to ../../quickstart/observations.yml
"""
self._save_to_file(self.get_observations(), output_conf_file, OBSERVATIONS)

def save_rules(self, output_conf_file: PathLike = Path('../../quickstart') / 'rules.yml'):
"""
Write all available rules to a file.
:param output_conf_file: File to write to. Defaults to ../../quickstart/rules.yml
"""
self._save_to_file(self.get_rules(), output_conf_file, RULES)

def save_all(self, output_conf_file: PathLike = Path('../../quickstart') / 'all.yml'):
"""
Write all available keywords to a file.
:param output_conf_file: File to write to. Defaults to ../../quickstart/all.yml
"""
self._save_to_file(self.get_all(), output_conf_file, 'ALL')

if __name__ == '__main__':
|
||||
ce = ConfigExplainer()
|
||||
# ce.get_actions()
|
||||
# ce.get_entities()
|
||||
# ce.get_rules()
|
||||
# ce.get_observations()
|
||||
all_conf = ce.get_all()
|
||||
ce.save_all()
|
||||