Compare commits
24 Commits
testing_ol ... main
Commits (SHA1):
bcbd4a8078
bdbd1c4e25
6c2df735d4
4f3924d3ab
6a24e7b518
e7461d7dcf
33f144fc93
0218f8f4e9
a9a4274370
b09c461754
ffc47752a7
3e19970a60
51fb73ebb8
a16d7e709e
3ce6302e8a
823aa075b9
d29ccbbb71
2a2aafa988
0e8a4af740
b6c8cbd2e3
3150757347
435056f373
78bf19f7f4
b43f595207
.gitignore (vendored, 1 line changed)
@@ -702,3 +702,4 @@ $RECYCLE.BIN/
 
 # End of https://www.toptal.com/developers/gitignore/api/linux,unity,macos,python,windows,pycharm,notepadpp,visualstudiocode,latex
 /studies/e_1/
+/studies/curious_study/
@@ -5,11 +5,25 @@ from networkx.algorithms.approximation import traveling_salesman as tsp
 from environments.factory.base.objects import Agent
 from environments.helpers import points_to_graph
 from environments import helpers as h
-from environments.helpers import Constants as c
+from environments.helpers import Constants as BaseConstants
+from environments.helpers import EnvActions as BaseActions
+
+
+class Constants(BaseConstants):
+    DIRT = 'DirtPile'
+
+
+class Actions(BaseActions):
+    CLEAN_UP = 'do_cleanup_action'
+
+
+a = Actions
+c = Constants
+
 
 future_planning = 7
 
 
 class TSPDirtAgent(Agent):
 
     def __init__(self, env, *args,
@@ -26,7 +40,7 @@ class TSPDirtAgent(Agent):
     def predict(self, *_, **__):
         if self._env[c.DIRT].by_pos(self.pos) is not None:
             # Translate the action_object to an integer to have the same output as any other model
-            action = h.EnvActions.CLEAN_UP
+            action = a.CLEAN_UP
         elif any('door' in x.name.lower() for x in self.tile.guests):
             door = next(x for x in self.tile.guests if 'door' in x.name.lower())
             if door.is_closed:
@@ -37,7 +51,7 @@ class TSPDirtAgent(Agent):
         else:
             action = self._predict_move()
         # Translate the action_object to an integer to have the same output as any other model
-        action_obj = next(action_i for action_i, action_obj in enumerate(self._env._actions) if action_obj == action)
+        action_obj = next(action_i for action_name, action_i in self._env.named_action_space.items() if action_name == action)
         return action_obj
 
     def _predict_move(self):
@@ -1,221 +0,0 @@
from typing import NamedTuple, Union
from collections import deque, OrderedDict, defaultdict
import numpy as np
import random

import pandas as pd
import torch
import torch.nn as nn

from tqdm import trange


class Experience(NamedTuple):
    # can be use for a single (s_t, a, r s_{t+1}) tuple
    # or for a batch of tuples
    observation: np.ndarray
    next_observation: np.ndarray
    action: np.ndarray
    reward: Union[float, np.ndarray]
    done: Union[bool, np.ndarray]
    episode: int = -1


class BaseLearner:

    def __init__(self, env, n_agents=1, train_every=('step', 4), n_grad_steps=1, stack_n_frames=1):
        assert train_every[0] in ['step', 'episode'], 'train_every[0] must be one of ["step", "episode"]'
        self.env = env
        self.n_agents = n_agents
        self.n_grad_steps = n_grad_steps
        self.train_every = train_every
        self.stack_n_frames = deque(maxlen=stack_n_frames)
        self.device = 'cpu'
        self.n_updates = 0
        self.step = 0
        self.episode_step = 0
        self.episode = 0
        self.running_reward = deque(maxlen=5)

    def to(self, device):
        self.device = device
        for attr, value in self.__dict__.items():
            if isinstance(value, nn.Module):
                value = value.to(self.device)
        return self

    def get_action(self, obs) -> Union[int, np.ndarray]:
        pass

    def on_new_experience(self, experience):
        pass

    def on_step_end(self, n_steps):
        pass

    def on_episode_end(self, n_steps):
        pass

    def on_all_done(self):
        pass

    def train(self):
        pass

    def reward(self, r):
        return r

    def learn(self, n_steps):
        train_type, train_freq = self.train_every
        while self.step < n_steps:
            obs, done = self.env.reset(), False
            total_reward = 0
            self.episode_step = 0
            while not done:
                action = self.get_action(obs)
                next_obs, reward, done, info = self.env.step(action if not len(action) == 1 else action[0])
                experience = Experience(observation=obs, next_observation=next_obs,
                                        action=action, reward=self.reward(reward),
                                        done=done, episode=self.episode)  # do we really need to copy?
                self.on_new_experience(experience)
                # end of step routine
                obs = next_obs
                total_reward += reward
                self.step += 1
                self.episode_step += 1
                self.on_step_end(n_steps)
                if train_type == 'step' and (self.step % train_freq == 0):
                    self.train()
                    self.n_updates += 1
            self.on_episode_end(n_steps)
            if train_type == 'episode' and (self.episode % train_freq == 0):
                self.train()
                self.n_updates += 1

            self.running_reward.append(total_reward)
            self.episode += 1
            try:
                if self.step % 100 == 0:
                    print(
                        f'Step: {self.step} ({(self.step / n_steps) * 100:.2f}%)\tRunning reward: {sum(list(self.running_reward)) / len(self.running_reward):.2f}\t'
                        f' eps: {self.eps:.4f}\tRunning loss: {sum(list(self.running_loss)) / len(self.running_loss):.4f}\tUpdates:{self.n_updates}')
            except Exception as e:
                pass
        self.on_all_done()

    def evaluate(self, n_episodes=100, render=False):
        with torch.no_grad():
            data = []
            for eval_i in trange(n_episodes):
                obs, done = self.env.reset(), False
                while not done:
                    action = self.get_action(obs)
                    next_obs, reward, done, info = self.env.step(action if not len(action) == 1 else action[0])
                    if render: self.env.render()
                    obs = next_obs  # srsly i'm so stupid
                    info.update({'reward': reward, 'eval_episode': eval_i})
                    data.append(info)
            return pd.DataFrame(data).fillna(0)


class BaseBuffer:

    def __init__(self, size: int):
        self.size = size
        self.experience = deque(maxlen=size)

    def __len__(self):
        return len(self.experience)

    def add(self, exp: Experience):
        self.experience.append(exp)

    def sample(self, k, cer=4):
        sample = random.choices(self.experience, k=k-cer)
        for i in range(cer): sample += [self.experience[-i]]
        observations = torch.stack([torch.from_numpy(e.observation) for e in sample], 0).float()
        next_observations = torch.stack([torch.from_numpy(e.next_observation) for e in sample], 0).float()
        actions = torch.tensor([e.action for e in sample]).long()
        rewards = torch.tensor([e.reward for e in sample]).float().view(-1, 1)
        dones = torch.tensor([e.done for e in sample]).float().view(-1, 1)
        #print(observations.shape, next_observations.shape, actions.shape, rewards.shape, dones.shape)
        return Experience(observations, next_observations, actions, rewards, dones)


class TrajectoryBuffer(BaseBuffer):
    def __init__(self, size):
        super(TrajectoryBuffer, self).__init__(size)
        self.experience = defaultdict(list)

    def add(self, exp: Experience):
        self.experience[exp.episode].append(exp)
        if len(self.experience) > self.size:
            oldest_traj_key = list(sorted(self.experience.keys()))[0]
            del self.experience[oldest_traj_key]


def soft_update(local_model, target_model, tau):
    # taken from https://github.com/BY571/Munchausen-RL/blob/master/M-DQN.ipynb
    for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
        target_param.data.copy_(tau*local_param.data + (1.-tau)*target_param.data)


def mlp_maker(dims, flatten=False, activation='elu', activation_last='identity'):
    activations = {'elu': nn.ELU, 'relu': nn.ReLU, 'sigmoid': nn.Sigmoid,
                   'leaky_relu': nn.LeakyReLU, 'tanh': nn.Tanh,
                   'gelu': nn.GELU, 'identity': nn.Identity}
    layers = [('Flatten', nn.Flatten())] if flatten else []
    for i in range(1, len(dims)):
        layers.append((f'Layer #{i - 1}: Linear', nn.Linear(dims[i - 1], dims[i])))
        activation_str = activation if i != len(dims)-1 else activation_last
        layers.append((f'Layer #{i - 1}: {activation_str.capitalize()}', activations[activation_str]()))
    return nn.Sequential(OrderedDict(layers))


class BaseDQN(nn.Module):
    def __init__(self, dims=[3*5*5, 64, 64, 9]):
        super(BaseDQN, self).__init__()
        self.net = mlp_maker(dims, flatten=True)

    @torch.no_grad()
    def act(self, x) -> np.ndarray:
        action = self.forward(x).max(-1)[1].numpy()
        return action

    def forward(self, x):
        return self.net(x)


class BaseDDQN(BaseDQN):
    def __init__(self,
                 backbone_dims=[3*5*5, 64, 64],
                 value_dims=[64, 1],
                 advantage_dims=[64, 9],
                 activation='elu'):
        super(BaseDDQN, self).__init__(backbone_dims)
        self.net = mlp_maker(backbone_dims, activation=activation, flatten=True)
        self.value_head = mlp_maker(value_dims)
        self.advantage_head = mlp_maker(advantage_dims)

    def forward(self, x):
        features = self.net(x)
        advantages = self.advantage_head(features)
        values = self.value_head(features)
        return values + (advantages - advantages.mean())


class BaseICM(nn.Module):
    def __init__(self, backbone_dims=[2*3*5*5, 64, 64], head_dims=[2*64, 64, 9]):
        super(BaseICM, self).__init__()
        self.backbone = mlp_maker(backbone_dims, flatten=True, activation_last='relu', activation='relu')
        self.icm = mlp_maker(head_dims)
        self.ce = nn.CrossEntropyLoss()

    def forward(self, s0, s1, a):
        phi_s0 = self.backbone(s0)
        phi_s1 = self.backbone(s1)
        cat = torch.cat((phi_s0, phi_s1), dim=1)
        a_prime = torch.softmax(self.icm(cat), dim=-1)
        ce = self.ce(a_prime, a)
        return dict(prediction=a_prime, loss=ce)
@@ -1,77 +0,0 @@
import numpy as np
import torch
import torch.nn.functional as F
from algorithms.q_learner import QLearner


class MQLearner(QLearner):
    # Munchhausen Q-Learning
    def __init__(self, *args, temperature=0.03, alpha=0.9, clip_l0=-1.0, **kwargs):
        super(MQLearner, self).__init__(*args, **kwargs)
        assert self.n_agents == 1, 'M-DQN currently only supports single agent training'
        self.temperature = temperature
        self.alpha = alpha
        self.clip0 = clip_l0

    def tau_ln_pi(self, qs):
        # computes log(softmax(qs/temperature))
        # Custom log-sum-exp trick from page 18 to compute the log-policy terms
        v_k = qs.max(-1)[0].unsqueeze(-1)
        advantage = qs - v_k
        logsum = torch.logsumexp(advantage / self.temperature, -1).unsqueeze(-1)
        tau_ln_pi = advantage - self.temperature * logsum
        return tau_ln_pi

    def train(self):
        if len(self.buffer) < self.batch_size: return
        for _ in range(self.n_grad_steps):

            experience = self.buffer.sample(self.batch_size, cer=self.train_every[-1])

            with torch.no_grad():
                q_target_next = self.target_q_net(experience.next_observation)
                tau_log_pi_next = self.tau_ln_pi(q_target_next)

                q_k_targets = self.target_q_net(experience.observation)
                log_pi = self.tau_ln_pi(q_k_targets)

                pi_target = F.softmax(q_target_next / self.temperature, dim=-1)
                q_target = (self.gamma * (pi_target * (q_target_next - tau_log_pi_next) * (1 - experience.done)).sum(-1)).unsqueeze(-1)

                munchausen_addon = log_pi.gather(-1, experience.action)

                munchausen_reward = (experience.reward + self.alpha * torch.clamp(munchausen_addon, min=self.clip0, max=0))

                # Compute Q targets for current states
                m_q_target = munchausen_reward + q_target

            # Get expected Q values from local model
            q_k = self.q_net(experience.observation)
            pred_q = q_k.gather(-1, experience.action)

            # Compute loss
            loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - m_q_target, 2))
            self._backprop_loss(loss)


from tqdm import trange
from collections import deque


class MQICMLearner(MQLearner):
    def __init__(self, *args, icm, **kwargs):
        super(MQICMLearner, self).__init__(*args, **kwargs)
        self.icm = icm
        self.icm_optimizer = torch.optim.AdamW(self.icm.parameters())
        self.normalize_reward = deque(maxlen=1000)

    def on_all_done(self):
        from collections import deque
        losses = deque(maxlen=100)
        for b in trange(10000):
            batch = self.buffer.sample(128, 0)
            s0, s1, a = batch.observation, batch.next_observation, batch.action
            loss = self.icm(s0, s1, a.squeeze())['loss']
            self.icm_optimizer.zero_grad()
            loss.backward()
            self.icm_optimizer.step()
            losses.append(loss.item())
            if b % 100 == 0:
                print(np.mean(losses))
algorithms/marl/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from algorithms.marl.base_ac import BaseActorCritic
from algorithms.marl.iac import LoopIAC
from algorithms.marl.snac import LoopSNAC
from algorithms.marl.seac import LoopSEAC
from algorithms.marl.mappo import LoopMAPPO
from algorithms.marl.memory import MARLActorCriticMemory
algorithms/marl/base_ac.py (new file, 221 lines)
@@ -0,0 +1,221 @@
import torch
from typing import Union, List
import numpy as np
from torch.distributions import Categorical
from algorithms.marl.memory import MARLActorCriticMemory
from algorithms.utils import add_env_props, instantiate_class
from pathlib import Path
import pandas as pd
from collections import deque


class Names:
    REWARD = 'reward'
    DONE = 'done'
    ACTION = 'action'
    OBSERVATION = 'observation'
    LOGITS = 'logits'
    HIDDEN_ACTOR = 'hidden_actor'
    HIDDEN_CRITIC = 'hidden_critic'
    AGENT = 'agent'
    ENV = 'env'
    N_AGENTS = 'n_agents'
    ALGORITHM = 'algorithm'
    MAX_STEPS = 'max_steps'
    N_STEPS = 'n_steps'
    BUFFER_SIZE = 'buffer_size'
    CRITIC = 'critic'
    BATCH_SIZE = 'bnatch_size'
    N_ACTIONS = 'n_actions'


nms = Names
ListOrTensor = Union[List, torch.Tensor]


class BaseActorCritic:
    def __init__(self, cfg):
        add_env_props(cfg)
        self.__training = True
        self.cfg = cfg
        self.n_agents = cfg[nms.ENV][nms.N_AGENTS]
        self.reset_memory_after_epoch = True
        self.setup()

    def setup(self):
        self.net = instantiate_class(self.cfg[nms.AGENT])
        self.optimizer = torch.optim.RMSprop(self.net.parameters(), lr=3e-4, eps=1e-5)

    @classmethod
    def _as_torch(cls, x):
        if isinstance(x, np.ndarray):
            return torch.from_numpy(x)
        elif isinstance(x, List):
            return torch.tensor(x)
        elif isinstance(x, (int, float)):
            return torch.tensor([x])
        return x

    def train(self):
        self.__training = False
        networks = [self.net] if not isinstance(self.net, List) else self.net
        for net in networks:
            net.train()

    def eval(self):
        self.__training = False
        networks = [self.net] if not isinstance(self.net, List) else self.net
        for net in networks:
            net.eval()

    def load_state_dict(self, path: Path):
        pass

    def get_actions(self, out) -> ListOrTensor:
        actions = [Categorical(logits=logits).sample().item() for logits in out[nms.LOGITS]]
        return actions

    def init_hidden(self) -> dict[ListOrTensor]:
        pass

    def forward(self,
                observations: ListOrTensor,
                actions: ListOrTensor,
                hidden_actor: ListOrTensor,
                hidden_critic: ListOrTensor
                ) -> dict[ListOrTensor]:
        pass

    @torch.no_grad()
    def train_loop(self, checkpointer=None):
        env = instantiate_class(self.cfg[nms.ENV])
        n_steps, max_steps = [self.cfg[nms.ALGORITHM][k] for k in [nms.N_STEPS, nms.MAX_STEPS]]
        tm = MARLActorCriticMemory(self.n_agents, self.cfg[nms.ALGORITHM].get(nms.BUFFER_SIZE, n_steps))
        global_steps, episode, df_results = 0, 0, []
        reward_queue = deque(maxlen=2000)

        while global_steps < max_steps:
            obs = env.reset()
            last_hiddens = self.init_hidden()
            last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
            done, rew_log = [False] * self.n_agents, 0

            if self.reset_memory_after_epoch:
                tm.reset()

            tm.add(observation=obs, action=last_action,
                   logits=torch.zeros(self.n_agents, 1, self.cfg[nms.AGENT][nms.N_ACTIONS]),
                   values=torch.zeros(self.n_agents, 1), reward=reward, done=done, **last_hiddens)

            while not all(done):
                out = self.forward(obs, last_action, **last_hiddens)
                action = self.get_actions(out)
                next_obs, reward, done, info = env.step(action)
                done = [done] * self.n_agents if isinstance(done, bool) else done

                last_hiddens = dict(hidden_actor=out[nms.HIDDEN_ACTOR],
                                    hidden_critic=out[nms.HIDDEN_CRITIC])

                tm.add(observation=obs, action=action, reward=reward, done=done,
                       logits=out.get(nms.LOGITS, None), values=out.get(nms.CRITIC, None),
                       **last_hiddens)

                obs = next_obs
                last_action = action

                if (global_steps+1) % n_steps == 0 or all(done):
                    with torch.inference_mode(False):
                        self.learn(tm)

                global_steps += 1
                rew_log += sum(reward)
                reward_queue.extend(reward)

                if checkpointer is not None:
                    checkpointer.step([
                        (f'agent#{i}', agent)
                        for i, agent in enumerate([self.net] if not isinstance(self.net, List) else self.net)
                    ])

                if global_steps >= max_steps:
                    break
            print(f'reward at episode: {episode} = {rew_log}')
            episode += 1
            df_results.append([episode, rew_log, *reward])
        df_results = pd.DataFrame(df_results, columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]])
        if checkpointer is not None:
            df_results.to_csv(checkpointer.path / 'results.csv', index=False)
        return df_results

    @torch.inference_mode(True)
    def eval_loop(self, n_episodes, render=False):
        env = instantiate_class(self.cfg[nms.ENV])
        episode, results = 0, []
        while episode < n_episodes:
            obs = env.reset()
            last_hiddens = self.init_hidden()
            last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
            done, rew_log, eps_rew = [False] * self.n_agents, 0, torch.zeros(self.n_agents)
            while not all(done):
                if render: env.render()

                out = self.forward(obs, last_action, **last_hiddens)
                action = self.get_actions(out)
                next_obs, reward, done, info = env.step(action)

                if isinstance(done, bool): done = [done] * obs.shape[0]
                obs = next_obs
                last_action = action
                last_hiddens = dict(hidden_actor=out.get(nms.HIDDEN_ACTOR, None),
                                    hidden_critic=out.get(nms.HIDDEN_CRITIC, None)
                                    )
                eps_rew += torch.tensor(reward)
            results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode])
            episode += 1
        agent_columns = [f'agent#{i}' for i in range(self.cfg['env']['n_agents'])]
        results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode'])
        results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'], value_name='reward', var_name='agent')
        return results

    @staticmethod
    def compute_advantages(critic, reward, done, gamma, gae_coef=0.0):
        tds = (reward + gamma * (1.0 - done) * critic[:, 1:].detach()) - critic[:, :-1]

        if gae_coef <= 0:
            return tds

        gae = torch.zeros_like(tds[:, -1])
        gaes = []
        for t in range(tds.shape[1]-1, -1, -1):
            gae = tds[:, t] + gamma * gae_coef * (1.0 - done[:, t]) * gae
            gaes.insert(0, gae)
        gaes = torch.stack(gaes, dim=1)
        return gaes

    def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
        obs, actions, done, reward = tm.observation, tm.action, tm.done[:, 1:], tm.reward[:, 1:]

        out = network(obs, actions, tm.hidden_actor[:, 0], tm.hidden_critic[:, 0])
        logits = out[nms.LOGITS][:, :-1]  # last one only needed for v_{t+1}
        critic = out[nms.CRITIC]

        entropy_loss = Categorical(logits=logits).entropy().mean(-1)
        advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
        value_loss = advantages.pow(2).mean(-1)  # n_agent

        # policy loss
        log_ap = torch.log_softmax(logits, -1)
        log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze()
        a2c_loss = -(advantages.detach() * log_ap).mean(-1)
        # weighted loss
        loss = a2c_loss + vf_coef*value_loss - entropy_coef * entropy_loss
        return loss.mean()

    def learn(self, tm: MARLActorCriticMemory, **kwargs):
        loss = self.actor_critic(tm, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
        # remove next_obs, will be added in next iter
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
        self.optimizer.step()
algorithms/marl/example_config.yaml (new file, 24 lines)
@@ -0,0 +1,24 @@
agent:
  classname: algorithms.marl.networks.RecurrentAC
  n_agents: 2
  obs_emb_size: 96
  action_emb_size: 16
  hidden_size_actor: 64
  hidden_size_critic: 64
  use_agent_embedding: False
env:
  classname: environments.factory.make
  env_name: "DirtyFactory-v0"
  n_agents: 2
  max_steps: 250
  pomdp_r: 2
  stack_n_frames: 0
  individual_rewards: True
method: algorithms.marl.LoopSEAC
algorithm:
  gamma: 0.99
  entropy_coef: 0.01
  vf_coef: 0.5
  n_steps: 5
  max_steps: 1000000
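
A minimal usage sketch (not part of this diff) of how the example config above could drive the new MARL loops; the YAML path and the CSV step are assumptions, while LoopSEAC, add_env_props and train_loop come from the files added in this compare:

from pathlib import Path
import yaml

from algorithms.marl import LoopSEAC

# hypothetical path to the config shown above
cfg = yaml.safe_load(Path('algorithms/marl/example_config.yaml').read_text())
loop = LoopSEAC(cfg)            # BaseActorCritic.__init__ runs add_env_props(cfg) and setup()
df_results = loop.train_loop()  # returns a pandas DataFrame of per-episode rewards
df_results.to_csv('seac_results.csv', index=False)  # illustrative only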
algorithms/marl/iac.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import torch
from algorithms.marl.base_ac import BaseActorCritic, nms
from algorithms.utils import instantiate_class
from pathlib import Path
from natsort import natsorted
from algorithms.marl.memory import MARLActorCriticMemory


class LoopIAC(BaseActorCritic):

    def __init__(self, cfg):
        super(LoopIAC, self).__init__(cfg)

    def setup(self):
        self.net = [
            instantiate_class(self.cfg[nms.AGENT]) for _ in range(self.n_agents)
        ]
        self.optimizer = [
            torch.optim.RMSprop(self.net[ag_i].parameters(), lr=3e-4, eps=1e-5) for ag_i in range(self.n_agents)
        ]

    def load_state_dict(self, path: Path):
        paths = natsorted(list(path.glob('*.pt')))
        for path, net in zip(paths, self.net):
            net.load_state_dict(torch.load(path))

    @staticmethod
    def merge_dicts(ds):  # todo could be recursive for more than 1 hierarchy
        d = {}
        for k in ds[0].keys():
            d[k] = [d[k] for d in ds]
        return d

    def init_hidden(self):
        ha = [net.init_hidden_actor() for net in self.net]
        hc = [net.init_hidden_critic() for net in self.net]
        return dict(hidden_actor=ha, hidden_critic=hc)

    def forward(self, observations, actions, hidden_actor, hidden_critic):
        outputs = [
            net(
                self._as_torch(observations[ag_i]).unsqueeze(0).unsqueeze(0),  # agents x time
                self._as_torch(actions[ag_i]).unsqueeze(0),
                hidden_actor[ag_i],
                hidden_critic[ag_i]
            ) for ag_i, net in enumerate(self.net)
        ]
        return self.merge_dicts(outputs)

    def learn(self, tms: MARLActorCriticMemory, **kwargs):
        for ag_i in range(self.n_agents):
            tm, net = tms(ag_i), self.net[ag_i]
            loss = self.actor_critic(tm, net, **self.cfg[nms.ALGORITHM], **kwargs)
            self.optimizer[ag_i].zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 0.5)
            self.optimizer[ag_i].step()
algorithms/marl/mappo.py (new file, 67 lines)
@@ -0,0 +1,67 @@
from algorithms.marl.base_ac import Names as nms
from algorithms.marl import LoopSNAC
from algorithms.marl.memory import MARLActorCriticMemory
import random
import torch
from torch.distributions import Categorical
from algorithms.utils import instantiate_class


class LoopMAPPO(LoopSNAC):
    def __init__(self, *args, **kwargs):
        super(LoopMAPPO, self).__init__(*args, **kwargs)
        self.reset_memory_after_epoch = False

    def setup(self):
        self.net = instantiate_class(self.cfg[nms.AGENT])
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=3e-4, eps=1e-5)

    def learn(self, tm: MARLActorCriticMemory, **kwargs):
        if len(tm) >= self.cfg['algorithm']['buffer_size']:
            # only learn when buffer is full
            for batch_i in range(self.cfg['algorithm']['n_updates']):
                batch = tm.chunk_dataloader(chunk_len=self.cfg['algorithm']['n_steps'],
                                            k=self.cfg['algorithm']['batch_size'])
                loss = self.mappo(batch, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
                self.optimizer.step()

    def monte_carlo_returns(self, rewards, done, gamma):
        rewards_ = []
        discounted_reward = torch.zeros_like(rewards[:, -1])
        for t in range(rewards.shape[1]-1, -1, -1):
            discounted_reward = rewards[:, t] + (gamma * (1.0 - done[:, t]) * discounted_reward)
            rewards_.insert(0, discounted_reward)
        rewards_ = torch.stack(rewards_, dim=1)
        return rewards_

    def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **kwargs):
        out = network(batch[nms.OBSERVATION], batch[nms.ACTION], batch[nms.HIDDEN_ACTOR], batch[nms.HIDDEN_CRITIC])
        logits = out[nms.LOGITS][:, :-1]  # last one only needed for v_{t+1}

        old_log_probs = torch.log_softmax(batch[nms.LOGITS], -1)
        old_log_probs = torch.gather(old_log_probs, index=batch[nms.ACTION][:, 1:].unsqueeze(-1), dim=-1).squeeze()

        # monte carlo returns
        mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma)
        mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8)  # todo: norm across agents ok?
        advantages = mc_returns - out[nms.CRITIC][:, :-1]

        # policy loss
        log_ap = torch.log_softmax(logits, -1)
        log_ap = torch.gather(log_ap, dim=-1, index=batch[nms.ACTION][:, 1:].unsqueeze(-1)).squeeze()
        ratio = (log_ap - old_log_probs).exp()
        surr1 = ratio * advantages.detach()
        surr2 = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantages.detach()
        policy_loss = -torch.min(surr1, surr2).mean(-1)

        # entropy & value loss
        entropy_loss = Categorical(logits=logits).entropy().mean(-1)
        value_loss = advantages.pow(2).mean(-1)  # n_agent

        # weighted loss
        loss = policy_loss + vf_coef*value_loss - entropy_coef * entropy_loss

        return loss.mean()
algorithms/marl/memory.py (new file, 221 lines)
@@ -0,0 +1,221 @@
import numpy as np
from collections import deque
import torch
from typing import Union
from torch import Tensor
from torch.utils.data import Dataset, ConcatDataset
import random


class ActorCriticMemory(object):
    def __init__(self, capacity=10):
        self.capacity = capacity
        self.reset()

    def reset(self):
        self.__actions = LazyTensorFiFoQueue(maxlen=self.capacity+1)
        self.__hidden_actor = LazyTensorFiFoQueue(maxlen=self.capacity+1)
        self.__hidden_critic = LazyTensorFiFoQueue(maxlen=self.capacity+1)
        self.__states = LazyTensorFiFoQueue(maxlen=self.capacity+1)
        self.__rewards = LazyTensorFiFoQueue(maxlen=self.capacity+1)
        self.__dones = LazyTensorFiFoQueue(maxlen=self.capacity+1)
        self.__logits = LazyTensorFiFoQueue(maxlen=self.capacity+1)
        self.__values = LazyTensorFiFoQueue(maxlen=self.capacity+1)

    def __len__(self):
        return len(self.__rewards) - 1

    @property
    def observation(self, sls=slice(0, None)):  # add time dimension through stacking
        return self.__states[sls].unsqueeze(0)  # 1 x time x hidden dim

    @property
    def hidden_actor(self, sls=slice(0, None)):  # 1 x n_layers x dim
        return self.__hidden_actor[sls].unsqueeze(0)  # 1 x time x n_layers x dim

    @property
    def hidden_critic(self, sls=slice(0, None)):  # 1 x n_layers x dim
        return self.__hidden_critic[sls].unsqueeze(0)  # 1 x time x n_layers x dim

    @property
    def reward(self, sls=slice(0, None)):
        return self.__rewards[sls].squeeze().unsqueeze(0)  # 1 x time

    @property
    def action(self, sls=slice(0, None)):
        return self.__actions[sls].long().squeeze().unsqueeze(0)  # 1 x time

    @property
    def done(self, sls=slice(0, None)):
        return self.__dones[sls].float().squeeze().unsqueeze(0)  # 1 x time

    @property
    def logits(self, sls=slice(0, None)):  # assumes a trailing 1 for time dimension - common when using output from NN
        return self.__logits[sls].squeeze().unsqueeze(0)  # 1 x time x actions

    @property
    def values(self, sls=slice(0, None)):
        return self.__values[sls].squeeze().unsqueeze(0)  # 1 x time x actions

    def add_observation(self, state: Union[Tensor, np.ndarray]):
        self.__states.append(state if isinstance(state, Tensor) else torch.from_numpy(state))

    def add_hidden_actor(self, hidden: Tensor):
        # layers x hidden dim
        self.__hidden_actor.append(hidden)

    def add_hidden_critic(self, hidden: Tensor):
        # layers x hidden dim
        self.__hidden_critic.append(hidden)

    def add_action(self, action: Union[int, Tensor]):
        if not isinstance(action, Tensor):
            action = torch.tensor(action)
        self.__actions.append(action)

    def add_reward(self, reward: Union[float, Tensor]):
        if not isinstance(reward, Tensor):
            reward = torch.tensor(reward)
        self.__rewards.append(reward)

    def add_done(self, done: bool):
        if not isinstance(done, Tensor):
            done = torch.tensor(done)
        self.__dones.append(done)

    def add_logits(self, logits: Tensor):
        self.__logits.append(logits)

    def add_values(self, values: Tensor):
        self.__values.append(values)

    def add(self, **kwargs):
        for k, v in kwargs.items():
            func = getattr(ActorCriticMemory, f'add_{k}')
            func(self, v)


class MARLActorCriticMemory(object):
    def __init__(self, n_agents, capacity):
        self.n_agents = n_agents
        self.memories = [
            ActorCriticMemory(capacity) for _ in range(n_agents)
        ]

    def __call__(self, agent_i):
        return self.memories[agent_i]

    def __len__(self):
        return len(self.memories[0])  # todo add assertion check!

    def reset(self):
        for mem in self.memories:
            mem.reset()

    def add(self, **kwargs):
        for agent_i in range(self.n_agents):
            for k, v in kwargs.items():
                func = getattr(ActorCriticMemory, f'add_{k}')
                func(self.memories[agent_i], v[agent_i])

    def __getattr__(self, attr):
        all_attrs = [getattr(mem, attr) for mem in self.memories]
        return torch.cat(all_attrs, 0)  # agents x time ...

    def chunk_dataloader(self, chunk_len, k):
        datasets = [ExperienceChunks(mem, chunk_len, k) for mem in self.memories]
        dataset = ConcatDataset(datasets)
        data = [dataset[i] for i in range(len(dataset))]
        data = custom_collate_fn(data)
        return data


def custom_collate_fn(batch):
    elem = batch[0]
    return {key: torch.cat([d[key] for d in batch], dim=0) for key in elem}


class ExperienceChunks(Dataset):
    def __init__(self, memory, chunk_len, k):
        assert chunk_len <= len(memory), 'chunk_len cannot be longer than the size of the memory'
        self.memory = memory
        self.chunk_len = chunk_len
        self.k = k

    @property
    def whitelist(self):
        whitelist = torch.ones(len(self.memory) - self.chunk_len)
        for d in self.memory.done.squeeze().nonzero().flatten():
            whitelist[max((0, d-self.chunk_len-1)):d+2] = 0
        whitelist[0] = 0
        return whitelist.tolist()

    def sample(self, start=1):
        cl = self.chunk_len
        sample = dict(observation=self.memory.observation[:, start:start+cl+1],
                      action=self.memory.action[:, start-1:start+cl],
                      hidden_actor=self.memory.hidden_actor[:, start-1],
                      hidden_critic=self.memory.hidden_critic[:, start-1],
                      reward=self.memory.reward[:, start:start + cl],
                      done=self.memory.done[:, start:start + cl],
                      logits=self.memory.logits[:, start:start + cl],
                      values=self.memory.values[:, start:start + cl])
        return sample

    def __len__(self):
        return self.k

    def __getitem__(self, i):
        idx = random.choices(range(0, len(self.memory) - self.chunk_len), weights=self.whitelist, k=1)
        return self.sample(idx[0])


class LazyTensorFiFoQueue:
    def __init__(self, maxlen):
        self.maxlen = maxlen
        self.reset()

    def reset(self):
        self.__lazy_queue = deque(maxlen=self.maxlen)
        self.shape = None
        self.queue = None

    def shape_init(self, tensor: Tensor):
        self.shape = torch.Size([self.maxlen, *tensor.shape])

    def build_tensor_queue(self):
        if len(self.__lazy_queue) > 0:
            block = torch.stack(list(self.__lazy_queue), dim=0)
            l = block.shape[0]
            if self.queue is None:
                self.queue = block
            elif self.true_len() <= self.maxlen:
                self.queue = torch.cat((self.queue, block), dim=0)
            else:
                self.queue = torch.cat((self.queue[l:], block), dim=0)
            self.__lazy_queue.clear()

    def append(self, data):
        if self.shape is None:
            self.shape_init(data)
        self.__lazy_queue.append(data)
        if len(self.__lazy_queue) >= self.maxlen:
            self.build_tensor_queue()

    def true_len(self):
        return len(self.__lazy_queue) + (0 if self.queue is None else self.queue.shape[0])

    def __len__(self):
        return min((self.true_len(), self.maxlen))

    def __str__(self):
        return f'LazyTensorFiFoQueue\tmaxlen: {self.maxlen}, shape: {self.shape}, ' \
               f'len: {len(self)}, true_len: {self.true_len()}, elements in lazy queue: {len(self.__lazy_queue)}'

    def __getitem__(self, item_or_slice):
        self.build_tensor_queue()
        return self.queue[item_or_slice]
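
An illustrative sketch (not part of this diff) of the access pattern base_ac.py uses with MARLActorCriticMemory; the tensor shapes (3x5x5 observations, 10 actions, 64-dim GRU state) are placeholder assumptions:

import torch
from algorithms.marl.memory import MARLActorCriticMemory

n_agents = 2
tm = MARLActorCriticMemory(n_agents, capacity=5)

# keyword names must match the add_* methods of ActorCriticMemory
tm.add(observation=[torch.zeros(3, 5, 5)] * n_agents,
       action=[-1] * n_agents,
       reward=[0.] * n_agents,
       done=[False] * n_agents,
       logits=torch.zeros(n_agents, 1, 10),   # placeholder n_actions = 10
       values=torch.zeros(n_agents, 1),
       hidden_actor=[torch.zeros(1, 1, 64)] * n_agents,
       hidden_critic=[torch.zeros(1, 1, 64)] * n_agents)

# attribute access concatenates the per-agent memories: agents x time x ...
print(tm.observation.shape)  # torch.Size([2, 1, 3, 5, 5])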
algorithms/marl/networks.py (new file, 104 lines)
@@ -0,0 +1,104 @@
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.nn.utils import spectral_norm


class RecurrentAC(nn.Module):
    def __init__(self, observation_size, n_actions, obs_emb_size,
                 action_emb_size, hidden_size_actor, hidden_size_critic,
                 n_agents, use_agent_embedding=True):
        super(RecurrentAC, self).__init__()
        observation_size = np.prod(observation_size)
        self.n_layers = 1
        self.n_actions = n_actions
        self.use_agent_embedding = use_agent_embedding
        self.hidden_size_actor = hidden_size_actor
        self.hidden_size_critic = hidden_size_critic
        self.action_emb_size = action_emb_size
        self.obs_proj = nn.Linear(observation_size, obs_emb_size)
        self.action_emb = nn.Embedding(n_actions+1, action_emb_size, padding_idx=0)
        self.agent_emb = nn.Embedding(n_agents, action_emb_size)
        mix_in_size = obs_emb_size+action_emb_size if not use_agent_embedding else obs_emb_size+n_agents*action_emb_size
        self.mix = nn.Sequential(nn.Tanh(),
                                 nn.Linear(mix_in_size, obs_emb_size),
                                 nn.Tanh(),
                                 nn.Linear(obs_emb_size, obs_emb_size)
                                 )
        self.gru_actor = nn.GRU(obs_emb_size, hidden_size_actor, batch_first=True, num_layers=self.n_layers)
        self.gru_critic = nn.GRU(obs_emb_size, hidden_size_critic, batch_first=True, num_layers=self.n_layers)
        self.action_head = nn.Sequential(
            nn.Linear(hidden_size_actor, hidden_size_actor),
            nn.Tanh(),
            nn.Linear(hidden_size_actor, n_actions)
        )
        # spectral_norm(nn.Linear(hidden_size_actor, hidden_size_actor)),
        self.critic_head = nn.Sequential(
            nn.Linear(hidden_size_critic, hidden_size_critic),
            nn.Tanh(),
            nn.Linear(hidden_size_critic, 1)
        )
        #self.action_head[-1].weight.data.uniform_(-3e-3, 3e-3)
        #self.action_head[-1].bias.data.uniform_(-3e-3, 3e-3)

    def init_hidden_actor(self):
        return torch.zeros(1, self.n_layers, self.hidden_size_actor)

    def init_hidden_critic(self):
        return torch.zeros(1, self.n_layers, self.hidden_size_critic)

    def forward(self, observations, actions, hidden_actor=None, hidden_critic=None):
        n_agents, t, *_ = observations.shape
        obs_emb = self.obs_proj(observations.view(n_agents, t, -1).float())
        action_emb = self.action_emb(actions+1)  # shift by one due to padding idx

        if not self.use_agent_embedding:
            x_t = torch.cat((obs_emb, action_emb), -1)
        else:
            agent_emb = self.agent_emb(
                torch.cat([torch.arange(0, n_agents, 1).view(-1, 1)] * t, 1)
            )
            x_t = torch.cat((obs_emb, agent_emb, action_emb), -1)

        mixed_x_t = self.mix(x_t)
        output_p, _ = self.gru_actor(input=mixed_x_t, hx=hidden_actor.swapaxes(1, 0))
        output_c, _ = self.gru_critic(input=mixed_x_t, hx=hidden_critic.swapaxes(1, 0))

        logits = self.action_head(output_p)
        critic = self.critic_head(output_c).squeeze(-1)
        return dict(logits=logits, critic=critic, hidden_actor=output_p, hidden_critic=output_c)


class RecurrentACL2(RecurrentAC):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.action_head = nn.Sequential(
            nn.Linear(self.hidden_size_actor, self.hidden_size_actor),
            nn.Tanh(),
            NormalizedLinear(self.hidden_size_actor, self.n_actions, trainable_magnitude=True)
        )


class NormalizedLinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int,
                 device=None, dtype=None, trainable_magnitude=False):
        super(NormalizedLinear, self).__init__(in_features, out_features, False, device, dtype)
        self.d_sqrt = in_features**0.5
        self.trainable_magnitude = trainable_magnitude
        self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude)

    def forward(self, input):
        normalized_input = F.normalize(input, dim=-1, p=2, eps=1e-5)
        normalized_weight = F.normalize(self.weight, dim=-1, p=2, eps=1e-5)
        return F.linear(normalized_input, normalized_weight) * self.d_sqrt * self.scale


class L2Norm(nn.Module):
    def __init__(self, in_features, trainable_magnitude=False):
        super(L2Norm, self).__init__()
        self.d_sqrt = in_features**0.5
        self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude)

    def forward(self, x):
        return F.normalize(x, dim=-1, p=2, eps=1e-5) * self.d_sqrt * self.scale
algorithms/marl/seac.py (new file, 56 lines)
@@ -0,0 +1,56 @@
import torch
from torch.distributions import Categorical
from algorithms.marl.iac import LoopIAC
from algorithms.marl.base_ac import nms
from algorithms.marl.memory import MARLActorCriticMemory


class LoopSEAC(LoopIAC):
    def __init__(self, cfg):
        super(LoopSEAC, self).__init__(cfg)

    def actor_critic(self, tm, networks, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
        obs, actions, done, reward = tm.observation, tm.action, tm.done[:, 1:], tm.reward[:, 1:]
        outputs = [net(obs, actions, tm.hidden_actor[:, 0], tm.hidden_critic[:, 0]) for net in networks]

        with torch.inference_mode(True):
            true_action_logp = torch.stack([
                torch.log_softmax(out[nms.LOGITS][ag_i, :-1], -1)
                .gather(index=actions[ag_i, 1:, None], dim=-1)
                for ag_i, out in enumerate(outputs)
            ], 0).squeeze()

        losses = []

        for ag_i, out in enumerate(outputs):
            logits = out[nms.LOGITS][:, :-1]  # last one only needed for v_{t+1}
            critic = out[nms.CRITIC]

            entropy_loss = Categorical(logits=logits[ag_i]).entropy().mean()
            advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)

            # policy loss
            log_ap = torch.log_softmax(logits, -1)
            log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze()

            # importance weights
            iw = (log_ap - true_action_logp).exp().detach()  # importance_weights

            a2c_loss = (-iw*log_ap * advantages.detach()).mean(-1)

            value_loss = (iw*advantages.pow(2)).mean(-1)  # n_agent

            # weighted loss
            loss = (a2c_loss + vf_coef*value_loss - entropy_coef * entropy_loss).mean()
            losses.append(loss)

        return losses

    def learn(self, tms: MARLActorCriticMemory, **kwargs):
        losses = self.actor_critic(tms, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
        for ag_i, loss in enumerate(losses):
            self.optimizer[ag_i].zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.net[ag_i].parameters(), 0.5)
            self.optimizer[ag_i].step()
algorithms/marl/snac.py (new file, 33 lines)
@@ -0,0 +1,33 @@
from algorithms.marl.base_ac import BaseActorCritic
from algorithms.marl.base_ac import nms
import torch
from torch.distributions import Categorical
from pathlib import Path


class LoopSNAC(BaseActorCritic):
    def __init__(self, cfg):
        super().__init__(cfg)

    def load_state_dict(self, path: Path):
        path2weights = list(path.glob('*.pt'))
        assert len(path2weights) == 1, f'Expected a single set of weights but got {len(path2weights)}'
        self.net.load_state_dict(torch.load(path2weights[0]))

    def init_hidden(self):
        hidden_actor = self.net.init_hidden_actor()
        hidden_critic = self.net.init_hidden_critic()
        return dict(hidden_actor=torch.cat([hidden_actor] * self.n_agents, 0),
                    hidden_critic=torch.cat([hidden_critic] * self.n_agents, 0)
                    )

    def get_actions(self, out):
        actions = Categorical(logits=out[nms.LOGITS]).sample().squeeze()
        return actions

    def forward(self, observations, actions, hidden_actor, hidden_critic):
        out = self.net(self._as_torch(observations).unsqueeze(1),
                       self._as_torch(actions).unsqueeze(1),
                       hidden_actor, hidden_critic
                       )
        return out
@@ -1,127 +0,0 @@
from typing import Union
import gym
import torch
import torch.nn as nn
import numpy as np
from collections import deque
from pathlib import Path
import yaml
from algorithms.common import BaseLearner, BaseBuffer, soft_update, Experience


class QLearner(BaseLearner):
    def __init__(self, q_net, target_q_net, env, buffer_size=1e5, target_update=3000, eps_end=0.05, n_agents=1,
                 gamma=0.99, train_every=('step', 4), n_grad_steps=1, tau=1.0, max_grad_norm=10, weight_decay=1e-2,
                 exploration_fraction=0.2, batch_size=64, lr=1e-4, reg_weight=0.0, eps_start=1):
        super(QLearner, self).__init__(env, n_agents, train_every, n_grad_steps)
        self.q_net = q_net
        self.target_q_net = target_q_net
        self.target_q_net.eval()
        #soft_update(self.q_net, self.target_q_net, tau=1.0)
        self.buffer = BaseBuffer(buffer_size)
        self.target_update = target_update
        self.eps = eps_start
        self.eps_start = eps_start
        self.eps_end = eps_end
        self.exploration_fraction = exploration_fraction
        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau
        self.reg_weight = reg_weight
        self.weight_decay = weight_decay
        self.lr = lr
        self.optimizer = torch.optim.AdamW(self.q_net.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        self.max_grad_norm = max_grad_norm
        self.running_reward = deque(maxlen=5)
        self.running_loss = deque(maxlen=5)
        self.n_updates = 0

    def save(self, path):
        path = Path(path)  # no-op if already instance of Path
        path.mkdir(parents=True, exist_ok=True)
        hparams = {k: v for k, v in self.__dict__.items() if not(isinstance(v, BaseBuffer) or
                                                                 isinstance(v, torch.optim.Optimizer) or
                                                                 isinstance(v, gym.Env) or
                                                                 isinstance(v, nn.Module))
                   }
        hparams.update({'class': self.__class__.__name__})
        with (path / 'hparams.yaml').open('w') as outfile:
            yaml.dump(hparams, outfile)
        torch.save(self.q_net, path / 'q_net.pt')

    def anneal_eps(self, step, n_steps):
        fraction = min(float(step) / int(self.exploration_fraction*n_steps), 1.0)
        self.eps = 1 + fraction * (self.eps_end - 1)

    def get_action(self, obs) -> Union[int, np.ndarray]:
        o = torch.from_numpy(obs).unsqueeze(0) if self.n_agents <= 1 else torch.from_numpy(obs)
        if np.random.rand() > self.eps:
            action = self.q_net.act(o.float())
        else:
            action = np.array([self.env.action_space.sample() for _ in range(self.n_agents)])
        return action

    def on_new_experience(self, experience):
        self.buffer.add(experience)

    def on_step_end(self, n_steps):
        self.anneal_eps(self.step, n_steps)
        if self.step % self.target_update == 0:
            print('UPDATE')
            soft_update(self.q_net, self.target_q_net, tau=self.tau)

    def _training_routine(self, obs, next_obs, action):
        current_q_values = self.q_net(obs)
        current_q_values = torch.gather(current_q_values, dim=-1, index=action)
        next_q_values_raw = self.target_q_net(next_obs).max(dim=-1)[0].reshape(-1, 1).detach()
        return current_q_values, next_q_values_raw

    def _backprop_loss(self, loss):
        # log loss
        self.running_loss.append(loss.item())
        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
        self.optimizer.step()

    def train(self):
        if len(self.buffer) < self.batch_size: return
        for _ in range(self.n_grad_steps):
            experience = self.buffer.sample(self.batch_size, cer=self.train_every[-1])
            pred_q, target_q_raw = self._training_routine(experience.observation,
                                                          experience.next_observation,
                                                          experience.action)
            target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw
            loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - target_q, 2))
            self._backprop_loss(loss)


if __name__ == '__main__':
    from environments.factory.factory_dirt import DirtFactory, DirtProperties, MovementProperties
    from algorithms.common import BaseDDQN, BaseICM
    from algorithms.m_q_learner import MQLearner, MQICMLearner
    from algorithms.vdn_learner import VDNLearner

    N_AGENTS = 1

    with (Path(f'../environments/factory/env_default_param.yaml')).open('r') as f:
        env_kwargs = yaml.load(f, Loader=yaml.FullLoader)

    env = DirtFactory(**env_kwargs)
    obs_shape = np.prod(env.observation_space.shape)
    n_actions = env.action_space.n

    dqn, target_dqn = BaseDDQN(backbone_dims=[obs_shape, 128, 128], advantage_dims=[128, n_actions], value_dims=[128, 1], activation='leaky_relu'),\
                      BaseDDQN(backbone_dims=[obs_shape, 128, 128], advantage_dims=[128, n_actions], value_dims=[128, 1], activation='leaky_relu')

    icm = BaseICM(backbone_dims=[obs_shape, 64, 32], head_dims=[2*32, 64, n_actions])

    learner = MQICMLearner(dqn, target_dqn, env, 50000, icm=icm,
                           target_update=5000, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
                           train_every=('step', 4), eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25,
                           batch_size=64, weight_decay=1e-3
                           )
    #learner.save(Path(__file__).parent / 'test' / 'testexperiment1337')
    learner.learn(100000)
@ -1,52 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import stable_baselines3 as sb3
|
|
||||||
from stable_baselines3.common import logger
|
|
||||||
|
|
||||||
|
|
||||||
class RegDQN(sb3.dqn.DQN):
|
|
||||||
def __init__(self, *args, reg_weight=0.1, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.reg_weight = reg_weight
|
|
||||||
|
|
||||||
def train(self, gradient_steps: int, batch_size: int = 100) -> None:
|
|
||||||
# Update learning rate according to schedule
|
|
||||||
self._update_learning_rate(self.policy.optimizer)
|
|
||||||
|
|
||||||
losses = []
|
|
||||||
for _ in range(gradient_steps):
|
|
||||||
# Sample replay buffer
|
|
||||||
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
# Compute the next Q-values using the target network
|
|
||||||
next_q_values = self.q_net_target(replay_data.next_observations)
|
|
||||||
# Follow greedy policy: use the one with the highest value
|
|
||||||
next_q_values, _ = next_q_values.max(dim=1)
|
|
||||||
# Avoid potential broadcast issue
|
|
||||||
next_q_values = next_q_values.reshape(-1, 1)
|
|
||||||
# 1-step TD target
|
|
||||||
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
|
|
||||||
|
|
||||||
# Get current Q-values estimates
|
|
||||||
current_q_values = self.q_net(replay_data.observations)
|
|
||||||
|
|
||||||
# Retrieve the q-values for the actions from the replay buffer
|
|
||||||
current_q_values = torch.gather(current_q_values, dim=1, index=replay_data.actions.long())
|
|
||||||
|
|
||||||
delta = current_q_values - target_q_values
|
|
||||||
loss = torch.mean(self.reg_weight * current_q_values + torch.pow(delta, 2))
|
|
||||||
losses.append(loss.item())
|
|
||||||
|
|
||||||
# Optimize the policy
|
|
||||||
self.policy.optimizer.zero_grad()
|
|
||||||
loss.backward()
|
|
||||||
# Clip gradient norm
|
|
||||||
torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
|
|
||||||
self.policy.optimizer.step()
|
|
||||||
|
|
||||||
# Increase update counter
|
|
||||||
self._n_updates += gradient_steps
|
|
||||||
|
|
||||||
logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
|
|
||||||
logger.record("train/loss", np.mean(losses))
|
|
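For context on the file removed here: `RegDQN` subclassed stable-baselines3's `DQN` and only overrode `train`, so it would have been driven like any SB3 model. The environment and hyperparameters below are placeholders, not taken from the repository.

model = RegDQN('MlpPolicy', env, reg_weight=0.1, verbose=0)   # env: any Gym environment
model.learn(total_timesteps=100_000)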
@ -3,14 +3,51 @@ import torch
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import yaml
|
import yaml
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from salina import instantiate_class
|
|
||||||
from salina import TAgent
|
|
||||||
from salina.agents.gyma import (
|
def load_class(classname):
|
||||||
AutoResetGymAgent,
|
from importlib import import_module
|
||||||
_torch_type,
|
module_path, class_name = classname.rsplit(".", 1)
|
||||||
_format_frame,
|
module = import_module(module_path)
|
||||||
_torch_cat_dict
|
c = getattr(module, class_name)
|
||||||
)
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
def instantiate_class(arguments):
|
||||||
|
from importlib import import_module
|
||||||
|
|
||||||
|
d = dict(arguments)
|
||||||
|
classname = d["classname"]
|
||||||
|
del d["classname"]
|
||||||
|
module_path, class_name = classname.rsplit(".", 1)
|
||||||
|
module = import_module(module_path)
|
||||||
|
c = getattr(module, class_name)
|
||||||
|
return c(**d)
|
||||||
|
|
||||||
|
|
||||||
|
def get_class(arguments):
|
||||||
|
from importlib import import_module
|
||||||
|
|
||||||
|
if isinstance(arguments, dict):
|
||||||
|
classname = arguments["classname"]
|
||||||
|
module_path, class_name = classname.rsplit(".", 1)
|
||||||
|
module = import_module(module_path)
|
||||||
|
c = getattr(module, class_name)
|
||||||
|
return c
|
||||||
|
else:
|
||||||
|
classname = arguments.classname
|
||||||
|
module_path, class_name = classname.rsplit(".", 1)
|
||||||
|
module = import_module(module_path)
|
||||||
|
c = getattr(module, class_name)
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
def get_arguments(arguments):
|
||||||
|
from importlib import import_module
|
||||||
|
d = dict(arguments)
|
||||||
|
if "classname" in d:
|
||||||
|
del d["classname"]
|
||||||
|
return d
|
||||||
|
|
||||||
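These helpers replace the former salina imports: a plain dict carries a dotted `classname` plus constructor kwargs. A minimal usage sketch; `collections.Counter` is only a stand-in for a real class path from this project.

cfg = dict(classname='collections.Counter')
obj = instantiate_class(cfg)      # imports the module, then calls Counter() with the remaining kwargs
cls = get_class(cfg)              # returns the Counter class itself, without instantiating it
kwargs = get_arguments(dict(classname='collections.Counter', foo=1))   # -> {'foo': 1}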
|
|
||||||
def load_yaml_file(path: Path):
|
def load_yaml_file(path: Path):
|
||||||
@ -21,90 +58,29 @@ def load_yaml_file(path: Path):
|
|||||||
|
|
||||||
def add_env_props(cfg):
|
def add_env_props(cfg):
|
||||||
env = instantiate_class(cfg['env'].copy())
|
env = instantiate_class(cfg['env'].copy())
|
||||||
cfg['agent'].update(dict(observation_size=env.observation_space.shape,
|
cfg['agent'].update(dict(observation_size=list(env.observation_space.shape),
|
||||||
n_actions=env.action_space.n))
|
n_actions=env.action_space.n))
|
||||||
|
|
||||||
|
|
||||||
|
class Checkpointer(object):
|
||||||
|
def __init__(self, experiment_name, root, config, total_steps, n_checkpoints):
|
||||||
|
self.path = root / experiment_name
|
||||||
|
self.checkpoint_indices = list(np.linspace(1, total_steps, n_checkpoints, dtype=int) - 1)
|
||||||
|
self.__current_checkpoint = 0
|
||||||
|
self.__current_step = 0
|
||||||
|
self.path.mkdir(exist_ok=True, parents=True)
|
||||||
|
with (self.path / 'config.yaml').open('w') as outfile:
|
||||||
|
yaml.dump(config, outfile, default_flow_style=False)
|
||||||
|
|
||||||
|
def save_experiment(self, name: str, model):
|
||||||
|
cpt_path = self.path / f'checkpoint_{self.__current_checkpoint}'
|
||||||
|
cpt_path.mkdir(exist_ok=True, parents=True)
|
||||||
|
torch.save(model.state_dict(), cpt_path / f'{name}.pt')
|
||||||
|
|
||||||
def step(self, to_save):
    if self.__current_step in self.checkpoint_indices:
        print(f'Checkpointing #{self.__current_checkpoint}')
        for name, model in to_save:
            self.save_experiment(name, model)
        self.__current_checkpoint += 1
    self.__current_step += 1

AGENT_PREFIX = 'agent#'
REWARD = 'reward'
CUMU_REWARD = 'cumulated_reward'
OBS = 'env_obs'
SEP = '_'
ACTION = 'action'
|
||||||
|
|
||||||
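A hedged usage sketch of the `Checkpointer` defined above; the experiment name, root path, and `model` object are placeholders.

checkpointer = Checkpointer('example_run', Path('./studies'), config={}, total_steps=1_000, n_checkpoints=4)
for _ in range(1_000):
    ...                                    # one training step
    checkpointer.step([('q_net', model)])  # writes q_net.pt into checkpoint_<i>/ at the scheduled steps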
def access_str(agent_i, name, prefix=''):
|
|
||||||
return f'{prefix}{AGENT_PREFIX}{agent_i}{SEP}{name}'
|
|
||||||
|
|
||||||
|
|
||||||
class AutoResetGymMultiAgent(AutoResetGymAgent):
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super(AutoResetGymMultiAgent, self).__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def per_agent_values(self, name, values):
|
|
||||||
return {access_str(agent_i, name): value
|
|
||||||
for agent_i, value in zip(range(self.n_agents), values)}
|
|
||||||
|
|
||||||
def _initialize_envs(self, n):
|
|
||||||
super()._initialize_envs(n)
|
|
||||||
n_agents_list = [self.envs[i].unwrapped.n_agents for i in range(n)]
|
|
||||||
assert all(n_agents == n_agents_list[0] for n_agents in n_agents_list), \
|
|
||||||
'All envs must have the same number of agents.'
|
|
||||||
self.n_agents = n_agents_list[0]
|
|
||||||
|
|
||||||
def _reset(self, k, save_render):
|
|
||||||
ret = super()._reset(k, save_render)
|
|
||||||
obs = ret['env_obs'].squeeze()
|
|
||||||
self.cumulated_reward[k] = [0.0]*self.n_agents
|
|
||||||
obs = self.per_agent_values(OBS, [_format_frame(obs[i]) for i in range(self.n_agents)])
|
|
||||||
cumu_rew = self.per_agent_values(CUMU_REWARD, torch.zeros(self.n_agents, 1).float().unbind())
|
|
||||||
rewards = self.per_agent_values(REWARD, torch.zeros(self.n_agents, 1).float().unbind())
|
|
||||||
ret.update(cumu_rew)
|
|
||||||
ret.update(rewards)
|
|
||||||
ret.update(obs)
|
|
||||||
for remove in ['env_obs', 'cumulated_reward', 'reward']:
|
|
||||||
del ret[remove]
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def _step(self, k, action, save_render):
|
|
||||||
self.timestep[k] += 1
|
|
||||||
env = self.envs[k]
|
|
||||||
if len(action.size()) == 0:
|
|
||||||
action = action.item()
|
|
||||||
assert isinstance(action, int)
|
|
||||||
else:
|
|
||||||
action = np.array(action.tolist())
|
|
||||||
o, r, d, _ = env.step(action)
|
|
||||||
self.cumulated_reward[k] = [x+y for x, y in zip(r, self.cumulated_reward[k])]
|
|
||||||
observation = self.per_agent_values(OBS, [_format_frame(o[i]) for i in range(self.n_agents)])
|
|
||||||
if d:
|
|
||||||
self.is_running[k] = False
|
|
||||||
if save_render:
|
|
||||||
image = env.render(mode="image").unsqueeze(0)
|
|
||||||
observation["rendering"] = image
|
|
||||||
rewards = self.per_agent_values(REWARD, torch.tensor(r).float().view(-1, 1).unbind())
|
|
||||||
cumulated_rewards = self.per_agent_values(CUMU_REWARD, torch.tensor(self.cumulated_reward[k]).float().view(-1, 1).unbind())
|
|
||||||
ret = {
|
|
||||||
**observation,
|
|
||||||
**rewards,
|
|
||||||
**cumulated_rewards,
|
|
||||||
"done": torch.tensor([d]),
|
|
||||||
"initial_state": torch.tensor([False]),
|
|
||||||
"timestep": torch.tensor([self.timestep[k]])
|
|
||||||
}
|
|
||||||
return _torch_type(ret)
|
|
||||||
|
|
||||||
|
|
||||||
class CombineActionsAgent(TAgent):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.pattern = fr'^{AGENT_PREFIX}\d{SEP}{ACTION}$'
|
|
||||||
|
|
||||||
def forward(self, t, **kwargs):
|
|
||||||
keys = list(self.workspace.keys())
|
|
||||||
action_keys = sorted([k for k in keys if bool(re.match(self.pattern, k))])
|
|
||||||
actions = torch.cat([self.get((k, t)) for k in action_keys], 0)
|
|
||||||
actions = actions if len(action_keys) <= 1 else actions.unsqueeze(0)
|
|
||||||
self.set((f'action', t), actions)
|
|
@ -1,55 +0,0 @@
|
|||||||
from typing import Union
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
from algorithms.q_learner import QLearner
|
|
||||||
|
|
||||||
|
|
||||||
class VDNLearner(QLearner):
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super(VDNLearner, self).__init__(*args, **kwargs)
|
|
||||||
assert self.n_agents >= 2, 'VDN requires more than one agent, use QLearner instead'
|
|
||||||
|
|
||||||
def get_action(self, obs) -> Union[int, np.ndarray]:
|
|
||||||
o = torch.from_numpy(obs).unsqueeze(0) if self.n_agents <= 1 else torch.from_numpy(obs)
|
|
||||||
eps = np.random.rand(self.n_agents)
|
|
||||||
greedy = eps > self.eps
|
|
||||||
agent_actions = None
|
|
||||||
actions = []
|
|
||||||
for i in range(self.n_agents):
|
|
||||||
if greedy[i]:
|
|
||||||
if agent_actions is None: agent_actions = self.q_net.act(o.float())
|
|
||||||
action = agent_actions[i]
|
|
||||||
else:
|
|
||||||
action = self.env.action_space.sample()
|
|
||||||
actions.append(action)
|
|
||||||
return np.array(actions)
|
|
||||||
|
|
||||||
def train(self):
|
|
||||||
if len(self.buffer) < self.batch_size: return
|
|
||||||
for _ in range(self.n_grad_steps):
|
|
||||||
experience = self.buffer.sample(self.batch_size, cer=self.train_every_n_steps)
|
|
||||||
pred_q, target_q_raw = torch.zeros((self.batch_size, 1)), torch.zeros((self.batch_size, 1))
|
|
||||||
for agent_i in range(self.n_agents):
|
|
||||||
q_values, next_q_values_raw = self._training_routine(experience.observation[:, agent_i],
|
|
||||||
experience.next_observation[:, agent_i],
|
|
||||||
experience.action[:, agent_i].unsqueeze(-1))
|
|
||||||
pred_q += q_values
|
|
||||||
target_q_raw += next_q_values_raw
|
|
||||||
target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw
|
|
||||||
loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - target_q, 2))
|
|
||||||
self._backprop_loss(loss)
|
|
||||||
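For orientation, the value-decomposition step the loop above implements, written out as comments; the shapes are assumptions (batch B, N agents), not taken from the code.

# Q_tot(s, a_1..a_N) = sum_i Q_i(o_i, a_i)                                             -> pred_q, shape (B, 1)
# y                  = r_team + gamma * (1 - done) * sum_i max_a' Q_i_target(o_i', a') -> target_q
# loss               = mean(reg_weight * Q_tot + (Q_tot - y)^2)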
|
|
||||||
def evaluate(self, n_episodes=100, render=False):
|
|
||||||
with torch.no_grad():
|
|
||||||
data = []
|
|
||||||
for eval_i in range(n_episodes):
|
|
||||||
obs, done = self.env.reset(), False
|
|
||||||
while not done:
|
|
||||||
action = self.get_action(obs)
|
|
||||||
next_obs, reward, done, info = self.env.step(action)
|
|
||||||
if render: self.env.render()
|
|
||||||
obs = next_obs  # carry the new observation into the next loop iteration
|
|
||||||
info.update({'reward': reward, 'eval_episode': eval_i})
|
|
||||||
data.append(info)
|
|
||||||
return pd.DataFrame(data).fillna(0)
|
|
@ -1,22 +1,28 @@
|
|||||||
def make(env_name, pomdp_r=2, max_steps=400, stack_n_frames=3, n_agents=1, individual_rewards=False):
|
def make(env_name, pomdp_r=2, max_steps=400, stack_n_frames=3, n_agents=1, individual_rewards=False):
|
||||||
import yaml
|
import yaml
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from environments.factory.combined_factories import DirtItemFactory
|
from environments.factory.combined_factories import DirtItemFactory
|
||||||
from environments.factory.factory_item import ItemFactory, ItemProperties
|
from environments.factory.factory_item import ItemFactory
|
||||||
from environments.factory.factory_dirt import DirtProperties, DirtFactory
|
from environments.factory.additional.item.item_util import ItemProperties
|
||||||
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
from environments.factory.factory_dirt import DirtFactory
|
||||||
|
from environments.factory.dirt_util import DirtProperties
|
||||||
|
from environments.factory.dirt_util import RewardsDirt
|
||||||
|
from environments.utility_classes import AgentRenderOptions
|
||||||
|
|
||||||
with (Path(__file__).parent / 'levels' / 'parameters' / f'{env_name}.yaml').open('r') as stream:
|
with (Path(__file__).parent / 'levels' / 'parameters' / f'{env_name}.yaml').open('r') as stream:
|
||||||
dictionary = yaml.load(stream, Loader=yaml.FullLoader)
|
dictionary = yaml.load(stream, Loader=yaml.FullLoader)
|
||||||
|
|
||||||
obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED,
|
obs_props = dict(render_agents=AgentRenderOptions.COMBINED,
|
||||||
frames_to_stack=stack_n_frames, pomdp_r=pomdp_r)
|
pomdp_r=pomdp_r,
|
||||||
|
indicate_door_area=True,
|
||||||
|
show_global_position_info=False,
|
||||||
|
frames_to_stack=stack_n_frames)
|
||||||
|
|
||||||
factory_kwargs = dict(n_agents=n_agents, individual_rewards=individual_rewards,
|
factory_kwargs = dict(**dictionary,
|
||||||
max_steps=max_steps, obs_prop=obs_props,
|
n_agents=n_agents,
|
||||||
mv_prop=MovementProperties(**dictionary['movement_props']),
|
individual_rewards=individual_rewards,
|
||||||
dirt_prop=DirtProperties(**dictionary['dirt_props']),
|
max_steps=max_steps,
|
||||||
record_episodes=False, verbose=False, **dictionary['factory_props']
|
obs_prop=obs_props,
|
||||||
|
verbose=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
return DirtFactory(**factory_kwargs).__enter__()
|
return DirtFactory(**factory_kwargs).__enter__()
|
||||||
|
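A hedged usage sketch of `make` as rewritten above; the parameter file name is illustrative and must exist under environments/levels/parameters/. Note the function returns a context-entered `DirtFactory`.

env = make('dirt_example', n_agents=2, max_steps=400)   # 'dirt_example' is a placeholder YAML name
obs = env.reset()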
0
environments/factory/additional/__init__.py
Normal file
0
environments/factory/additional/btry/__init__.py
Normal file
41
environments/factory/additional/btry/btry_collections.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
from environments.factory.additional.btry.btry_objects import Battery, ChargePod
|
||||||
|
from environments.factory.base.registers import EnvObjectCollection, EntityCollection
|
||||||
|
|
||||||
|
|
||||||
|
class Batteries(EnvObjectCollection):
|
||||||
|
|
||||||
|
_accepted_objects = Battery
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(Batteries, self).__init__(*args, individual_slices=True,
|
||||||
|
is_blocking_light=False, can_be_shadowed=False, **kwargs)
|
||||||
|
self.is_observable = True
|
||||||
|
|
||||||
|
def spawn_batteries(self, agents, initial_charge_level):
|
||||||
|
batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)]
|
||||||
|
self.add_additional_items(batteries)
|
||||||
|
|
||||||
|
# Todo Move this to Mixin!
|
||||||
|
def by_entity(self, entity):
|
||||||
|
try:
|
||||||
|
return next((x for x in self if x.belongs_to_entity(entity)))
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def idx_by_entity(self, entity):
|
||||||
|
try:
|
||||||
|
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def as_array_by_entity(self, entity):
|
||||||
|
return self._array[self.idx_by_entity(entity)]
|
||||||
|
|
||||||
|
|
||||||
|
class ChargePods(EntityCollection):
|
||||||
|
|
||||||
|
_accepted_objects = ChargePod
|
||||||
|
_stateless_entities = True
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return super(ChargePods, self).__repr__()
|
60
environments/factory/additional/btry/btry_objects.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
from environments import helpers as h
|
||||||
|
from environments.factory.base.objects import BoundingMixin, EnvObject, Entity
|
||||||
|
from environments.factory.additional.btry.btry_util import Constants as c
|
||||||
|
|
||||||
|
|
||||||
|
class Battery(BoundingMixin, EnvObject):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_discharged(self):
|
||||||
|
return self.charge_level == 0
|
||||||
|
|
||||||
|
def __init__(self, initial_charge_level: float, *args, **kwargs):
|
||||||
|
super(Battery, self).__init__(*args, **kwargs)
|
||||||
|
self.charge_level = initial_charge_level
|
||||||
|
|
||||||
|
def encoding(self):
|
||||||
|
return self.charge_level
|
||||||
|
|
||||||
|
def do_charge_action(self, amount):
|
||||||
|
if self.charge_level < 1:
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
self.charge_level = min(1, amount + self.charge_level)
|
||||||
|
return c.VALID
|
||||||
|
else:
|
||||||
|
return c.NOT_VALID
|
||||||
|
|
||||||
|
def decharge(self, amount) -> c:
|
||||||
|
if self.charge_level != 0:
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
self.charge_level = max(0, self.charge_level - amount)  # discharging reduces the stored level
|
||||||
|
self._collection.notify_change_to_value(self)
|
||||||
|
return c.VALID
|
||||||
|
else:
|
||||||
|
return c.NOT_VALID
|
||||||
|
|
||||||
|
def summarize_state(self, **_):
|
||||||
|
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||||
|
attr_dict.update(dict(name=self.name, belongs_to=self._bound_entity.name))
|
||||||
|
return attr_dict
|
||||||
|
|
||||||
|
|
||||||
|
class ChargePod(Entity):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
return c.CHARGE_POD
|
||||||
|
|
||||||
|
def __init__(self, *args, charge_rate: float = 0.4,
|
||||||
|
multi_charge: bool = False, **kwargs):
|
||||||
|
super(ChargePod, self).__init__(*args, **kwargs)
|
||||||
|
self.charge_rate = charge_rate
|
||||||
|
self.multi_charge = multi_charge
|
||||||
|
|
||||||
|
def charge_battery(self, battery: Battery):
|
||||||
|
if battery.charge_level == 1.0:
|
||||||
|
return c.NOT_VALID
|
||||||
|
if sum(1 for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1:
|
||||||
|
return c.NOT_VALID
|
||||||
|
valid = battery.do_charge_action(self.charge_rate)
|
||||||
|
return valid
|
30
environments/factory/additional/btry/btry_util.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
from typing import NamedTuple, Union
|
||||||
|
|
||||||
|
from environments.helpers import Constants as BaseConstants, EnvActions as BaseActions
|
||||||
|
|
||||||
|
|
||||||
|
class Constants(BaseConstants):
|
||||||
|
# Battery Env
|
||||||
|
CHARGE_PODS = 'Charge_Pod'
|
||||||
|
BATTERIES = 'BATTERIES'
|
||||||
|
BATTERY_DISCHARGED = 'DISCHARGED'
|
||||||
|
CHARGE_POD = 1
|
||||||
|
|
||||||
|
|
||||||
|
class Actions(BaseActions):
|
||||||
|
CHARGE = 'do_charge_action'
|
||||||
|
|
||||||
|
|
||||||
|
class RewardsBtry(NamedTuple):
|
||||||
|
CHARGE_VALID: float = 0.1
|
||||||
|
CHARGE_FAIL: float = -0.1
|
||||||
|
BATTERY_DISCHARGED: float = -1.0
|
||||||
|
|
||||||
|
|
||||||
|
class BatteryProperties(NamedTuple):
|
||||||
|
initial_charge: float = 0.8 #
|
||||||
|
charge_rate: float = 0.4 #
|
||||||
|
charge_locations: int = 20 #
|
||||||
|
per_action_costs: Union[dict, float] = 0.02
|
||||||
|
done_when_discharged: bool = False
|
||||||
|
multi_charge: bool = False
|
139
environments/factory/additional/btry/factory_battery.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from environments.factory.additional.btry.btry_collections import Batteries, ChargePods
|
||||||
|
from environments.factory.additional.btry.btry_util import Constants, Actions, RewardsBtry, BatteryProperties
|
||||||
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
|
from environments.factory.base.objects import Agent, Action
|
||||||
|
from environments.factory.base.renderer import RenderEntity
|
||||||
|
|
||||||
|
c = Constants
|
||||||
|
a = Actions
|
||||||
|
|
||||||
|
|
||||||
|
class BatteryFactory(BaseFactory):
|
||||||
|
|
||||||
|
def __init__(self, *args, btry_prop=BatteryProperties(), rewards_btry: RewardsBtry = RewardsBtry(),
|
||||||
|
**kwargs):
|
||||||
|
if isinstance(btry_prop, dict):
|
||||||
|
btry_prop = BatteryProperties(**btry_prop)
|
||||||
|
if isinstance(rewards_btry, dict):
|
||||||
|
rewards_btry = RewardsBtry(**rewards_btry)
|
||||||
|
self.btry_prop = btry_prop
|
||||||
|
self.rewards_dest = rewards_btry
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||||
|
additional_raw_observations = super().per_agent_raw_observations_hook(agent)
|
||||||
|
additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].as_array_by_entity(agent)})
|
||||||
|
return additional_raw_observations
|
||||||
|
|
||||||
|
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
|
||||||
|
additional_observations = super().observations_hook()
|
||||||
|
additional_observations.update({c.CHARGE_PODS: self[c.CHARGE_PODS].as_array()})
|
||||||
|
return additional_observations
|
||||||
|
|
||||||
|
@property
|
||||||
|
def entities_hook(self):
|
||||||
|
super_entities = super().entities_hook
|
||||||
|
|
||||||
|
empty_tiles = self[c.FLOOR].empty_tiles[:self.btry_prop.charge_locations]
|
||||||
|
charge_pods = ChargePods.from_tiles(
|
||||||
|
empty_tiles, self._level_shape,
|
||||||
|
entity_kwargs=dict(charge_rate=self.btry_prop.charge_rate,
|
||||||
|
multi_charge=self.btry_prop.multi_charge)
|
||||||
|
)
|
||||||
|
|
||||||
|
batteries = Batteries(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
|
||||||
|
)
|
||||||
|
batteries.spawn_batteries(self[c.AGENT], self.btry_prop.initial_charge)
|
||||||
|
super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods})
|
||||||
|
return super_entities
|
||||||
|
|
||||||
|
def step_hook(self) -> (List[dict], dict):
|
||||||
|
super_reward_info = super(BatteryFactory, self).step_hook()
|
||||||
|
|
||||||
|
# Decharge
|
||||||
|
batteries = self[c.BATTERIES]
|
||||||
|
|
||||||
|
for agent in self[c.AGENT]:
|
||||||
|
if isinstance(self.btry_prop.per_action_costs, dict):
|
||||||
|
energy_consumption = self.btry_prop.per_action_costs[agent.temp_action]
|
||||||
|
else:
|
||||||
|
energy_consumption = self.btry_prop.per_action_costs
|
||||||
|
|
||||||
|
batteries.by_entity(agent).decharge(energy_consumption)
|
||||||
|
|
||||||
|
return super_reward_info
|
||||||
|
|
||||||
|
def do_charge_action(self, agent) -> (dict, dict):
|
||||||
|
if charge_pod := self[c.CHARGE_PODS].by_pos(agent.pos):
|
||||||
|
valid = charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent))
|
||||||
|
if valid:
|
||||||
|
info_dict = {f'{agent.name}_{a.CHARGE}_VALID': 1}
|
||||||
|
self.print(f'{agent.name} just charged batteries at {charge_pod.name}.')
|
||||||
|
else:
|
||||||
|
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
|
||||||
|
self.print(f'{agent.name} failed to charge batteries at {charge_pod.name}.')
|
||||||
|
else:
|
||||||
|
valid = c.NOT_VALID
|
||||||
|
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
|
||||||
|
# info_dict = {f'{agent.name}_no_charger': 1}
|
||||||
|
self.print(f'{agent.name} failed to charge batteries at {agent.pos}.')
|
||||||
|
reward = dict(value=self.rewards_dest.CHARGE_VALID if valid else self.rewards_dest.CHARGE_FAIL,
|
||||||
|
reason=a.CHARGE, info=info_dict)
|
||||||
|
return valid, reward
|
||||||
|
|
||||||
|
def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
|
||||||
|
action_result = super().do_additional_actions(agent, action)
|
||||||
|
if action_result is None:
|
||||||
|
if action == a.CHARGE:
|
||||||
|
action_result = self.do_charge_action(agent)
|
||||||
|
return action_result
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return action_result
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reset_hook(self) -> (List[dict], dict):
|
||||||
|
super_reward_info = super(BatteryFactory, self).reset_hook()
|
||||||
|
# There is Nothing to reset.
|
||||||
|
return super_reward_info
|
||||||
|
|
||||||
|
def check_additional_done(self) -> (bool, dict):
|
||||||
|
super_done, super_dict = super(BatteryFactory, self).check_additional_done()
|
||||||
|
if super_done:
|
||||||
|
return super_done, super_dict
|
||||||
|
else:
|
||||||
|
if self.btry_prop.done_when_discharged:
|
||||||
|
if btry_done := any(battery.is_discharged for battery in self[c.BATTERIES]):
|
||||||
|
super_dict.update(DISCHARGE_DONE=1)
|
||||||
|
return btry_done, super_dict
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
return super_done, super_dict
|
||||||
|
|
||||||
|
def per_agent_reward_hook(self, agent: Agent) -> List[dict]:
|
||||||
|
reward_event_list = super(BatteryFactory, self).per_agent_reward_hook(agent)
|
||||||
|
if self[c.BATTERIES].by_entity(agent).is_discharged:
|
||||||
|
self.print(f'{agent.name} Battery is discharged!')
|
||||||
|
info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1}
|
||||||
|
reward_event_list.append({'value': self.rewards_dest.BATTERY_DISCHARGED,
|
||||||
|
'reason': c.BATTERY_DISCHARGED,
|
||||||
|
'info': info_dict}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# All Fine
|
||||||
|
pass
|
||||||
|
return reward_event_list
|
||||||
|
|
||||||
|
def render_assets_hook(self):
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
additional_assets = super().render_assets_hook()
|
||||||
|
charge_pods = [RenderEntity(c.CHARGE_PODS, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_PODS]]
|
||||||
|
additional_assets.extend(charge_pods)
|
||||||
|
return additional_assets
|
@ -1,11 +1,15 @@
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
from environments.factory.factory_battery import BatteryFactory, BatteryProperties
|
|
||||||
from environments.factory.factory_dirt import DirtFactory, DirtProperties
|
|
||||||
from environments.factory.factory_item import ItemFactory
|
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyAbstractClass
|
# noinspection PyAbstractClass
|
||||||
|
from environments.factory.additional.btry.btry_util import BatteryProperties
|
||||||
|
from environments.factory.additional.btry.factory_battery import BatteryFactory
|
||||||
|
from environments.factory.additional.dest.factory_dest import DestFactory
|
||||||
|
from environments.factory.additional.dirt.dirt_util import DirtProperties
|
||||||
|
from environments.factory.additional.dirt.factory_dirt import DirtFactory
|
||||||
|
from environments.factory.additional.item.factory_item import ItemFactory
|
||||||
|
|
||||||
|
|
||||||
class DirtItemFactory(ItemFactory, DirtFactory):
|
class DirtItemFactory(ItemFactory, DirtFactory):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@ -17,6 +21,18 @@ class DirtBatteryFactory(DirtFactory, BatteryFactory):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyAbstractClass
|
||||||
|
class DirtDestItemFactory(ItemFactory, DirtFactory, DestFactory):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyAbstractClass
|
||||||
|
class DestBatteryFactory(BatteryFactory, DestFactory):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties
|
from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties
|
||||||
|
|
0
environments/factory/additional/dest/__init__.py
Normal file
38
environments/factory/additional/dest/dest_collections.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
from environments.factory.base.registers import EntityCollection
|
||||||
|
from environments.factory.additional.dest.dest_util import Constants as c
|
||||||
|
from environments.factory.additional.dest.dest_enitites import Destination
|
||||||
|
|
||||||
|
|
||||||
|
class Destinations(EntityCollection):
|
||||||
|
|
||||||
|
_accepted_objects = Destination
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.is_blocking_light = False
|
||||||
|
self.can_be_shadowed = False
|
||||||
|
|
||||||
|
def as_array(self):
|
||||||
|
self._array[:] = c.FREE_CELL
|
||||||
|
# ToDo: Switch to new Style Array Put
|
||||||
|
# indices = list(zip(range(len(cls)), *zip(*[x.pos for x in cls])))
|
||||||
|
# np.put(cls._array, [np.ravel_multi_index(x, cls._array.shape) for x in indices], cls.encodings)
|
||||||
|
for item in self:
|
||||||
|
if item.pos != c.NO_POS:
|
||||||
|
self._array[0, item.x, item.y] = item.encoding
|
||||||
|
return self._array
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return super(Destinations, self).__repr__()
|
||||||
|
|
||||||
|
|
||||||
|
class ReachedDestinations(Destinations):
|
||||||
|
_accepted_objects = Destination
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(ReachedDestinations, self).__init__(*args, **kwargs)
|
||||||
|
self.can_be_shadowed = False
|
||||||
|
self.is_blocking_light = False
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return super(ReachedDestinations, self).__repr__()
|
45
environments/factory/additional/dest/dest_enitites.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from environments.factory.base.objects import Entity, Agent
|
||||||
|
from environments.factory.additional.dest.dest_util import Constants as c
|
||||||
|
|
||||||
|
|
||||||
|
class Destination(Entity):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def any_agent_has_dwelled(self):
|
||||||
|
return bool(len(self._per_agent_times))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def currently_dwelling_names(self):
|
||||||
|
return self._per_agent_times.keys()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
return c.DESTINATION
|
||||||
|
|
||||||
|
def __init__(self, *args, dwell_time: int = 0, **kwargs):
|
||||||
|
super(Destination, self).__init__(*args, **kwargs)
|
||||||
|
self.dwell_time = dwell_time
|
||||||
|
self._per_agent_times = defaultdict(lambda: dwell_time)
|
||||||
|
|
||||||
|
def do_wait_action(self, agent: Agent):
|
||||||
|
self._per_agent_times[agent.name] -= 1
|
||||||
|
return c.VALID
|
||||||
|
|
||||||
|
def leave(self, agent: Agent):
|
||||||
|
del self._per_agent_times[agent.name]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_considered_reached(self):
|
||||||
|
agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
|
||||||
|
return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values())
|
||||||
|
|
||||||
|
def agent_is_dwelling(self, agent: Agent):
|
||||||
|
return self._per_agent_times[agent.name] < self.dwell_time
|
||||||
|
|
||||||
|
def summarize_state(self) -> dict:
|
||||||
|
state_summary = super().summarize_state()
|
||||||
|
state_summary.update(per_agent_times=[
|
||||||
|
dict(belongs_to=key, time=val) for key, val in self._per_agent_times.items()], dwell_time=self.dwell_time)
|
||||||
|
return state_summary
|
41
environments/factory/additional/dest/dest_util.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
from environments.helpers import Constants as BaseConstants, EnvActions as BaseActions
|
||||||
|
|
||||||
|
|
||||||
|
class Constants(BaseConstants):
|
||||||
|
# Destination Env
|
||||||
|
DEST = 'Destination'
|
||||||
|
DESTINATION = 1
|
||||||
|
DESTINATION_DONE = 0.5
|
||||||
|
DEST_REACHED = 'ReachedDestination'
|
||||||
|
|
||||||
|
|
||||||
|
class Actions(BaseActions):
|
||||||
|
WAIT_ON_DEST = 'WAIT'
|
||||||
|
|
||||||
|
|
||||||
|
class RewardsDest(NamedTuple):
|
||||||
|
|
||||||
|
WAIT_VALID: float = 0.1
|
||||||
|
WAIT_FAIL: float = -0.1
|
||||||
|
DEST_REACHED: float = 5.0
|
||||||
|
|
||||||
|
|
||||||
|
class DestModeOptions(object):
|
||||||
|
DONE = 'DONE'
|
||||||
|
GROUPED = 'GROUPED'
|
||||||
|
PER_DEST = 'PER_DEST'
|
||||||
|
|
||||||
|
|
||||||
|
class DestProperties(NamedTuple):
|
||||||
|
n_dests: int = 1 # How many destinations are there
|
||||||
|
dwell_time: int = 0 # How long does the agent need to "wait" on a destination
|
||||||
|
spawn_frequency: int = 0
|
||||||
|
spawn_in_other_zone: bool = True #
|
||||||
|
spawn_mode: str = DestModeOptions.DONE
|
||||||
|
|
||||||
|
assert dwell_time >= 0, 'dwell_time cannot be < 0!'
|
||||||
|
assert spawn_frequency >= 0, 'spawn_frequency cannot be < 0!'
|
||||||
|
assert n_dests >= 0, 'n_destinations cannot be < 0!'
|
||||||
|
assert (spawn_mode == DestModeOptions.DONE) != bool(spawn_frequency)
|
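The class-level asserts above only admit `spawn_frequency > 0` when `spawn_mode` is not `DONE`. Two combinations that satisfy them (values are illustrative):

DestProperties(n_dests=2, spawn_mode=DestModeOptions.DONE, spawn_frequency=0)
DestProperties(n_dests=2, spawn_mode=DestModeOptions.GROUPED, spawn_frequency=5)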
203
environments/factory/additional/dest/factory_dest.py
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
import time
|
||||||
|
from enum import Enum
|
||||||
|
from typing import List, Union, Dict
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
|
||||||
|
from environments.factory.additional.dest.dest_collections import Destinations, ReachedDestinations
|
||||||
|
from environments.factory.additional.dest.dest_enitites import Destination
|
||||||
|
from environments.factory.additional.dest.dest_util import Constants, Actions, RewardsDest, DestModeOptions, \
|
||||||
|
DestProperties
|
||||||
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
|
from environments.factory.base.objects import Agent, Action
|
||||||
|
from environments.factory.base.registers import Entities
|
||||||
|
|
||||||
|
from environments.factory.base.renderer import RenderEntity
|
||||||
|
|
||||||
|
c = Constants
|
||||||
|
a = Actions
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
||||||
|
class DestFactory(BaseFactory):
|
||||||
|
# noinspection PyMissingConstructor
|
||||||
|
|
||||||
|
def __init__(self, *args, dest_prop: DestProperties = DestProperties(), rewards_dest: RewardsDest = RewardsDest(),
|
||||||
|
env_seed=time.time_ns(), **kwargs):
|
||||||
|
if isinstance(dest_prop, dict):
|
||||||
|
dest_prop = DestProperties(**dest_prop)
|
||||||
|
if isinstance(rewards_dest, dict):
|
||||||
|
rewards_dest = RewardsDest(**rewards_dest)
|
||||||
|
self.dest_prop = dest_prop
|
||||||
|
self.rewards_dest = rewards_dest
|
||||||
|
kwargs.update(env_seed=env_seed)
|
||||||
|
self._dest_rng = np.random.default_rng(env_seed)
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def actions_hook(self) -> Union[Action, List[Action]]:
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
super_actions = super().actions_hook
|
||||||
|
# If targets are considered reached after some time, agents need an action for that.
|
||||||
|
if self.dest_prop.dwell_time:
|
||||||
|
super_actions.append(Action(enum_ident=a.WAIT_ON_DEST))
|
||||||
|
return super_actions
|
||||||
|
|
||||||
|
@property
|
||||||
|
def entities_hook(self) -> Dict[(Enum, Entities)]:
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
super_entities = super().entities_hook
|
||||||
|
|
||||||
|
empty_tiles = self[c.FLOOR].empty_tiles[:self.dest_prop.n_dests]
|
||||||
|
destinations = Destinations.from_tiles(
|
||||||
|
empty_tiles, self._level_shape,
|
||||||
|
entity_kwargs=dict(
|
||||||
|
dwell_time=self.dest_prop.dwell_time)
|
||||||
|
)
|
||||||
|
reached_destinations = ReachedDestinations(level_shape=self._level_shape)
|
||||||
|
|
||||||
|
super_entities.update({c.DEST: destinations, c.DEST_REACHED: reached_destinations})
|
||||||
|
return super_entities
|
||||||
|
|
||||||
|
def do_wait_action(self, agent: Agent) -> (dict, dict):
|
||||||
|
if destination := self[c.DEST].by_pos(agent.pos):
|
||||||
|
valid = destination.do_wait_action(agent)
|
||||||
|
self.print(f'{agent.name} just waited at {agent.pos}')
|
||||||
|
info_dict = {f'{agent.name}_{a.WAIT_ON_DEST}_VALID': 1}
|
||||||
|
else:
|
||||||
|
valid = c.NOT_VALID
|
||||||
|
self.print(f'{agent.name} just tried to do_wait_action at {agent.pos} but failed')
|
||||||
|
info_dict = {f'{agent.name}_{a.WAIT_ON_DEST}_FAIL': 1}
|
||||||
|
reward = dict(value=self.rewards_dest.WAIT_VALID if valid else self.rewards_dest.WAIT_FAIL,
|
||||||
|
reason=a.WAIT_ON_DEST, info=info_dict)
|
||||||
|
return valid, reward
|
||||||
|
|
||||||
|
def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
super_action_result = super().do_additional_actions(agent, action)
|
||||||
|
if super_action_result is None:
|
||||||
|
if action == a.WAIT_ON_DEST:
|
||||||
|
action_result = self.do_wait_action(agent)
|
||||||
|
return action_result
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return super_action_result
|
||||||
|
|
||||||
|
def reset_hook(self) -> None:
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
super().reset_hook()
|
||||||
|
self._dest_spawn_timer = dict()
|
||||||
|
|
||||||
|
def trigger_destination_spawn(self):
|
||||||
|
destinations_to_spawn = [key for key, val in self._dest_spawn_timer.items()
|
||||||
|
if val == self.dest_prop.spawn_frequency]
|
||||||
|
if destinations_to_spawn:
|
||||||
|
n_dest_to_spawn = len(destinations_to_spawn)
|
||||||
|
if self.dest_prop.spawn_mode != DestModeOptions.GROUPED:
|
||||||
|
destinations = [Destination(tile, self[c.DEST]) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||||
|
self[c.DEST].add_additional_items(destinations)
|
||||||
|
for dest in destinations_to_spawn:
|
||||||
|
del self._dest_spawn_timer[dest]
|
||||||
|
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
||||||
|
elif self.dest_prop.spawn_mode == DestModeOptions.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests:
|
||||||
|
destinations = [Destination(tile, self[c.DEST]) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||||
|
self[c.DEST].add_additional_items(destinations)
|
||||||
|
for dest in destinations_to_spawn:
|
||||||
|
del self._dest_spawn_timer[dest]
|
||||||
|
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
||||||
|
else:
|
||||||
|
self.print(f'{n_dest_to_spawn} new destinations could be spawned, but waiting for all.')
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
self.print('No destinations are spawning, limit is reached.')
|
||||||
|
|
||||||
|
def step_hook(self) -> (List[dict], dict):
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
super_reward_info = super().step_hook()
|
||||||
|
for key, val in self._dest_spawn_timer.items():
|
||||||
|
self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1)
|
||||||
|
|
||||||
|
for dest in list(self[c.DEST].values()):
|
||||||
|
if dest.is_considered_reached:
|
||||||
|
dest.change_parent_collection(self[c.DEST_REACHED])
|
||||||
|
self._dest_spawn_timer[dest.name] = 0
|
||||||
|
self.print(f'{dest.name} is reached now, removing...')
|
||||||
|
else:
|
||||||
|
for agent_name in dest.currently_dwelling_names:
|
||||||
|
agent = self[c.AGENT].by_name(agent_name)
|
||||||
|
if agent.pos == dest.pos:
|
||||||
|
self.print(f'{agent.name} is still waiting.')
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
dest.leave(agent)
|
||||||
|
self.print(f'{agent.name} left the destination early.')
|
||||||
|
self.trigger_destination_spawn()
|
||||||
|
return super_reward_info
|
||||||
|
|
||||||
|
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
|
||||||
|
additional_observations = super().observations_hook()
|
||||||
|
additional_observations.update({c.DEST: self[c.DEST].as_array()})
|
||||||
|
return additional_observations
|
||||||
|
|
||||||
|
def per_agent_reward_hook(self, agent: Agent) -> List[dict]:
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
reward_event_list = super().per_agent_reward_hook(agent)
|
||||||
|
if len(self[c.DEST_REACHED]):
|
||||||
|
for reached_dest in list(self[c.DEST_REACHED]):
|
||||||
|
if agent.pos == reached_dest.pos:
|
||||||
|
self.print(f'{agent.name} just reached destination at {agent.pos}')
|
||||||
|
self[c.DEST_REACHED].delete_env_object(reached_dest)
|
||||||
|
info_dict = {f'{agent.name}_{c.DEST_REACHED}': 1}
|
||||||
|
reward_event_list.append({'value': self.rewards_dest.DEST_REACHED,
|
||||||
|
'reason': c.DEST_REACHED,
|
||||||
|
'info': info_dict})
|
||||||
|
return reward_event_list
|
||||||
|
|
||||||
|
def render_assets_hook(self, mode='human'):
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
additional_assets = super().render_assets_hook()
|
||||||
|
destinations = [RenderEntity(c.DEST, dest.pos) for dest in self[c.DEST]]
|
||||||
|
additional_assets.extend(destinations)
|
||||||
|
return additional_assets
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
from environments.utility_classes import AgentRenderOptions as aro, ObservationProperties
|
||||||
|
|
||||||
|
render = True
|
||||||
|
|
||||||
|
dest_probs = DestProperties(n_dests=2, spawn_frequency=5, spawn_mode=DestModeOptions.GROUPED)
|
||||||
|
|
||||||
|
obs_props = ObservationProperties(render_agents=aro.LEVEL, omit_agent_self=True, pomdp_r=2)
|
||||||
|
|
||||||
|
move_props = {'allow_square_movement': True,
|
||||||
|
'allow_diagonal_movement': False,
|
||||||
|
'allow_no_op': False}
|
||||||
|
|
||||||
|
factory = DestFactory(n_agents=10, done_at_collision=False,
|
||||||
|
level_name='rooms', max_steps=400,
|
||||||
|
obs_prop=obs_props, parse_doors=True,
|
||||||
|
verbose=True,
|
||||||
|
mv_prop=move_props, dest_prop=dest_probs
|
||||||
|
)
|
||||||
|
|
||||||
|
# noinspection DuplicatedCode
|
||||||
|
n_actions = factory.action_space.n - 1
|
||||||
|
_ = factory.observation_space
|
||||||
|
|
||||||
|
for epoch in range(4):
|
||||||
|
random_actions = [[random.randint(0, n_actions) for _
|
||||||
|
in range(factory.n_agents)] for _
|
||||||
|
in range(factory.max_steps + 1)]
|
||||||
|
env_state = factory.reset()
|
||||||
|
r = 0
|
||||||
|
for agent_i_action in random_actions:
|
||||||
|
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
|
||||||
|
r += step_r
|
||||||
|
if render:
|
||||||
|
factory.render()
|
||||||
|
if done_bool:
|
||||||
|
break
|
||||||
|
print(f'Factory run {epoch} done, reward is:\n {r}')
|
||||||
|
pass
|
0
environments/factory/additional/dirt/__init__.py
Normal file
42
environments/factory/additional/dirt/dirt_collections.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from environments.factory.additional.dirt.dirt_entity import DirtPile
|
||||||
|
from environments.factory.additional.dirt.dirt_util import DirtProperties
|
||||||
|
from environments.factory.base.objects import Floor
|
||||||
|
from environments.factory.base.registers import EntityCollection
|
||||||
|
from environments.factory.additional.dirt.dirt_util import Constants as c
|
||||||
|
|
||||||
|
|
||||||
|
class DirtPiles(EntityCollection):
|
||||||
|
|
||||||
|
_accepted_objects = DirtPile
|
||||||
|
|
||||||
|
@property
|
||||||
|
def amount(self):
|
||||||
|
return sum([dirt.amount for dirt in self])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dirt_properties(self):
|
||||||
|
return self._dirt_properties
|
||||||
|
|
||||||
|
def __init__(self, dirt_properties, *args):
|
||||||
|
super(DirtPiles, self).__init__(*args)
|
||||||
|
self._dirt_properties: DirtProperties = dirt_properties
|
||||||
|
|
||||||
|
def spawn_dirt(self, then_dirty_tiles) -> bool:
|
||||||
|
if isinstance(then_dirty_tiles, Floor):
|
||||||
|
then_dirty_tiles = [then_dirty_tiles]
|
||||||
|
for tile in then_dirty_tiles:
|
||||||
|
if not self.amount > self.dirt_properties.max_global_amount:
|
||||||
|
dirt = self.by_pos(tile.pos)
|
||||||
|
if dirt is None:
|
||||||
|
dirt = DirtPile(tile, self, amount=self.dirt_properties.max_spawn_amount)
|
||||||
|
self.add_item(dirt)
|
||||||
|
else:
|
||||||
|
new_value = dirt.amount + self.dirt_properties.max_spawn_amount
|
||||||
|
dirt.set_new_amount(min(new_value, self.dirt_properties.max_local_amount))
|
||||||
|
else:
|
||||||
|
return c.NOT_VALID
|
||||||
|
return c.VALID
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
s = super(DirtPiles, self).__repr__()
|
||||||
|
return f'{s[:-1]}, {self.amount})'
|
26
environments/factory/additional/dirt/dirt_entity.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
from environments.factory.base.objects import Entity
|
||||||
|
|
||||||
|
|
||||||
|
class DirtPile(Entity):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def amount(self):
|
||||||
|
return self._amount
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
# Edit this if you want items to be drawn in the ops differntly
|
||||||
|
return self._amount
|
||||||
|
|
||||||
|
def __init__(self, *args, amount=None, **kwargs):
|
||||||
|
super(DirtPile, self).__init__(*args, **kwargs)
|
||||||
|
self._amount = amount
|
||||||
|
|
||||||
|
def set_new_amount(self, amount):
|
||||||
|
self._amount = amount
|
||||||
|
self._collection.notify_change_to_value(self)
|
||||||
|
|
||||||
|
def summarize_state(self):
|
||||||
|
state_dict = super().summarize_state()
|
||||||
|
state_dict.update(amount=float(self.amount))
|
||||||
|
return state_dict
|
30
environments/factory/additional/dirt/dirt_util.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
from environments.helpers import Constants as BaseConstants, EnvActions as BaseActions
|
||||||
|
|
||||||
|
|
||||||
|
class Constants(BaseConstants):
|
||||||
|
DIRT = 'DirtPile'
|
||||||
|
|
||||||
|
|
||||||
|
class Actions(BaseActions):
|
||||||
|
CLEAN_UP = 'do_cleanup_action'
|
||||||
|
|
||||||
|
|
||||||
|
class RewardsDirt(NamedTuple):
|
||||||
|
CLEAN_UP_VALID: float = 0.5
|
||||||
|
CLEAN_UP_FAIL: float = -0.1
|
||||||
|
CLEAN_UP_LAST_PIECE: float = 4.5
|
||||||
|
|
||||||
|
|
||||||
|
class DirtProperties(NamedTuple):
|
||||||
|
initial_dirt_ratio: float = 0.3 # On INIT, on max how many tiles does the dirt spawn in percent.
|
||||||
|
initial_dirt_spawn_r_var: float = 0.05 # How much does the dirt spawn amount vary?
|
||||||
|
clean_amount: float = 1 # How much does the robot clean with one actions.
|
||||||
|
max_spawn_ratio: float = 0.20 # On max how many tiles does the dirt spawn in percent.
|
||||||
|
max_spawn_amount: float = 0.3 # How much dirt does spawn per tile at max.
|
||||||
|
spawn_frequency: int = 0 # Spawn Frequency in Steps.
|
||||||
|
max_local_amount: int = 2 # Max dirt amount per tile.
|
||||||
|
max_global_amount: int = 20 # Max dirt amount in the whole environment.
|
||||||
|
dirt_smear_amount: float = 0.2 # Agents smear dirt, when not cleaning up in place.
|
||||||
|
done_when_clean: bool = True
|
252
environments/factory/additional/dirt/factory_dirt.py
Normal file
@ -0,0 +1,252 @@
|
|||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Union, Dict
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from environments.factory.additional.dirt.dirt_collections import DirtPiles
|
||||||
|
from environments.factory.additional.dirt.dirt_entity import DirtPile
|
||||||
|
from environments.factory.additional.dirt.dirt_util import Constants, Actions, RewardsDirt, DirtProperties
|
||||||
|
|
||||||
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
|
from environments.factory.base.objects import Agent, Action
|
||||||
|
from environments.factory.base.registers import Entities
|
||||||
|
|
||||||
|
from environments.factory.base.renderer import RenderEntity
|
||||||
|
from environments.utility_classes import ObservationProperties
|
||||||
|
|
||||||
|
|
||||||
|
def softmax(x):
|
||||||
|
"""Compute softmax values for each sets of scores in x."""
|
||||||
|
e_x = np.exp(x - np.max(x))
|
||||||
|
return e_x / e_x.sum()
|
||||||
|
|
||||||
|
|
||||||
|
def entropy(x):
|
||||||
|
return -(x * np.log(x + 1e-8)).sum()
|
||||||
|
|
||||||
|
|
||||||
|
c = Constants
|
||||||
|
a = Actions
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
||||||
|
class DirtFactory(BaseFactory):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def actions_hook(self) -> Union[Action, List[Action]]:
|
||||||
|
super_actions = super().actions_hook
|
||||||
|
super_actions.append(Action(str_ident=a.CLEAN_UP))
|
||||||
|
return super_actions
|
||||||
|
|
||||||
|
@property
|
||||||
|
def entities_hook(self) -> Dict[(str, Entities)]:
|
||||||
|
super_entities = super().entities_hook
|
||||||
|
dirt_register = DirtPiles(self.dirt_prop, self._level_shape)
|
||||||
|
super_entities.update(({c.DIRT: dirt_register}))
|
||||||
|
return super_entities
|
||||||
|
|
||||||
|
def __init__(self, *args,
|
||||||
|
dirt_prop: DirtProperties = DirtProperties(), rewards_dirt: RewardsDirt = RewardsDirt(),
|
||||||
|
env_seed=time.time_ns(), **kwargs):
|
||||||
|
if isinstance(dirt_prop, dict):
|
||||||
|
dirt_prop = DirtProperties(**dirt_prop)
|
||||||
|
if isinstance(rewards_dirt, dict):
|
||||||
|
rewards_dirt = RewardsDirt(**rewards_dirt)
|
||||||
|
self.dirt_prop = dirt_prop
|
||||||
|
self.rewards_dirt = rewards_dirt
|
||||||
|
self._dirt_rng = np.random.default_rng(env_seed)
|
||||||
|
self._dirt: DirtPiles
|
||||||
|
kwargs.update(env_seed=env_seed)
|
||||||
|
# TODO: Reset ---> document this
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def render_assets_hook(self, mode='human'):
|
||||||
|
additional_assets = super().render_assets_hook()
|
||||||
|
dirt = [RenderEntity('dirt', dirt.tile.pos, min(0.15 + dirt.amount, 1.5), 'scale')
|
||||||
|
for dirt in self[c.DIRT]]
|
||||||
|
additional_assets.extend(dirt)
|
||||||
|
return additional_assets
|
||||||
|
|
||||||
|
def do_cleanup_action(self, agent: Agent) -> (dict, dict):
|
||||||
|
if dirt := self[c.DIRT].by_pos(agent.pos):
|
||||||
|
new_dirt_amount = dirt.amount - self.dirt_prop.clean_amount
|
||||||
|
|
||||||
|
if new_dirt_amount <= 0:
|
||||||
|
self[c.DIRT].delete_env_object(dirt)
|
||||||
|
else:
|
||||||
|
dirt.set_new_amount(max(new_dirt_amount, c.FREE_CELL.value))
|
||||||
|
valid = c.VALID
|
||||||
|
self.print(f'{agent.name} just cleaned up some dirt at {agent.pos}.')
|
||||||
|
info_dict = {f'{agent.name}_{a.CLEAN_UP}_VALID': 1, 'cleanup_valid': 1}
|
||||||
|
reward = self.rewards_dirt.CLEAN_UP_VALID
|
||||||
|
else:
|
||||||
|
valid = c.NOT_VALID
|
||||||
|
self.print(f'{agent.name} just tried to clean up some dirt at {agent.pos}, but failed.')
|
||||||
|
info_dict = {f'{agent.name}_{a.CLEAN_UP}_FAIL': 1, 'cleanup_fail': 1}
|
||||||
|
reward = self.rewards_dirt.CLEAN_UP_FAIL
|
||||||
|
|
||||||
|
if valid and self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0):
|
||||||
|
reward += self.rewards_dirt.CLEAN_UP_LAST_PIECE
|
||||||
|
self.print(f'{agent.name} picked up the last piece of dirt!')
|
||||||
|
info_dict = {f'{agent.name}_{a.CLEAN_UP}_LAST_PIECE': 1}
|
||||||
|
return valid, dict(value=reward, reason=a.CLEAN_UP, info=info_dict)
|
||||||
|
|
||||||
|
def trigger_dirt_spawn(self, initial_spawn=False):
|
||||||
|
dirt_rng = self._dirt_rng
|
||||||
|
free_for_dirt = [x for x in self[c.FLOOR]
|
||||||
|
if len(x.guests) == 0 or (len(x.guests) == 1 and isinstance(next(y for y in x.guests), DirtPile))
|
||||||
|
]
|
||||||
|
        self._dirt_rng.shuffle(free_for_dirt)
        if initial_spawn:
            var = self.dirt_prop.initial_dirt_spawn_r_var
            new_spawn = self.dirt_prop.initial_dirt_ratio + dirt_rng.uniform(-var, var)
        else:
            new_spawn = dirt_rng.uniform(0, self.dirt_prop.max_spawn_ratio)
        n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt)))
        self[c.DIRT].spawn_dirt(free_for_dirt[:n_dirt_tiles])

    def step_hook(self) -> (List[dict], dict):
        super_reward_info = super().step_hook()
        if smear_amount := self.dirt_prop.dirt_smear_amount:
            for agent in self[c.AGENT]:
                if agent.step_result['action_valid'] and agent.last_pos != c.NO_POS:
                    if self._actions.is_moving_action(agent.step_result['action_name']):
                        if old_pos_dirt := self[c.DIRT].by_pos(agent.last_pos):
                            if smeared_dirt := round(old_pos_dirt.amount * smear_amount, 2):
                                old_pos_dirt.set_new_amount(max(0, old_pos_dirt.amount - smeared_dirt))
                                if new_pos_dirt := self[c.DIRT].by_pos(agent.pos):
                                    new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt))
                                else:
                                    if self[c.DIRT].spawn_dirt(agent.tile):
                                        new_pos_dirt = self[c.DIRT].by_pos(agent.pos)
                                        new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt))
        if self._next_dirt_spawn < 0:
            pass  # No DirtPile Spawn
        elif not self._next_dirt_spawn:
            self.trigger_dirt_spawn()
            self._next_dirt_spawn = self.dirt_prop.spawn_frequency
        else:
            self._next_dirt_spawn -= 1
        return super_reward_info

    def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
        action_result = super().do_additional_actions(agent, action)
        if action_result is None:
            if action == a.CLEAN_UP:
                return self.do_cleanup_action(agent)
            else:
                return None
        else:
            return action_result

    def reset_hook(self) -> None:
        super().reset_hook()
        self.trigger_dirt_spawn(initial_spawn=True)
        self._next_dirt_spawn = self.dirt_prop.spawn_frequency if self.dirt_prop.spawn_frequency else -1

    def check_additional_done(self) -> (bool, dict):
        super_done, super_dict = super().check_additional_done()
        if self.dirt_prop.done_when_clean:
            if all_cleaned := len(self[c.DIRT]) == 0:
                super_dict.update(ALL_CLEAN_DONE=all_cleaned)
                return all_cleaned, super_dict
        return super_done, super_dict

    def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
        additional_observations = super().observations_hook()
        additional_observations.update({c.DIRT: self[c.DIRT].as_array()})
        return additional_observations

    def post_step_hook(self) -> List[Dict[str, int]]:
        super_post_step = super(DirtFactory, self).post_step_hook()
        info_dict = dict()

        dirt = [dirt.amount for dirt in self[c.DIRT]]
        current_dirt_amount = sum(dirt)
        dirty_tile_count = len(dirt)

        # if dirty_tile_count:
        #     dirt_distribution_score = entropy(softmax(np.asarray(dirt)) / dirty_tile_count)
        # else:
        #     dirt_distribution_score = 0

        info_dict.update(dirt_amount=current_dirt_amount)
        info_dict.update(dirty_tile_count=dirty_tile_count)

        super_post_step.append(info_dict)
        return super_post_step


if __name__ == '__main__':
    from environments.utility_classes import AgentRenderOptions as aro
    render = True

    dirt_props = DirtProperties(
        initial_dirt_ratio=0.35,
        initial_dirt_spawn_r_var=0.1,
        clean_amount=0.34,
        max_spawn_amount=0.1,
        max_global_amount=20,
        max_local_amount=1,
        spawn_frequency=0,
        max_spawn_ratio=0.05,
        dirt_smear_amount=0.0
    )

    obs_props = ObservationProperties(render_agents=aro.COMBINED, omit_agent_self=True,
                                      pomdp_r=2, additional_agent_placeholder=None, cast_shadows=True,
                                      indicate_door_area=False)

    move_props = {'allow_square_movement': True,
                  'allow_diagonal_movement': False,
                  'allow_no_op': False}
    import time
    global_timings = []
    for i in range(10):

        factory = DirtFactory(n_agents=10, done_at_collision=False,
                              level_name='rooms', max_steps=1000,
                              doors_have_area=False,
                              obs_prop=obs_props, parse_doors=True,
                              verbose=True,
                              mv_prop=move_props, dirt_prop=dirt_props,
                              # inject_agents=[TSPDirtAgent],
                              )

        # noinspection DuplicatedCode
        n_actions = factory.action_space.n - 1
        _ = factory.observation_space
        obs_space = factory.observation_space
        obs_space_named = factory.named_observation_space
        action_space_named = factory.named_action_space
        times = []
        for epoch in range(10):
            start_time = time.time()
            random_actions = [[random.randint(0, n_actions) for _
                               in range(factory.n_agents)] for _
                              in range(factory.max_steps + 1)]
            env_state = factory.reset()
            if render:
                factory.render()
            # tsp_agent = factory.get_injected_agents()[0]

            rwrd = 0
            for agent_i_action in random_actions:
                # agent_i_action = tsp_agent.predict()
                env_state, step_rwrd, done_bool, info_obj = factory.step(agent_i_action)
                rwrd += step_rwrd
                if render:
                    factory.render()
                if done_bool:
                    break
            times.append(time.time() - start_time)
            # print(f'Factory run {epoch} done, reward is:\n {rwrd}')
        print('Mean Time Taken: ', sum(times) / 10)
        global_timings.extend(times)
    print('Mean Time Taken: ', sum(global_timings) / len(global_timings))
    print('Median Time Taken: ', sorted(global_timings)[len(global_timings) // 2])

    pass
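# Hedged sketch (illustration only, not part of the diff): the smear bookkeeping in step_hook()
# above moves round(old_amount * dirt_smear_amount, 2) dirt from the tile the agent left onto
# the tile it entered. A standalone toy helper with the same arithmetic:
def smear_dirt(old_amount: float, new_amount: float, smear_ratio: float) -> tuple:
    """Return the updated (old_tile, new_tile) dirt amounts after one agent move."""
    smeared = round(old_amount * smear_ratio, 2)
    return max(0, old_amount - smeared), max(0, new_amount + smeared)

# e.g. smear_dirt(1.0, 0.0, 0.2) -> (0.8, 0.2): a fifth of the pile travels with the agent.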
0 environments/factory/additional/item/__init__.py Normal file
193 environments/factory/additional/item/factory_item.py Normal file
@ -0,0 +1,193 @@
import time
from typing import List, Union, Dict
import numpy as np
import random

from environments.factory.additional.item.item_collections import Items, Inventories, DropOffLocations
from environments.factory.additional.item.item_util import Constants, Actions, RewardsItem, ItemProperties
from environments.factory.base.base_factory import BaseFactory
from environments.factory.base.objects import Agent, Action
from environments.factory.base.registers import Entities

from environments.factory.base.renderer import RenderEntity

c = Constants
a = Actions


# noinspection PyAttributeOutsideInit, PyAbstractClass
class ItemFactory(BaseFactory):
    # noinspection PyMissingConstructor
    def __init__(self, *args, item_prop: ItemProperties = ItemProperties(), env_seed=time.time_ns(),
                 rewards_item: RewardsItem = RewardsItem(), **kwargs):
        if isinstance(item_prop, dict):
            item_prop = ItemProperties(**item_prop)
        if isinstance(rewards_item, dict):
            rewards_item = RewardsItem(**rewards_item)
        self.item_prop = item_prop
        self.rewards_item = rewards_item
        kwargs.update(env_seed=env_seed)
        self._item_rng = np.random.default_rng(env_seed)
        assert (item_prop.n_items <= ((1 + kwargs.get('_pomdp_r', 0) * 2) ** 2)) or not kwargs.get('_pomdp_r', 0)
        super().__init__(*args, **kwargs)

    @property
    def actions_hook(self) -> Union[Action, List[Action]]:
        # noinspection PyUnresolvedReferences
        super_actions = super().actions_hook
        super_actions.append(Action(str_ident=a.ITEM_ACTION))
        return super_actions

    @property
    def entities_hook(self) -> Dict[(str, Entities)]:
        # noinspection PyUnresolvedReferences
        super_entities = super().entities_hook

        empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_drop_off_locations]
        drop_offs = DropOffLocations.from_tiles(
            empty_tiles, self._level_shape,
            entity_kwargs=dict(
                storage_size_until_full=self.item_prop.max_dropoff_storage_size)
        )
        item_register = Items(self._level_shape)
        empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_items]
        item_register.spawn_items(empty_tiles)

        inventories = Inventories(self._obs_shape, self._level_shape)
        inventories.spawn_inventories(self[c.AGENT], self.item_prop.max_agent_inventory_capacity)

        super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories})
        return super_entities

    def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
        additional_raw_observations = super().per_agent_raw_observations_hook(agent)
        additional_raw_observations.update({c.INVENTORY: self[c.INVENTORY].by_entity(agent).as_array()})
        return additional_raw_observations

    def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
        additional_observations = super().observations_hook()
        additional_observations.update({c.ITEM: self[c.ITEM].as_array()})
        additional_observations.update({c.DROP_OFF: self[c.DROP_OFF].as_array()})
        return additional_observations

    def do_item_action(self, agent: Agent) -> (dict, dict):
        inventory = self[c.INVENTORY].by_entity(agent)
        if drop_off := self[c.DROP_OFF].by_pos(agent.pos):
            if inventory:
                valid = drop_off.place_item(inventory.pop())
            else:
                valid = c.NOT_VALID
            if valid:
                self.print(f'{agent.name} just dropped off an item at {drop_off.pos}.')
                info_dict = {f'{agent.name}_DROPOFF_VALID': 1, 'DROPOFF_VALID': 1}
            else:
                self.print(f'{agent.name} just tried to drop off at {agent.pos}, but failed.')
                info_dict = {f'{agent.name}_DROPOFF_FAIL': 1, 'DROPOFF_FAIL': 1}
            reward = dict(value=self.rewards_item.DROP_OFF_VALID if valid else self.rewards_item.DROP_OFF_FAIL,
                          reason=a.ITEM_ACTION, info=info_dict)
            return valid, reward
        elif item := self[c.ITEM].by_pos(agent.pos):
            item.change_parent_collection(inventory)
            item.set_tile_to(self._NO_POS_TILE)
            self.print(f'{agent.name} just picked up an item at {agent.pos}')
            info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1, f'{a.ITEM_ACTION}_VALID': 1}
            return c.VALID, dict(value=self.rewards_item.PICK_UP_VALID, reason=a.ITEM_ACTION, info=info_dict)
        else:
            self.print(f'{agent.name} just tried to pick up an item at {agent.pos}, but failed.')
            info_dict = {f'{agent.name}_{a.ITEM_ACTION}_FAIL': 1, f'{a.ITEM_ACTION}_FAIL': 1}
            return c.NOT_VALID, dict(value=self.rewards_item.PICK_UP_FAIL, reason=a.ITEM_ACTION, info=info_dict)

    def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
        # noinspection PyUnresolvedReferences
        action_result = super().do_additional_actions(agent, action)
        if action_result is None:
            if action == a.ITEM_ACTION:
                action_result = self.do_item_action(agent)
                return action_result
            else:
                return None
        else:
            return action_result

    def reset_hook(self) -> None:
        # noinspection PyUnresolvedReferences
        super().reset_hook()
        self._next_item_spawn = self.item_prop.spawn_frequency
        self.trigger_item_spawn()

    def trigger_item_spawn(self):
        if item_to_spawns := max(0, (self.item_prop.n_items - len(self[c.ITEM]))):
            empty_tiles = self[c.FLOOR].empty_tiles[:item_to_spawns]
            self[c.ITEM].spawn_items(empty_tiles)
            self._next_item_spawn = self.item_prop.spawn_frequency
            self.print(f'{item_to_spawns} new items have been spawned; next spawn in {self._next_item_spawn}')
        else:
            self.print('No Items are spawning, limit is reached.')

    def step_hook(self) -> (List[dict], dict):
        # noinspection PyUnresolvedReferences
        super_reward_info = super().step_hook()
        for item in list(self[c.ITEM].values()):
            if item.auto_despawn >= 1:
                item.set_auto_despawn(item.auto_despawn - 1)
            elif not item.auto_despawn:
                self[c.ITEM].delete_env_object(item)
            else:
                pass

        if not self._next_item_spawn:
            self.trigger_item_spawn()
        else:
            self._next_item_spawn = max(0, self._next_item_spawn - 1)
        return super_reward_info

    def render_assets_hook(self, mode='human'):
        # noinspection PyUnresolvedReferences
        additional_assets = super().render_assets_hook()
        items = [RenderEntity(c.ITEM, item.tile.pos) for item in self[c.ITEM] if item.tile != self._NO_POS_TILE]
        additional_assets.extend(items)
        drop_offs = [RenderEntity(c.DROP_OFF, drop_off.tile.pos) for drop_off in self[c.DROP_OFF]]
        additional_assets.extend(drop_offs)
        return additional_assets


if __name__ == '__main__':
    from environments.utility_classes import AgentRenderOptions as aro, ObservationProperties

    render = True

    item_probs = ItemProperties(n_items=30, n_drop_off_locations=6)

    obs_props = ObservationProperties(render_agents=aro.SEPERATE, omit_agent_self=True, pomdp_r=2)

    move_props = {'allow_square_movement': True,
                  'allow_diagonal_movement': True,
                  'allow_no_op': False}

    factory = ItemFactory(n_agents=6, done_at_collision=False,
                          level_name='rooms', max_steps=400,
                          obs_prop=obs_props, parse_doors=True,
                          record_episodes=True, verbose=True,
                          mv_prop=move_props, item_prop=item_probs
                          )

    # noinspection DuplicatedCode
    n_actions = factory.action_space.n - 1
    obs_space = factory.observation_space
    obs_space_named = factory.named_observation_space

    for epoch in range(400):
        random_actions = [[random.randint(0, n_actions) for _
                           in range(factory.n_agents)] for _
                          in range(factory.max_steps + 1)]
        env_state = factory.reset()
        rwrd = 0
        for agent_i_action in random_actions:
            env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
            rwrd += step_r
            if render:
                factory.render()
            if done_bool:
                break
        print(f'Factory run {epoch} done, reward is:\n {rwrd}')
    pass
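# Hedged sketch (illustration only, not part of the diff): the (valid, reward) pair that
# do_item_action() above returns; the agent name and numeric value below are assumptions
# made for illustration, not values taken from the repository.
example_valid = True
example_reward = dict(value=0.1,                 # rewards_item.PICK_UP_VALID default
                      reason='ITEMACTION',       # a.ITEM_ACTION
                      info={'Agent_0_ITEMACTION_VALID': 1, 'ITEMACTION_VALID': 1})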
89 environments/factory/additional/item/item_collections.py Normal file
@ -0,0 +1,89 @@
from typing import List

import numpy as np

from environments.factory.base.objects import Floor, Agent
from environments.factory.base.registers import EntityCollection, BoundEnvObjCollection, ObjectCollection
from environments.factory.additional.item.item_entities import Item, DropOffLocation


class Items(EntityCollection):

    _accepted_objects = Item

    def spawn_items(self, tiles: List[Floor]):
        items = [Item(tile, self) for tile in tiles]
        self.add_additional_items(items)

    def despawn_items(self, items: List[Item]):
        items = [items] if isinstance(items, Item) else items
        for item in items:
            del self[item]


class Inventory(BoundEnvObjCollection):

    @property
    def name(self):
        return f'{self.__class__.__name__}({self._bound_entity.name})'

    def __init__(self, agent: Agent, capacity: int, *args, **kwargs):
        super(Inventory, self).__init__(agent, *args, is_blocking_light=False, can_be_shadowed=False, **kwargs)
        self.capacity = capacity

    def as_array(self):
        if self._array is None:
            self._array = np.zeros((1, *self._shape))
        return super(Inventory, self).as_array()

    def summarize_states(self, **kwargs):
        attr_dict = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
        attr_dict.update(dict(items=[val.summarize_state(**kwargs) for key, val in self.items()]))
        attr_dict.update(dict(name=self.name, belongs_to=self._bound_entity.name))
        return attr_dict

    def pop(self):
        item_to_pop = self[0]
        self.delete_env_object(item_to_pop)
        return item_to_pop


class Inventories(ObjectCollection):

    _accepted_objects = Inventory
    is_blocking_light = False
    can_be_shadowed = False

    def __init__(self, obs_shape, *args, **kwargs):
        super(Inventories, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
        self._obs_shape = obs_shape

    def as_array(self):
        return np.stack([inventory.as_array() for inv_idx, inventory in enumerate(self)])

    def spawn_inventories(self, agents, capacity):
        inventories = [self._accepted_objects(agent, capacity, self._obs_shape)
                       for _, agent in enumerate(agents)]
        self.add_additional_items(inventories)

    def idx_by_entity(self, entity):
        try:
            return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity)))
        except StopIteration:
            return None

    def by_entity(self, entity):
        try:
            return next((inv for inv in self if inv.belongs_to_entity(entity)))
        except StopIteration:
            return None

    def summarize_states(self, **kwargs):
        return [val.summarize_states(**kwargs) for key, val in self.items()]


class DropOffLocations(EntityCollection):

    _accepted_objects = DropOffLocation
    _stateless_entities = True
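# Hedged usage sketch (not part of the diff): Inventory.pop() above always removes the entry at
# index 0, i.e. the item that was picked up first. `factory` and `agent` are hypothetical,
# already-constructed objects, and Constants refers to item_util.Constants.
inventory = factory[Constants.INVENTORY].by_entity(agent)
if inventory:                          # an empty inventory is falsy
    oldest_item = inventory.pop()      # FIFO order, mirroring the pop() defined above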
57 environments/factory/additional/item/item_entities.py Normal file
@ -0,0 +1,57 @@
from collections import deque

from environments import helpers as h
from environments.factory.additional.item.item_util import Constants
from environments.factory.base.objects import Entity


class Item(Entity):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._auto_despawn = -1

    @property
    def auto_despawn(self):
        return self._auto_despawn

    @property
    def encoding(self):
        # Edit this if you want items to be drawn in the obs differently
        return 1

    def set_auto_despawn(self, auto_despawn):
        self._auto_despawn = auto_despawn

    def set_tile_to(self, no_pos_tile):
        self._tile = no_pos_tile

    def summarize_state(self) -> dict:
        super_summarization = super(Item, self).summarize_state()
        super_summarization.update(dict(auto_despawn=self.auto_despawn))
        return super_summarization


class DropOffLocation(Entity):

    @property
    def encoding(self):
        return Constants.ITEM_DROP_OFF

    def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs):
        super(DropOffLocation, self).__init__(*args, **kwargs)
        self.auto_item_despawn_interval = auto_item_despawn_interval
        self.storage = deque(maxlen=storage_size_until_full or None)

    def place_item(self, item: Item):
        if self.is_full:
            raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
            return Constants.NOT_VALID
        else:
            self.storage.append(item)
            item.set_auto_despawn(self.auto_item_despawn_interval)
            return Constants.VALID

    @property
    def is_full(self):
        return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage)
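# Hedged sketch (not part of the diff): how `deque(maxlen=storage_size_until_full or None)` above
# maps the config value onto the drop-off capacity.
from collections import deque

unbounded = deque(maxlen=0 or None)    # storage_size_until_full == 0 -> maxlen None, never "full"
bounded = deque(maxlen=5 or None)      # storage_size_until_full == 5 -> is_full after five items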
31 environments/factory/additional/item/item_util.py Normal file
@ -0,0 +1,31 @@
from typing import NamedTuple

from environments.helpers import Constants as BaseConstants, EnvActions as BaseActions


class Constants(BaseConstants):
    NO_ITEM = 0
    ITEM_DROP_OFF = 1
    # Item Env
    ITEM = 'Item'
    INVENTORY = 'Inventory'
    DROP_OFF = 'Drop_Off'


class Actions(BaseActions):
    ITEM_ACTION = 'ITEMACTION'


class RewardsItem(NamedTuple):
    DROP_OFF_VALID: float = 0.1
    DROP_OFF_FAIL: float = -0.1
    PICK_UP_FAIL: float = -0.1
    PICK_UP_VALID: float = 0.1


class ItemProperties(NamedTuple):
    n_items: int = 5                        # How many items are there at the same time
    spawn_frequency: int = 10               # Spawn Frequency in Steps
    n_drop_off_locations: int = 5           # How many DropOff locations are there at the same time
    max_dropoff_storage_size: int = 0       # How many items are needed until the dropoff is full
    max_agent_inventory_capacity: int = 5   # How many items are needed until the agent inventory is full
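# Hedged usage sketch (not part of the diff): ItemFactory.__init__ also accepts plain dicts for
# these NamedTuples and converts them; the values below are illustrative only.
item_prop = dict(n_items=12, spawn_frequency=5, n_drop_off_locations=3)
rewards_item = dict(DROP_OFF_VALID=0.5, PICK_UP_VALID=0.25)
# factory = ItemFactory(item_prop=item_prop, rewards_item=rewards_item)   # hypothetical call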
@ -1,7 +1,7 @@
|
|||||||
import abc
|
import abc
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from enum import Enum
|
from itertools import chain
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Union, Iterable, Dict
|
from typing import List, Union, Iterable, Dict
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -11,10 +11,13 @@ from gym import spaces
|
|||||||
from gym.wrappers import FrameStack
|
from gym.wrappers import FrameStack
|
||||||
|
|
||||||
from environments.factory.base.shadow_casting import Map
|
from environments.factory.base.shadow_casting import Map
|
||||||
from environments.helpers import Constants as c, Constants
|
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.factory.base.objects import Agent, Tile, Action
|
from environments.helpers import Constants as c
|
||||||
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders
|
from environments.helpers import EnvActions as a
|
||||||
|
from environments.helpers import RewardsBase
|
||||||
|
from environments.factory.base.objects import Agent, Floor, Action
|
||||||
|
from environments.factory.base.registers import Actions, Entities, Agents, Doors, Floors, Walls, PlaceHolders, \
|
||||||
|
GlobalPositions
|
||||||
from environments.utility_classes import MovementProperties, ObservationProperties, MarlFrameStack
|
from environments.utility_classes import MovementProperties, ObservationProperties, MarlFrameStack
|
||||||
from environments.utility_classes import AgentRenderOptions as a_obs
|
from environments.utility_classes import AgentRenderOptions as a_obs
|
||||||
|
|
||||||
@ -30,17 +33,30 @@ class BaseFactory(gym.Env):
|
|||||||
def action_space(self):
|
def action_space(self):
|
||||||
return spaces.Discrete(len(self._actions))
|
return spaces.Discrete(len(self._actions))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def named_action_space(self):
|
||||||
|
return {x.identifier: idx for idx, x in enumerate(self._actions.values())}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def observation_space(self):
|
def observation_space(self):
|
||||||
if r := self._pomdp_r:
|
obs, _ = self._build_observations()
|
||||||
z = self._obs_cube.shape[0]
|
if self.n_agents > 1:
|
||||||
xy = r*2 + 1
|
shape = obs[0].shape
|
||||||
level_shape = (z, xy, xy)
|
|
||||||
else:
|
else:
|
||||||
level_shape = self._obs_cube.shape
|
shape = obs.shape
|
||||||
space = spaces.Box(low=0, high=1, shape=level_shape, dtype=np.float32)
|
space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
||||||
return space
|
return space
|
||||||
|
|
||||||
|
@property
|
||||||
|
def named_observation_space(self):
|
||||||
|
# Build it
|
||||||
|
_, named_obs = self._build_observations()
|
||||||
|
if self.n_agents > 1:
|
||||||
|
# Only return the first named obs space, as their structures are the same at the moment.
|
||||||
|
return named_obs[list(named_obs.keys())[0]]
|
||||||
|
else:
|
||||||
|
return named_obs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pomdp_diameter(self):
|
def pomdp_diameter(self):
|
||||||
return self._pomdp_r * 2 + 1
|
return self._pomdp_r * 2 + 1
|
||||||
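# Hedged sketch (illustration only, not part of the diff): named_action_space above maps each
# action identifier to its integer index, so an external controller can translate action names
# without guessing the ordering. `factory` is a hypothetical, already-constructed BaseFactory
# subclass, and the identifier string is an assumption.
action_idx = factory.named_action_space['use_door']
named_obs = factory.named_observation_space        # e.g. {'Walls': [0], 'Agent': [1], ...}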
@ -52,8 +68,15 @@ class BaseFactory(gym.Env):
|
|||||||
@property
|
@property
|
||||||
def params(self) -> dict:
|
def params(self) -> dict:
|
||||||
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
|
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
|
||||||
|
d['class_name'] = self.__class__.__name__
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
@property
|
||||||
|
def summarize_header(self):
|
||||||
|
summary_dict = self._summarize_state(stateless_entities=True)
|
||||||
|
summary_dict.update(actions=self._actions.summarize())
|
||||||
|
return summary_dict
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self if self.obs_prop.frames_to_stack == 0 else \
|
return self if self.obs_prop.frames_to_stack == 0 else \
|
||||||
MarlFrameStack(FrameStack(self, self.obs_prop.frames_to_stack))
|
MarlFrameStack(FrameStack(self, self.obs_prop.frames_to_stack))
|
||||||
@ -64,17 +87,26 @@ class BaseFactory(gym.Env):
|
|||||||
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2),
|
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2),
|
||||||
mv_prop: MovementProperties = MovementProperties(),
|
mv_prop: MovementProperties = MovementProperties(),
|
||||||
obs_prop: ObservationProperties = ObservationProperties(),
|
obs_prop: ObservationProperties = ObservationProperties(),
|
||||||
|
rewards_base: RewardsBase = RewardsBase(),
|
||||||
parse_doors=False, done_at_collision=False, inject_agents: Union[None, List] = None,
|
parse_doors=False, done_at_collision=False, inject_agents: Union[None, List] = None,
|
||||||
verbose=False, doors_have_area=True, env_seed=time.time_ns(), individual_rewards=False,
|
verbose=False, doors_have_area=True, env_seed=time.time_ns(), individual_rewards=False,
|
||||||
**kwargs):
|
class_name='', **kwargs):
|
||||||
|
|
||||||
|
if class_name:
|
||||||
|
print(f'You loaded parameters for {class_name}', f'this is: {self.__class__.__name__}')
|
||||||
|
|
||||||
if isinstance(mv_prop, dict):
|
if isinstance(mv_prop, dict):
|
||||||
mv_prop = MovementProperties(**mv_prop)
|
mv_prop = MovementProperties(**mv_prop)
|
||||||
if isinstance(obs_prop, dict):
|
if isinstance(obs_prop, dict):
|
||||||
obs_prop = ObservationProperties(**obs_prop)
|
obs_prop = ObservationProperties(**obs_prop)
|
||||||
|
if isinstance(rewards_base, dict):
|
||||||
|
rewards_base = RewardsBase(**rewards_base)
|
||||||
|
|
||||||
assert obs_prop.frames_to_stack != 1 and \
|
assert obs_prop.frames_to_stack != 1 and \
|
||||||
obs_prop.frames_to_stack >= 0, "'frames_to_stack' cannot be negative or 1."
|
obs_prop.frames_to_stack >= 0, \
|
||||||
|
"'frames_to_stack' cannot be negative or 1."
|
||||||
|
assert doors_have_area or not obs_prop.indicate_door_area, \
|
||||||
|
'"indicate_door_area" can only active, when "doors_have_area"'
|
||||||
if kwargs:
|
if kwargs:
|
||||||
print(f'Following kwargs were passed, but ignored: {kwargs}')
|
print(f'Following kwargs were passed, but ignored: {kwargs}')
|
||||||
|
|
||||||
@ -84,13 +116,17 @@ class BaseFactory(gym.Env):
|
|||||||
self._base_rng = np.random.default_rng(self.env_seed)
|
self._base_rng = np.random.default_rng(self.env_seed)
|
||||||
self.mv_prop = mv_prop
|
self.mv_prop = mv_prop
|
||||||
self.obs_prop = obs_prop
|
self.obs_prop = obs_prop
|
||||||
|
self.rewards_base = rewards_base
|
||||||
self.level_name = level_name
|
self.level_name = level_name
|
||||||
self._level_shape = None
|
self._level_shape = None
|
||||||
|
self._obs_shape = None
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self._renderer = None # expensive - don't use it when not required !
|
self._renderer = None # expensive - don't use it when not required !
|
||||||
self._entities = Entities()
|
self._entities = Entities()
|
||||||
|
|
||||||
self.n_agents = n_agents
|
self.n_agents = n_agents
|
||||||
|
level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.level_name}.txt'
|
||||||
|
self._parsed_level = h.parse_level(level_filepath)
|
||||||
|
|
||||||
self.max_steps = max_steps
|
self.max_steps = max_steps
|
||||||
self._pomdp_r = self.obs_prop.pomdp_r
|
self._pomdp_r = self.obs_prop.pomdp_r
|
||||||
@ -102,7 +138,7 @@ class BaseFactory(gym.Env):
|
|||||||
self.doors_have_area = doors_have_area
|
self.doors_have_area = doors_have_area
|
||||||
self.individual_rewards = individual_rewards
|
self.individual_rewards = individual_rewards
|
||||||
|
|
||||||
# Reset
|
# TODO: Reset ---> document this
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
@ -114,94 +150,96 @@ class BaseFactory(gym.Env):
|
|||||||
# Objects
|
# Objects
|
||||||
self._entities = Entities()
|
self._entities = Entities()
|
||||||
# Level
|
# Level
|
||||||
level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.level_name}.txt'
|
level_array = h.one_hot_level(self._parsed_level)
|
||||||
parsed_level = h.parse_level(level_filepath)
|
self._level_init_shape = level_array.shape
|
||||||
level_array = h.one_hot_level(parsed_level)
|
level_array = np.pad(level_array, self.obs_prop.pomdp_r, 'constant', constant_values=c.OCCUPIED_CELL)
|
||||||
|
|
||||||
self._level_shape = level_array.shape
|
self._level_shape = level_array.shape
|
||||||
|
self._obs_shape = self._level_shape if not self.obs_prop.pomdp_r else (self.pomdp_diameter, ) * 2
|
||||||
|
|
||||||
# Walls
|
# Walls
|
||||||
walls = WallTiles.from_argwhere_coordinates(
|
walls = Walls.from_argwhere_coordinates(
|
||||||
np.argwhere(level_array == c.OCCUPIED_CELL.value),
|
np.argwhere(level_array == c.OCCUPIED_CELL),
|
||||||
self._level_shape
|
self._level_shape
|
||||||
)
|
)
|
||||||
self._entities.register_additional_items({c.WALLS: walls})
|
self._entities.add_additional_items({c.WALLS: walls})
|
||||||
|
|
||||||
# Floor
|
# Floor
|
||||||
floor = FloorTiles.from_argwhere_coordinates(
|
floor = Floors.from_argwhere_coordinates(
|
||||||
np.argwhere(level_array == c.FREE_CELL.value),
|
np.argwhere(level_array == c.FREE_CELL),
|
||||||
self._level_shape
|
self._level_shape
|
||||||
)
|
)
|
||||||
self._entities.register_additional_items({c.FLOOR: floor})
|
self._entities.add_additional_items({c.FLOOR: floor})
|
||||||
|
|
||||||
# NOPOS
|
# NOPOS
|
||||||
self._NO_POS_TILE = Tile(c.NO_POS.value)
|
self._NO_POS_TILE = Floor(c.NO_POS, None)
|
||||||
|
|
||||||
# Doors
|
# Doors
|
||||||
if self.parse_doors:
|
if self.parse_doors:
|
||||||
parsed_doors = h.one_hot_level(parsed_level, c.DOOR)
|
parsed_doors = h.one_hot_level(self._parsed_level, c.DOOR)
|
||||||
|
parsed_doors = np.pad(parsed_doors, self.obs_prop.pomdp_r, 'constant', constant_values=0)
|
||||||
if np.any(parsed_doors):
|
if np.any(parsed_doors):
|
||||||
door_tiles = [floor.by_pos(pos) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL.value)]
|
door_tiles = [floor.by_pos(tuple(pos)) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL)]
|
||||||
doors = Doors.from_tiles(door_tiles, self._level_shape,
|
doors = Doors.from_tiles(door_tiles, self._level_shape, have_area=self.obs_prop.indicate_door_area,
|
||||||
entity_kwargs=dict(context=floor)
|
entity_kwargs=dict(context=floor)
|
||||||
)
|
)
|
||||||
self._entities.register_additional_items({c.DOORS: doors})
|
self._entities.add_additional_items({c.DOORS: doors})
|
||||||
|
|
||||||
# Actions
|
# Actions
|
||||||
|
# TODO: Move this to Agent init, so that agents can have individual action sets.
|
||||||
self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
|
self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
|
||||||
if additional_actions := self.additional_actions:
|
if additional_actions := self.actions_hook:
|
||||||
self._actions.register_additional_items(additional_actions)
|
self._actions.add_additional_items(additional_actions)
|
||||||
|
|
||||||
# Agents
|
# Agents
|
||||||
agents_to_spawn = self.n_agents-len(self._injected_agents)
|
agents_to_spawn = self.n_agents-len(self._injected_agents)
|
||||||
agents_kwargs = dict(level_shape=self._level_shape,
|
agents_kwargs = dict(individual_slices=self.obs_prop.render_agents == a_obs.SEPERATE,
|
||||||
individual_slices=self.obs_prop.render_agents == a_obs.SEPERATE,
|
hide_from_obs_builder=self.obs_prop.render_agents in [a_obs.NOT, a_obs.LEVEL],
|
||||||
hide_from_obs_builder=self.obs_prop.render_agents == a_obs.LEVEL,
|
)
|
||||||
is_observable=self.obs_prop.render_agents != a_obs.NOT)
|
|
||||||
if agents_to_spawn:
|
if agents_to_spawn:
|
||||||
agents = Agents.from_tiles(floor.empty_tiles[:agents_to_spawn], **agents_kwargs)
|
agents = Agents.from_tiles(floor.empty_tiles[:agents_to_spawn], self._level_shape, **agents_kwargs)
|
||||||
else:
|
else:
|
||||||
agents = Agents(**agents_kwargs)
|
agents = Agents(self._level_shape, **agents_kwargs)
|
||||||
if self._injected_agents:
|
if self._injected_agents:
|
||||||
initialized_injections = list()
|
initialized_injections = list()
|
||||||
for i, injection in enumerate(self._injected_agents):
|
for i, injection in enumerate(self._injected_agents):
|
||||||
agents.register_item(injection(self, floor.empty_tiles[agents_to_spawn+i+1], static_problem=False))
|
agents.add_item(injection(self, floor.empty_tiles[0], agents, static_problem=False))
|
||||||
initialized_injections.append(agents[-1])
|
initialized_injections.append(agents[-1])
|
||||||
self._initialized_injections = initialized_injections
|
self._initialized_injections = initialized_injections
|
||||||
self._entities.register_additional_items({c.AGENT: agents})
|
self._entities.add_additional_items({c.AGENT: agents})
|
||||||
|
|
||||||
if self.obs_prop.additional_agent_placeholder is not None:
|
if self.obs_prop.additional_agent_placeholder is not None:
|
||||||
# TODO: Make this accept Lists for multiple placeholders
|
# TODO: Make this accept Lists for multiple placeholders
|
||||||
|
|
||||||
# Empty Observations with either [0, 1, N(0, 1)]
|
# Empty Observations with either [0, 1, N(0, 1)]
|
||||||
placeholder = PlaceHolders.from_tiles([self._NO_POS_TILE], self._level_shape,
|
placeholder = PlaceHolders.from_values(self.obs_prop.additional_agent_placeholder, self._level_shape,
|
||||||
entity_kwargs=dict(
|
entity_kwargs=dict(
|
||||||
fill_value=self.obs_prop.additional_agent_placeholder)
|
fill_value=self.obs_prop.additional_agent_placeholder)
|
||||||
)
|
)
|
||||||
|
|
||||||
self._entities.register_additional_items({c.AGENT_PLACEHOLDER: placeholder})
|
self._entities.add_additional_items({c.AGENT_PLACEHOLDER: placeholder})
|
||||||
|
|
||||||
# Additional Entitites from SubEnvs
|
# Additional Entitites from SubEnvs
|
||||||
if additional_entities := self.additional_entities:
|
if additional_entities := self.entities_hook:
|
||||||
self._entities.register_additional_items(additional_entities)
|
self._entities.add_additional_items(additional_entities)
|
||||||
|
|
||||||
|
if self.obs_prop.show_global_position_info:
|
||||||
|
global_positions = GlobalPositions(self._level_shape)
|
||||||
|
# This moved into the GlobalPosition object
|
||||||
|
# obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2)
|
||||||
|
global_positions.spawn_global_position_objects(self[c.AGENT])
|
||||||
|
self._entities.add_additional_items({c.GLOBAL_POSITION: global_positions})
|
||||||
|
|
||||||
# Return
|
# Return
|
||||||
return self._entities
|
return self._entities
|
||||||
|
|
||||||
def _init_obs_cube(self):
|
def reset(self) -> (np.typing.ArrayLike, int, bool, dict):
|
||||||
arrays = self._entities.obs_arrays
|
|
||||||
|
|
||||||
obs_cube_z = sum([a.shape[0] if not self[key].is_per_agent else 1 for key, a in arrays.items()])
|
|
||||||
obs_cube_z += 1 if self.obs_prop.show_global_position_info else 0
|
|
||||||
self._obs_cube = np.zeros((obs_cube_z, *self._level_shape), dtype=np.float32)
|
|
||||||
|
|
||||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
|
||||||
_ = self._base_init_env()
|
_ = self._base_init_env()
|
||||||
self._init_obs_cube()
|
self.reset_hook()
|
||||||
self.do_additional_reset()
|
|
||||||
|
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
|
|
||||||
obs = self._get_observations()
|
obs, _ = self._build_observations()
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def step(self, actions):
|
def step(self, actions):
|
||||||
@ -213,39 +251,53 @@ class BaseFactory(gym.Env):
|
|||||||
self._steps += 1
|
self._steps += 1
|
||||||
|
|
||||||
# Pre step Hook for later use
|
# Pre step Hook for later use
|
||||||
self.hook_pre_step()
|
self.pre_step_hook()
|
||||||
|
|
||||||
# Move this in a seperate function?
|
|
||||||
for action, agent in zip(actions, self[c.AGENT]):
|
for action, agent in zip(actions, self[c.AGENT]):
|
||||||
agent.clear_temp_state()
|
agent.clear_temp_state()
|
||||||
action_obj = self._actions[int(action)]
|
action_obj = self._actions[int(action)]
|
||||||
# self.print(f'Action #{action} has been resolved to: {action_obj}')
|
step_result = dict(collisions=[], rewards=[], info={}, action_name='', action_valid=False)
|
||||||
if h.MovingAction.is_member(action_obj):
|
# cls.print(f'Action #{action} has been resolved to: {action_obj}')
|
||||||
valid = self._move_or_colide(agent, action_obj)
|
if a.is_move(action_obj):
|
||||||
elif h.EnvActions.NOOP == agent.temp_action:
|
action_valid, reward = self._do_move_action(agent, action_obj)
|
||||||
valid = c.VALID
|
elif a.NOOP == action_obj:
|
||||||
elif h.EnvActions.USE_DOOR == action_obj:
|
action_valid = c.VALID
|
||||||
valid = self._handle_door_interaction(agent)
|
reward = dict(value=self.rewards_base.NOOP, reason=a.NOOP, info={f'{agent.name}_NOOP': 1, 'NOOP': 1})
|
||||||
|
elif a.USE_DOOR == action_obj:
|
||||||
|
action_valid, reward = self._handle_door_interaction(agent)
|
||||||
else:
|
else:
|
||||||
valid = self.do_additional_actions(agent, action_obj)
|
# noinspection PyTupleAssignmentBalance
|
||||||
assert valid is not None, 'This should not happen, every Action musst be detected correctly!'
|
action_valid, reward = self.do_additional_actions(agent, action_obj)
|
||||||
agent.temp_action = action_obj
|
# Not needed any more since the tuple assignment above will fail in case of a failing action resolution.
|
||||||
agent.temp_valid = valid
|
# assert step_result is not None, 'This should not happen, every Action musst be detected correctly!'
|
||||||
|
step_result['action_name'] = action_obj.identifier
|
||||||
# In-between step Hook for later use
|
step_result['action_valid'] = action_valid
|
||||||
info = self.do_additional_step()
|
step_result['rewards'].append(reward)
|
||||||
|
agent.step_result = step_result
|
||||||
|
|
||||||
|
# Additional step and Reward, Info Init
|
||||||
|
rewards, info = self.step_hook()
|
||||||
|
# Todo: Make this faster, so that only tiles of entities that can collide are searched.
|
||||||
tiles_with_collisions = self.get_all_tiles_with_collisions()
|
tiles_with_collisions = self.get_all_tiles_with_collisions()
|
||||||
for tile in tiles_with_collisions:
|
for tile in tiles_with_collisions:
|
||||||
guests = tile.guests_that_can_collide
|
guests = tile.guests_that_can_collide
|
||||||
for i, guest in enumerate(guests):
|
for i, guest in enumerate(guests):
|
||||||
|
# This does make a copy, but is faster than .copy()
|
||||||
this_collisions = guests[:]
|
this_collisions = guests[:]
|
||||||
del this_collisions[i]
|
del this_collisions[i]
|
||||||
guest.temp_collisions = this_collisions
|
assert hasattr(guest, 'step_result')
|
||||||
|
for collision in this_collisions:
|
||||||
|
guest.step_result['collisions'].append(collision)
|
||||||
|
|
||||||
done = self.done_at_collision and tiles_with_collisions
|
done = False
|
||||||
|
if self.done_at_collision:
|
||||||
|
if done_at_col := bool(tiles_with_collisions):
|
||||||
|
done = done_at_col
|
||||||
|
info.update(COLLISION_DONE=done_at_col)
|
||||||
|
|
||||||
done = done or self.check_additional_done()
|
additional_done, additional_done_info = self.check_additional_done()
|
||||||
|
done = done or additional_done
|
||||||
|
info.update(additional_done_info)
|
||||||
|
|
||||||
# Step the door close intervall
|
# Step the door close intervall
|
||||||
if self.parse_doors:
|
if self.parse_doors:
|
||||||
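# Hedged sketch (illustration only, not part of the diff): shape of the per-agent step_result
# dict assembled in the action loop above; the identifier and reward value are assumptions
# made for illustration.
example_step_result = dict(action_name='move_north',
                           action_valid=True,
                           collisions=[],      # filled from tiles_with_collisions above
                           rewards=[{'value': -0.001, 'reason': 'move_north', 'info': {}}])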
@ -253,7 +305,8 @@ class BaseFactory(gym.Env):
|
|||||||
doors.tick_doors()
|
doors.tick_doors()
|
||||||
|
|
||||||
# Finalize
|
# Finalize
|
||||||
reward, reward_info = self.calculate_reward()
|
reward, reward_info = self.build_reward_result(rewards)
|
||||||
|
|
||||||
info.update(reward_info)
|
info.update(reward_info)
|
||||||
if self._steps >= self.max_steps:
|
if self._steps >= self.max_steps:
|
||||||
done = True
|
done = True
|
||||||
@ -262,13 +315,14 @@ class BaseFactory(gym.Env):
|
|||||||
info.update(self._summarize_state())
|
info.update(self._summarize_state())
|
||||||
|
|
||||||
# Post step Hook for later use
|
# Post step Hook for later use
|
||||||
info.update(self.hook_post_step())
|
for post_step_info in self.post_step_hook():
|
||||||
|
info.update(post_step_info)
|
||||||
|
|
||||||
obs = self._get_observations()
|
obs, _ = self._build_observations()
|
||||||
|
|
||||||
return obs, reward, done, info
|
return obs, reward, done, info
|
||||||
|
|
||||||
def _handle_door_interaction(self, agent) -> c:
|
def _handle_door_interaction(self, agent) -> (bool, dict):
|
||||||
if doors := self[c.DOORS]:
|
if doors := self[c.DOORS]:
|
||||||
# Check if agent really is standing on a door:
|
# Check if agent really is standing on a door:
|
||||||
if self.doors_have_area:
|
if self.doors_have_area:
|
||||||
@ -277,164 +331,177 @@ class BaseFactory(gym.Env):
|
|||||||
door = doors.by_pos(agent.pos)
|
door = doors.by_pos(agent.pos)
|
||||||
if door is not None:
|
if door is not None:
|
||||||
door.use()
|
door.use()
|
||||||
return c.VALID
|
valid = c.VALID
|
||||||
|
self.print(f'{agent.name} just used a {door.name} at {door.pos}')
|
||||||
|
info_dict = {f'{agent.name}_door_use': 1, f'door_use': 1}
|
||||||
# When he doesn't...
|
# When he doesn't...
|
||||||
else:
|
else:
|
||||||
return c.NOT_VALID
|
valid = c.NOT_VALID
|
||||||
|
info_dict = {f'{agent.name}_failed_door_use': 1, 'failed_door_use': 1}
|
||||||
|
self.print(f'{agent.name} just tried to use a door at {agent.pos}, but there is none.')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return c.NOT_VALID
|
raise RuntimeError('This should not happen, since the door action should not be available.')
|
||||||
|
reward = dict(value=self.rewards_base.USE_DOOR_VALID if valid else self.rewards_base.USE_DOOR_FAIL,
|
||||||
|
reason=a.USE_DOOR, info=info_dict)
|
||||||
|
|
||||||
def _get_observations(self) -> np.ndarray:
|
return valid, reward
|
||||||
state_array_dict = self._entities.obs_arrays
|
|
||||||
if self.n_agents == 1:
|
def _build_observations(self) -> np.typing.ArrayLike:
|
||||||
obs = self._build_per_agent_obs(self[c.AGENT][0], state_array_dict)
|
# Observation dict:
|
||||||
elif self.n_agents >= 2:
|
per_agent_expl_idx = dict()
|
||||||
obs = np.stack([self._build_per_agent_obs(agent, state_array_dict) for agent in self[c.AGENT]])
|
per_agent_obsn = dict()
|
||||||
|
# Generel Observations
|
||||||
|
lvl_obs = self[c.WALLS].as_array()
|
||||||
|
door_obs = self[c.DOORS].as_array() if self.parse_doors else None
|
||||||
|
if self.obs_prop.render_agents == a_obs.NOT:
|
||||||
|
global_agent_obs = None
|
||||||
|
elif self.obs_prop.omit_agent_self and self.n_agents == 1:
|
||||||
|
global_agent_obs = None
|
||||||
else:
|
else:
|
||||||
raise ValueError('n_agents cannot be smaller than 1!!')
|
global_agent_obs = self[c.AGENT].as_array().copy()
|
||||||
return obs
|
placeholder_obs = self[c.AGENT_PLACEHOLDER].as_array() if self[c.AGENT_PLACEHOLDER] else None
|
||||||
|
add_obs_dict = self.observations_hook()
|
||||||
|
|
||||||
def _build_per_agent_obs(self, agent: Agent, state_array_dict) -> np.ndarray:
|
for agent_idx, agent in enumerate(self[c.AGENT]):
|
||||||
agent_pos_is_omitted = False
|
obs_dict = dict()
|
||||||
agent_omit_idx = None
|
# Build Agent Observations
|
||||||
|
if self.obs_prop.render_agents != a_obs.NOT:
|
||||||
if self.obs_prop.omit_agent_self and self.n_agents == 1:
|
if self.obs_prop.omit_agent_self and self.n_agents >= 2:
|
||||||
pass
|
if self.obs_prop.render_agents == a_obs.SEPERATE:
|
||||||
elif self.obs_prop.omit_agent_self and self.obs_prop.render_agents in [a_obs.COMBINED, ] and self.n_agents > 1:
|
other_agent_obs_idx = [x for x in range(self.n_agents) if x != agent_idx]
|
||||||
state_array_dict[c.AGENT][0, agent.x, agent.y] -= agent.encoding
|
agent_obs = np.take(global_agent_obs, other_agent_obs_idx, axis=0)
|
||||||
agent_pos_is_omitted = True
|
|
||||||
elif self.obs_prop.omit_agent_self and self.obs_prop.render_agents == a_obs.SEPERATE and self.n_agents > 1:
|
|
||||||
agent_omit_idx = next((i for i, a in enumerate(self[c.AGENT]) if a == agent))
|
|
||||||
|
|
||||||
running_idx, shadowing_idxs, can_be_shadowed_idxs = 0, [], []
|
|
||||||
self._obs_cube[:] = 0
|
|
||||||
|
|
||||||
# FIXME: Refactor this! Make a globally build observation, then add individual per-agent-obs
|
|
||||||
for key, array in state_array_dict.items():
|
|
||||||
# Flush state array object representation to obs cube
|
|
||||||
if not self[key].hide_from_obs_builder:
|
|
||||||
if self[key].is_per_agent:
|
|
||||||
per_agent_idx = self[key].idx_by_entity(agent)
|
|
||||||
z = 1
|
|
||||||
self._obs_cube[running_idx: running_idx+z] = array[per_agent_idx]
|
|
||||||
else:
|
|
||||||
if key == c.AGENT and agent_omit_idx is not None:
|
|
||||||
z = array.shape[0] - 1
|
|
||||||
for array_idx in range(array.shape[0]):
|
|
||||||
self._obs_cube[running_idx: running_idx+z] = array[[x for x in range(array.shape[0])
|
|
||||||
if x != agent_omit_idx]]
|
|
||||||
# Agent OBS are combined
|
|
||||||
elif key == c.AGENT and self.obs_prop.omit_agent_self \
|
|
||||||
and self.obs_prop.render_agents == a_obs.COMBINED:
|
|
||||||
z = 1
|
|
||||||
self._obs_cube[running_idx: running_idx + z] = array
|
|
||||||
# Each Agent is rendered on a seperate array slice
|
|
||||||
else:
|
else:
|
||||||
z = array.shape[0]
|
agent_obs = global_agent_obs.copy()
|
||||||
self._obs_cube[running_idx: running_idx + z] = array
|
agent_obs[(0, *agent.pos)] -= agent.encoding
|
||||||
# Define which OBS SLices cast a Shadow
|
else:
|
||||||
if self[key].is_blocking_light:
|
agent_obs = global_agent_obs.copy()
|
||||||
for i in range(z):
|
|
||||||
shadowing_idxs.append(running_idx + i)
|
|
||||||
# Define which OBS SLices are effected by shadows
|
|
||||||
if self[key].can_be_shadowed:
|
|
||||||
for i in range(z):
|
|
||||||
can_be_shadowed_idxs.append(running_idx + i)
|
|
||||||
running_idx += z
|
|
||||||
|
|
||||||
if agent_pos_is_omitted:
|
|
||||||
state_array_dict[c.AGENT][0, agent.x, agent.y] += agent.encoding
|
|
||||||
|
|
||||||
if self._pomdp_r:
|
|
||||||
obs = self._do_pomdp_obs_cutout(agent, self._obs_cube)
|
|
||||||
else:
|
|
||||||
obs = self._obs_cube
|
|
||||||
|
|
||||||
obs = obs.copy()
|
|
||||||
|
|
||||||
if self.obs_prop.cast_shadows:
|
|
||||||
obs_block_light = [obs[idx] != c.OCCUPIED_CELL.value for idx in shadowing_idxs]
|
|
||||||
door_shadowing = False
|
|
||||||
if self.parse_doors:
|
|
||||||
if doors := self[c.DOORS]:
|
|
||||||
if door := doors.by_pos(agent.pos):
|
|
||||||
if door.is_closed:
|
|
||||||
for group in door.connectivity_subgroups:
|
|
||||||
if agent.last_pos not in group:
|
|
||||||
door_shadowing = True
|
|
||||||
if self._pomdp_r:
|
|
||||||
blocking = [tuple(np.subtract(x, agent.pos) + (self._pomdp_r, self._pomdp_r))
|
|
||||||
for x in group]
|
|
||||||
xs, ys = zip(*blocking)
|
|
||||||
else:
|
|
||||||
xs, ys = zip(*group)
|
|
||||||
|
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
obs_block_light[0][xs, ys] = False
|
|
||||||
|
|
||||||
light_block_map = Map((np.prod(obs_block_light, axis=0) != True).astype(int))
|
|
||||||
if self._pomdp_r:
|
|
||||||
light_block_map = light_block_map.do_fov(self._pomdp_r, self._pomdp_r, max(self._level_shape))
|
|
||||||
else:
|
else:
|
||||||
light_block_map = light_block_map.do_fov(*agent.pos, max(self._level_shape))
|
# agent_obs == None!!!!!
|
||||||
if door_shadowing:
|
agent_obs = global_agent_obs
|
||||||
# noinspection PyUnboundLocalVariable
|
|
||||||
light_block_map[xs, ys] = 0
|
|
||||||
agent.temp_light_map = light_block_map
|
|
||||||
for obs_idx in can_be_shadowed_idxs:
|
|
||||||
obs[obs_idx] = ((obs[obs_idx] * light_block_map) + 0.) - (1 - light_block_map) # * obs[0])
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Agents observe other agents as wall
|
# Build Level Observations
|
||||||
if self.obs_prop.render_agents == a_obs.LEVEL and self.n_agents > 1:
|
if self.obs_prop.render_agents == a_obs.LEVEL:
|
||||||
other_agent_obs = self[c.AGENT].as_array()
|
assert agent_obs is not None
|
||||||
if self.obs_prop.omit_agent_self:
|
lvl_obs = lvl_obs.copy()
|
||||||
other_agent_obs[:, agent.x, agent.y] -= agent.encoding
|
lvl_obs += agent_obs
|
||||||
|
|
||||||
|
obs_dict[c.WALLS] = lvl_obs
|
||||||
|
if self.obs_prop.render_agents in [a_obs.SEPERATE, a_obs.COMBINED] and agent_obs is not None:
|
||||||
|
obs_dict[c.AGENT] = agent_obs[:]
|
||||||
|
if self[c.AGENT_PLACEHOLDER] and placeholder_obs is not None:
|
||||||
|
obs_dict[c.AGENT_PLACEHOLDER] = placeholder_obs
|
||||||
|
if self.parse_doors and door_obs is not None:
|
||||||
|
obs_dict[c.DOORS] = door_obs[:]
|
||||||
|
obs_dict.update(add_obs_dict)
|
||||||
|
obsn = np.vstack(list(obs_dict.values()))
|
||||||
if self.obs_prop.pomdp_r:
|
if self.obs_prop.pomdp_r:
|
||||||
oobs = self._do_pomdp_obs_cutout(agent, other_agent_obs)[0]
|
obsn = self._do_pomdp_cutout(agent, obsn)
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
mask = (oobs != c.SHADOWED_CELL.value).astype(int)
|
|
||||||
obs[0] += oobs * mask
|
|
||||||
|
|
||||||
|
raw_obs = self.per_agent_raw_observations_hook(agent)
|
||||||
|
raw_obs = {key: np.expand_dims(val, 0) if val.ndim != 3 else val for key, val in raw_obs.items()}
|
||||||
|
obsn = np.vstack((obsn, *raw_obs.values()))
|
||||||
|
|
||||||
|
keys = list(chain(obs_dict.keys(), raw_obs.keys()))
|
||||||
|
idxs = np.cumsum([x.shape[0] for x in chain(obs_dict.values(), raw_obs.values())]) - 1
|
||||||
|
per_agent_expl_idx[agent.name] = {key: list(range(d, b)) for key, d, b in
|
||||||
|
zip(keys, idxs, list(idxs[1:]) + [idxs[-1]+1, ])}
|
||||||
|
|
||||||
|
# Shadow Casting
|
||||||
|
if agent.step_result is not None:
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
obs[0] += other_agent_obs
|
assert self._steps == 0
|
||||||
|
agent.step_result = {'action_name': a.NOOP, 'action_valid': True,
|
||||||
|
'collisions': [], 'lightmap': None}
|
||||||
|
if self.obs_prop.cast_shadows:
|
||||||
|
try:
|
||||||
|
light_block_obs = [obs_idx for key, obs_idx in per_agent_expl_idx[agent.name].items()
|
||||||
|
if self[key].is_blocking_light]
|
||||||
|
# Flatten
|
||||||
|
light_block_obs = [x for y in light_block_obs for x in y]
|
||||||
|
shadowed_obs = [obs_idx for key, obs_idx in per_agent_expl_idx[agent.name].items()
|
||||||
|
if self[key].can_be_shadowed]
|
||||||
|
# Flatten
|
||||||
|
shadowed_obs = [x for y in shadowed_obs for x in y]
|
||||||
|
except AttributeError as e:
|
||||||
|
print('Check your Keys! Only use Constants as Keys!')
|
||||||
|
print(e)
|
||||||
|
raise e
|
||||||
|
|
||||||
# Additional Observation:
|
obs_block_light = obsn[light_block_obs] != c.OCCUPIED_CELL
|
||||||
for additional_obs in self.additional_obs_build():
|
door_shadowing = False
|
||||||
obs[running_idx:running_idx+additional_obs.shape[0]] = additional_obs
|
if self.parse_doors:
|
||||||
running_idx += additional_obs.shape[0]
|
if doors := self[c.DOORS]:
|
||||||
for additional_per_agent_obs in self.additional_per_agent_obs_build(agent):
|
if door := doors.by_pos(agent.pos):
|
||||||
obs[running_idx:running_idx + additional_per_agent_obs.shape[0]] = additional_per_agent_obs
|
if door.is_closed:
|
||||||
running_idx += additional_per_agent_obs.shape[0]
|
for group in door.connectivity_subgroups:
|
||||||
|
if agent.last_pos not in group:
|
||||||
|
door_shadowing = True
|
||||||
|
if self._pomdp_r:
|
||||||
|
blocking = [
|
||||||
|
tuple(np.subtract(x, agent.pos) + (self._pomdp_r, self._pomdp_r))
|
||||||
|
for x in group]
|
||||||
|
xs, ys = zip(*blocking)
|
||||||
|
else:
|
||||||
|
xs, ys = zip(*group)
|
||||||
|
|
||||||
return obs
|
# noinspection PyUnresolvedReferences
|
||||||
|
obs_block_light[:, xs, ys] = False
|
||||||
|
|
||||||
def _do_pomdp_obs_cutout(self, agent, obs_to_be_padded):
|
light_block_map = Map((np.prod(obs_block_light, axis=0) != True).astype(int).squeeze())
|
||||||
|
if self._pomdp_r:
|
||||||
|
light_block_map = light_block_map.do_fov(self._pomdp_r, self._pomdp_r, max(self._level_shape))
|
||||||
|
else:
|
||||||
|
light_block_map = light_block_map.do_fov(*agent.pos, max(self._level_shape))
|
||||||
|
if door_shadowing:
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
|
light_block_map[xs, ys] = 0
|
||||||
|
|
||||||
|
agent.step_result['lightmap'] = light_block_map
|
||||||
|
|
||||||
|
obsn[shadowed_obs] = ((obsn[shadowed_obs] * light_block_map) + 0.) - (1 - light_block_map)
|
||||||
|
else:
|
||||||
|
if self._pomdp_r:
|
||||||
|
agent.step_result['lightmap'] = np.ones(self._obs_shape)
|
||||||
|
else:
|
||||||
|
agent.step_result['lightmap'] = None
|
||||||
|
|
||||||
|
per_agent_obsn[agent.name] = obsn
|
||||||
|
|
||||||
|
if self.n_agents == 1:
|
||||||
|
agent_name = self[c.AGENT][0].name
|
||||||
|
obs, explained_idx = per_agent_obsn[agent_name], per_agent_expl_idx[agent_name]
|
||||||
|
elif self.n_agents >= 2:
|
||||||
|
obs, explained_idx = np.stack(list(per_agent_obsn.values())), per_agent_expl_idx
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
return obs, explained_idx
|
||||||
|
|
||||||
|
def _do_pomdp_cutout(self, agent, obs_to_be_padded):
|
||||||
assert obs_to_be_padded.ndim == 3
|
assert obs_to_be_padded.ndim == 3
|
||||||
r, d = self._pomdp_r, self.pomdp_diameter
|
ra, d = self._pomdp_r, self.pomdp_diameter
|
||||||
x0, x1 = max(0, agent.x - r), min(agent.x + r + 1, self._level_shape[0])
|
x0, x1 = max(0, agent.x - ra), min(agent.x + ra + 1, self._level_shape[0])
|
||||||
y0, y1 = max(0, agent.y - r), min(agent.y + r + 1, self._level_shape[1])
|
y0, y1 = max(0, agent.y - ra), min(agent.y + ra + 1, self._level_shape[1])
|
||||||
# Other Agent Obs = oobs
|
|
||||||
oobs = obs_to_be_padded[:, x0:x1, y0:y1]
|
oobs = obs_to_be_padded[:, x0:x1, y0:y1]
|
||||||
if oobs.shape[0:] != (d, d):
|
if oobs.shape[1:] != (d, d):
|
||||||
if xd := oobs.shape[1] % d:
|
if xd := oobs.shape[1] % d:
|
||||||
if agent.x > r:
|
if agent.x > ra:
|
||||||
x0_pad = 0
|
x0_pad = 0
|
||||||
x1_pad = (d - xd)
|
x1_pad = (d - xd)
|
||||||
else:
|
else:
|
||||||
x0_pad = r - agent.x
|
x0_pad = ra - agent.x
|
||||||
x1_pad = 0
|
x1_pad = 0
|
||||||
else:
|
else:
|
||||||
x0_pad, x1_pad = 0, 0
|
x0_pad, x1_pad = 0, 0
|
||||||
|
|
||||||
if yd := oobs.shape[2] % d:
|
if yd := oobs.shape[2] % d:
|
||||||
if agent.y > r:
|
if agent.y > ra:
|
||||||
y0_pad = 0
|
y0_pad = 0
|
||||||
y1_pad = (d - yd)
|
y1_pad = (d - yd)
|
||||||
else:
|
else:
|
||||||
y0_pad = r - agent.y
|
y0_pad = ra - agent.y
|
||||||
y1_pad = 0
|
y1_pad = 0
|
||||||
else:
|
else:
|
||||||
y0_pad, y1_pad = 0, 0
|
y0_pad, y1_pad = 0, 0
|
||||||
@ -442,27 +509,43 @@ class BaseFactory(gym.Env):
|
|||||||
oobs = np.pad(oobs, ((0, 0), (x0_pad, x1_pad), (y0_pad, y1_pad)), 'constant')
|
oobs = np.pad(oobs, ((0, 0), (x0_pad, x1_pad), (y0_pad, y1_pad)), 'constant')
|
||||||
return oobs
|
return oobs
|
||||||
|
|
||||||
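# Hedged sketch (illustration only, not part of the diff): the cutout above keeps a
# (pomdp_diameter x pomdp_diameter) window centred on the agent and zero-pads whenever the
# window sticks out of the level; a standalone toy version under that assumption:
import numpy as np

def pomdp_cutout(obs: np.ndarray, x: int, y: int, r: int) -> np.ndarray:
    d = 2 * r + 1
    x0, x1 = max(0, x - r), min(x + r + 1, obs.shape[1])
    y0, y1 = max(0, y - r), min(y + r + 1, obs.shape[2])
    window = obs[:, x0:x1, y0:y1]
    pad_x = (max(0, r - x), d - (x1 - x0) - max(0, r - x))
    pad_y = (max(0, r - y), d - (y1 - y0) - max(0, r - y))
    return np.pad(window, ((0, 0), pad_x, pad_y), 'constant')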
def get_all_tiles_with_collisions(self) -> List[Tile]:
|
def get_all_tiles_with_collisions(self) -> List[Floor]:
|
||||||
tiles_with_collisions = list()
|
tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
|
||||||
for tile in self[c.FLOOR]:
|
if False:
|
||||||
if tile.is_occupied():
|
tiles_with_collisions = list()
|
||||||
guests = tile.guests_that_can_collide
|
for tile in self[c.FLOOR]:
|
||||||
if len(guests) >= 2:
|
if tile.is_occupied():
|
||||||
tiles_with_collisions.append(tile)
|
guests = tile.guests_that_can_collide
|
||||||
return tiles_with_collisions
|
if len(guests) >= 2:
|
||||||
|
tiles_with_collisions.append(tile)
|
||||||
|
return tiles
|
||||||
|
|
||||||
def _move_or_colide(self, agent: Agent, action: Action) -> Constants:
|
def _do_move_action(self, agent: Agent, action: Action) -> (dict, dict):
|
||||||
|
info_dict = dict()
|
||||||
new_tile, valid = self._check_agent_move(agent, action)
|
new_tile, valid = self._check_agent_move(agent, action)
|
||||||
if valid:
|
if valid:
|
||||||
# Does not collide width level boundaries
|
# Does not collide width level boundaries
|
||||||
return agent.move(new_tile)
|
valid = agent.move(new_tile)
|
||||||
|
if valid:
|
||||||
|
# This will spam your logs, beware!
|
||||||
|
self.print(f'{agent.name} just moved {action.identifier} from {agent.last_pos} to {agent.pos}.')
|
||||||
|
info_dict.update({f'{agent.name}_move': 1, 'move': 1})
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
valid = c.NOT_VALID
|
||||||
|
self.print(f'{agent.name} just hit the wall at {agent.pos}. ({action.identifier})')
|
||||||
|
info_dict.update({f'{agent.name}_wall_collide': 1, 'wall_collide': 1})
|
||||||
else:
|
else:
|
||||||
# Agent seems to be trying to collide in this step
|
# Agent seems to be trying to Leave the level
|
||||||
return c.NOT_VALID
|
self.print(f'{agent.name} tried to leave the level {agent.pos}. ({action.identifier})')
|
||||||
|
info_dict.update({f'{agent.name}_wall_collide': 1, 'wall_collide': 1})
|
||||||
|
reward_value = self.rewards_base.MOVEMENTS_VALID if valid else self.rewards_base.MOVEMENTS_FAIL
|
||||||
|
reward = {'value': reward_value, 'reason': action.identifier, 'info': info_dict}
|
||||||
|
return valid, reward
|
||||||
|
|
||||||
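_do_move_action now reports its outcome as a (valid, reward) pair, where the reward is a small dict carrying the value, the triggering action and an info dict for monitoring. A toy sketch of that convention (the reward magnitudes below are assumed examples, not the environment's configured values):

MOVEMENTS_VALID, MOVEMENTS_FAIL = -0.001, -0.05   # assumed example magnitudes

def move_reward(agent_name, action_name, valid):
    # one counter per agent plus a global counter, mirroring the info_dict updates above
    info = ({f'{agent_name}_move': 1, 'move': 1} if valid
            else {f'{agent_name}_wall_collide': 1, 'wall_collide': 1})
    value = MOVEMENTS_VALID if valid else MOVEMENTS_FAIL
    return {'value': value, 'reason': action_name, 'info': info}

print(move_reward('Agent#0', 'north', valid=False))
# {'value': -0.05, 'reason': 'north', 'info': {'Agent#0_wall_collide': 1, 'wall_collide': 1}}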
def _check_agent_move(self, agent, action: Action) -> (Tile, bool):
|
def _check_agent_move(self, agent, action: Action) -> (Floor, bool):
|
||||||
# Actions
|
# Actions
|
||||||
x_diff, y_diff = h.ACTIONMAP[action.identifier]
|
x_diff, y_diff = a.resolve_movement_action_to_coords(action.identifier)
|
||||||
x_new = agent.x + x_diff
|
x_new = agent.x + x_diff
|
||||||
y_new = agent.y + y_diff
|
y_new = agent.y + y_diff
|
||||||
|
|
||||||
@ -478,7 +561,7 @@ class BaseFactory(gym.Env):
|
|||||||
if doors := self[c.DOORS]:
|
if doors := self[c.DOORS]:
|
||||||
if self.doors_have_area:
|
if self.doors_have_area:
|
||||||
if door := doors.by_pos(new_tile.pos):
|
if door := doors.by_pos(new_tile.pos):
|
||||||
if door.can_collide:
|
if door.is_closed:
|
||||||
return agent.tile, c.NOT_VALID
|
return agent.tile, c.NOT_VALID
|
||||||
else: # door.is_closed:
|
else: # door.is_closed:
|
||||||
pass
|
pass
|
||||||
@ -498,78 +581,65 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
return new_tile, valid
|
return new_tile, valid
|
||||||
|
|
||||||
def calculate_reward(self) -> (int, dict):
|
def build_reward_result(self, global_env_rewards: list) -> (int, dict):
|
||||||
# Returns: Reward, Info
|
# Returns: Reward, Info
|
||||||
per_agent_info_dict = defaultdict(dict)
|
info = defaultdict(lambda: 0.0)
|
||||||
reward = {}
|
|
||||||
|
|
||||||
|
# Gather additional sub-env rewards and calculate collisions
|
||||||
for agent in self[c.AGENT]:
|
for agent in self[c.AGENT]:
|
||||||
per_agent_reward = 0
|
|
||||||
if self._actions.is_moving_action(agent.temp_action):
|
|
||||||
if agent.temp_valid:
|
|
||||||
# info_dict.update(movement=1)
|
|
||||||
per_agent_reward -= 0.001
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
per_agent_reward -= 0.05
|
|
||||||
self.print(f'{agent.name} just hit the wall at {agent.pos}.')
|
|
||||||
per_agent_info_dict[agent.name].update({f'{agent.name}_vs_LEVEL': 1})
|
|
||||||
|
|
||||||
elif h.EnvActions.USE_DOOR == agent.temp_action:
|
rewards = self.per_agent_reward_hook(agent)
|
||||||
if agent.temp_valid:
|
for reward in rewards:
|
||||||
# per_agent_reward += 0.00
|
agent.step_result['rewards'].append(reward)
|
||||||
self.print(f'{agent.name} did just use the door at {agent.pos}.')
|
if collisions := agent.step_result['collisions']:
|
||||||
per_agent_info_dict[agent.name].update(door_used=1)
|
self.print(f't = {self._steps}\t{agent.name} has collisions with {collisions}')
|
||||||
else:
|
info[c.COLLISION] += 1
|
||||||
# per_agent_reward -= 0.00
|
reward = {'value': self.rewards_base.COLLISION,
|
||||||
self.print(f'{agent.name} just tried to use a door at {agent.pos}, but failed.')
|
'reason': c.COLLISION,
|
||||||
per_agent_info_dict[agent.name].update({f'{agent.name}_failed_door_open': 1})
|
'info': {f'{agent.name}_{c.COLLISION}': 1}}
|
||||||
elif h.EnvActions.NOOP == agent.temp_action:
|
agent.step_result['rewards'].append(reward)
|
||||||
per_agent_info_dict[agent.name].update(no_op=1)
|
|
||||||
# per_agent_reward -= 0.00
|
|
||||||
|
|
||||||
# EnvMonitor Notes
|
|
||||||
if agent.temp_valid:
|
|
||||||
per_agent_info_dict[agent.name].update(valid_action=1)
|
|
||||||
per_agent_info_dict[agent.name].update({f'{agent.name}_valid_action': 1})
|
|
||||||
else:
|
else:
|
||||||
per_agent_info_dict[agent.name].update(failed_action=1)
|
# No Collisions, nothing to do
|
||||||
per_agent_info_dict[agent.name].update({f'{agent.name}_failed_action': 1})
|
pass
|
||||||
|
|
||||||
additional_reward, additional_info_dict = self.calculate_additional_reward(agent)
|
comb_rewards = {agent.name: sum(x['value'] for x in agent.step_result['rewards']) for agent in self[c.AGENT]}
|
||||||
per_agent_reward += additional_reward
|
|
||||||
per_agent_info_dict[agent.name].update(additional_info_dict)
|
|
||||||
|
|
||||||
if agent.temp_collisions:
|
|
||||||
self.print(f't = {self._steps}\t{agent.name} has collisions with {agent.temp_collisions}')
|
|
||||||
per_agent_info_dict[agent.name].update(collisions=1)
|
|
||||||
|
|
||||||
for other_agent in agent.temp_collisions:
|
|
||||||
per_agent_info_dict[agent.name].update({f'{agent.name}_vs_{other_agent.name}': 1})
|
|
||||||
reward[agent.name] = per_agent_reward
|
|
||||||
|
|
||||||
# Combine the per_agent_info_dict:
|
# Combine the per_agent_info_dict:
|
||||||
combined_info_dict = defaultdict(lambda: 0)
|
combined_info_dict = defaultdict(lambda: 0)
|
||||||
for info_dict in per_agent_info_dict.values():
|
for agent in self[c.AGENT]:
|
||||||
for key, value in info_dict.items():
|
for reward in agent.step_result['rewards']:
|
||||||
combined_info_dict[key] += value
|
combined_info_dict.update(reward['info'])
|
||||||
|
|
||||||
|
# Combine Info dicts into a global one
|
||||||
combined_info_dict = dict(combined_info_dict)
|
combined_info_dict = dict(combined_info_dict)
|
||||||
|
|
||||||
|
combined_info_dict.update(info)
|
||||||
|
|
||||||
|
global_reward_sum = sum(global_env_rewards)
|
||||||
if self.individual_rewards:
|
if self.individual_rewards:
|
||||||
self.print(f"rewards are {reward}")
|
self.print(f"rewards are {comb_rewards}")
|
||||||
reward = list(reward.values())
|
reward = list(comb_rewards.values())
|
||||||
|
reward = [x + global_reward_sum for x in reward]
|
||||||
return reward, combined_info_dict
|
return reward, combined_info_dict
|
||||||
else:
|
else:
|
||||||
reward = sum(reward.values())
|
reward = sum(comb_rewards.values()) + global_reward_sum
|
||||||
self.print(f"reward is {reward}")
|
self.print(f"reward is {reward}")
|
||||||
return reward, combined_info_dict
|
return reward, combined_info_dict
|
||||||
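build_reward_result sums every reward dict collected in an agent's step_result, merges their info dicts, and either returns one reward per agent (individual_rewards) or a single scalar. A self-contained sketch of that aggregation, with illustrative inputs:

from collections import defaultdict

def combine_rewards(step_results, global_env_rewards, individual=True):
    # step_results: {agent_name: [{'value': float, 'info': dict}, ...]}
    info = defaultdict(int)
    for entries in step_results.values():
        for entry in entries:
            for key, val in entry['info'].items():
                info[key] += val
    per_agent = {name: sum(e['value'] for e in entries) for name, entries in step_results.items()}
    global_sum = sum(global_env_rewards)
    if individual:
        return [value + global_sum for value in per_agent.values()], dict(info)
    return sum(per_agent.values()) + global_sum, dict(info)

rewards, infos = combine_rewards(
    {'Agent#0': [{'value': -0.05, 'info': {'Agent#0_wall_collide': 1}}],
     'Agent#1': [{'value': -0.001, 'info': {'Agent#1_move': 1}}]},
    global_env_rewards=[0.0])
assert rewards == [-0.05, -0.001]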
|
|
||||||
|
def start_recording(self):
|
||||||
|
self._record_episodes = True
|
||||||
|
return self._record_episodes
|
||||||
|
|
||||||
|
def stop_recording(self):
|
||||||
|
self._record_episodes = False
|
||||||
|
return not self._record_episodes
|
||||||
|
|
||||||
# noinspection PyGlobalUndefined
|
# noinspection PyGlobalUndefined
|
||||||
def render(self, mode='human'):
|
def render(self, mode='human'):
|
||||||
if not self._renderer: # lazy init
|
if not self._renderer: # lazy init
|
||||||
from environments.factory.base.renderer import Renderer, RenderEntity
|
from environments.factory.base.renderer import Renderer, RenderEntity
|
||||||
global Renderer, RenderEntity
|
global Renderer, RenderEntity
|
||||||
height, width = self._obs_cube.shape[1:]
|
height, width = self._level_shape
|
||||||
self._renderer = Renderer(width, height, view_radius=self._pomdp_r, fps=5)
|
self._renderer = Renderer(width, height, view_radius=self._pomdp_r, fps=5)
|
||||||
|
|
||||||
# noinspection PyUnboundLocalVariable
|
# noinspection PyUnboundLocalVariable
|
||||||
@ -578,13 +648,13 @@ class BaseFactory(gym.Env):
|
|||||||
agents = []
|
agents = []
|
||||||
for i, agent in enumerate(self[c.AGENT]):
|
for i, agent in enumerate(self[c.AGENT]):
|
||||||
name, state = h.asset_str(agent)
|
name, state = h.asset_str(agent)
|
||||||
agents.append(RenderEntity(name, agent.pos, 1, 'none', state, i + 1, agent.temp_light_map))
|
agents.append(RenderEntity(name, agent.pos, 1, 'none', state, i + 1, agent.step_result['lightmap']))
|
||||||
doors = []
|
doors = []
|
||||||
if self.parse_doors:
|
if self.parse_doors:
|
||||||
for i, door in enumerate(self[c.DOORS]):
|
for i, door in enumerate(self[c.DOORS]):
|
||||||
name, state = 'door_open' if door.is_open else 'door_closed', 'blank'
|
name, state = 'door_open' if door.is_open else 'door_closed', 'blank'
|
||||||
doors.append(RenderEntity(name, door.pos, 1, 'none', state, i + 1))
|
doors.append(RenderEntity(name, door.pos, 1, 'none', state, i + 1))
|
||||||
additional_assets = self.render_additional_assets()
|
additional_assets = self.render_assets_hook()
|
||||||
|
|
||||||
return self._renderer.render(walls + doors + additional_assets + agents)
|
return self._renderer.render(walls + doors + additional_assets + agents)
|
||||||
|
|
||||||
@ -601,12 +671,12 @@ class BaseFactory(gym.Env):
|
|||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def _summarize_state(self):
|
def _summarize_state(self, stateless_entities=False):
|
||||||
summary = {f'{REC_TAC}step': self._steps}
|
summary = {f'{REC_TAC}step': self._steps}
|
||||||
|
|
||||||
for entity_group in self._entities:
|
for entity_group in self._entities:
|
||||||
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states(n_steps=self._steps)})
|
if entity_group.is_stateless == stateless_entities:
|
||||||
|
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()})
|
||||||
return summary
|
return summary
|
||||||
|
|
||||||
def print(self, string):
|
def print(self, string):
|
||||||
@ -615,7 +685,8 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
# Properties which are called by the base class to extend beyond attributes of the base class
|
# Properties which are called by the base class to extend beyond attributes of the base class
|
||||||
@property
|
@property
|
||||||
def additional_actions(self) -> Union[Action, List[Action]]:
|
@abc.abstractmethod
|
||||||
|
def actions_hook(self) -> Union[Action, List[Action]]:
|
||||||
"""
|
"""
|
||||||
When inheriting from this base class, you must implement this method!
|
When inheriting from this base class, you must implement this method!
|
||||||
|
|
||||||
@ -625,7 +696,8 @@ class BaseFactory(gym.Env):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def additional_entities(self) -> Dict[(Enum, Entities)]:
|
@abc.abstractmethod
|
||||||
|
def entities_hook(self) -> Dict[(str, Entities)]:
|
||||||
"""
|
"""
|
||||||
When inheriting from this base class, you must implement this method!
|
When inheriting from this base class, you must implement this method!
|
||||||
|
|
||||||
@ -637,49 +709,46 @@ class BaseFactory(gym.Env):
|
|||||||
# Functions which provide additions to functions of the base class
|
# Functions which provide additions to functions of the base class
|
||||||
# Always call super!!!!!!
|
# Always call super!!!!!!
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def additional_obs_build(self) -> List[np.ndarray]:
|
def reset_hook(self) -> None:
|
||||||
return []
|
|
||||||
|
|
||||||
def additional_per_agent_obs_build(self, agent) -> List[np.ndarray]:
|
|
||||||
additional_per_agent_obs = []
|
|
||||||
if self.obs_prop.show_global_position_info:
|
|
||||||
pos_array = np.zeros(self.observation_space.shape[1:])
|
|
||||||
for xy in range(1):
|
|
||||||
pos_array[0, xy] = agent.pos[xy] / self._level_shape[xy]
|
|
||||||
additional_per_agent_obs.append(pos_array)
|
|
||||||
|
|
||||||
return additional_per_agent_obs
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def do_additional_reset(self) -> None:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def do_additional_step(self) -> dict:
|
def pre_step_hook(self) -> None:
|
||||||
return {}
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
|
def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def check_additional_done(self) -> bool:
|
def step_hook(self) -> (List[dict], dict):
|
||||||
return False
|
return [], {}
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
def check_additional_done(self) -> (bool, dict):
|
||||||
return 0, {}
|
return False, {}
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def render_additional_assets(self):
|
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def per_agent_reward_hook(self, agent: Agent) -> List[dict]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Hooks for in between operations.
|
|
||||||
# Always call super!!!!!!
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def hook_pre_step(self) -> None:
|
def post_step_hook(self) -> List[dict]:
|
||||||
pass
|
return []
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def hook_post_step(self) -> dict:
|
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||||
return {}
|
additional_raw_observations = {}
|
||||||
|
if self.obs_prop.show_global_position_info:
|
||||||
|
global_pos_obs = np.zeros(self._obs_shape)
|
||||||
|
global_pos_obs[:2, 0] = self[c.GLOBAL_POSITION].by_entity(agent).encoding
|
||||||
|
additional_raw_observations.update({c.GLOBAL_POSITION: global_pos_obs})
|
||||||
|
return additional_raw_observations
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def render_assets_hook(self):
|
||||||
|
return []
|
||||||
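Taken together, the renamed hooks form the extension surface a sub-environment has to fill in. A minimal sketch of a subclass against this API (the import path and class body are assumptions for illustration; every hook simply returns its neutral value):

from environments.factory.base.base_factory import BaseFactory   # import path assumed

class NoopFactory(BaseFactory):

    @property
    def actions_hook(self):
        return []                 # no additional actions

    @property
    def entities_hook(self):
        return {}                 # no additional entity collections

    def reset_hook(self):
        pass

    def step_hook(self):
        return [], {}             # (per-agent reward dicts, additional info)

    def check_additional_done(self):
        return False, {}          # never ends the episode on its own

    def per_agent_reward_hook(self, agent):
        return []

    def post_step_hook(self):
        return []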
|
@ -1,54 +1,51 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from enum import Enum
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as c
|
||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
|
##########################################################################
|
||||||
|
# ##################### Base Object Building Blocks ######################### #
|
||||||
|
##########################################################################
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Missing Documentation
|
||||||
class Object:
|
class Object:
|
||||||
|
|
||||||
|
"""Generell Objects for Organisation and Maintanance such as Actions etc..."""
|
||||||
|
|
||||||
_u_idx = defaultdict(lambda: 0)
|
_u_idx = defaultdict(lambda: 0)
|
||||||
|
|
||||||
def __bool__(self):
|
def __bool__(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@property
|
|
||||||
def is_blocking_light(self):
|
|
||||||
return self._is_blocking_light
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return self._name
|
return self._name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def identifier(self):
|
def identifier(self):
|
||||||
if self._enum_ident is not None:
|
if self._str_ident is not None:
|
||||||
return self._enum_ident
|
|
||||||
elif self._str_ident is not None:
|
|
||||||
return self._str_ident
|
return self._str_ident
|
||||||
else:
|
else:
|
||||||
return self._name
|
return self._name
|
||||||
|
|
||||||
def __init__(self, str_ident: Union[str, None] = None, enum_ident: Union[Enum, None] = None,
|
def __init__(self, str_ident: Union[str, None] = None, **kwargs):
|
||||||
is_blocking_light=False, **kwargs):
|
|
||||||
|
|
||||||
self._str_ident = str_ident
|
self._str_ident = str_ident
|
||||||
self._enum_ident = enum_ident
|
|
||||||
|
|
||||||
if self._enum_ident is not None and self._str_ident is None:
|
if self._str_ident is not None:
|
||||||
self._name = f'{self.__class__.__name__}[{self._enum_ident.name}]'
|
|
||||||
elif self._str_ident is not None and self._enum_ident is None:
|
|
||||||
self._name = f'{self.__class__.__name__}[{self._str_ident}]'
|
self._name = f'{self.__class__.__name__}[{self._str_ident}]'
|
||||||
elif self._str_ident is None and self._enum_ident is None:
|
elif self._str_ident is None:
|
||||||
self._name = f'{self.__class__.__name__}#{self._u_idx[self.__class__.__name__]}'
|
self._name = f'{self.__class__.__name__}#{Object._u_idx[self.__class__.__name__]}'
|
||||||
Object._u_idx[self.__class__.__name__] += 1
|
Object._u_idx[self.__class__.__name__] += 1
|
||||||
else:
|
else:
|
||||||
raise ValueError('Please use either of the idents.')
|
raise ValueError('Please use either of the idents.')
|
||||||
|
|
||||||
self._is_blocking_light = is_blocking_light
|
|
||||||
if kwargs:
|
if kwargs:
|
||||||
print(f'Following kwargs were passed, but ignored: {kwargs}')
|
print(f'Following kwargs were passed, but ignored: {kwargs}')
|
||||||
|
|
||||||
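With the enum identifiers gone, an Object's name is now either ClassName[str_ident] or ClassName#<running index>, with one counter per class name kept on the base class. A toy reconstruction of just that naming scheme:

from collections import defaultdict

class Thing:
    _u_idx = defaultdict(lambda: 0)            # one running counter per (sub)class name

    def __init__(self, str_ident=None):
        if str_ident is not None:
            self._name = f'{self.__class__.__name__}[{str_ident}]'
        else:
            self._name = f'{self.__class__.__name__}#{Thing._u_idx[self.__class__.__name__]}'
            Thing._u_idx[self.__class__.__name__] += 1

print(Thing()._name, Thing()._name, Thing(str_ident='door')._name)
# Thing#0 Thing#1 Thing[door]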
@ -56,27 +53,44 @@ class Object:
|
|||||||
return f'{self.name}'
|
return f'{self.name}'
|
||||||
|
|
||||||
def __eq__(self, other) -> bool:
|
def __eq__(self, other) -> bool:
|
||||||
if self._enum_ident is not None:
|
return other == self.identifier
|
||||||
if isinstance(other, Enum):
|
# Base
|
||||||
return other == self._enum_ident
|
|
||||||
elif isinstance(other, Object):
|
|
||||||
return other._enum_ident == self._enum_ident
|
|
||||||
else:
|
|
||||||
raise ValueError('Must be evaluated against an Enunm Identifier or Object with such.')
|
|
||||||
else:
|
|
||||||
assert isinstance(other, Object), ' This Object can only be compared to other Objects.'
|
|
||||||
return other.name == self.name
|
|
||||||
|
|
||||||
|
|
||||||
class Entity(Object):
|
# TODO: Missing Documentation
|
||||||
|
class EnvObject(Object):
|
||||||
|
|
||||||
|
"""Objects that hold Information that are observable, but have no position on the env grid. Inventories etc..."""
|
||||||
|
|
||||||
|
_u_idx = defaultdict(lambda: 0)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def can_collide(self):
|
def can_collide(self):
|
||||||
return True
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
return c.OCCUPIED_CELL.value
|
return c.OCCUPIED_CELL
|
||||||
|
|
||||||
|
def __init__(self, collection, **kwargs):
|
||||||
|
super(EnvObject, self).__init__(**kwargs)
|
||||||
|
self._collection = collection
|
||||||
|
|
||||||
|
def change_parent_collection(self, other_collection):
|
||||||
|
other_collection.add_item(self)
|
||||||
|
self._collection.delete_env_object(self)
|
||||||
|
self._collection = other_collection
|
||||||
|
return self._collection == other_collection
|
||||||
|
# With Rendering
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Missing Documentation
|
||||||
|
class Entity(EnvObject):
|
||||||
|
"""Full Env Entity that lives on the env Grid. Doors, Items, DirtPile etc..."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def can_collide(self):
|
||||||
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def x(self):
|
def x(self):
|
||||||
@ -94,19 +108,20 @@ class Entity(Object):
|
|||||||
def tile(self):
|
def tile(self):
|
||||||
return self._tile
|
return self._tile
|
||||||
|
|
||||||
def __init__(self, tile, **kwargs):
|
def __init__(self, tile, *args, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._tile = tile
|
self._tile = tile
|
||||||
tile.enter(self)
|
tile.enter(self)
|
||||||
|
|
||||||
def summarize_state(self, **_) -> dict:
|
def summarize_state(self) -> dict:
|
||||||
return dict(name=str(self.name), x=int(self.x), y=int(self.y),
|
return dict(name=str(self.name), x=int(self.x), y=int(self.y),
|
||||||
tile=str(self.tile.name), can_collide=bool(self.can_collide))
|
tile=str(self.tile.name), can_collide=bool(self.can_collide))
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.name}(@{self.pos})'
|
return super(Entity, self).__repr__() + f'(@{self.pos})'
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Missing Documentation
|
||||||
class MoveableEntity(Entity):
|
class MoveableEntity(Entity):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -137,9 +152,36 @@ class MoveableEntity(Entity):
|
|||||||
curr_tile.leave(self)
|
curr_tile.leave(self)
|
||||||
self._tile = next_tile
|
self._tile = next_tile
|
||||||
self._last_tile = curr_tile
|
self._last_tile = curr_tile
|
||||||
return True
|
self._collection.notify_change_to_value(self)
|
||||||
|
return c.VALID
|
||||||
else:
|
else:
|
||||||
return False
|
return c.NOT_VALID
|
||||||
|
# Can Move
|
||||||
|
|
||||||
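The reworked move() hands the entity over between tiles and then notifies its collection so the observation array is updated lazily. A toy version of the tile hand-over alone (ToyTile/ToyEntity are stand-ins, not the real Floor/MoveableEntity classes):

class ToyTile:
    def __init__(self, pos):
        self.pos, self.guests = pos, []
    def enter(self, guest):
        self.guests.append(guest)
    def leave(self, guest):
        self.guests.remove(guest)

class ToyEntity:
    def __init__(self, tile):
        self.tile, self.last_tile = tile, None
        tile.enter(self)
    def move(self, next_tile):
        if next_tile.pos == self.tile.pos:
            return False                       # staying put is not a valid move
        next_tile.enter(self)
        self.tile.leave(self)
        self.tile, self.last_tile = next_tile, self.tile
        return True                            # the real code returns c.VALID / c.NOT_VALID

a, b = ToyTile((0, 0)), ToyTile((0, 1))
entity = ToyEntity(a)
assert entity.move(b) and b.guests == [entity] and a.guests == []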
|
|
||||||
|
# TODO: Missing Documentation
|
||||||
|
class BoundingMixin(Object):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bound_entity(self):
|
||||||
|
return self._bound_entity
|
||||||
|
|
||||||
|
def __init__(self,entity_to_be_bound, *args, **kwargs):
|
||||||
|
super(BoundingMixin, self).__init__(*args, **kwargs)
|
||||||
|
assert entity_to_be_bound is not None
|
||||||
|
self._bound_entity = entity_to_be_bound
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return f'{super(BoundingMixin, self).name}({self._bound_entity.name})'
|
||||||
|
|
||||||
|
def belongs_to_entity(self, entity):
|
||||||
|
return entity == self.bound_entity
|
||||||
|
|
||||||
|
|
||||||
|
##########################################################################
|
||||||
|
# ####################### Objects and Entitys ########################## #
|
||||||
|
##########################################################################
|
||||||
|
|
||||||
|
|
||||||
class Action(Object):
|
class Action(Object):
|
||||||
@ -148,34 +190,45 @@ class Action(Object):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class PlaceHolder(MoveableEntity):
|
class PlaceHolder(Object):
|
||||||
|
|
||||||
def __init__(self, *args, fill_value=0, **kwargs):
|
def __init__(self, *args, fill_value=0, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._fill_value = fill_value
|
self._fill_value = fill_value
|
||||||
|
|
||||||
@property
|
|
||||||
def last_tile(self):
|
|
||||||
return self.tile
|
|
||||||
|
|
||||||
@property
|
|
||||||
def direction_of_view(self):
|
|
||||||
return self.pos
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def can_collide(self):
|
def can_collide(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
return c.NO_POS.value[0]
|
return self._fill_value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return "PlaceHolder"
|
return "PlaceHolder"
|
||||||
|
|
||||||
|
|
||||||
class Tile(Object):
|
class GlobalPosition(BoundingMixin, EnvObject):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
if self._normalized:
|
||||||
|
return tuple(np.divide(self._bound_entity.pos, self._level_shape))
|
||||||
|
else:
|
||||||
|
return self.bound_entity.pos
|
||||||
|
|
||||||
|
def __init__(self, level_shape: (int, int), *args, normalized: bool = True, **kwargs):
|
||||||
|
super(GlobalPosition, self).__init__(*args, **kwargs)
|
||||||
|
self._level_shape = level_shape
|
||||||
|
self._normalized = normalized
|
||||||
|
|
||||||
|
|
||||||
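The new GlobalPosition object encodes the bound agent's coordinates, optionally normalised by the level extent so the values stay in [0, 1]. The encoding itself reduces to:

import numpy as np

def global_position_encoding(pos, level_shape, normalized=True):
    # e.g. pos=(3, 8) on a (10, 16) level -> (0.3, 0.5); un-normalised it is the raw position
    return tuple(np.divide(pos, level_shape)) if normalized else tuple(pos)

assert global_position_encoding((3, 8), (10, 16)) == (0.3, 0.5)
assert global_position_encoding((3, 8), (10, 16), normalized=False) == (3, 8)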
|
class Floor(EnvObject):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
return c.FREE_CELL
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def guests_that_can_collide(self):
|
def guests_that_can_collide(self):
|
||||||
@ -197,8 +250,8 @@ class Tile(Object):
|
|||||||
def pos(self):
|
def pos(self):
|
||||||
return self._pos
|
return self._pos
|
||||||
|
|
||||||
def __init__(self, pos, **kwargs):
|
def __init__(self, pos, *args, **kwargs):
|
||||||
super(Tile, self).__init__(**kwargs)
|
super(Floor, self).__init__(*args, **kwargs)
|
||||||
self._guests = dict()
|
self._guests = dict()
|
||||||
self._pos = tuple(pos)
|
self._pos = tuple(pos)
|
||||||
|
|
||||||
@ -232,7 +285,16 @@ class Tile(Object):
|
|||||||
return dict(name=self.name, x=int(self.x), y=int(self.y))
|
return dict(name=self.name, x=int(self.x), y=int(self.y))
|
||||||
|
|
||||||
|
|
||||||
class Wall(Tile):
|
class Wall(Floor):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def can_collide(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
return c.OCCUPIED_CELL
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -247,7 +309,8 @@ class Door(Entity):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
return 1 if self.is_closed else 2
|
# This is important, as the shadow is checked by the occupation value
|
||||||
|
return c.CLOSED_DOOR_CELL if self.is_closed else c.OPEN_DOOR_CELL
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def str_state(self):
|
def str_state(self):
|
||||||
@ -275,8 +338,8 @@ class Door(Entity):
|
|||||||
if not closed_on_init:
|
if not closed_on_init:
|
||||||
self._open()
|
self._open()
|
||||||
|
|
||||||
def summarize_state(self, **kwargs):
|
def summarize_state(self):
|
||||||
state_dict = super().summarize_state(**kwargs)
|
state_dict = super().summarize_state()
|
||||||
state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close))
|
state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close))
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
@ -307,11 +370,13 @@ class Door(Entity):
|
|||||||
def _open(self):
|
def _open(self):
|
||||||
self.connectivity.add_edges_from([(self.pos, x) for x in range(len(self.connectivity_subgroups))])
|
self.connectivity.add_edges_from([(self.pos, x) for x in range(len(self.connectivity_subgroups))])
|
||||||
self._state = c.OPEN_DOOR
|
self._state = c.OPEN_DOOR
|
||||||
|
self._collection.notify_change_to_value(self)
|
||||||
self.time_to_close = self.auto_close_interval
|
self.time_to_close = self.auto_close_interval
|
||||||
|
|
||||||
def _close(self):
|
def _close(self):
|
||||||
self.connectivity.remove_node(self.pos)
|
self.connectivity.remove_node(self.pos)
|
||||||
self._state = c.CLOSED_DOOR
|
self._state = c.CLOSED_DOOR
|
||||||
|
self._collection.notify_change_to_value(self)
|
||||||
|
|
||||||
def is_linked(self, old_pos, new_pos):
|
def is_linked(self, old_pos, new_pos):
|
||||||
try:
|
try:
|
||||||
@ -323,20 +388,21 @@ class Door(Entity):
|
|||||||
|
|
||||||
class Agent(MoveableEntity):
|
class Agent(MoveableEntity):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def can_collide(self):
|
||||||
|
return True
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Agent, self).__init__(*args, **kwargs)
|
super(Agent, self).__init__(*args, **kwargs)
|
||||||
self.clear_temp_state()
|
self.clear_temp_state()
|
||||||
|
|
||||||
# noinspection PyAttributeOutsideInit
|
# noinspection PyAttributeOutsideInit
|
||||||
def clear_temp_state(self):
|
def clear_temp_state(self):
|
||||||
# for attr in self.__dict__:
|
# for attr in cls.__dict__:
|
||||||
# if attr.startswith('temp'):
|
# if attr.startswith('temp'):
|
||||||
self.temp_collisions = []
|
self.step_result = None
|
||||||
self.temp_valid = None
|
|
||||||
self.temp_action = None
|
|
||||||
self.temp_light_map = None
|
|
||||||
|
|
||||||
def summarize_state(self, **kwargs):
|
def summarize_state(self):
|
||||||
state_dict = super().summarize_state(**kwargs)
|
state_dict = super().summarize_state()
|
||||||
state_dict.update(valid=bool(self.temp_valid), action=str(self.temp_action))
|
state_dict.update(valid=bool(self.step_result['action_valid']), action=str(self.step_result['action_name']))
|
||||||
return state_dict
|
return state_dict
|
||||||
|
@ -1,106 +1,183 @@
|
|||||||
import numbers
|
import numbers
|
||||||
import random
|
import random
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import List, Union, Dict
|
from typing import List, Union, Dict, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import six
|
||||||
|
|
||||||
from environments.factory.base.objects import Entity, Tile, Agent, Door, Action, Wall, Object, PlaceHolder
|
from environments.factory.base.objects import Entity, Floor, Agent, Door, Action, Wall, PlaceHolder, GlobalPosition, \
|
||||||
|
Object, EnvObject
|
||||||
from environments.utility_classes import MovementProperties
|
from environments.utility_classes import MovementProperties
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as c
|
||||||
|
|
||||||
|
##########################################################################
|
||||||
|
# ################## Base Collections Definition ####################### #
|
||||||
|
##########################################################################
|
||||||
|
|
||||||
class Register:
|
|
||||||
_accepted_objects = Entity
|
class ObjectCollection:
|
||||||
|
_accepted_objects = Object
|
||||||
|
_stateless_entities = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_stateless(self):
|
||||||
|
return self._stateless_entities
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return f'{self.__class__.__name__}'
|
return f'{self.__class__.__name__}'
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self._register = dict()
|
self._collection = dict()
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self._register)
|
return len(self._collection)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return iter(self.values())
|
return iter(self.values())
|
||||||
|
|
||||||
def register_item(self, other: _accepted_objects):
|
def add_item(self, other: _accepted_objects):
|
||||||
assert isinstance(other, self._accepted_objects), f'All item names have to be of type ' \
|
assert isinstance(other, self._accepted_objects), f'All item names have to be of type ' \
|
||||||
f'{self._accepted_objects}, ' \
|
f'{self._accepted_objects}, ' \
|
||||||
f'but were {other.__class__}.,'
|
f'but were {other.__class__}.,'
|
||||||
self._register.update({other.name: other})
|
self._collection.update({other.name: other})
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def register_additional_items(self, others: List[_accepted_objects]):
|
def add_additional_items(self, others: List[_accepted_objects]):
|
||||||
for other in others:
|
for other in others:
|
||||||
self.register_item(other)
|
self.add_item(other)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def keys(self):
|
def keys(self):
|
||||||
return self._register.keys()
|
return self._collection.keys()
|
||||||
|
|
||||||
def values(self):
|
def values(self):
|
||||||
return self._register.values()
|
return self._collection.values()
|
||||||
|
|
||||||
def items(self):
|
def items(self):
|
||||||
return self._register.items()
|
return self._collection.items()
|
||||||
|
|
||||||
|
def _get_index(self, item):
|
||||||
|
try:
|
||||||
|
return next(i for i, v in enumerate(self._collection.values()) if v == item)
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
if isinstance(item, (int, np.int64, np.int32)):
|
if isinstance(item, (int, np.int64, np.int32)):
|
||||||
if item < 0:
|
if item < 0:
|
||||||
item = len(self._register) - abs(item)
|
item = len(self._collection) - abs(item)
|
||||||
try:
|
try:
|
||||||
return next(v for i, v in enumerate(self._register.values()) if i == item)
|
return next(v for i, v in enumerate(self._collection.values()) if i == item)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
return None
|
return None
|
||||||
try:
|
try:
|
||||||
return self._register[item]
|
return self._collection[item]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.__class__.__name__}({self._register})'
|
return f'{self.__class__.__name__}[{self._collection}]'
|
||||||
|
|
||||||
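Apart from the rename, ObjectCollection keeps the Register's lookup semantics: integer indices (including negative ones) walk the insertion order, strings look up by name, and misses return None instead of raising. A toy version of __getitem__ with those semantics:

class ToyCollection:
    def __init__(self):
        self._collection = dict()
    def add_item(self, name, obj):
        self._collection[name] = obj           # the real code keys by obj.name
        return self
    def __getitem__(self, item):
        if isinstance(item, int):
            values = list(self._collection.values())
            return values[item] if -len(values) <= item < len(values) else None
        return self._collection.get(item)      # None on a miss, like the try/except KeyError above

col = ToyCollection().add_item('Wall#0', 'wall_0').add_item('Wall#1', 'wall_1')
assert col[0] == 'wall_0' and col[-1] == 'wall_1' and col['Wall#1'] == 'wall_1' and col[5] is None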
|
|
||||||
class ObjectRegister(Register):
|
class EnvObjectCollection(ObjectCollection):
|
||||||
|
|
||||||
hide_from_obs_builder = False
|
_accepted_objects = EnvObject
|
||||||
|
|
||||||
def __init__(self, level_shape: (int, int), *args, individual_slices=False, is_per_agent=False, **kwargs):
|
@property
|
||||||
super(ObjectRegister, self).__init__(*args, **kwargs)
|
def encodings(self):
|
||||||
self.is_per_agent = is_per_agent
|
return [x.encoding for x in self]
|
||||||
self.individual_slices = individual_slices
|
|
||||||
self._level_shape = level_shape
|
def __init__(self, obs_shape: (int, int), *args,
|
||||||
|
individual_slices: bool = False,
|
||||||
|
is_blocking_light: bool = False,
|
||||||
|
can_collide: bool = False,
|
||||||
|
can_be_shadowed: bool = True, **kwargs):
|
||||||
|
super(EnvObjectCollection, self).__init__(*args, **kwargs)
|
||||||
|
self._shape = obs_shape
|
||||||
self._array = None
|
self._array = None
|
||||||
|
self._individual_slices = individual_slices
|
||||||
|
self._lazy_eval_transforms = []
|
||||||
|
self.is_blocking_light = is_blocking_light
|
||||||
|
self.can_be_shadowed = can_be_shadowed
|
||||||
|
self.can_collide = can_collide
|
||||||
|
|
||||||
def register_item(self, other):
|
def add_item(self, other: EnvObject):
|
||||||
super(ObjectRegister, self).register_item(other)
|
super(EnvObjectCollection, self).add_item(other)
|
||||||
if self._array is None:
|
if self._array is None:
|
||||||
self._array = np.zeros((1, *self._level_shape))
|
self._array = np.zeros((1, *self._shape))
|
||||||
else:
|
else:
|
||||||
if self.individual_slices:
|
if self._individual_slices:
|
||||||
self._array = np.concatenate((self._array, np.zeros((1, *self._array.shape[1:]))))
|
self._array = np.vstack((self._array, np.zeros((1, *self._shape))))
|
||||||
|
self.notify_change_to_value(other)
|
||||||
def summarize_states(self, n_steps=None):
|
|
||||||
return [val.summarize_state(n_steps=n_steps) for val in self.values()]
|
|
||||||
|
|
||||||
|
|
||||||
class EntityObjectRegister(ObjectRegister, ABC):
|
|
||||||
|
|
||||||
def as_array(self):
|
def as_array(self):
|
||||||
raise NotImplementedError
|
if self._lazy_eval_transforms:
|
||||||
|
idxs, values = zip(*self._lazy_eval_transforms)
|
||||||
|
# numpy.put respects the ordering of the queued transforms, so that
|
||||||
|
np.put(self._array, idxs, values)
|
||||||
|
self._lazy_eval_transforms = []
|
||||||
|
return self._array
|
||||||
|
|
||||||
|
def summarize_states(self):
|
||||||
|
return [entity.summarize_state() for entity in self.values()]
|
||||||
|
|
||||||
|
def notify_change_to_free(self, env_object: EnvObject):
|
||||||
|
self._array_change_notifyer(env_object, value=c.FREE_CELL)
|
||||||
|
|
||||||
|
def notify_change_to_value(self, env_object: EnvObject):
|
||||||
|
self._array_change_notifyer(env_object)
|
||||||
|
|
||||||
|
def _array_change_notifyer(self, env_object: EnvObject, value=None):
|
||||||
|
pos = self._get_index(env_object)
|
||||||
|
value = value if value is not None else env_object.encoding
|
||||||
|
self._lazy_eval_transforms.append((pos, value))
|
||||||
|
if self._individual_slices:
|
||||||
|
idx = (self._get_index(env_object) * np.prod(self._shape[1:]), value)
|
||||||
|
self._lazy_eval_transforms.append((idx, value))
|
||||||
|
else:
|
||||||
|
self._lazy_eval_transforms.append((pos, value))
|
||||||
|
|
||||||
|
def _refresh_arrays(self):
|
||||||
|
poss, values = zip(*[(idx, x.encoding) for idx,x in enumerate(self.values())])
|
||||||
|
for pos, value in zip(poss, values):
|
||||||
|
self._lazy_eval_transforms.append((pos, value))
|
||||||
|
|
||||||
|
def __delitem__(self, name):
|
||||||
|
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
|
||||||
|
if self._individual_slices:
|
||||||
|
self._array = np.delete(self._array, idx, axis=0)
|
||||||
|
else:
|
||||||
|
self.notify_change_to_free(self._collection[name])
|
||||||
|
# Dirty hack to check if not being subclassed. In that case we need to refresh the array, since positions
|
||||||
|
# in the observation array are the result of enumeration and can override each other.
|
||||||
|
# Todo: Find a better solution
|
||||||
|
if not issubclass(self.__class__, EntityCollection) and issubclass(self.__class__, EnvObjectCollection):
|
||||||
|
self._refresh_arrays()
|
||||||
|
del self._collection[name]
|
||||||
|
|
||||||
|
def delete_env_object(self, env_object: EnvObject):
|
||||||
|
del self[env_object.name]
|
||||||
|
|
||||||
|
def delete_env_object_by_name(self, name):
|
||||||
|
del self[name]
|
||||||
|
|
||||||
|
|
||||||
|
class EntityCollection(EnvObjectCollection, ABC):
|
||||||
|
|
||||||
|
_accepted_objects = Entity
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
|
def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
|
||||||
# objects_name = cls._accepted_objects.__name__
|
# objects_name = cls._accepted_objects.__name__
|
||||||
register_obj = cls(*args, **kwargs)
|
collection = cls(*args, **kwargs)
|
||||||
entities = [cls._accepted_objects(tile, str_ident=i, **entity_kwargs if entity_kwargs is not None else {})
|
entities = [cls._accepted_objects(tile, collection, str_ident=i,
|
||||||
|
**entity_kwargs if entity_kwargs is not None else {})
|
||||||
for i, tile in enumerate(tiles)]
|
for i, tile in enumerate(tiles)]
|
||||||
register_obj.register_additional_items(entities)
|
collection.add_additional_items(entities)
|
||||||
return register_obj
|
return collection
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_argwhere_coordinates(cls, positions: [(int, int)], tiles, *args, entity_kwargs=None, **kwargs, ):
|
def from_argwhere_coordinates(cls, positions: [(int, int)], tiles, *args, entity_kwargs=None, **kwargs, ):
|
||||||
@ -115,90 +192,169 @@ class EntityObjectRegister(ObjectRegister, ABC):
|
|||||||
def tiles(self):
|
def tiles(self):
|
||||||
return [entity.tile for entity in self]
|
return [entity.tile for entity in self]
|
||||||
|
|
||||||
def __init__(self, *args, is_blocking_light=False, is_observable=True, can_be_shadowed=True, **kwargs):
|
def __init__(self, level_shape, *args, **kwargs):
|
||||||
super(EntityObjectRegister, self).__init__(*args, **kwargs)
|
super(EntityCollection, self).__init__(level_shape, *args, **kwargs)
|
||||||
self.can_be_shadowed = can_be_shadowed
|
self._lazy_eval_transforms = []
|
||||||
self.is_blocking_light = is_blocking_light
|
|
||||||
self.is_observable = is_observable
|
|
||||||
|
|
||||||
def by_pos(self, pos):
|
|
||||||
if isinstance(pos, np.ndarray):
|
|
||||||
pos = tuple(pos)
|
|
||||||
try:
|
|
||||||
return next(item for item in self.values() if item.pos == pos)
|
|
||||||
except StopIteration:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class MovingEntityObjectRegister(EntityObjectRegister, ABC):
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super(MovingEntityObjectRegister, self).__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def by_pos(self, pos):
|
|
||||||
if isinstance(pos, np.ndarray):
|
|
||||||
pos = tuple(pos)
|
|
||||||
try:
|
|
||||||
return next(x for x in self if x.pos == pos)
|
|
||||||
except StopIteration:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def __delitem__(self, name):
|
def __delitem__(self, name):
|
||||||
idx = next(i for i, entity in enumerate(self) if entity.name == name)
|
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
|
||||||
del self._register[name]
|
obj.tile.leave(obj)
|
||||||
if self.individual_slices:
|
super(EntityCollection, self).__delitem__(name)
|
||||||
self._array = np.delete(self._array, idx, axis=0)
|
|
||||||
|
|
||||||
def delete_entity(self, item):
|
def as_array(self):
|
||||||
self.delete_entity_by_name(item.name)
|
if self._lazy_eval_transforms:
|
||||||
|
idxs, values = zip(*self._lazy_eval_transforms)
|
||||||
|
# numpy.put respects the ordering of the queued transforms, so that
|
||||||
|
# Todo: Export the index building into a separate function
|
||||||
|
np.put(self._array, [np.ravel_multi_index(idx, self._array.shape) for idx in idxs], values)
|
||||||
|
self._lazy_eval_transforms = []
|
||||||
|
return self._array
|
||||||
|
|
||||||
def delete_entity_by_name(self, name):
|
def _array_change_notifyer(self, entity, pos=None, value=None):
|
||||||
del self[name]
|
# Todo: Export the construction into a separate function
|
||||||
|
pos = pos if pos is not None else entity.pos
|
||||||
|
value = value if value is not None else entity.encoding
|
||||||
|
x, y = pos
|
||||||
|
if self._individual_slices:
|
||||||
|
idx = (self._get_index(entity), x, y)
|
||||||
|
else:
|
||||||
|
idx = (0, x, y)
|
||||||
|
self._lazy_eval_transforms.append((idx, value))
|
||||||
|
|
||||||
|
def by_pos(self, pos: Tuple[int, int]):
|
||||||
|
try:
|
||||||
|
return next(item for item in self if item.pos == tuple(pos))
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class PlaceHolders(MovingEntityObjectRegister):
|
class BoundEnvObjCollection(EnvObjectCollection, ABC):
|
||||||
|
|
||||||
|
def __init__(self, entity_to_be_bound, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._bound_entity = entity_to_be_bound
|
||||||
|
|
||||||
|
def belongs_to_entity(self, entity):
|
||||||
|
return self._bound_entity == entity
|
||||||
|
|
||||||
|
def by_entity(self, entity):
|
||||||
|
try:
|
||||||
|
return next((x for x in self if x.belongs_to_entity(entity)))
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def idx_by_entity(self, entity):
|
||||||
|
try:
|
||||||
|
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def as_array_by_entity(self, entity):
|
||||||
|
return self._array[self.idx_by_entity(entity)]
|
||||||
|
|
||||||
|
|
||||||
|
class MovingEntityObjectCollection(EntityCollection, ABC):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(MovingEntityObjectCollection, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def notify_change_to_value(self, entity):
|
||||||
|
super(MovingEntityObjectCollection, self).notify_change_to_value(entity)
|
||||||
|
if entity.last_pos != c.NO_POS:
|
||||||
|
try:
|
||||||
|
self._array_change_notifyer(entity, entity.last_pos, value=c.FREE_CELL)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
##########################################################################
|
||||||
|
# ################# Objects and Entity Collection ###################### #
|
||||||
|
##########################################################################
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalPositions(EnvObjectCollection):
|
||||||
|
|
||||||
|
_accepted_objects = GlobalPosition
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(GlobalPositions, self).__init__(*args, is_per_agent=True, individual_slices=True, is_blocking_light = False,
|
||||||
|
can_be_shadowed = False, can_collide = False, **kwargs)
|
||||||
|
|
||||||
|
def as_array(self):
|
||||||
|
# FIXME DEBUG!!! make this lazy?
|
||||||
|
return np.stack([gp.as_array() for inv_idx, gp in enumerate(self)])
|
||||||
|
|
||||||
|
def as_array_by_entity(self, entity):
|
||||||
|
# FIXME DEBUG!!! make this lazy?
|
||||||
|
return np.stack([gp.as_array() for inv_idx, gp in enumerate(self)])
|
||||||
|
|
||||||
|
def spawn_global_position_objects(self, agents):
|
||||||
|
# Todo, change to 'from xy'-form
|
||||||
|
global_positions = [self._accepted_objects(self._shape, agent, self)
|
||||||
|
for _, agent in enumerate(agents)]
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
self.add_additional_items(global_positions)
|
||||||
|
|
||||||
|
def idx_by_entity(self, entity):
|
||||||
|
try:
|
||||||
|
return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity)))
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def by_entity(self, entity):
|
||||||
|
try:
|
||||||
|
return next((inv for inv in self if inv.belongs_to_entity(entity)))
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class PlaceHolders(EnvObjectCollection):
|
||||||
_accepted_objects = PlaceHolder
|
_accepted_objects = PlaceHolder
|
||||||
|
|
||||||
def __init__(self, *args, fill_value: Union[str, int] = 0, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
assert 'individual_slices' not in kwargs, 'Keyword "individual_slices" is fixed for PlaceHolders and must not be altered'
|
||||||
|
kwargs.update(individual_slices=False)
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.fill_value = fill_value
|
|
||||||
|
@classmethod
|
||||||
|
def from_values(cls, values: Union[str, numbers.Number, List[Union[str, numbers.Number]]],
|
||||||
|
*args, object_kwargs=None, **kwargs):
|
||||||
|
# objects_name = cls._accepted_objects.__name__
|
||||||
|
if isinstance(values, (str, numbers.Number)):
|
||||||
|
values = [values]
|
||||||
|
collection = cls(*args, **kwargs)
|
||||||
|
objects = [cls._accepted_objects(collection, str_ident=i, fill_value=value,
|
||||||
|
**object_kwargs if object_kwargs is not None else {})
|
||||||
|
for i, value in enumerate(values)]
|
||||||
|
collection.add_additional_items(objects)
|
||||||
|
return collection
|
||||||
|
|
||||||
# noinspection DuplicatedCode
|
# noinspection DuplicatedCode
|
||||||
def as_array(self):
|
def as_array(self):
|
||||||
if isinstance(self.fill_value, numbers.Number):
|
for idx, placeholder in enumerate(self):
|
||||||
self._array[:] = self.fill_value
|
if isinstance(placeholder.encoding, numbers.Number):
|
||||||
elif isinstance(self.fill_value, str):
|
self._array[idx][:] = placeholder.fill_value
|
||||||
if self.fill_value.lower() in ['normal', 'n']:
|
elif isinstance(placeholder.fill_value, str):
|
||||||
self._array = np.random.normal(size=self._array.shape)
|
if placeholder.fill_value.lower() in ['normal', 'n']:
|
||||||
|
self._array[:] = np.random.normal(size=self._array.shape)
|
||||||
|
else:
|
||||||
|
raise ValueError('Choose one of: ["normal", "N"]')
|
||||||
else:
|
else:
|
||||||
raise ValueError('Choose one of: ["normal", "N"]')
|
raise TypeError('Objects of type "str" or "number" is required here.')
|
||||||
else:
|
|
||||||
raise TypeError('Objects of type "str" or "number" is required here.')
|
|
||||||
|
|
||||||
if self.individual_slices:
|
return self._array
|
||||||
return self._array
|
|
||||||
else:
|
|
||||||
return self._array[None, 0]
|
|
||||||
|
|
||||||
|
|
||||||
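The reworked PlaceHolders build one slice per placeholder, where a numeric fill value is broadcast over the slice and 'normal'/'N' draws Gaussian noise. A standalone sketch of that fill rule (the helper name is ours, not part of the environment):

import numbers
import numpy as np

def fill_placeholder_slice(shape, fill_value):
    # numbers are broadcast, 'normal'/'n' draws noise; anything else is rejected like in as_array above
    if isinstance(fill_value, numbers.Number):
        return np.full(shape, fill_value, dtype=float)
    if isinstance(fill_value, str) and fill_value.lower() in ('normal', 'n'):
        return np.random.normal(size=shape)
    raise TypeError('Objects of type "str" or "number" are required here.')

slices = np.stack([fill_placeholder_slice((5, 5), value) for value in (0, 'normal')])
assert slices.shape == (2, 5, 5)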
class Entities(Register):
|
class Entities(ObjectCollection):
|
||||||
|
_accepted_objects = EntityCollection
|
||||||
_accepted_objects = EntityObjectRegister
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def observable_arrays(self):
|
def arrays(self):
|
||||||
# FIXME: Find a better name
|
return {key: val.as_array() for key, val in self.items()}
|
||||||
return {key: val.as_array() for key, val in self.items() if val.is_observable}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def obs_arrays(self):
|
|
||||||
# FIXME: Find a better name
|
|
||||||
return {key: val.as_array() for key, val in self.items() if val.is_observable and not val.hide_from_obs_builder}
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def names(self):
|
def names(self):
|
||||||
return list(self._register.keys())
|
return list(self._collection.keys())
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Entities, self).__init__()
|
super(Entities, self).__init__()
|
||||||
@ -206,48 +362,45 @@ class Entities(Register):
|
|||||||
def iter_individual_entitites(self):
|
def iter_individual_entitites(self):
|
||||||
return iter((x for sublist in self.values() for x in sublist))
|
return iter((x for sublist in self.values() for x in sublist))
|
||||||
|
|
||||||
def register_item(self, other: dict):
|
def add_item(self, other: dict):
|
||||||
assert not any([key for key in other.keys() if key in self.keys()]), \
|
assert not any([key for key in other.keys() if key in self.keys()]), \
|
||||||
"This group of entities has already been registered!"
|
"This group of entities has already been added!"
|
||||||
self._register.update(other)
|
self._collection.update(other)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def register_additional_items(self, others: Dict):
|
def add_additional_items(self, others: Dict):
|
||||||
return self.register_item(others)
|
return self.add_item(others)
|
||||||
|
|
||||||
def by_pos(self, pos: (int, int)):
|
def by_pos(self, pos: (int, int)):
|
||||||
found_entities = [y for y in (x.by_pos(pos) for x in self.values() if hasattr(x, 'by_pos')) if y is not None]
|
found_entities = [y for y in (x.by_pos(pos) for x in self.values() if hasattr(x, 'by_pos')) if y is not None]
|
||||||
return found_entities
|
return found_entities
|
||||||
|
|
||||||
|
|
||||||
class WallTiles(EntityObjectRegister):
|
class Walls(EntityCollection):
|
||||||
_accepted_objects = Wall
|
_accepted_objects = Wall
|
||||||
_light_blocking = True
|
_stateless_entities = True
|
||||||
|
|
||||||
def as_array(self):
|
def as_array(self):
|
||||||
if not np.any(self._array):
|
if not np.any(self._array):
|
||||||
|
# Which is Faster?
|
||||||
|
# indices = [x.pos for x in cls]
|
||||||
|
# np.put(cls._array, [np.ravel_multi_index((0, *x), cls._array.shape) for x in indices], cls.encodings)
|
||||||
x, y = zip(*[x.pos for x in self])
|
x, y = zip(*[x.pos for x in self])
|
||||||
self._array[0, x, y] = self.encoding
|
self._array[0, x, y] = self._value
|
||||||
return self._array
|
return self._array
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, is_blocking_light=True, **kwargs):
|
||||||
super(WallTiles, self).__init__(*args, individual_slices=False,
|
super(Walls, self).__init__(*args, individual_slices=False,
|
||||||
is_blocking_light=self._light_blocking, **kwargs)
|
can_collide=True,
|
||||||
|
is_blocking_light=is_blocking_light, **kwargs)
|
||||||
@property
|
self._value = c.OCCUPIED_CELL
|
||||||
def encoding(self):
|
|
||||||
return c.OCCUPIED_CELL.value
|
|
||||||
|
|
||||||
@property
|
|
||||||
def array(self):
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs):
|
def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs):
|
||||||
tiles = cls(*args, **kwargs)
|
tiles = cls(*args, **kwargs)
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
tiles.register_additional_items(
|
tiles.add_additional_items(
|
||||||
[cls._accepted_objects(pos, is_blocking_light=cls._light_blocking)
|
[cls._accepted_objects(pos, tiles)
|
||||||
for pos in argwhere_coordinates]
|
for pos in argwhere_coordinates]
|
||||||
)
|
)
|
||||||
return tiles
|
return tiles
|
||||||
@ -256,24 +409,14 @@ class WallTiles(EntityObjectRegister):
|
|||||||
def from_tiles(cls, tiles, *args, **kwargs):
|
def from_tiles(cls, tiles, *args, **kwargs):
|
||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
|
||||||
if n_steps == h.STEPS_START:
|
|
||||||
return super(WallTiles, self).summarize_states(n_steps=n_steps)
|
|
||||||
else:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
class Floors(Walls):
|
||||||
|
_accepted_objects = Floor
|
||||||
|
_stateless_entities = True
|
||||||
|
|
||||||
class FloorTiles(WallTiles):
|
def __init__(self, *args, is_blocking_light=False, **kwargs):
|
||||||
|
super(Floors, self).__init__(*args, is_blocking_light=is_blocking_light, **kwargs)
|
||||||
_accepted_objects = Tile
|
self._value = c.FREE_CELL
|
||||||
_light_blocking = False
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super(FloorTiles, self).__init__(*args, is_observable=False, **kwargs)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def encoding(self):
|
|
||||||
return c.FREE_CELL.value
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def occupied_tiles(self):
|
def occupied_tiles(self):
|
||||||
@ -282,7 +425,7 @@ class FloorTiles(WallTiles):
|
|||||||
return tiles
|
return tiles
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def empty_tiles(self) -> List[Tile]:
|
def empty_tiles(self) -> List[Floor]:
|
||||||
tiles = [tile for tile in self if tile.is_empty()]
|
tiles = [tile for tile in self if tile.is_empty()]
|
||||||
random.shuffle(tiles)
|
random.shuffle(tiles)
|
||||||
return tiles
|
return tiles
|
||||||
@ -291,32 +434,12 @@ class FloorTiles(WallTiles):
|
|||||||
def from_tiles(cls, tiles, *args, **kwargs):
|
def from_tiles(cls, tiles, *args, **kwargs):
|
||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
|
||||||
# Do not summarize
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
class Agents(MovingEntityObjectRegister):
|
|
||||||
|
|
||||||
|
class Agents(MovingEntityObjectCollection):
|
||||||
_accepted_objects = Agent
|
_accepted_objects = Agent
|
||||||
|
|
||||||
def __init__(self, *args, hide_from_obs_builder=False, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, can_collide=True, **kwargs)
|
||||||
self.hide_from_obs_builder = hide_from_obs_builder
|
|
||||||
|
|
||||||
# noinspection DuplicatedCode
|
|
||||||
def as_array(self):
|
|
||||||
self._array[:] = c.FREE_CELL.value
|
|
||||||
# noinspection PyTupleAssignmentBalance
|
|
||||||
for z, x, y, v in zip(range(len(self)), *zip(*[x.pos for x in self]), [x.encoding for x in self]):
|
|
||||||
if self.individual_slices:
|
|
||||||
self._array[z, x, y] += v
|
|
||||||
else:
|
|
||||||
self._array[0, x, y] += v
|
|
||||||
if self.individual_slices:
|
|
||||||
return self._array
|
|
||||||
else:
|
|
||||||
return self._array.sum(axis=0, keepdims=True)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def positions(self):
|
def positions(self):
|
||||||
@@ -326,19 +449,15 @@ class Agents(MovingEntityObjectRegister):
         old_agent = self[key]
         self[key].tile.leave(self[key])
         agent._name = old_agent.name
-        self._register[agent.name] = agent
+        self._collection[agent.name] = agent


-class Doors(EntityObjectRegister):
+class Doors(EntityCollection):

-    def __init__(self, *args, **kwargs):
-        super(Doors, self).__init__(*args, is_blocking_light=True, **kwargs)
-
-    def as_array(self):
-        self._array[:] = 0
-        for door in self:
-            self._array[0, door.x, door.y] = door.encoding
-        return self._array
+    def __init__(self, *args, have_area: bool = False, **kwargs):
+        self.have_area = have_area
+        self._area_marked = False
+        super(Doors, self).__init__(*args, is_blocking_light=True, can_collide=True, **kwargs)

     _accepted_objects = Door

@@ -352,9 +471,20 @@ class Doors(EntityObjectRegister):
         for door in self:
             door.tick()

+    def as_array(self):
+        if self.have_area and not self._area_marked:
+            for door in self:
+                for pos in door.access_area:
+                    if self._individual_slices:
+                        pass
+                    else:
+                        pos = (0, *pos)
+                    self._lazy_eval_transforms.append((pos, c.ACCESS_DOOR_CELL))
+            self._area_marked = True
+        return super(Doors, self).as_array()

-class Actions(Register):
+class Actions(ObjectCollection):
     _accepted_objects = Action

     @property
@@ -369,27 +499,31 @@ class Actions(Register):
         self.can_use_doors = can_use_doors
         super(Actions, self).__init__()

+        # Move this to Baseclass, Env init?
         if self.allow_square_movement:
-            self.register_additional_items([self._accepted_objects(enum_ident=direction)
-                                            for direction in h.MovingAction.square()])
+            self.add_additional_items([self._accepted_objects(str_ident=direction)
+                                       for direction in h.EnvActions.square_move()])
         if self.allow_diagonal_movement:
-            self.register_additional_items([self._accepted_objects(enum_ident=direction)
-                                            for direction in h.MovingAction.diagonal()])
+            self.add_additional_items([self._accepted_objects(str_ident=direction)
+                                       for direction in h.EnvActions.diagonal_move()])
-        self._movement_actions = self._register.copy()
+        self._movement_actions = self._collection.copy()
         if self.can_use_doors:
-            self.register_additional_items([self._accepted_objects(enum_ident=h.EnvActions.USE_DOOR)])
+            self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)])
         if self.allow_no_op:
-            self.register_additional_items([self._accepted_objects(enum_ident=h.EnvActions.NOOP)])
+            self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)])

     def is_moving_action(self, action: Union[int]):
         return action in self.movement_actions.values()

+    def summarize(self):
+        return [dict(name=action.identifier) for action in self]

-class Zones(Register):
+
+class Zones(ObjectCollection):

     @property
     def accounting_zones(self):
-        return [self[idx] for idx, name in self.items() if name != c.DANGER_ZONE.value]
+        return [self[idx] for idx, name in self.items() if name != c.DANGER_ZONE]

     def __init__(self, parsed_level):
         raise NotImplementedError('This needs a Rework')
@@ -398,9 +532,9 @@ class Zones(Register):
         self._accounting_zones = list()
         self._danger_zones = list()
         for symbol in np.unique(parsed_level):
-            if symbol == c.WALL.value:
+            if symbol == c.WALL:
                 continue
-            elif symbol == c.DANGER_ZONE.value:
+            elif symbol == c.DANGER_ZONE:
                 self + symbol
                 slices.append(h.one_hot_level(parsed_level, symbol))
                 self._danger_zones.append(symbol)
@@ -414,5 +548,5 @@ class Zones(Register):
     def __getitem__(self, item):
         return self._zone_slices[item]

-    def register_additional_items(self, other: Union[str, List[str]]):
+    def add_additional_items(self, other: Union[str, List[str]]):
         raise AttributeError('You are not allowed to add additional Zones in runtime.')
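The `Actions` collection above builds its movement entries first and snapshots them, so that `is_moving_action` reduces to a membership test. A minimal sketch of that pattern, assuming plain string actions rather than the repo's `Action` objects:

SQUARE_MOVES = ['north', 'east', 'south', 'west']
DIAGONAL_MOVES = ['north_east', 'south_east', 'south_west', 'north_west']

def build_actions(allow_square=True, allow_diagonal=False, can_use_doors=True, allow_no_op=False):
    actions = []
    if allow_square:
        actions += SQUARE_MOVES
    if allow_diagonal:
        actions += DIAGONAL_MOVES
    # snapshot the movement subset before the non-movement actions are appended
    movement_actions = {name: idx for idx, name in enumerate(actions)}
    if can_use_doors:
        actions.append('use_door')
    if allow_no_op:
        actions.append('no_op')
    return actions, movement_actions

actions, movement_actions = build_actions()

def is_moving_action(action_idx):
    return action_idx in movement_actions.values()

assert is_moving_action(0) and not is_moving_action(len(actions) - 1)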
||||||
|
@@ -20,21 +20,33 @@ class RenderEntity(NamedTuple):
     aux: Any = None


+class RenderNames:
+    AGENT: str = 'agent'
+    BLANK: str = 'blank'
+    DOOR: str = 'door'
+    OPACITY: str = 'opacity'
+    SCALE: str = 'scale'
+
+
+rn = RenderNames
+
+
 class Renderer:
     BG_COLOR = (178, 190, 195)  # (99, 110, 114)
     WHITE = (223, 230, 233)  # (200, 200, 200)
     AGENT_VIEW_COLOR = (9, 132, 227)
     ASSETS = Path(__file__).parent.parent / 'assets'

-    def __init__(self, grid_w=16, grid_h=16, cell_size=40, fps=7, grid_lines=True, view_radius=2):
-        self.grid_h = grid_h
-        self.grid_w = grid_w
+    def __init__(self, lvl_shape=(16, 16),
+                 lvl_padded_shape=None,
+                 cell_size=40, fps=7,
+                 grid_lines=True, view_radius=2):
+        self.grid_h, self.grid_w = lvl_shape
+        self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
         self.cell_size = cell_size
         self.fps = fps
         self.grid_lines = grid_lines
         self.view_radius = view_radius
         pygame.init()
-        self.screen_size = (grid_w*cell_size, grid_h*cell_size)
+        self.screen_size = (self.grid_w*cell_size, self.grid_h*cell_size)
         self.screen = pygame.display.set_mode(self.screen_size)
         self.clock = pygame.time.Clock()
         assets = list(self.ASSETS.rglob('*.png'))
@@ -43,7 +55,7 @@ class Renderer:

         now = time.time()
         self.font = pygame.font.Font(None, 20)
-        self.font.set_bold(1)
+        self.font.set_bold(True)
         print('Loading System font with pygame.font.Font took', time.time() - now)

     def fill_bg(self):
@@ -56,11 +68,16 @@ class Renderer:
                 pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1)

     def blit_params(self, entity):
+        offset_r, offset_c = (self.lvl_padded_shape[0] - self.grid_h) // 2, \
+                             (self.lvl_padded_shape[1] - self.grid_w) // 2
+
         r, c = entity.pos
+        r, c = r - offset_r, c-offset_c
+
         img = self.assets[entity.name.lower()]
-        if entity.value_operation == 'opacity':
+        if entity.value_operation == rn.OPACITY:
             img.set_alpha(255*entity.value)
-        elif entity.value_operation == 'scale':
+        elif entity.value_operation == rn.SCALE:
             re = img.get_rect()
             img = pygame.transform.smoothscale(
                 img, (int(entity.value*re.width), int(entity.value*re.height))
@@ -99,19 +116,19 @@ class Renderer:
                 sys.exit()
         self.fill_bg()
         blits = deque()
-        for entity in [x for x in entities if 'door' in x.name]:
+        for entity in [x for x in entities if rn.DOOR in x.name]:
             bp = self.blit_params(entity)
             blits.append(bp)
-        for entity in [x for x in entities if 'door' not in x.name]:
+        for entity in [x for x in entities if rn.DOOR not in x.name]:
             bp = self.blit_params(entity)
             blits.append(bp)
-            if entity.name.lower() == 'agent':
+            if entity.name.lower() == rn.AGENT:
                 if self.view_radius > 0:
                     vis_rects = self.visibility_rects(bp, entity.aux)
                     blits.extendleft(vis_rects)
-                if entity.state != 'blank':
+                if entity.state != rn.BLANK:
                     agent_state_blits = self.blit_params(
-                        RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, 'scale')
+                        RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, rn.SCALE)
                     )
                     textsurface = self.font.render(str(entity.id), False, (0, 0, 0))
                     text_blit = dict(source=textsurface, dest=(bp['dest'].center[0]-.07*self.cell_size,
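The new `blit_params` subtracts half the padding between the padded level array and the visible grid before drawing. A small sketch of just that coordinate shift, assuming symmetric padding (the helper name is hypothetical):

def shift_into_view(pos, lvl_padded_shape, grid_shape):
    """Translate a position in padded-level coordinates into visible-grid coordinates."""
    offset_r = (lvl_padded_shape[0] - grid_shape[0]) // 2
    offset_c = (lvl_padded_shape[1] - grid_shape[1]) // 2
    r, c = pos
    return r - offset_r, c - offset_c

# a 16x16 grid stored in an 18x18 padded array: padded cell (1, 1) is the first visible cell
assert shift_into_view((1, 1), (18, 18), (16, 16)) == (0, 0)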
|
@@ -2,6 +2,7 @@ import numpy as np

 from environments.helpers import Constants as c

+# Multipliers for transforming coordinates to other octants:
 mult_array = np.asarray([
     [1, 0, 0, -1, -1, 0, 0, 1],
     [0, 1, -1, 0, 0, -1, 1, 0],
@@ -11,19 +12,17 @@ mult_array = np.asarray([


 class Map(object):
-    # Multipliers for transforming coordinates to other octants:
-
-    def __init__(self, map_array: np.ndarray, diamond_slope: float = 0.9):
+    def __init__(self, map_array: np.typing.ArrayLike, diamond_slope: float = 0.9):
         self.data = map_array
         self.width, self.height = map_array.shape
-        self.light = np.full_like(self.data, c.FREE_CELL.value)
-        self.flag = c.FREE_CELL.value
+        self.light = np.full_like(self.data, c.FREE_CELL)
+        self.flag = c.FREE_CELL
         self.d_slope = diamond_slope

     def blocked(self, x, y):
         return (x < 0 or y < 0
                 or x >= self.width or y >= self.height
-                or self.data[x, y] == c.OCCUPIED_CELL.value)
+                or self.data[x, y] == c.OCCUPIED_CELL)

     def lit(self, x, y):
         return self.light[x, y] == self.flag
@@ -33,7 +32,7 @@ class Map(object):
         self.light[x, y] = self.flag

     def _cast_light(self, cx, cy, row, start, end, radius, xx, xy, yx, yy, id):
-        "Recursive lightcasting function"
+        """Recursive lightcasting function"""
         if start < end:
             return
         radius_squared = radius*radius
@@ -46,14 +45,14 @@ class Map(object):
                 # Translate the dx, dy coordinates into map coordinates:
                 X, Y = cx + dx * xx + dy * xy, cy + dx * yx + dy * yy
                 # l_slope and r_slope store the slopes of the left and right
-                # extremities of the square we're considering:
+                # extremities of the square_move we're considering:
                 l_slope, r_slope = (dx-self.d_slope)/(dy+self.d_slope), (dx+self.d_slope)/(dy-self.d_slope)
                 if start < r_slope:
                     continue
                 elif end > l_slope:
                     break
                 else:
-                    # Our light beam is touching this square; light it:
+                    # Our light beam is touching this square_move; light it:
                     if dx*dx + dy*dy < radius_squared:
                         self.set_lit(X, Y)
                     if blocked:
@@ -66,12 +65,12 @@ class Map(object):
                         start = new_start
                     else:
                         if self.blocked(X, Y) and j < radius:
-                            # This is a blocking square, start a child scan:
+                            # This is a blocking square_move, start a child scan:
                             blocked = True
                             self._cast_light(cx, cy, j+1, start, l_slope,
                                              radius, xx, xy, yx, yy, id+1)
                             new_start = r_slope
-            # Row is scanned; do next row unless last square was blocked:
+            # Row is scanned; do next row unless last square_move was blocked:
             if blocked:
                 break
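`_cast_light` is invoked once per octant; the columns of `mult_array` are the classic shadow-casting multipliers that map a first-octant scan onto the other seven. A standalone sketch of just that coordinate transform. The first two rows appear in the diff above; the last two rows are the standard values used by common Python ports of this algorithm and are filled in here as an assumption:

import numpy as np

mult_array = np.asarray([
    [1, 0, 0, -1, -1, 0, 0, 1],
    [0, 1, -1, 0, 0, -1, 1, 0],
    [0, 1, 1, 0, 0, -1, -1, 0],   # assumed standard row
    [1, 0, 0, 1, -1, 0, 0, -1],   # assumed standard row
])

def octant_coords(cx, cy, dx, dy):
    """Yield the map coordinates of (dx, dy) mirrored into all eight octants around (cx, cy)."""
    for octant in range(8):
        xx, xy, yx, yy = mult_array[:, octant]
        yield cx + dx * xx + dy * xy, cy + dx * yx + dy * yy

# the eight mirror images of a single offset around the origin are all distinct
assert len(set(octant_coords(0, 0, 1, 2))) == 8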
|
|
||||||
|
@ -1,275 +0,0 @@
|
|||||||
from typing import Union, NamedTuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
|
||||||
from environments.factory.base.objects import Agent, Action, Entity
|
|
||||||
from environments.factory.base.registers import EntityObjectRegister, ObjectRegister
|
|
||||||
from environments.factory.base.renderer import RenderEntity
|
|
||||||
from environments.helpers import Constants as c
|
|
||||||
|
|
||||||
from environments import helpers as h
|
|
||||||
|
|
||||||
|
|
||||||
CHARGE_ACTION = h.EnvActions.CHARGE
|
|
||||||
ITEM_DROP_OFF = 1
|
|
||||||
|
|
||||||
|
|
||||||
class BatteryProperties(NamedTuple):
    initial_charge: float = 0.8             #
    charge_rate: float = 0.4                #
    charge_locations: int = 20              #
    per_action_costs: Union[dict, float] = 0.02
    done_when_discharged = False
    multi_charge: bool = False
|
|
||||||
|
|
||||||
class Battery(object):
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_discharged(self):
|
|
||||||
return self.charge_level == 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_blocking_light(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def can_collide(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
return f'{self.__class__.__name__}({self.agent.name})'
|
|
||||||
|
|
||||||
def __init__(self, pomdp_r: int, level_shape: (int, int), agent: Agent, initial_charge_level: float):
|
|
||||||
super().__init__()
|
|
||||||
self.agent = agent
|
|
||||||
self._pomdp_r = pomdp_r
|
|
||||||
self._level_shape = level_shape
|
|
||||||
if self._pomdp_r:
|
|
||||||
self._array = np.zeros((1, pomdp_r * 2 + 1, pomdp_r * 2 + 1))
|
|
||||||
else:
|
|
||||||
self._array = np.zeros((1, *self._level_shape))
|
|
||||||
self.charge_level = initial_charge_level
|
|
||||||
|
|
||||||
def as_array(self):
|
|
||||||
self._array[:] = c.FREE_CELL.value
|
|
||||||
self._array[0, 0] = self.charge_level
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f'{self.__class__.__name__}[{self.agent.name}]({self.charge_level})'
|
|
||||||
|
|
||||||
def charge(self, amount) -> c:
|
|
||||||
if self.charge_level < 1:
|
|
||||||
# noinspection PyTypeChecker
|
|
||||||
self.charge_level = min(1, amount + self.charge_level)
|
|
||||||
return c.VALID
|
|
||||||
else:
|
|
||||||
return c.NOT_VALID
|
|
||||||
|
|
||||||
def decharge(self, amount) -> c:
|
|
||||||
if self.charge_level != 0:
|
|
||||||
# noinspection PyTypeChecker
|
|
||||||
self.charge_level = max(0, amount + self.charge_level)
|
|
||||||
return c.VALID
|
|
||||||
else:
|
|
||||||
return c.NOT_VALID
|
|
||||||
|
|
||||||
def belongs_to_entity(self, entity):
|
|
||||||
return self.agent == entity
|
|
||||||
|
|
||||||
def summarize_state(self, **kwargs):
|
|
||||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
|
||||||
attr_dict.update(dict(name=self.name))
|
|
||||||
return attr_dict
|
|
||||||
|
|
||||||
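The deleted `Battery` keeps `charge_level` clamped to `[0, 1]` and reports whether a charge or discharge actually changed anything. Note that the deleted `decharge` adds `amount` to the charge level, so it only clamps correctly if the caller passes a negative cost; the following standalone sketch assumes a positive cost that is subtracted instead:

class SimpleBattery:
    def __init__(self, initial_charge=0.8):
        self.charge_level = initial_charge

    def charge(self, amount):
        if self.charge_level >= 1:
            return False                  # already full, nothing to do
        self.charge_level = min(1.0, self.charge_level + amount)
        return True

    def decharge(self, amount):
        if self.charge_level <= 0:
            return False                  # already empty
        self.charge_level = max(0.0, self.charge_level - amount)
        return True

b = SimpleBattery()
b.decharge(0.02)        # per-step action cost
b.charge(0.4)           # standing on a charge pod
assert 0.0 <= b.charge_level <= 1.0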
|
|
||||||
class BatteriesRegister(ObjectRegister):
|
|
||||||
|
|
||||||
_accepted_objects = Battery
|
|
||||||
is_blocking_light = False
|
|
||||||
can_be_shadowed = False
|
|
||||||
hide_from_obs_builder = True
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super(BatteriesRegister, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
|
|
||||||
self.is_observable = True
|
|
||||||
|
|
||||||
def as_array(self):
|
|
||||||
# self._array[:] = c.FREE_CELL.value
|
|
||||||
for inv_idx, battery in enumerate(self):
|
|
||||||
self._array[inv_idx] = battery.as_array()
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
def spawn_batteries(self, agents, pomdp_r, initial_charge_level):
|
|
||||||
inventories = [self._accepted_objects(pomdp_r, self._level_shape, agent,
|
|
||||||
initial_charge_level)
|
|
||||||
for _, agent in enumerate(agents)]
|
|
||||||
self.register_additional_items(inventories)
|
|
||||||
|
|
||||||
def idx_by_entity(self, entity):
|
|
||||||
try:
|
|
||||||
return next((idx for idx, bat in enumerate(self) if bat.belongs_to_entity(entity)))
|
|
||||||
except StopIteration:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def by_entity(self, entity):
|
|
||||||
try:
|
|
||||||
return next((bat for bat in self if bat.belongs_to_entity(entity)))
|
|
||||||
except StopIteration:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
|
||||||
# as dict with additional nesting
|
|
||||||
# return dict(items=super(Inventories, self).summarize_states())
|
|
||||||
return super(BatteriesRegister, self).summarize_states(n_steps=n_steps)
|
|
||||||
|
|
||||||
|
|
||||||
class ChargePod(Entity):
|
|
||||||
|
|
||||||
@property
|
|
||||||
def can_collide(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def encoding(self):
|
|
||||||
return ITEM_DROP_OFF
|
|
||||||
|
|
||||||
def __init__(self, *args, charge_rate: float = 0.4,
|
|
||||||
multi_charge: bool = False, **kwargs):
|
|
||||||
super(ChargePod, self).__init__(*args, **kwargs)
|
|
||||||
self.charge_rate = charge_rate
|
|
||||||
self.multi_charge = multi_charge
|
|
||||||
|
|
||||||
def charge_battery(self, battery: Battery):
|
|
||||||
if battery.charge_level == 1.0:
|
|
||||||
return c.NOT_VALID
|
|
||||||
if sum(guest for guest in self.tile.guests if c.AGENT.name in guest.name) > 1:
|
|
||||||
return c.NOT_VALID
|
|
||||||
battery.charge(self.charge_rate)
|
|
||||||
return c.VALID
|
|
||||||
|
|
||||||
def summarize_state(self, n_steps=None) -> dict:
|
|
||||||
if n_steps == h.STEPS_START:
|
|
||||||
summary = super().summarize_state(n_steps=n_steps)
|
|
||||||
return summary
|
|
||||||
|
|
||||||
|
|
||||||
class ChargePods(EntityObjectRegister):
|
|
||||||
|
|
||||||
_accepted_objects = ChargePod
|
|
||||||
|
|
||||||
def as_array(self):
|
|
||||||
self._array[:] = c.FREE_CELL.value
|
|
||||||
for item in self:
|
|
||||||
if item.pos != c.NO_POS.value:
|
|
||||||
self._array[0, item.x, item.y] = item.encoding
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
super(ChargePods, self).__repr__()
|
|
||||||
|
|
||||||
|
|
||||||
class BatteryFactory(BaseFactory):
|
|
||||||
|
|
||||||
def __init__(self, *args, btry_prop=BatteryProperties(), **kwargs):
|
|
||||||
if isinstance(btry_prop, dict):
|
|
||||||
btry_prop = BatteryProperties(**btry_prop)
|
|
||||||
self.btry_prop = btry_prop
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def additional_entities(self):
|
|
||||||
super_entities = super().additional_entities
|
|
||||||
|
|
||||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.btry_prop.charge_locations]
|
|
||||||
charge_pods = ChargePods.from_tiles(
|
|
||||||
empty_tiles, self._level_shape,
|
|
||||||
entity_kwargs=dict(charge_rate=self.btry_prop.charge_rate,
|
|
||||||
multi_charge=self.btry_prop.multi_charge)
|
|
||||||
)
|
|
||||||
|
|
||||||
batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
|
|
||||||
)
|
|
||||||
batteries.spawn_batteries(self[c.AGENT], self._pomdp_r, self.btry_prop.initial_charge)
|
|
||||||
super_entities.update({c.BATTERIES: batteries, c.CHARGE_POD: charge_pods})
|
|
||||||
return super_entities
|
|
||||||
|
|
||||||
def do_additional_step(self) -> dict:
|
|
||||||
info_dict = super(BatteryFactory, self).do_additional_step()
|
|
||||||
|
|
||||||
# Decharge
|
|
||||||
batteries = self[c.BATTERIES]
|
|
||||||
|
|
||||||
for agent in self[c.AGENT]:
|
|
||||||
if isinstance(self.btry_prop.per_action_costs, dict):
|
|
||||||
energy_consumption = self.btry_prop.per_action_costs[agent.temp_action]
|
|
||||||
else:
|
|
||||||
energy_consumption = self.btry_prop.per_action_costs
|
|
||||||
|
|
||||||
batteries.by_entity(agent).decharge(energy_consumption)
|
|
||||||
|
|
||||||
return info_dict
|
|
||||||
|
|
||||||
def do_charge(self, agent) -> c:
|
|
||||||
if charge_pod := self[c.CHARGE_POD].by_pos(agent.pos):
|
|
||||||
return charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent))
|
|
||||||
else:
|
|
||||||
return c.NOT_VALID
|
|
||||||
|
|
||||||
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
|
|
||||||
valid = super().do_additional_actions(agent, action)
|
|
||||||
if valid is None:
|
|
||||||
if action == CHARGE_ACTION:
|
|
||||||
valid = self.do_charge(agent)
|
|
||||||
return valid
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return valid
|
|
||||||
pass
|
|
||||||
|
|
||||||
def do_additional_reset(self) -> None:
|
|
||||||
# There is Nothing to reset.
|
|
||||||
pass
|
|
||||||
|
|
||||||
def check_additional_done(self) -> bool:
|
|
||||||
super_done = super(BatteryFactory, self).check_additional_done()
|
|
||||||
if super_done:
|
|
||||||
return super_done
|
|
||||||
else:
|
|
||||||
return self.btry_prop.done_when_discharged and any(battery.is_discharged for battery in self[c.BATTERIES])
|
|
||||||
pass
|
|
||||||
|
|
||||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
|
||||||
reward, info_dict = super(BatteryFactory, self).calculate_additional_reward(agent)
|
|
||||||
if h.EnvActions.CHARGE == agent.temp_action:
|
|
||||||
if agent.temp_valid:
|
|
||||||
charge_pod = self[c.CHARGE_POD].by_pos(agent.pos)
|
|
||||||
info_dict.update({f'{agent.name}_charge': 1})
|
|
||||||
info_dict.update(agent_charged=1)
|
|
||||||
self.print(f'{agent.name} just charged batteries at {charge_pod.pos}.')
|
|
||||||
reward += 0.1
|
|
||||||
else:
|
|
||||||
self[c.DROP_OFF].by_pos(agent.pos)
|
|
||||||
info_dict.update({f'{agent.name}_failed_charge': 1})
|
|
||||||
info_dict.update(failed_charge=1)
|
|
||||||
self.print(f'{agent.name} just tried to charge at {agent.pos}, but failed.')
|
|
||||||
reward -= 0.1
|
|
||||||
|
|
||||||
if self[c.BATTERIES].by_entity(agent).is_discharged:
|
|
||||||
info_dict.update({f'{agent.name}_discharged': 1})
|
|
||||||
reward -= 1
|
|
||||||
else:
|
|
||||||
info_dict.update({f'{agent.name}_battery_level': self[c.BATTERIES].by_entity(agent).charge_level})
|
|
||||||
return reward, info_dict
|
|
||||||
|
|
||||||
def render_additional_assets(self):
|
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
additional_assets = super().render_additional_assets()
|
|
||||||
charge_pods = [RenderEntity(c.CHARGE_POD.value, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_POD]]
|
|
||||||
additional_assets.extend(charge_pods)
|
|
||||||
return additional_assets
|
|
||||||
|
|
@ -1,292 +0,0 @@
|
|||||||
import time
|
|
||||||
from collections import defaultdict
|
|
||||||
from enum import Enum
|
|
||||||
from typing import List, Union, NamedTuple, Dict
|
|
||||||
import numpy as np
|
|
||||||
import random
|
|
||||||
|
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
|
||||||
from environments.helpers import Constants as c
|
|
||||||
from environments import helpers as h
|
|
||||||
from environments.factory.base.objects import Agent, Entity, Action, Tile
|
|
||||||
from environments.factory.base.registers import Entities, MovingEntityObjectRegister
|
|
||||||
|
|
||||||
from environments.factory.base.renderer import RenderEntity
|
|
||||||
|
|
||||||
|
|
||||||
DESTINATION = 1
|
|
||||||
DESTINATION_DONE = 0.5
|
|
||||||
|
|
||||||
|
|
||||||
class Destination(Entity):
|
|
||||||
|
|
||||||
@property
|
|
||||||
def any_agent_has_dwelled(self):
|
|
||||||
return bool(len(self._per_agent_times))
|
|
||||||
|
|
||||||
@property
|
|
||||||
def currently_dwelling_names(self):
|
|
||||||
return self._per_agent_times.keys()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def can_collide(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def encoding(self):
|
|
||||||
return DESTINATION
|
|
||||||
|
|
||||||
def __init__(self, *args, dwell_time: int = 0, **kwargs):
|
|
||||||
super(Destination, self).__init__(*args, **kwargs)
|
|
||||||
self.dwell_time = dwell_time
|
|
||||||
self._per_agent_times = defaultdict(lambda: dwell_time)
|
|
||||||
|
|
||||||
def wait(self, agent: Agent):
|
|
||||||
self._per_agent_times[agent.name] -= 1
|
|
||||||
return c.VALID
|
|
||||||
|
|
||||||
def leave(self, agent: Agent):
|
|
||||||
del self._per_agent_times[agent.name]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_considered_reached(self):
|
|
||||||
agent_at_position = any(c.AGENT.name.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
|
|
||||||
return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values())
|
|
||||||
|
|
||||||
def agent_is_dwelling(self, agent: Agent):
|
|
||||||
return self._per_agent_times[agent.name] < self.dwell_time
|
|
||||||
|
|
||||||
def summarize_state(self, n_steps=None) -> dict:
|
|
||||||
state_summary = super().summarize_state(n_steps=n_steps)
|
|
||||||
state_summary.update(per_agent_times=self._per_agent_times)
|
|
||||||
return state_summary
|
|
||||||
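Each `Destination` tracks a countdown per agent: `wait` decrements that agent's remaining dwell time, and the destination counts as reached once any counter hits zero (or immediately when `dwell_time` is 0). A compact sketch of that bookkeeping, assuming agents are identified by name strings:

from collections import defaultdict

class DwellCounter:
    def __init__(self, dwell_time=2):
        self.dwell_time = dwell_time
        self._per_agent_times = defaultdict(lambda: dwell_time)

    def wait(self, agent_name):
        self._per_agent_times[agent_name] -= 1

    def leave(self, agent_name):
        # an agent that walks away simply drops its counter
        del self._per_agent_times[agent_name]

    @property
    def is_reached(self):
        return not self.dwell_time or any(t <= 0 for t in self._per_agent_times.values())

d = DwellCounter(dwell_time=2)
d.wait('agent_0'); assert not d.is_reached
d.wait('agent_0'); assert d.is_reached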
|
|
||||||
|
|
||||||
class Destinations(MovingEntityObjectRegister):
|
|
||||||
|
|
||||||
_accepted_objects = Destination
|
|
||||||
_light_blocking = False
|
|
||||||
|
|
||||||
def as_array(self):
|
|
||||||
self._array[:] = c.FREE_CELL.value
|
|
||||||
for item in self:
|
|
||||||
if item.pos != c.NO_POS.value:
|
|
||||||
self._array[0, item.x, item.y] = item.encoding
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
super(Destinations, self).__repr__()
|
|
||||||
|
|
||||||
|
|
||||||
class ReachedDestinations(Destinations):
|
|
||||||
_accepted_objects = Destination
|
|
||||||
_light_blocking = False
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super(ReachedDestinations, self).__init__(*args, is_observable=False, **kwargs)
|
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
class DestSpawnMode(object):
|
|
||||||
DONE = 'DONE'
|
|
||||||
GROUPED = 'GROUPED'
|
|
||||||
PER_DEST = 'PER_DEST'
|
|
||||||
|
|
||||||
|
|
||||||
class DestinationProperties(NamedTuple):
    n_dests: int = 1                        # How many destinations are there
    dwell_time: int = 0                     # How long does the agent need to "wait" on a destination
    spawn_frequency: int = 0
    spawn_in_other_zone: bool = True        #
    spawn_mode: str = DestSpawnMode.DONE

    assert dwell_time >= 0, 'dwell_time cannot be < 0!'
    assert spawn_frequency >= 0, 'spawn_frequency cannot be < 0!'
    assert n_dests >= 0, 'n_destinations cannot be < 0!'
    assert (spawn_mode == DestSpawnMode.DONE) != bool(spawn_frequency)
|
|
||||||
|
|
||||||
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
|
||||||
class DestinationFactory(BaseFactory):
|
|
||||||
# noinspection PyMissingConstructor
|
|
||||||
|
|
||||||
def __init__(self, *args, dest_prop: DestinationProperties = DestinationProperties(),
|
|
||||||
env_seed=time.time_ns(), **kwargs):
|
|
||||||
if isinstance(dest_prop, dict):
|
|
||||||
dest_prop = DestinationProperties(**dest_prop)
|
|
||||||
self.dest_prop = dest_prop
|
|
||||||
kwargs.update(env_seed=env_seed)
|
|
||||||
self._dest_rng = np.random.default_rng(env_seed)
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def additional_actions(self) -> Union[Action, List[Action]]:
|
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
super_actions = super().additional_actions
|
|
||||||
if self.dest_prop.dwell_time:
|
|
||||||
super_actions.append(Action(enum_ident=h.EnvActions.WAIT_ON_DEST))
|
|
||||||
return super_actions
|
|
||||||
|
|
||||||
@property
|
|
||||||
def additional_entities(self) -> Dict[(Enum, Entities)]:
|
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
super_entities = super().additional_entities
|
|
||||||
|
|
||||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.dest_prop.n_dests]
|
|
||||||
destinations = Destinations.from_tiles(
|
|
||||||
empty_tiles, self._level_shape,
|
|
||||||
entity_kwargs=dict(
|
|
||||||
dwell_time=self.dest_prop.dwell_time)
|
|
||||||
)
|
|
||||||
reached_destinations = ReachedDestinations(level_shape=self._level_shape)
|
|
||||||
|
|
||||||
super_entities.update({c.DESTINATION: destinations, c.REACHEDDESTINATION: reached_destinations})
|
|
||||||
return super_entities
|
|
||||||
|
|
||||||
def additional_per_agent_obs_build(self, agent) -> List[np.ndarray]:
|
|
||||||
additional_per_agent_obs_build = super().additional_per_agent_obs_build(agent)
|
|
||||||
return additional_per_agent_obs_build
|
|
||||||
|
|
||||||
def wait(self, agent: Agent):
|
|
||||||
if destination := self[c.DESTINATION].by_pos(agent.pos):
|
|
||||||
valid = destination.wait(agent)
|
|
||||||
return valid
|
|
||||||
else:
|
|
||||||
return c.NOT_VALID
|
|
||||||
|
|
||||||
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
|
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
valid = super().do_additional_actions(agent, action)
|
|
||||||
if valid is None:
|
|
||||||
if action == h.EnvActions.WAIT_ON_DEST:
|
|
||||||
valid = self.wait(agent)
|
|
||||||
return valid
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return valid
|
|
||||||
|
|
||||||
def do_additional_reset(self) -> None:
|
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
super().do_additional_reset()
|
|
||||||
self._dest_spawn_timer = dict()
|
|
||||||
|
|
||||||
def trigger_destination_spawn(self):
|
|
||||||
destinations_to_spawn = [key for key, val in self._dest_spawn_timer.items()
|
|
||||||
if val == self.dest_prop.spawn_frequency]
|
|
||||||
if destinations_to_spawn:
|
|
||||||
n_dest_to_spawn = len(destinations_to_spawn)
|
|
||||||
if self.dest_prop.spawn_mode != DestSpawnMode.GROUPED:
|
|
||||||
destinations = [Destination(tile) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
|
||||||
self[c.DESTINATION].register_additional_items(destinations)
|
|
||||||
for dest in destinations_to_spawn:
|
|
||||||
del self._dest_spawn_timer[dest]
|
|
||||||
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
|
||||||
elif self.dest_prop.spawn_mode == DestSpawnMode.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests:
|
|
||||||
destinations = [Destination(tile) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
|
||||||
self[c.DESTINATION].register_additional_items(destinations)
|
|
||||||
for dest in destinations_to_spawn:
|
|
||||||
del self._dest_spawn_timer[dest]
|
|
||||||
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
|
||||||
else:
|
|
||||||
self.print(f'{n_dest_to_spawn} new destinations could be spawned, but waiting for all.')
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
self.print('No Items are spawning, limit is reached.')
|
|
||||||
|
|
||||||
def do_additional_step(self) -> dict:
|
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
info_dict = super().do_additional_step()
|
|
||||||
for key, val in self._dest_spawn_timer.items():
|
|
||||||
self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1)
|
|
||||||
for dest in list(self[c.DESTINATION].values()):
|
|
||||||
if dest.is_considered_reached:
|
|
||||||
self[c.REACHEDDESTINATION].register_item(dest)
|
|
||||||
self[c.DESTINATION].delete_entity(dest)
|
|
||||||
self._dest_spawn_timer[dest.name] = 0
|
|
||||||
self.print(f'{dest.name} is reached now, removing...')
|
|
||||||
else:
|
|
||||||
for agent_name in dest.currently_dwelling_names:
|
|
||||||
agent = self[c.AGENT].by_name(agent_name)
|
|
||||||
if agent.pos == dest.pos:
|
|
||||||
self.print(f'{agent.name} is still waiting.')
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
dest.leave(agent)
|
|
||||||
self.print(f'{agent.name} left the destination early.')
|
|
||||||
self.trigger_destination_spawn()
|
|
||||||
return info_dict
|
|
||||||
|
|
||||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
reward, info_dict = super().calculate_additional_reward(agent)
|
|
||||||
if h.EnvActions.WAIT_ON_DEST == agent.temp_action:
|
|
||||||
if agent.temp_valid:
|
|
||||||
info_dict.update({f'{agent.name}_waiting_at_dest': 1})
|
|
||||||
info_dict.update(agent_waiting_at_dest=1)
|
|
||||||
self.print(f'{agent.name} just waited at {agent.pos}')
|
|
||||||
reward += 0.1
|
|
||||||
else:
|
|
||||||
info_dict.update({f'{agent.name}_tried_failed': 1})
|
|
||||||
info_dict.update(agent_waiting_failed=1)
|
|
||||||
self.print(f'{agent.name} just tried to wait at {agent.pos} but failed')
|
|
||||||
reward -= 0.1
|
|
||||||
if len(self[c.REACHEDDESTINATION]):
|
|
||||||
for reached_dest in list(self[c.REACHEDDESTINATION]):
|
|
||||||
if agent.pos == reached_dest.pos:
|
|
||||||
info_dict.update({f'{agent.name}_reached_destination': 1})
|
|
||||||
info_dict.update(agent_reached_destination=1)
|
|
||||||
self.print(f'{agent.name} just reached destination at {agent.pos}')
|
|
||||||
reward += 0.5
|
|
||||||
self[c.REACHEDDESTINATION].delete_entity(reached_dest)
|
|
||||||
return reward, info_dict
|
|
||||||
|
|
||||||
def render_additional_assets(self, mode='human'):
|
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
additional_assets = super().render_additional_assets()
|
|
||||||
destinations = [RenderEntity(c.DESTINATION.value, dest.pos) for dest in self[c.DESTINATION]]
|
|
||||||
additional_assets.extend(destinations)
|
|
||||||
return additional_assets
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties
|
|
||||||
|
|
||||||
render = True
|
|
||||||
|
|
||||||
dest_probs = DestinationProperties(n_dests=2, spawn_frequency=5, spawn_mode=DestSpawnMode.GROUPED)
|
|
||||||
|
|
||||||
obs_props = ObservationProperties(render_agents=ARO.LEVEL, omit_agent_self=True, pomdp_r=2)
|
|
||||||
|
|
||||||
move_props = {'allow_square_movement': True,
|
|
||||||
'allow_diagonal_movement': False,
|
|
||||||
'allow_no_op': False}
|
|
||||||
|
|
||||||
factory = DestinationFactory(n_agents=10, done_at_collision=False,
|
|
||||||
level_name='rooms', max_steps=400,
|
|
||||||
obs_prop=obs_props, parse_doors=True,
|
|
||||||
verbose=True,
|
|
||||||
mv_prop=move_props, dest_prop=dest_probs
|
|
||||||
)
|
|
||||||
|
|
||||||
# noinspection DuplicatedCode
|
|
||||||
n_actions = factory.action_space.n - 1
|
|
||||||
_ = factory.observation_space
|
|
||||||
|
|
||||||
for epoch in range(4):
|
|
||||||
random_actions = [[random.randint(0, n_actions) for _
|
|
||||||
in range(factory.n_agents)] for _
|
|
||||||
in range(factory.max_steps + 1)]
|
|
||||||
env_state = factory.reset()
|
|
||||||
r = 0
|
|
||||||
for agent_i_action in random_actions:
|
|
||||||
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
|
|
||||||
r += step_r
|
|
||||||
if render:
|
|
||||||
factory.render()
|
|
||||||
if done_bool:
|
|
||||||
break
|
|
||||||
print(f'Factory run {epoch} done, reward is:\n {r}')
|
|
||||||
pass
|
|
@ -1,318 +0,0 @@
|
|||||||
import time
|
|
||||||
from enum import Enum
|
|
||||||
from typing import List, Union, NamedTuple, Dict
|
|
||||||
import random
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from algorithms.TSP_dirt_agent import TSPDirtAgent
|
|
||||||
from environments.helpers import Constants as c
|
|
||||||
from environments import helpers as h
|
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
|
||||||
from environments.factory.base.objects import Agent, Action, Entity, Tile
|
|
||||||
from environments.factory.base.registers import Entities, MovingEntityObjectRegister
|
|
||||||
|
|
||||||
from environments.factory.base.renderer import RenderEntity
|
|
||||||
from environments.utility_classes import ObservationProperties
|
|
||||||
|
|
||||||
CLEAN_UP_ACTION = h.EnvActions.CLEAN_UP
|
|
||||||
|
|
||||||
|
|
||||||
class DirtProperties(NamedTuple):
    initial_dirt_ratio: float = 0.3          # On INIT, on max how much tiles does the dirt spawn in percent.
    initial_dirt_spawn_r_var: float = 0.05   # How much does the dirt spawn amount vary?
    clean_amount: float = 1                  # How much does the robot clean with one actions.
    max_spawn_ratio: float = 0.20            # On max how much tiles does the dirt spawn in percent.
    max_spawn_amount: float = 0.3            # How much dirt does spawn per tile at max.
    spawn_frequency: int = 0                 # Spawn Frequency in Steps.
    max_local_amount: int = 2                # Max dirt amount per tile.
    max_global_amount: int = 20              # Max dirt amount in the whole environment.
    dirt_smear_amount: float = 0.2           # Agents smear dirt, when not cleaning up in place.
    agent_can_interact: bool = True          # Whether the agents can interact with the dirt in this environment.
    done_when_clean: bool = True
|
|
||||||
|
|
||||||
class Dirt(Entity):
|
|
||||||
|
|
||||||
@property
|
|
||||||
def can_collide(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def amount(self):
|
|
||||||
return self._amount
|
|
||||||
|
|
||||||
def encoding(self):
|
|
||||||
# Edit this if you want items to be drawn in the ops differently
|
|
||||||
return self._amount
|
|
||||||
|
|
||||||
def __init__(self, *args, amount=None, **kwargs):
|
|
||||||
super(Dirt, self).__init__(*args, **kwargs)
|
|
||||||
self._amount = amount
|
|
||||||
|
|
||||||
def set_new_amount(self, amount):
|
|
||||||
self._amount = amount
|
|
||||||
|
|
||||||
def summarize_state(self, **kwargs):
|
|
||||||
state_dict = super().summarize_state(**kwargs)
|
|
||||||
state_dict.update(amount=float(self.amount))
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
|
|
||||||
class DirtRegister(MovingEntityObjectRegister):
|
|
||||||
|
|
||||||
def as_array(self):
|
|
||||||
if self._array is not None:
|
|
||||||
self._array[:] = c.FREE_CELL.value
|
|
||||||
for dirt in list(self.values()):
|
|
||||||
if dirt.amount == 0:
|
|
||||||
self.delete_entity(dirt)
|
|
||||||
self._array[0, dirt.x, dirt.y] = dirt.amount
|
|
||||||
else:
|
|
||||||
self._array = np.zeros((1, *self._level_shape))
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
_accepted_objects = Dirt
|
|
||||||
|
|
||||||
@property
|
|
||||||
def amount(self):
|
|
||||||
return sum([dirt.amount for dirt in self])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dirt_properties(self):
|
|
||||||
return self._dirt_properties
|
|
||||||
|
|
||||||
def __init__(self, dirt_properties, *args):
|
|
||||||
super(DirtRegister, self).__init__(*args)
|
|
||||||
self._dirt_properties: DirtProperties = dirt_properties
|
|
||||||
|
|
||||||
def spawn_dirt(self, then_dirty_tiles) -> c:
|
|
||||||
if isinstance(then_dirty_tiles, Tile):
|
|
||||||
then_dirty_tiles = [then_dirty_tiles]
|
|
||||||
for tile in then_dirty_tiles:
|
|
||||||
if not self.amount > self.dirt_properties.max_global_amount:
|
|
||||||
dirt = self.by_pos(tile.pos)
|
|
||||||
if dirt is None:
|
|
||||||
dirt = Dirt(tile, amount=self.dirt_properties.max_spawn_amount)
|
|
||||||
self.register_item(dirt)
|
|
||||||
else:
|
|
||||||
new_value = dirt.amount + self.dirt_properties.max_spawn_amount
|
|
||||||
dirt.set_new_amount(min(new_value, self.dirt_properties.max_local_amount))
|
|
||||||
else:
|
|
||||||
return c.NOT_VALID
|
|
||||||
return c.VALID
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
s = super(DirtRegister, self).__repr__()
|
|
||||||
return f'{s[:-1]}, {self.amount})'
|
|
||||||
|
|
||||||
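`spawn_dirt` above stops adding dirt once the global total exceeds `max_global_amount` and caps each tile at `max_local_amount`. A small sketch of those two caps, assuming dirt is just a dict from position to amount:

def spawn_dirt(dirt, tiles, spawn_amount=0.3, max_local=2.0, max_global=20.0):
    """dirt: {pos: amount}; returns True if anything was spawned."""
    spawned = False
    for pos in tiles:
        if sum(dirt.values()) > max_global:
            break                                    # global budget exhausted
        dirt[pos] = min(dirt.get(pos, 0.0) + spawn_amount, max_local)
        spawned = True
    return spawned

dirt = {}
spawn_dirt(dirt, [(1, 1), (1, 1), (2, 2)], spawn_amount=1.5, max_local=2.0)
assert dirt[(1, 1)] == 2.0 and dirt[(2, 2)] == 1.5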
|
|
||||||
def softmax(x):
    """Compute softmax values for each set of scores in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()


def entropy(x):
    return -(x * np.log(x + 1e-8)).sum()
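These two helpers feed the commented-out `dirt_distribution_score` further down in `calculate_additional_reward`; the idea is that evenly spread dirt has higher entropy than one concentrated pile. A quick check of that behaviour:

import numpy as np

concentrated = np.asarray([2.0, 0.01, 0.01, 0.01])
spread_out = np.asarray([0.5, 0.5, 0.5, 0.5])

# softmax turns raw dirt amounts into a distribution, entropy scores how even it is
assert entropy(softmax(spread_out)) > entropy(softmax(concentrated))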
|
|
||||||
|
|
||||||
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
|
||||||
class DirtFactory(BaseFactory):
|
|
||||||
|
|
||||||
@property
|
|
||||||
def additional_actions(self) -> Union[Action, List[Action]]:
|
|
||||||
super_actions = super().additional_actions
|
|
||||||
if self.dirt_prop.agent_can_interact:
|
|
||||||
super_actions.append(Action(enum_ident=CLEAN_UP_ACTION))
|
|
||||||
return super_actions
|
|
||||||
|
|
||||||
@property
|
|
||||||
def additional_entities(self) -> Dict[(Enum, Entities)]:
|
|
||||||
super_entities = super().additional_entities
|
|
||||||
dirt_register = DirtRegister(self.dirt_prop, self._level_shape)
|
|
||||||
super_entities.update(({c.DIRT: dirt_register}))
|
|
||||||
return super_entities
|
|
||||||
|
|
||||||
def __init__(self, *args, dirt_prop: DirtProperties = DirtProperties(), env_seed=time.time_ns(), **kwargs):
|
|
||||||
if isinstance(dirt_prop, dict):
|
|
||||||
dirt_prop = DirtProperties(**dirt_prop)
|
|
||||||
self.dirt_prop = dirt_prop
|
|
||||||
self._dirt_rng = np.random.default_rng(env_seed)
|
|
||||||
self._dirt: DirtRegister
|
|
||||||
kwargs.update(env_seed=env_seed)
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def render_additional_assets(self, mode='human'):
|
|
||||||
additional_assets = super().render_additional_assets()
|
|
||||||
dirt = [RenderEntity('dirt', dirt.tile.pos, min(0.15 + dirt.amount, 1.5), 'scale')
|
|
||||||
for dirt in self[c.DIRT]]
|
|
||||||
additional_assets.extend(dirt)
|
|
||||||
return additional_assets
|
|
||||||
|
|
||||||
def clean_up(self, agent: Agent) -> c:
|
|
||||||
if dirt := self[c.DIRT].by_pos(agent.pos):
|
|
||||||
new_dirt_amount = dirt.amount - self.dirt_prop.clean_amount
|
|
||||||
|
|
||||||
if new_dirt_amount <= 0:
|
|
||||||
self[c.DIRT].delete_entity(dirt)
|
|
||||||
else:
|
|
||||||
dirt.set_new_amount(max(new_dirt_amount, c.FREE_CELL.value))
|
|
||||||
return c.VALID
|
|
||||||
else:
|
|
||||||
return c.NOT_VALID
|
|
||||||
|
|
||||||
def trigger_dirt_spawn(self, initial_spawn=False):
|
|
||||||
dirt_rng = self._dirt_rng
|
|
||||||
free_for_dirt = [x for x in self[c.FLOOR]
|
|
||||||
if len(x.guests) == 0 or (len(x.guests) == 1 and isinstance(next(y for y in x.guests), Dirt))
|
|
||||||
]
|
|
||||||
self._dirt_rng.shuffle(free_for_dirt)
|
|
||||||
if initial_spawn:
|
|
||||||
var = self.dirt_prop.initial_dirt_spawn_r_var
|
|
||||||
new_spawn = self.dirt_prop.initial_dirt_ratio + dirt_rng.uniform(-var, var)
|
|
||||||
else:
|
|
||||||
new_spawn = dirt_rng.uniform(0, self.dirt_prop.max_spawn_ratio)
|
|
||||||
n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt)))
|
|
||||||
self[c.DIRT].spawn_dirt(free_for_dirt[:n_dirt_tiles])
|
|
||||||
|
|
||||||
def do_additional_step(self) -> dict:
|
|
||||||
info_dict = super().do_additional_step()
|
|
||||||
if smear_amount := self.dirt_prop.dirt_smear_amount:
|
|
||||||
for agent in self[c.AGENT]:
|
|
||||||
if agent.temp_valid and agent.last_pos != c.NO_POS:
|
|
||||||
if self._actions.is_moving_action(agent.temp_action):
|
|
||||||
if old_pos_dirt := self[c.DIRT].by_pos(agent.last_pos):
|
|
||||||
if smeared_dirt := round(old_pos_dirt.amount * smear_amount, 2):
|
|
||||||
old_pos_dirt.set_new_amount(max(0, old_pos_dirt.amount-smeared_dirt))
|
|
||||||
if new_pos_dirt := self[c.DIRT].by_pos(agent.pos):
|
|
||||||
new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt))
|
|
||||||
else:
|
|
||||||
if self[c.DIRT].spawn_dirt(agent.tile):
|
|
||||||
new_pos_dirt = self[c.DIRT].by_pos(agent.pos)
|
|
||||||
new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt))
|
|
||||||
if self._next_dirt_spawn < 0:
|
|
||||||
pass # No Dirt Spawn
|
|
||||||
elif not self._next_dirt_spawn:
|
|
||||||
self.trigger_dirt_spawn()
|
|
||||||
self._next_dirt_spawn = self.dirt_prop.spawn_frequency
|
|
||||||
else:
|
|
||||||
self._next_dirt_spawn -= 1
|
|
||||||
return info_dict
|
|
||||||
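The smear branch in `do_additional_step` moves a fraction of the dirt from the tile an agent just left onto the tile it stepped onto. A toy version of that transfer, assuming the same dict-of-amounts representation as in the sketch above:

def smear(dirt, last_pos, new_pos, smear_ratio=0.2):
    """Move a share of the dirt under the agent's previous tile onto its new tile."""
    carried = round(dirt.get(last_pos, 0.0) * smear_ratio, 2)
    if carried:
        dirt[last_pos] = max(0.0, dirt[last_pos] - carried)
        dirt[new_pos] = dirt.get(new_pos, 0.0) + carried
    return carried

dirt = {(1, 1): 1.0}
smear(dirt, (1, 1), (1, 2))
assert dirt[(1, 2)] > 0 and dirt[(1, 1)] < 1.0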
|
|
||||||
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
|
|
||||||
valid = super().do_additional_actions(agent, action)
|
|
||||||
if valid is None:
|
|
||||||
if action == CLEAN_UP_ACTION:
|
|
||||||
if self.dirt_prop.agent_can_interact:
|
|
||||||
valid = self.clean_up(agent)
|
|
||||||
return valid
|
|
||||||
else:
|
|
||||||
return c.NOT_VALID
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return valid
|
|
||||||
|
|
||||||
def do_additional_reset(self) -> None:
|
|
||||||
super().do_additional_reset()
|
|
||||||
self.trigger_dirt_spawn(initial_spawn=True)
|
|
||||||
self._next_dirt_spawn = self.dirt_prop.spawn_frequency if self.dirt_prop.spawn_frequency else -1
|
|
||||||
|
|
||||||
def check_additional_done(self):
|
|
||||||
super_done = super().check_additional_done()
|
|
||||||
done = self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0)
|
|
||||||
return super_done or done
|
|
||||||
|
|
||||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
|
||||||
reward, info_dict = super().calculate_additional_reward(agent)
|
|
||||||
dirt = [dirt.amount for dirt in self[c.DIRT]]
|
|
||||||
current_dirt_amount = sum(dirt)
|
|
||||||
dirty_tile_count = len(dirt)
|
|
||||||
# if dirty_tile_count:
|
|
||||||
# dirt_distribution_score = entropy(softmax(np.asarray(dirt)) / dirty_tile_count)
|
|
||||||
#else:
|
|
||||||
# dirt_distribution_score = 0
|
|
||||||
|
|
||||||
info_dict.update(dirt_amount=current_dirt_amount)
|
|
||||||
info_dict.update(dirty_tile_count=dirty_tile_count)
|
|
||||||
# info_dict.update(dirt_distribution_score=dirt_distribution_score)
|
|
||||||
|
|
||||||
if agent.temp_action == CLEAN_UP_ACTION:
|
|
||||||
if agent.temp_valid:
|
|
||||||
# Reward if pickup succeeds,
|
|
||||||
# 0.5 on every pickup
|
|
||||||
reward += 0.5
|
|
||||||
info_dict.update(dirt_cleaned=1)
|
|
||||||
if self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0):
|
|
||||||
# 0.5 additional reward for the very last pickup
|
|
||||||
reward += 4.5
|
|
||||||
info_dict.update(done_clean=1)
|
|
||||||
self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.')
|
|
||||||
else:
|
|
||||||
reward -= 0.01
|
|
||||||
self.print(f'{agent.name} just tried to clean up some dirt at {agent.pos}, but failed.')
|
|
||||||
info_dict.update({f'{agent.name}_failed_dirt_cleanup': 1})
|
|
||||||
info_dict.update(failed_dirt_clean=1)
|
|
||||||
|
|
||||||
# Potential based rewards ->
|
|
||||||
# track the last reward , minus the current reward = potential
|
|
||||||
return reward, info_dict
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
from environments.utility_classes import AgentRenderOptions as ARO
|
|
||||||
render = True
|
|
||||||
|
|
||||||
dirt_props = DirtProperties(
|
|
||||||
initial_dirt_ratio=0.35,
|
|
||||||
initial_dirt_spawn_r_var=0.1,
|
|
||||||
clean_amount=0.34,
|
|
||||||
max_spawn_amount=0.1,
|
|
||||||
max_global_amount=20,
|
|
||||||
max_local_amount=1,
|
|
||||||
spawn_frequency=0,
|
|
||||||
max_spawn_ratio=0.05,
|
|
||||||
dirt_smear_amount=0.0,
|
|
||||||
agent_can_interact=True
|
|
||||||
)
|
|
||||||
|
|
||||||
obs_props = ObservationProperties(render_agents=ARO.COMBINED, omit_agent_self=True,
|
|
||||||
pomdp_r=2, additional_agent_placeholder=None)
|
|
||||||
|
|
||||||
move_props = {'allow_square_movement': True,
|
|
||||||
'allow_diagonal_movement': False,
|
|
||||||
'allow_no_op': False}
|
|
||||||
|
|
||||||
factory = DirtFactory(n_agents=1, done_at_collision=False,
|
|
||||||
level_name='rooms', max_steps=400,
|
|
||||||
doors_have_area=False,
|
|
||||||
obs_prop=obs_props, parse_doors=True,
|
|
||||||
record_episodes=True, verbose=True,
|
|
||||||
mv_prop=move_props, dirt_prop=dirt_props,
|
|
||||||
inject_agents=[TSPDirtAgent]
|
|
||||||
)
|
|
||||||
|
|
||||||
# noinspection DuplicatedCode
|
|
||||||
n_actions = factory.action_space.n - 1
|
|
||||||
_ = factory.observation_space
|
|
||||||
|
|
||||||
for epoch in range(10):
|
|
||||||
random_actions = [[random.randint(0, n_actions) for _
|
|
||||||
in range(factory.n_agents)] for _
|
|
||||||
in range(factory.max_steps+1)]
|
|
||||||
env_state = factory.reset()
|
|
||||||
if render:
|
|
||||||
factory.render()
|
|
||||||
tsp_agent = factory.get_injected_agents()[0]
|
|
||||||
|
|
||||||
r = 0
|
|
||||||
for agent_i_action in random_actions:
|
|
||||||
env_state, step_r, done_bool, info_obj = factory.step(tsp_agent.predict())
|
|
||||||
r += step_r
|
|
||||||
if render:
|
|
||||||
factory.render()
|
|
||||||
if done_bool:
|
|
||||||
break
|
|
||||||
print(f'Factory run {epoch} done, reward is:\n {r}')
|
|
||||||
pass
|
|
59 environments/factory/factory_dirt_stationary_machines.py Normal file
@@ -0,0 +1,59 @@
from typing import Dict, List, Union

import numpy as np

from environments.factory.base.objects import Agent, Entity, Action
from environments.factory.factory_dirt import DirtFactory
from environments.factory.additional.dirt.dirt_collections import DirtPiles
from environments.factory.additional.dirt.dirt_entity import DirtPile
from environments.factory.base.objects import Floor
from environments.factory.base.registers import Floors, Entities, EntityCollection


class Machines(EntityCollection):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


class Machine(Entity):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


class StationaryMachinesDirtFactory(DirtFactory):

    def __init__(self, *args, **kwargs):
        self._machine_coords = [(6, 6), (12, 13)]
        super().__init__(*args, **kwargs)

    def entities_hook(self) -> Dict[(str, Entities)]:
        super_entities = super().entities_hook()
        return super_entities

    def reset_hook(self) -> None:
        pass

    def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
        pass

    def actions_hook(self) -> Union[Action, List[Action]]:
        pass

    def step_hook(self) -> (List[dict], dict):

        pass

    def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
        super_per_agent_raw_observations = super().per_agent_raw_observations_hook(agent)
        return super_per_agent_raw_observations

    def per_agent_reward_hook(self, agent: Agent) -> List[dict]:
        return super(StationaryMachinesDirtFactory, self).per_agent_reward_hook(agent)

    def pre_step_hook(self) -> None:
        pass

    def post_step_hook(self) -> dict:
        pass
|
@ -1,397 +0,0 @@
|
|||||||
import time
|
|
||||||
from collections import deque, UserList
|
|
||||||
from enum import Enum
|
|
||||||
from typing import List, Union, NamedTuple, Dict
|
|
||||||
import numpy as np
|
|
||||||
import random
|
|
||||||
|
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
|
||||||
from environments.helpers import Constants as c
|
|
||||||
from environments import helpers as h
|
|
||||||
from environments.factory.base.objects import Agent, Entity, Action, Tile, MoveableEntity
|
|
||||||
from environments.factory.base.registers import Entities, EntityObjectRegister, ObjectRegister, \
|
|
||||||
MovingEntityObjectRegister
|
|
||||||
|
|
||||||
from environments.factory.base.renderer import RenderEntity
|
|
||||||
|
|
||||||
|
|
||||||
NO_ITEM = 0
|
|
||||||
ITEM_DROP_OFF = 1
|
|
||||||
|
|
||||||
|
|
||||||
class Item(MoveableEntity):
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self._auto_despawn = -1
|
|
||||||
|
|
||||||
@property
|
|
||||||
def auto_despawn(self):
|
|
||||||
return self._auto_despawn
|
|
||||||
|
|
||||||
@property
|
|
||||||
def can_collide(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def encoding(self):
|
|
||||||
# Edit this if you want items to be drawn in the ops differently
|
|
||||||
return 1
|
|
||||||
|
|
||||||
def set_auto_despawn(self, auto_despawn):
|
|
||||||
self._auto_despawn = auto_despawn
|
|
||||||
|
|
||||||
|
|
||||||
class ItemRegister(MovingEntityObjectRegister):
|
|
||||||
|
|
||||||
def as_array(self):
|
|
||||||
self._array[:] = c.FREE_CELL.value
|
|
||||||
for item in self:
|
|
||||||
if item.pos != c.NO_POS.value:
|
|
||||||
self._array[0, item.x, item.y] = item.encoding
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
_accepted_objects = Item
|
|
||||||
|
|
||||||
def spawn_items(self, tiles: List[Tile]):
|
|
||||||
items = [Item(tile) for tile in tiles]
|
|
||||||
self.register_additional_items(items)
|
|
||||||
|
|
||||||
def despawn_items(self, items: List[Item]):
|
|
||||||
items = [items] if isinstance(items, Item) else items
|
|
||||||
for item in items:
|
|
||||||
del self[item]
|
|
||||||
|
|
||||||
|
|
||||||
class Inventory(UserList):
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_blocking_light(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
return f'{self.__class__.__name__}({self.agent.name})'
|
|
||||||
|
|
||||||
def __init__(self, pomdp_r: int, level_shape: (int, int), agent: Agent, capacity: int):
|
|
||||||
super(Inventory, self).__init__()
|
|
||||||
self.agent = agent
|
|
||||||
self.pomdp_r = pomdp_r
|
|
||||||
self._level_shape = level_shape
|
|
||||||
if self.pomdp_r:
|
|
||||||
self._array = np.zeros((1, pomdp_r * 2 + 1, pomdp_r * 2 + 1))
|
|
||||||
else:
|
|
||||||
self._array = np.zeros((1, *self._level_shape))
|
|
||||||
self.capacity = min(capacity, self._array.size)
|
|
||||||
|
|
||||||
def as_array(self):
|
|
||||||
self._array[:] = c.FREE_CELL.value
|
|
||||||
for item_idx, item in enumerate(self):
|
|
||||||
x_diff, y_diff = divmod(item_idx, self._array.shape[1])
|
|
||||||
self._array[0, int(x_diff), int(y_diff)] = item.encoding
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f'{self.__class__.__name__}[{self.agent.name}]({self.data})'
|
|
||||||
|
|
||||||
def append(self, item) -> None:
|
|
||||||
if len(self) < self.capacity:
|
|
||||||
super(Inventory, self).append(item)
|
|
||||||
else:
|
|
||||||
raise RuntimeError('Inventory is full')
|
|
||||||
|
|
||||||
def belongs_to_entity(self, entity):
|
|
||||||
return self.agent == entity
|
|
||||||
|
|
||||||
def summarize_state(self, **kwargs):
|
|
||||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
|
||||||
attr_dict.update(dict(items={val.name: val.summarize_state(**kwargs) for val in self}))
|
|
||||||
attr_dict.update(dict(name=self.name))
|
|
||||||
return attr_dict
|
|
||||||
|
|
||||||
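`Inventory.as_array` lays the carried items out one after another inside the agent-sized observation slice via `divmod`. A tiny sketch of that packing, assuming row-major filling and that every item encodes as 1:

import numpy as np

def inventory_to_slice(n_items, slice_shape=(5, 5)):
    """Place one marker per carried item into a (1, H, W) slice, filling row by row."""
    arr = np.zeros((1, *slice_shape))
    for item_idx in range(n_items):
        row, col = divmod(item_idx, slice_shape[1])
        arr[0, row, col] = 1
    return arr

assert inventory_to_slice(7)[0, 0].sum() == 5   # first row full
assert inventory_to_slice(7)[0, 1].sum() == 2   # two items spill into the second row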
|
|
||||||
class Inventories(ObjectRegister):

    _accepted_objects = Inventory
    is_blocking_light = False
    can_be_shadowed = False
    hide_from_obs_builder = True

    def __init__(self, *args, **kwargs):
        super(Inventories, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
        self.is_observable = True

    def as_array(self):
        # self._array[:] = c.FREE_CELL.value
        for inv_idx, inventory in enumerate(self):
            self._array[inv_idx] = inventory.as_array()
        return self._array

    def spawn_inventories(self, agents, pomdp_r, capacity):
        inventories = [self._accepted_objects(pomdp_r, self._level_shape, agent, capacity)
                       for _, agent in enumerate(agents)]
        self.register_additional_items(inventories)

    def idx_by_entity(self, entity):
        try:
            return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity)))
        except StopIteration:
            return None

    def by_entity(self, entity):
        try:
            return next((inv for inv in self if inv.belongs_to_entity(entity)))
        except StopIteration:
            return None

    def summarize_states(self, n_steps=None):
        # as dict with additional nesting
        # return dict(items=super(Inventories, self).summarize_states())
        return super(Inventories, self).summarize_states(n_steps=n_steps)
class DropOffLocation(Entity):

    @property
    def can_collide(self):
        return False

    @property
    def encoding(self):
        return ITEM_DROP_OFF

    def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs):
        super(DropOffLocation, self).__init__(*args, **kwargs)
        self.auto_item_despawn_interval = auto_item_despawn_interval
        self.storage = deque(maxlen=storage_size_until_full or None)

    def place_item(self, item: Item):
        if self.is_full:
            raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
            return c.NOT_VALID
        else:
            self.storage.append(item)
            item.set_auto_despawn(self.auto_item_despawn_interval)
            return c.VALID

    @property
    def is_full(self):
        return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage)

    def summarize_state(self, n_steps=None) -> dict:
        if n_steps == h.STEPS_START:
            return super().summarize_state(n_steps=n_steps)
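# Side note, illustrative sketch only (not part of the original file): `deque(maxlen=storage_size_until_full or None)`
# maps a storage size of 0 to an unbounded deque, so `is_full` never becomes True for such a drop-off.
def _storage_capacity_demo():
    from collections import deque

    def _is_full(storage: deque) -> bool:
        # Mirrors DropOffLocation.is_full
        return False if not storage.maxlen else storage.maxlen == len(storage)

    bounded = deque(maxlen=5 or None)       # storage_size_until_full = 5
    unbounded = deque(maxlen=0 or None)     # storage_size_until_full = 0 -> maxlen is None
    bounded.extend(range(5))
    assert _is_full(bounded) is True        # reached the configured capacity
    assert _is_full(unbounded) is False     # unbounded storage never reports full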
class DropOffLocations(EntityObjectRegister):

    _accepted_objects = DropOffLocation

    def as_array(self):
        self._array[:] = c.FREE_CELL.value
        for item in self:
            if item.pos != c.NO_POS.value:
                self._array[0, item.x, item.y] = item.encoding
        return self._array

    def __repr__(self):
        # __repr__ must return a string, so forward the parent representation.
        return super(DropOffLocations, self).__repr__()
class ItemProperties(NamedTuple):
    n_items: int = 5                        # How many items are there at the same time
    spawn_frequency: int = 10               # Spawn frequency in steps
    n_drop_off_locations: int = 5           # How many drop-off locations are there at the same time
    max_dropoff_storage_size: int = 0       # How many items are needed until the drop-off is full
    max_agent_inventory_capacity: int = 5   # How many items are needed until the agent inventory is full
    agent_can_interact: bool = True         # Whether agents have the possibility to interact with the domain items
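# Illustrative only: ItemProperties is a NamedTuple, so a configuration can be built directly
# or from a plain dict (ItemFactory.__init__ below performs the dict conversion itself).
# The chosen numbers are arbitrary example values.
def _item_properties_example():
    sparse_items = ItemProperties(n_items=3, spawn_frequency=20, n_drop_off_locations=2)
    from_dict = ItemProperties(**{'n_items': 3, 'spawn_frequency': 20, 'n_drop_off_locations': 2})
    assert sparse_items == from_dict
    return sparse_items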
# noinspection PyAttributeOutsideInit, PyAbstractClass
class ItemFactory(BaseFactory):
    # noinspection PyMissingConstructor
    def __init__(self, *args, item_prop: ItemProperties = ItemProperties(), env_seed=time.time_ns(), **kwargs):
        if isinstance(item_prop, dict):
            item_prop = ItemProperties(**item_prop)
        self.item_prop = item_prop
        kwargs.update(env_seed=env_seed)
        self._item_rng = np.random.default_rng(env_seed)
        assert (item_prop.n_items <= ((1 + kwargs.get('_pomdp_r', 0) * 2) ** 2)) or not kwargs.get('_pomdp_r', 0)
        super().__init__(*args, **kwargs)

    @property
    def additional_actions(self) -> Union[Action, List[Action]]:
        # noinspection PyUnresolvedReferences
        super_actions = super().additional_actions
        super_actions.append(Action(enum_ident=h.EnvActions.ITEM_ACTION))
        return super_actions

    @property
    def additional_entities(self) -> Dict[Enum, Entities]:
        # noinspection PyUnresolvedReferences
        super_entities = super().additional_entities

        empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_drop_off_locations]
        drop_offs = DropOffLocations.from_tiles(
            empty_tiles, self._level_shape,
            entity_kwargs=dict(
                storage_size_until_full=self.item_prop.max_dropoff_storage_size)
        )
        item_register = ItemRegister(self._level_shape)
        empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_items]
        item_register.spawn_items(empty_tiles)

        inventories = Inventories(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2))
        inventories.spawn_inventories(self[c.AGENT], self._pomdp_r,
                                      self.item_prop.max_agent_inventory_capacity)

        super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories})
        return super_entities
    def additional_per_agent_obs_build(self, agent) -> List[np.ndarray]:
        additional_per_agent_obs_build = super().additional_per_agent_obs_build(agent)
        additional_per_agent_obs_build.append(self[c.INVENTORY].by_entity(agent).as_array())
        return additional_per_agent_obs_build

    def do_item_action(self, agent: Agent):
        inventory = self[c.INVENTORY].by_entity(agent)
        if drop_off := self[c.DROP_OFF].by_pos(agent.pos):
            if inventory:
                valid = drop_off.place_item(inventory.pop(0))
                return valid
            else:
                return c.NOT_VALID
        elif item := self[c.ITEM].by_pos(agent.pos):
            try:
                inventory.append(item)
                item.move(self._NO_POS_TILE)
                return c.VALID
            except RuntimeError:
                return c.NOT_VALID
        else:
            return c.NOT_VALID

    def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
        # noinspection PyUnresolvedReferences
        valid = super().do_additional_actions(agent, action)
        if valid is None:
            if action == h.EnvActions.ITEM_ACTION:
                if self.item_prop.agent_can_interact:
                    valid = self.do_item_action(agent)
                    return valid
                else:
                    return c.NOT_VALID
            else:
                return None
        else:
            return valid
    def do_additional_reset(self) -> None:
        # noinspection PyUnresolvedReferences
        super().do_additional_reset()
        self._next_item_spawn = self.item_prop.spawn_frequency
        self.trigger_item_spawn()

    def trigger_item_spawn(self):
        if item_to_spawns := max(0, (self.item_prop.n_items - len(self[c.ITEM]))):
            empty_tiles = self[c.FLOOR].empty_tiles[:item_to_spawns]
            self[c.ITEM].spawn_items(empty_tiles)
            self._next_item_spawn = self.item_prop.spawn_frequency
            self.print(f'{item_to_spawns} new items have been spawned; next spawn in {self._next_item_spawn}')
        else:
            self.print('No Items are spawning, limit is reached.')

    def do_additional_step(self) -> dict:
        # noinspection PyUnresolvedReferences
        info_dict = super().do_additional_step()
        for item in list(self[c.ITEM].values()):
            if item.auto_despawn >= 1:
                item.set_auto_despawn(item.auto_despawn - 1)
            elif not item.auto_despawn:
                self[c.ITEM].delete_entity(item)
            else:
                pass

        if not self._next_item_spawn:
            self.trigger_item_spawn()
        else:
            self._next_item_spawn = max(0, self._next_item_spawn - 1)
        return info_dict
    def calculate_additional_reward(self, agent: Agent) -> (int, dict):
        # noinspection PyUnresolvedReferences
        reward, info_dict = super().calculate_additional_reward(agent)
        if h.EnvActions.ITEM_ACTION == agent.temp_action:
            if agent.temp_valid:
                if drop_off := self[c.DROP_OFF].by_pos(agent.pos):
                    info_dict.update({f'{agent.name}_item_drop_off': 1})
                    info_dict.update(item_drop_off=1)
                    self.print(f'{agent.name} just dropped off an item at {drop_off.pos}.')
                    reward += 0.5
                else:
                    info_dict.update({f'{agent.name}_item_pickup': 1})
                    info_dict.update(item_pickup=1)
                    self.print(f'{agent.name} just picked up an item at {agent.pos}.')
                    reward += 0.1
            else:
                if self[c.DROP_OFF].by_pos(agent.pos):
                    info_dict.update({f'{agent.name}_failed_drop_off': 1})
                    info_dict.update(failed_drop_off=1)
                    self.print(f'{agent.name} just tried to drop off at {agent.pos}, but failed.')
                    reward -= 0.1
                else:
                    info_dict.update({f'{agent.name}_failed_item_action': 1})
                    info_dict.update(failed_pick_up=1)
                    self.print(f'{agent.name} just tried to pick up an item at {agent.pos}, but failed.')
                    reward -= 0.1
        return reward, info_dict
    def render_additional_assets(self, mode='human'):
        # noinspection PyUnresolvedReferences
        additional_assets = super().render_additional_assets()
        items = [RenderEntity(c.ITEM.value, item.tile.pos) for item in self[c.ITEM]]
        additional_assets.extend(items)
        drop_offs = [RenderEntity(c.DROP_OFF.value, drop_off.tile.pos) for drop_off in self[c.DROP_OFF]]
        additional_assets.extend(drop_offs)
        return additional_assets
if __name__ == '__main__':
    from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties

    render = True

    item_probs = ItemProperties()

    obs_props = ObservationProperties(render_agents=ARO.LEVEL, omit_agent_self=True, pomdp_r=2)

    move_props = {'allow_square_movement': True,
                  'allow_diagonal_movement': False,
                  'allow_no_op': False}

    factory = ItemFactory(n_agents=3, done_at_collision=False,
                          level_name='rooms', max_steps=400,
                          obs_prop=obs_props, parse_doors=True,
                          record_episodes=True, verbose=True,
                          mv_prop=move_props, item_prop=item_probs
                          )

    # noinspection DuplicatedCode
    n_actions = factory.action_space.n - 1
    _ = factory.observation_space

    for epoch in range(4):
        random_actions = [[random.randint(0, n_actions) for _
                           in range(factory.n_agents)] for _
                          in range(factory.max_steps + 1)]
        env_state = factory.reset()
        r = 0
        for agent_i_action in random_actions:
            env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
            r += step_r
            if render:
                factory.render()
            if done_bool:
                break
        print(f'Factory run {epoch} done, reward is:\n {r}')
    pass
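# Sketch only: the random-action loop above suggests the factory exposes the usual gym
# reset/step interface, so a single-agent instance could also be handed to a learning
# algorithm such as PPO from stable-baselines3 (used elsewhere in this repository).
# The policy name, hyperparameters and timestep budget below are placeholders.
def _train_item_factory_with_ppo():
    from stable_baselines3 import PPO
    from environments.utility_classes import ObservationProperties

    single_agent_factory = ItemFactory(n_agents=1, level_name='rooms', max_steps=400,
                                       obs_prop=ObservationProperties(pomdp_r=2),
                                       item_prop=ItemProperties())
    model = PPO('MlpPolicy', single_agent_factory, verbose=1)
    model.learn(total_timesteps=10_000)
    return model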
@@ -1,8 +1,12 @@
-movement_props:
+parse_doors: True
+doors_have_area: True
+done_at_collision: False
+level_name: "rooms"
+mv_prop:
   allow_diagonal_movement: True
   allow_square_movement: True
   allow_no_op: False
-dirt_props:
+dirt_prop:
   initial_dirt_ratio: 0.35
   initial_dirt_spawn_r_var: 0.1
   clean_amount: 0.34
@@ -12,8 +16,15 @@ dirt_props:
   spawn_frequency: 0
   max_spawn_ratio: 0.05
   dirt_smear_amount: 0.0
-  agent_can_interact: True
-factory_props:
-  parse_doors: True
-  level_name: "rooms"
-  doors_have_area: False
+  done_when_clean: True
+rewards_base:
+  MOVEMENTS_VALID: 0
+  MOVEMENTS_FAIL: 0
+  NOOP: 0
+  USE_DOOR_VALID: 0
+  USE_DOOR_FAIL: 0
+  COLLISION: 0
+rewards_dirt:
+  CLEAN_UP_VALID: 1
+  CLEAN_UP_FAIL: 0
+  CLEAN_UP_LAST_PIECE: 5
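As a reading aid for the configuration above: the factory `__main__` blocks pass plain dicts and property objects for the same keys, so a YAML file like this can be mapped onto constructor keyword arguments. The loader below is a hypothetical sketch; the file name and the one-to-one key-to-parameter assumption are illustrative, not part of this change-set.

from pathlib import Path

import yaml

from environments.factory.additional.dirt.factory_dirt import DirtFactory

with Path('env_config.yaml').open('r') as f:   # illustrative file name
    config = yaml.safe_load(f)

# Assumes every top-level key (parse_doors, level_name, mv_prop, dirt_prop, rewards_*, ...)
# matches a keyword argument of the chosen factory.
factory = DirtFactory(**config)
_ = factory.reset()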
@ -1,116 +1,286 @@
|
|||||||
import itertools
|
import itertools
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from enum import Enum, auto
|
from typing import Tuple, Union, Dict, List, NamedTuple
|
||||||
from typing import Tuple, Union
|
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from numpy.typing import ArrayLike
|
||||||
|
|
||||||
from stable_baselines3 import PPO, DQN, A2C
|
from stable_baselines3 import PPO, DQN, A2C
|
||||||
|
|
||||||
MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C)
|
|
||||||
|
|
||||||
LEVELS_DIR = 'levels'
|
"""
|
||||||
STEPS_START = 1
|
This file is used for:
|
||||||
|
1. string based definition
|
||||||
|
Use a class like `Constants`, to define attributes, which then reveal strings.
|
||||||
|
These can be used for naming convention along the environments as well as keys for mappings such as dicts etc.
|
||||||
|
When defining new envs, use class inheritance.
|
||||||
|
|
||||||
TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
|
2. utility function definition
|
||||||
IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amount',
|
There are static utility functions which are not bound to a specific environment.
|
||||||
'dirty_tile_count', 'terminal_observation', 'episode']
|
In this file they are defined to be used across the entire package.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
# Constants
|
MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C) # For use in studies and experiments
|
||||||
class Constants(Enum):
|
|
||||||
WALL = '#'
|
|
||||||
WALLS = 'Walls'
|
|
||||||
FLOOR = 'Floor'
|
|
||||||
DOOR = 'D'
|
|
||||||
DANGER_ZONE = 'x'
|
|
||||||
LEVEL = 'Level'
|
|
||||||
AGENT = 'Agent'
|
|
||||||
AGENT_PLACEHOLDER = 'AGENT_PLACEHOLDER'
|
|
||||||
FREE_CELL = 0
|
|
||||||
OCCUPIED_CELL = 1
|
|
||||||
SHADOWED_CELL = -1
|
|
||||||
NO_POS = (-9999, -9999)
|
|
||||||
|
|
||||||
DOORS = 'Doors'
|
|
||||||
CLOSED_DOOR = 'closed'
|
|
||||||
OPEN_DOOR = 'open'
|
|
||||||
|
|
||||||
ACTION = 'action'
|
|
||||||
COLLISIONS = 'collision'
|
|
||||||
VALID = 'valid'
|
|
||||||
NOT_VALID = 'not_valid'
|
|
||||||
|
|
||||||
# Dirt Env
|
|
||||||
DIRT = 'Dirt'
|
|
||||||
|
|
||||||
# Item Env
|
|
||||||
ITEM = 'Item'
|
|
||||||
INVENTORY = 'Inventory'
|
|
||||||
DROP_OFF = 'Drop_Off'
|
|
||||||
|
|
||||||
# Battery Env
|
|
||||||
CHARGE_POD = 'Charge_Pod'
|
|
||||||
BATTERIES = 'BATTERIES'
|
|
||||||
|
|
||||||
# Destination Env
|
|
||||||
DESTINATION = 'Destination'
|
|
||||||
REACHEDDESTINATION = 'ReachedDestination'
|
|
||||||
|
|
||||||
def __bool__(self):
|
|
||||||
if 'not_' in self.value:
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
return bool(self.value)
|
|
||||||
|
|
||||||
|
|
||||||
class MovingAction(Enum):
|
LEVELS_DIR = 'levels' # for use in studies and experiments
|
||||||
NORTH = 'north'
|
STEPS_START = 1 # Define where to the stepcount; which is the first step
|
||||||
EAST = 'east'
|
|
||||||
SOUTH = 'south'
|
# Not used anymore? Clean!
|
||||||
WEST = 'west'
|
# TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
|
||||||
NORTHEAST = 'north_east'
|
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
|
||||||
SOUTHEAST = 'south_east'
|
'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count', 'terminal_observation',
|
||||||
SOUTHWEST = 'south_west'
|
'episode']
|
||||||
NORTHWEST = 'north_west'
|
|
||||||
|
|
||||||
|
class Constants:
|
||||||
|
|
||||||
|
"""
|
||||||
|
String based mapping. Use these to handle keys or define values, which can be then be used globaly.
|
||||||
|
Please use class inheritance when defining new environments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
WALL = '#' # Wall tile identifier for resolving the string based map files.
|
||||||
|
DOOR = 'D' # Door identifier for resolving the string based map files.
|
||||||
|
DANGER_ZONE = 'x' # Dange Zone tile identifier for resolving the string based map files.
|
||||||
|
|
||||||
|
WALLS = 'Walls' # Identifier of Wall-objects and sets (collections).
|
||||||
|
FLOOR = 'Floor' # Identifier of Floor-objects and sets (collections).
|
||||||
|
DOORS = 'Doors' # Identifier of Door-objects and sets (collections).
|
||||||
|
LEVEL = 'Level' # Identifier of Level-objects and sets (collections).
|
||||||
|
AGENT = 'Agent' # Identifier of Agent-objects and sets (collections).
|
||||||
|
AGENT_PLACEHOLDER = 'AGENT_PLACEHOLDER' # Identifier of Placeholder-objects and sets (collections).
|
||||||
|
GLOBAL_POSITION = 'GLOBAL_POSITION' # Identifier of the global position slice
|
||||||
|
|
||||||
|
FREE_CELL = 0 # Free-Cell value used in observation
|
||||||
|
OCCUPIED_CELL = 1 # Occupied-Cell value used in observation
|
||||||
|
SHADOWED_CELL = -1 # Shadowed-Cell value used in observation
|
||||||
|
ACCESS_DOOR_CELL = 1/3 # Access-door-Cell value used in observation
|
||||||
|
OPEN_DOOR_CELL = 2/3 # Open-door-Cell value used in observation
|
||||||
|
CLOSED_DOOR_CELL = 3/3 # Closed-door-Cell value used in observation
|
||||||
|
|
||||||
|
NO_POS = (-9999, -9999) # Invalid Position value used in the environment (something is off-grid)
|
||||||
|
|
||||||
|
CLOSED_DOOR = 'closed' # Identifier to compare door-is-closed state
|
||||||
|
OPEN_DOOR = 'open' # Identifier to compare door-is-open state
|
||||||
|
# ACCESS_DOOR = 'access' # Identifier to compare access positions
|
||||||
|
|
||||||
|
ACTION = 'action' # Identifier of Action-objects and sets (collections).
|
||||||
|
COLLISION = 'collision' # Identifier to use in the context of collitions.
|
||||||
|
VALID = True # Identifier to rename boolean values in the context of actions.
|
||||||
|
NOT_VALID = False # Identifier to rename boolean values in the context of actions.
|
||||||
|
|
||||||
|
|
||||||
|
class EnvActions:
|
||||||
|
"""
|
||||||
|
String based mapping. Use these to identifiy actions, can be used globaly.
|
||||||
|
Please use class inheritance when defining new environments with new actions.
|
||||||
|
"""
|
||||||
|
# Movements
|
||||||
|
NORTH = 'north'
|
||||||
|
EAST = 'east'
|
||||||
|
SOUTH = 'south'
|
||||||
|
WEST = 'west'
|
||||||
|
NORTHEAST = 'north_east'
|
||||||
|
SOUTHEAST = 'south_east'
|
||||||
|
SOUTHWEST = 'south_west'
|
||||||
|
NORTHWEST = 'north_west'
|
||||||
|
|
||||||
|
# Other
|
||||||
|
# MOVE = 'move'
|
||||||
|
NOOP = 'no_op'
|
||||||
|
USE_DOOR = 'use_door'
|
||||||
|
|
||||||
|
_ACTIONMAP = defaultdict(lambda: (0, 0),
|
||||||
|
{NORTH: (-1, 0), NORTHEAST: (-1, 1),
|
||||||
|
EAST: (0, 1), SOUTHEAST: (1, 1),
|
||||||
|
SOUTH: (1, 0), SOUTHWEST: (1, -1),
|
||||||
|
WEST: (0, -1), NORTHWEST: (-1, -1)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_member(cls, other):
|
def is_move(cls, action):
|
||||||
return any([other == direction for direction in cls])
|
"""
|
||||||
|
Classmethod; checks if given action is a movement action or not. Depending on the env. configuration,
|
||||||
|
Movement actions are either `manhattan` (square) style movements (up,down, left, right) and/or diagonal.
|
||||||
|
|
||||||
|
:param action: Action to be checked
|
||||||
|
:type action: str
|
||||||
|
:return: Whether the given action is a movement action.
|
||||||
|
:rtype: bool
|
||||||
|
"""
|
||||||
|
return any([action == direction for direction in cls.movement_actions()])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def square(cls):
|
def square_move(cls):
|
||||||
|
"""
|
||||||
|
Classmethod; return a list of movement actions that are considered square or `manhattan` style movements.
|
||||||
|
|
||||||
|
:return: A list of movement actions.
|
||||||
|
:rtype: list(str)
|
||||||
|
"""
|
||||||
return [cls.NORTH, cls.EAST, cls.SOUTH, cls.WEST]
|
return [cls.NORTH, cls.EAST, cls.SOUTH, cls.WEST]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def diagonal(cls):
|
def diagonal_move(cls):
|
||||||
|
"""
|
||||||
|
Classmethod; return a list of movement actions that are considered diagonal movements.
|
||||||
|
|
||||||
|
:return: A list of movement actions.
|
||||||
|
:rtype: list(str)
|
||||||
|
"""
|
||||||
return [cls.NORTHEAST, cls.SOUTHEAST, cls.SOUTHWEST, cls.NORTHWEST]
|
return [cls.NORTHEAST, cls.SOUTHEAST, cls.SOUTHWEST, cls.NORTHWEST]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def movement_actions(cls):
|
||||||
|
"""
|
||||||
|
Classmethod; return a list of all available movement actions.
|
||||||
|
Please note, that this is indipendent from the env. properties
|
||||||
|
|
||||||
class EnvActions(Enum):
|
:return: A list of movement actions.
|
||||||
NOOP = 'no_op'
|
:rtype: list(str)
|
||||||
USE_DOOR = 'use_door'
|
"""
|
||||||
CLEAN_UP = 'clean_up'
|
return list(itertools.chain(cls.square_move(), cls.diagonal_move()))
|
||||||
ITEM_ACTION = 'item_action'
|
|
||||||
CHARGE = 'charge'
|
@classmethod
|
||||||
WAIT_ON_DEST = 'wait'
|
def resolve_movement_action_to_coords(cls, action):
|
||||||
|
"""
|
||||||
|
Classmethod; resolve movement actions. Given a movement action, return the delta in coordinates it stands for.
|
||||||
|
How does the current entity coordinate change if it performs the given action?
|
||||||
|
Please note, this is indipendent from the env. properties
|
||||||
|
|
||||||
|
:return: Delta coorinates.
|
||||||
|
:rtype: tuple(int, int)
|
||||||
|
"""
|
||||||
|
return cls._ACTIONMAP[action]
|
||||||
|
|
||||||
|
|
||||||
m = MovingAction
|
class RewardsBase(NamedTuple):
|
||||||
c = Constants
|
"""
|
||||||
|
Value based mapping. Use these to define reward values for specific conditions (i.e. the action
|
||||||
|
in a given context), can be used globaly.
|
||||||
|
Please use class inheritance when defining new environments with new rewards.
|
||||||
|
"""
|
||||||
|
MOVEMENTS_VALID: float = -0.001
|
||||||
|
MOVEMENTS_FAIL: float = -0.05
|
||||||
|
NOOP: float = -0.01
|
||||||
|
USE_DOOR_VALID: float = -0.00
|
||||||
|
USE_DOOR_FAIL: float = -0.01
|
||||||
|
COLLISION: float = -0.5
|
||||||
|
|
||||||
ACTIONMAP = defaultdict(lambda: (0, 0), {m.NORTH: (-1, 0), m.NORTHEAST: (-1, +1),
|
|
||||||
m.EAST: (0, 1), m.SOUTHEAST: (1, 1),
|
class ObservationTranslator:
|
||||||
m.SOUTH: (1, 0), m.SOUTHWEST: (+1, -1),
|
|
||||||
m.WEST: (0, -1), m.NORTHWEST: (-1, -1)
|
def __init__(self, this_named_observation_space: Dict[str, dict],
|
||||||
}
|
*per_agent_named_obs_spaces: Dict[str, dict],
|
||||||
)
|
placeholder_fill_value: Union[int, str, None] = None):
|
||||||
|
"""
|
||||||
|
This is a helper class, which converts agents observations from joined environments.
|
||||||
|
For example, agents trained in different environments may expect different observations.
|
||||||
|
This class translates from larger observations spaces to smaller.
|
||||||
|
A string identifier based approach is used.
|
||||||
|
Currently, it is not possible to mix different obs shapes.
|
||||||
|
|
||||||
|
|
||||||
|
:param this_named_observation_space: `Named observation space` of the joined environment.
|
||||||
|
:type this_named_observation_space: Dict[str, dict]
|
||||||
|
|
||||||
|
:param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded.
|
||||||
|
type per_agent_named_obs_spaces: Dict[str, dict]
|
||||||
|
|
||||||
|
:param placeholder_fill_value: Currently not fully implemented!!!
|
||||||
|
:type placeholder_fill_value: Union[int, str] = 'N')
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(placeholder_fill_value, str):
|
||||||
|
if placeholder_fill_value.lower() in ['normal', 'n']:
|
||||||
|
self.random_fill = np.random.normal
|
||||||
|
elif placeholder_fill_value.lower() in ['uniform', 'u']:
|
||||||
|
self.random_fill = np.random.uniform
|
||||||
|
else:
|
||||||
|
raise ValueError('Please choose between "uniform" or "normal" ("u", "n").')
|
||||||
|
elif isinstance(placeholder_fill_value, int):
|
||||||
|
raise NotImplementedError('"Future Work."')
|
||||||
|
else:
|
||||||
|
self.random_fill = None
|
||||||
|
|
||||||
|
self._this_named_obs_space = this_named_observation_space
|
||||||
|
self._per_agent_named_obs_space = list(per_agent_named_obs_spaces)
|
||||||
|
|
||||||
|
def translate_observation(self, agent_idx: int, obs: np.ndarray):
|
||||||
|
target_obs_space = self._per_agent_named_obs_space[agent_idx]
|
||||||
|
translation = dict()
|
||||||
|
for name, idxs in target_obs_space.items():
|
||||||
|
if name in self._this_named_obs_space:
|
||||||
|
for target_idx, this_idx in zip(idxs, self._this_named_obs_space[name]):
|
||||||
|
taken_slice = np.take(obs, [this_idx], axis=1 if obs.ndim == 4 else 0)
|
||||||
|
translation[target_idx] = taken_slice
|
||||||
|
elif random_fill := self.random_fill:
|
||||||
|
for target_idx in idxs:
|
||||||
|
translation[target_idx] = random_fill(size=obs.shape[:-3] + (1,) + obs.shape[-2:])
|
||||||
|
else:
|
||||||
|
for target_idx in idxs:
|
||||||
|
translation[target_idx] = np.zeros(shape=(obs.shape[:-3] + (1,) + obs.shape[-2:]))
|
||||||
|
|
||||||
|
translation = dict(sorted(translation.items()))
|
||||||
|
return np.concatenate(list(translation.values()), axis=-3)
|
||||||
|
|
||||||
|
def translate_observations(self, observations: List[ArrayLike]):
|
||||||
|
return [self.translate_observation(idx, observation) for idx, observation in enumerate(observations)]
|
||||||
|
|
||||||
|
def __call__(self, observations):
|
||||||
|
return self.translate_observations(observations)
|
||||||
|
|
||||||
|
|
||||||
|
class ActionTranslator:
|
||||||
|
|
||||||
|
def __init__(self, target_named_action_space: Dict[str, int], *per_agent_named_action_space: Dict[str, int]):
|
||||||
|
"""
|
||||||
|
This is a helper class, which converts agents action spaces to a joined environments action space.
|
||||||
|
For example, agents trained in different environments may have different action spaces.
|
||||||
|
This class translates from smaller individual agent action spaces to larger joined spaces.
|
||||||
|
A string identifier based approach is used.
|
||||||
|
|
||||||
|
:param target_named_action_space: Joined `Named action space` for the current environment.
|
||||||
|
:type target_named_action_space: Dict[str, dict]
|
||||||
|
|
||||||
|
:param per_agent_named_action_space: `Named action space` one for each agent. Overloaded.
|
||||||
|
:type per_agent_named_action_space: Dict[str, dict]
|
||||||
|
"""
|
||||||
|
|
||||||
|
self._target_named_action_space = target_named_action_space
|
||||||
|
if isinstance(per_agent_named_action_space, (list, tuple)):
|
||||||
|
self._per_agent_named_action_space = per_agent_named_action_space
|
||||||
|
else:
|
||||||
|
self._per_agent_named_action_space = list(per_agent_named_action_space)
|
||||||
|
self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space]
|
||||||
|
|
||||||
|
def translate_action(self, agent_idx: int, action: int):
|
||||||
|
named_action = self._per_agent_idx_actions[agent_idx][action]
|
||||||
|
translated_action = self._target_named_action_space[named_action]
|
||||||
|
return translated_action
|
||||||
|
|
||||||
|
def translate_actions(self, actions: List[int]):
|
||||||
|
return [self.translate_action(idx, action) for idx, action in enumerate(actions)]
|
||||||
|
|
||||||
|
def __call__(self, actions):
|
||||||
|
return self.translate_actions(actions)
|
||||||
|
|
||||||
|
|
||||||
# Utility functions
|
# Utility functions
|
||||||
def parse_level(path):
|
def parse_level(path):
|
||||||
|
"""
|
||||||
|
Given the path to a strin based `level` or `map` representation, this function reads the content.
|
||||||
|
Cleans `space`, checks for equal length of each row and returns a list of lists.
|
||||||
|
|
||||||
|
:param path: Path to the `level` or `map` file on harddrive.
|
||||||
|
:type path: os.Pathlike
|
||||||
|
|
||||||
|
:return: The read string representation of the `level` or `map`
|
||||||
|
:rtype: List[List[str]]
|
||||||
|
"""
|
||||||
with path.open('r') as lvl:
|
with path.open('r') as lvl:
|
||||||
level = list(map(lambda x: list(x.strip()), lvl.readlines()))
|
level = list(map(lambda x: list(x.strip()), lvl.readlines()))
|
||||||
if len(set([len(line) for line in level])) > 1:
|
if len(set([len(line) for line in level])) > 1:
|
||||||
@ -118,66 +288,107 @@ def parse_level(path):
|
|||||||
return level
|
return level
|
||||||
|
|
||||||
|
|
||||||
def one_hot_level(level, wall_char: Union[c, str] = c.WALL):
|
def one_hot_level(level, wall_char: str = Constants.WALL):
|
||||||
|
"""
|
||||||
|
Given a string based level representation (list of lists, see function `parse_level`), this function creates a
|
||||||
|
binary numpy array or `grid`. Grid values that equal `wall_char` become of `Constants.OCCUPIED_CELL` value.
|
||||||
|
Can be changed to filter for any symbol.
|
||||||
|
|
||||||
|
:param level: String based level representation (list of lists, see function `parse_level`).
|
||||||
|
:param wall_char: List[List[str]]
|
||||||
|
|
||||||
|
:return: Binary numpy array
|
||||||
|
:rtype: np.typing._array_like.ArrayLike
|
||||||
|
"""
|
||||||
|
|
||||||
grid = np.array(level)
|
grid = np.array(level)
|
||||||
binary_grid = np.zeros(grid.shape, dtype=np.int8)
|
binary_grid = np.zeros(grid.shape, dtype=np.int8)
|
||||||
if wall_char in c:
|
binary_grid[grid == wall_char] = Constants.OCCUPIED_CELL
|
||||||
binary_grid[grid == wall_char.value] = c.OCCUPIED_CELL.value
|
|
||||||
else:
|
|
||||||
binary_grid[grid == wall_char] = c.OCCUPIED_CELL.value
|
|
||||||
return binary_grid
|
return binary_grid
|
||||||
|
|
||||||
|
|
||||||
def check_position(slice_to_check_against: np.ndarray, position_to_check: Tuple[int, int]):
|
def check_position(slice_to_check_against: ArrayLike, position_to_check: Tuple[int, int]):
|
||||||
|
"""
|
||||||
|
Given a slice (2-D Arraylike object)
|
||||||
|
|
||||||
|
:param slice_to_check_against: The slice to check for accessability
|
||||||
|
:type slice_to_check_against: np.typing._array_like.ArrayLike
|
||||||
|
|
||||||
|
:param position_to_check: Position in slice that should be checked. Can be outside of slice boundarys.
|
||||||
|
:type position_to_check: tuple(int, int)
|
||||||
|
|
||||||
|
:return: Whether a position can be moved to.
|
||||||
|
:rtype: bool
|
||||||
|
"""
|
||||||
x_pos, y_pos = position_to_check
|
x_pos, y_pos = position_to_check
|
||||||
|
|
||||||
# Check if agent colides with grid boundrys
|
# Check if agent colides with grid boundrys
|
||||||
valid = not (
|
valid = not (
|
||||||
x_pos < 0 or y_pos < 0
|
x_pos < 0 or y_pos < 0
|
||||||
or x_pos >= slice_to_check_against.shape[0]
|
or x_pos >= slice_to_check_against.shape[0]
|
||||||
or y_pos >= slice_to_check_against.shape[0]
|
or y_pos >= slice_to_check_against.shape[1]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check for collision with level walls
|
# Check for collision with level walls
|
||||||
valid = valid and not slice_to_check_against[x_pos, y_pos]
|
valid = valid and not slice_to_check_against[x_pos, y_pos]
|
||||||
return c.VALID if valid else c.NOT_VALID
|
return Constants.VALID if valid else Constants.NOT_VALID
|
||||||
|
|
||||||
|
|
||||||
def asset_str(agent):
|
def asset_str(agent):
|
||||||
|
"""
|
||||||
|
FIXME @ romue
|
||||||
|
"""
|
||||||
# What does this abonimation do?
|
# What does this abonimation do?
|
||||||
# if any([x is None for x in [self._slices[j] for j in agent.collisions]]):
|
# if any([x is None for x in [cls._slices[j] for j in agent.collisions]]):
|
||||||
# print('error')
|
# print('error')
|
||||||
col_names = [x.name for x in agent.temp_collisions]
|
if step_result := agent.step_result:
|
||||||
if any(c.AGENT.value in name for name in col_names):
|
action = step_result['action_name']
|
||||||
return 'agent_collision', 'blank'
|
valid = step_result['action_valid']
|
||||||
elif not agent.temp_valid or c.LEVEL.name in col_names or c.AGENT.name in col_names:
|
col_names = [x.name for x in step_result['collisions']]
|
||||||
return c.AGENT.value, 'invalid'
|
if any(Constants.AGENT in name for name in col_names):
|
||||||
elif agent.temp_valid and not MovingAction.is_member(agent.temp_action):
|
return 'agent_collision', 'blank'
|
||||||
return c.AGENT.value, 'valid'
|
elif not valid or Constants.LEVEL in col_names or Constants.AGENT in col_names:
|
||||||
elif agent.temp_valid and MovingAction.is_member(agent.temp_action):
|
return Constants.AGENT, 'invalid'
|
||||||
return c.AGENT.value, 'move'
|
elif valid and not EnvActions.is_move(action):
|
||||||
|
return Constants.AGENT, 'valid'
|
||||||
|
elif valid and EnvActions.is_move(action):
|
||||||
|
return Constants.AGENT, 'move'
|
||||||
|
else:
|
||||||
|
return Constants.AGENT, 'idle'
|
||||||
else:
|
else:
|
||||||
return c.AGENT.value, 'idle'
|
return Constants.AGENT, 'idle'
|
||||||
|
|
||||||
|
|
||||||
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
|
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
|
||||||
|
"""
|
||||||
|
Given a set of coordinates, this function contructs a non-directed graph, by conncting adjected points.
|
||||||
|
There are three combinations of settings:
|
||||||
|
Allow all neigbors: Distance(a, b) <= sqrt(2)
|
||||||
|
Allow only manhattan: Distance(a, b) == 1
|
||||||
|
Allow only euclidean: Distance(a, b) == sqrt(2)
|
||||||
|
|
||||||
|
|
||||||
|
:param coordiniates_or_tiles: A set of coordinates.
|
||||||
|
:type coordiniates_or_tiles: Tiles
|
||||||
|
:param allow_euclidean_connections: Whether to regard diagonal adjected cells as neighbors
|
||||||
|
:type: bool
|
||||||
|
:param allow_manhattan_connections: Whether to regard directly adjected cells as neighbors
|
||||||
|
:type: bool
|
||||||
|
|
||||||
|
:return: A graph with nodes that are conneceted as specified by the parameters.
|
||||||
|
:rtype: nx.Graph
|
||||||
|
"""
|
||||||
assert allow_euclidean_connections or allow_manhattan_connections
|
assert allow_euclidean_connections or allow_manhattan_connections
|
||||||
if hasattr(coordiniates_or_tiles, 'positions'):
|
if hasattr(coordiniates_or_tiles, 'positions'):
|
||||||
coordiniates_or_tiles = coordiniates_or_tiles.positions
|
coordiniates_or_tiles = coordiniates_or_tiles.positions
|
||||||
possible_connections = itertools.combinations(coordiniates_or_tiles, 2)
|
possible_connections = itertools.combinations(coordiniates_or_tiles, 2)
|
||||||
graph = nx.Graph()
|
graph = nx.Graph()
|
||||||
for a, b in possible_connections:
|
for a, b in possible_connections:
|
||||||
diff = abs(np.subtract(a, b))
|
diff = np.linalg.norm(np.asarray(a)-np.asarray(b))
|
||||||
if not max(diff) > 1:
|
if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2):
|
||||||
if allow_manhattan_connections and allow_euclidean_connections:
|
graph.add_edge(a, b)
|
||||||
graph.add_edge(a, b)
|
elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2):
|
||||||
elif not allow_manhattan_connections and allow_euclidean_connections and all(diff):
|
graph.add_edge(a, b)
|
||||||
graph.add_edge(a, b)
|
elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1:
|
||||||
elif allow_manhattan_connections and not allow_euclidean_connections and not all(diff) and any(diff):
|
graph.add_edge(a, b)
|
||||||
graph.add_edge(a, b)
|
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
parsed_level = parse_level(Path(__file__).parent / 'factory' / 'levels' / 'simple.txt')
|
|
||||||
y = one_hot_level(parsed_level)
|
|
||||||
print(np.argwhere(y == 0))
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import pickle
|
import pickle
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from os import PathLike
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Dict, Union
|
from typing import List, Dict, Union
|
||||||
|
|
||||||
@ -9,14 +10,17 @@ from environments.helpers import IGNORED_DF_COLUMNS
|
|||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
from plotting.compare_runs import plot_single_run
|
||||||
|
|
||||||
|
|
||||||
class EnvMonitor(BaseCallback):
|
class EnvMonitor(BaseCallback):
|
||||||
|
|
||||||
ext = 'png'
|
ext = 'png'
|
||||||
|
|
||||||
def __init__(self, env):
|
def __init__(self, env, filepath: Union[str, PathLike] = None):
|
||||||
super(EnvMonitor, self).__init__()
|
super(EnvMonitor, self).__init__()
|
||||||
self.unwrapped = env
|
self.unwrapped = env
|
||||||
|
self._filepath = filepath
|
||||||
self._monitor_df = pd.DataFrame()
|
self._monitor_df = pd.DataFrame()
|
||||||
self._monitor_dicts = defaultdict(dict)
|
self._monitor_dicts = defaultdict(dict)
|
||||||
|
|
||||||
@ -67,8 +71,10 @@ class EnvMonitor(BaseCallback):
|
|||||||
pass
|
pass
|
||||||
return
|
return
|
||||||
|
|
||||||
def save_run(self, filepath: Union[Path, str]):
|
def save_run(self, filepath: Union[Path, str, None] = None, auto_plotting_keys=None):
|
||||||
filepath = Path(filepath)
|
filepath = Path(filepath or self._filepath)
|
||||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||||
with filepath.open('wb') as f:
|
with filepath.open('wb') as f:
|
||||||
pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
|
pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
if auto_plotting_keys:
|
||||||
|
plot_single_run(filepath, column_keys=auto_plotting_keys)
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
|
import warnings
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from os import PathLike
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import simplejson
|
import simplejson
|
||||||
|
from deepdiff.operator import BaseOperator
|
||||||
from stable_baselines3.common.callbacks import BaseCallback
|
from stable_baselines3.common.callbacks import BaseCallback
|
||||||
|
|
||||||
from environments.factory.base.base_factory import REC_TAC
|
from environments.factory.base.base_factory import REC_TAC
|
||||||
@ -12,11 +15,15 @@ from environments.factory.base.base_factory import REC_TAC
|
|||||||
|
|
||||||
class EnvRecorder(BaseCallback):
|
class EnvRecorder(BaseCallback):
|
||||||
|
|
||||||
def __init__(self, env, entities='all'):
|
def __init__(self, env, entities: str = 'all', filepath: Union[str, PathLike] = None, freq: int = 0):
|
||||||
super(EnvRecorder, self).__init__()
|
super(EnvRecorder, self).__init__()
|
||||||
|
self.filepath = filepath
|
||||||
self.unwrapped = env
|
self.unwrapped = env
|
||||||
|
self.freq = freq
|
||||||
self._recorder_dict = defaultdict(list)
|
self._recorder_dict = defaultdict(list)
|
||||||
self._recorder_out_list = list()
|
self._recorder_out_list = list()
|
||||||
|
self._episode_counter = 1
|
||||||
|
self._do_record_dict = defaultdict(lambda: False)
|
||||||
if isinstance(entities, str):
|
if isinstance(entities, str):
|
||||||
if entities.lower() == 'all':
|
if entities.lower() == 'all':
|
||||||
self._entities = None
|
self._entities = None
|
||||||
@ -24,45 +31,70 @@ class EnvRecorder(BaseCallback):
|
|||||||
self._entities = [entities]
|
self._entities = [entities]
|
||||||
else:
|
else:
|
||||||
self._entities = entities
|
self._entities = entities
|
||||||
self.started = False
|
|
||||||
self.closed = False
|
|
||||||
|
|
||||||
def __getattr__(self, item):
|
def __getattr__(self, item):
|
||||||
return getattr(self.unwrapped, item)
|
return getattr(self.unwrapped, item)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.unwrapped._record_episodes = True
|
self._on_training_start()
|
||||||
return self.unwrapped.reset()
|
return self.unwrapped.reset()
|
||||||
|
|
||||||
def _on_training_start(self) -> None:
|
def _on_training_start(self) -> None:
|
||||||
self.unwrapped._record_episodes = True
|
assert self.start_recording()
|
||||||
pass
|
|
||||||
|
|
||||||
def _read_info(self, env_idx, info: dict):
|
def _read_info(self, env_idx, info: dict):
|
||||||
if info_dict := {key.replace(REC_TAC, ''): val for key, val in info.items() if key.startswith(f'{REC_TAC}')}:
|
if info_dict := {key.replace(REC_TAC, ''): val for key, val in info.items() if key.startswith(f'{REC_TAC}')}:
|
||||||
if self._entities:
|
if self._entities:
|
||||||
info_dict = {k: v for k, v in info_dict.items() if k in self._entities}
|
info_dict = {k: v for k, v in info_dict.items() if k in self._entities}
|
||||||
|
|
||||||
info_dict.update(episode=(self.num_timesteps + env_idx))
|
|
||||||
self._recorder_dict[env_idx].append(info_dict)
|
self._recorder_dict[env_idx].append(info_dict)
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
return
|
return True
|
||||||
|
|
||||||
def _read_done(self, env_idx, done):
|
def _read_done(self, env_idx, done):
|
||||||
if done:
|
if done:
|
||||||
self._recorder_out_list.append({'steps': self._recorder_dict[env_idx],
|
self._recorder_out_list.append({'steps': self._recorder_dict[env_idx],
|
||||||
'episode': len(self._recorder_out_list)})
|
'episode': self._episode_counter})
|
||||||
self._recorder_dict[env_idx] = list()
|
self._recorder_dict[env_idx] = list()
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def save_records(self, filepath: Union[Path, str], save_occupation_map=False, save_trajectory_map=False):
|
def step(self, actions):
|
||||||
filepath = Path(filepath)
|
step_result = self.unwrapped.step(actions)
|
||||||
|
if self.do_record_episode(0):
|
||||||
|
info = step_result[-1]
|
||||||
|
self._read_info(0, info)
|
||||||
|
if self._do_record_dict[0]:
|
||||||
|
self._read_done(0, step_result[-2])
|
||||||
|
return step_result
|
||||||
|
|
||||||
|
def finalize(self):
|
||||||
|
self._on_training_end()
|
||||||
|
return True
|
||||||
|
|
||||||
|
def save_records(self, filepath: Union[Path, str, None] = None,
|
||||||
|
only_deltas=True,
|
||||||
|
save_occupation_map=False,
|
||||||
|
save_trajectory_map=False,
|
||||||
|
):
|
||||||
|
filepath = Path(filepath or self.filepath)
|
||||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||||
# self.out_file.unlink(missing_ok=True)
|
# cls.out_file.unlink(missing_ok=True)
|
||||||
with filepath.open('w') as f:
|
with filepath.open('w') as f:
|
||||||
out_dict = {'episodes': self._recorder_out_list, 'header': self.unwrapped.params}
|
if only_deltas:
|
||||||
|
from deepdiff import DeepDiff, Delta
|
||||||
|
diff_dict = [DeepDiff(t1,t2, ignore_order=True)
|
||||||
|
for t1, t2 in zip(self._recorder_out_list, self._recorder_out_list[1:])
|
||||||
|
]
|
||||||
|
out_dict = {'episodes': diff_dict}
|
||||||
|
|
||||||
|
else:
|
||||||
|
out_dict = {'episodes': self._recorder_out_list}
|
||||||
|
out_dict.update(
|
||||||
|
{'n_episodes': self._episode_counter,
|
||||||
|
'env_params': self.unwrapped.params,
|
||||||
|
'header': self.unwrapped.summarize_header
|
||||||
|
})
|
||||||
try:
|
try:
|
||||||
simplejson.dump(out_dict, f, indent=4)
|
simplejson.dump(out_dict, f, indent=4)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
@ -70,6 +102,7 @@ class EnvRecorder(BaseCallback):
|
|||||||
|
|
||||||
if save_occupation_map:
|
if save_occupation_map:
|
||||||
a = np.zeros((15, 15))
|
a = np.zeros((15, 15))
|
||||||
|
# noinspection PyTypeChecker
|
||||||
for episode in out_dict['episodes']:
|
for episode in out_dict['episodes']:
|
||||||
df = pd.DataFrame([y for x in episode['steps'] for y in x['Agents']])
|
df = pd.DataFrame([y for x in episode['steps'] for y in x['Agents']])
|
||||||
|
|
||||||
@ -87,16 +120,34 @@ class EnvRecorder(BaseCallback):
|
|||||||
if save_trajectory_map:
|
if save_trajectory_map:
|
||||||
raise NotImplementedError('This has not yet been implemented.')
|
raise NotImplementedError('This has not yet been implemented.')
|
||||||
|
|
||||||
|
def do_record_episode(self, env_idx):
|
||||||
|
if not self._recorder_dict[env_idx]:
|
||||||
|
if self.freq:
|
||||||
|
self._do_record_dict[env_idx] = (self.freq == -1) or (self._episode_counter % self.freq) == 0
|
||||||
|
else:
|
||||||
|
self._do_record_dict[env_idx] = False
|
||||||
|
warnings.warn('You did wrap your Environment with a recorder, but set the freq to zero\n'
|
||||||
|
'Nothing will be recorded')
|
||||||
|
self._episode_counter += 1
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
return self._do_record_dict[env_idx]
|
||||||
|
|
||||||
def _on_step(self) -> bool:
|
def _on_step(self) -> bool:
|
||||||
for env_idx, info in enumerate(self.locals.get('infos', [])):
|
for env_idx, info in enumerate(self.locals.get('infos', [])):
|
||||||
self._read_info(env_idx, info)
|
if self._do_record_dict[env_idx]:
|
||||||
|
self._read_info(env_idx, info)
|
||||||
dones = list(enumerate(self.locals.get('dones', [])))
|
dones = list(enumerate(self.locals.get('dones', [])))
|
||||||
dones.extend(list(enumerate(self.locals.get('done', []))))
|
dones.extend(list(enumerate(self.locals.get('done', []))))
|
||||||
for env_idx, done in dones:
|
for env_idx, done in dones:
|
||||||
self._read_done(env_idx, done)
|
if self._do_record_dict[env_idx]:
|
||||||
|
self._read_done(env_idx, done)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _on_training_end(self) -> None:
|
def _on_training_end(self) -> None:
|
||||||
|
for env_idx in range(len(self._recorder_dict)):
|
||||||
|
if self._recorder_dict[env_idx]:
|
||||||
|
self._recorder_out_list.append({'steps': self._recorder_dict[env_idx],
|
||||||
|
'episode': self._episode_counter})
|
||||||
pass
|
pass
|
||||||
|
@ -3,7 +3,38 @@ import gym
|
|||||||
from gym.wrappers.frame_stack import FrameStack
|
from gym.wrappers.frame_stack import FrameStack
|
||||||
|
|
||||||
|
|
||||||
|
class EnvCombiner(object):
|
||||||
|
|
||||||
|
def __init__(self, *envs_cls):
|
||||||
|
self._env_dict = {env_cls.__name__: env_cls for env_cls in envs_cls}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def combine_cls(name, *envs_cls):
|
||||||
|
return type(name,envs_cls,{})
|
||||||
|
|
||||||
|
def build(self):
|
||||||
|
name = f'{"".join([x.lower().replace("factory").capitalize() for x in self._env_dict.keys()])}Factory'
|
||||||
|
|
||||||
|
return self.combine_cls(name, *self._env_dict.values())
|
||||||
|
|
||||||
|
|
||||||
class AgentRenderOptions(object):
|
class AgentRenderOptions(object):
|
||||||
|
"""
|
||||||
|
Class that specifies the available options for the way agents are represented in the env observation.
|
||||||
|
|
||||||
|
SEPERATE:
|
||||||
|
Each agent is represented in a seperate slice as Constant.OCCUPIED_CELL value (one hot)
|
||||||
|
|
||||||
|
COMBINED:
|
||||||
|
For all agent, value of Constant.OCCUPIED_CELL is added to a zero-value slice at the agents position (sum(SEPERATE))
|
||||||
|
|
||||||
|
LEVEL:
|
||||||
|
The combined slice is added to the LEVEL-slice. (Agents appear as obstacle / wall)
|
||||||
|
|
||||||
|
NOT:
|
||||||
|
The position of individual agents can not be read from the observation.
|
||||||
|
"""
|
||||||
|
|
||||||
SEPERATE = 'seperate'
|
SEPERATE = 'seperate'
|
||||||
COMBINED = 'combined'
|
COMBINED = 'combined'
|
||||||
LEVEL = 'lvl'
|
LEVEL = 'lvl'
|
||||||
@ -11,22 +42,61 @@ class AgentRenderOptions(object):
|
|||||||
|
|
||||||
|
|
||||||
class MovementProperties(NamedTuple):
|
class MovementProperties(NamedTuple):
|
||||||
|
"""
|
||||||
|
Property holder; for setting multiple related parameters through a single parameter. Comes with default values.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""Allow the manhattan style movement on a grid (move to cells that are connected by square edges)."""
|
||||||
allow_square_movement: bool = True
|
allow_square_movement: bool = True
|
||||||
|
|
||||||
|
"""Allow diagonal movement on the grid (move to cells that are connected by square corners)."""
|
||||||
allow_diagonal_movement: bool = False
|
allow_diagonal_movement: bool = False
|
||||||
|
|
||||||
|
"""Allow the agent to just do nothing; not move (NO-OP)."""
|
||||||
allow_no_op: bool = False
|
allow_no_op: bool = False
|
||||||
|
|
||||||
|
|
||||||
class ObservationProperties(NamedTuple):
|
class ObservationProperties(NamedTuple):
|
||||||
|
"""
|
||||||
|
Property holder; for setting multiple related parameters through a single parameter. Comes with default values.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""How to represent agents in the observation space. This may also alter the obs-shape."""
|
||||||
render_agents: AgentRenderOptions = AgentRenderOptions.SEPERATE
|
render_agents: AgentRenderOptions = AgentRenderOptions.SEPERATE
|
||||||
|
|
||||||
|
"""Obserations are build per agent; whether the current agent should be represented in its own observation."""
|
||||||
omit_agent_self: bool = True
|
omit_agent_self: bool = True
|
||||||
|
|
||||||
|
"""Their might be the case you want to modify the agents obs-space, so that it can be used with additional obs.
|
||||||
|
The additional slice can be filled with any number"""
|
||||||
additional_agent_placeholder: Union[None, str, int] = None
|
additional_agent_placeholder: Union[None, str, int] = None
|
||||||
|
|
||||||
|
"""Whether to cast shadows (make floortiles and items hidden).; """
|
||||||
cast_shadows: bool = True
|
cast_shadows: bool = True
|
||||||
|
|
||||||
|
"""Frame Stacking is a methode do give some temporal information to the agents.
|
||||||
|
This paramters controls how many "old-frames" """
|
||||||
frames_to_stack: int = 0
|
frames_to_stack: int = 0
|
||||||
pomdp_r: int = 0
|
|
||||||
show_global_position_info: bool = True
|
"""Specifies the radius (_r) of the agents field of view. Please note, that the agents grid cellis not taken
|
||||||
|
accountance for. This means, that the resulting field of view diameter = `pomdp_r * 2 + 1`.
|
||||||
|
A 'pomdp_r' of 0 always returns the full env == no partial observability."""
|
||||||
|
pomdp_r: int = 2
|
||||||
|
|
||||||
|
"""Whether to place a visual encoding on walkable tiles around the doors. This is helpfull when the doors can be
|
||||||
|
operated from their surrounding area. So the agent can more easily get a notion of where to choose the door option.
|
||||||
|
However, this is not necesarry at all.
|
||||||
|
"""
|
||||||
|
indicate_door_area: bool = False
|
||||||
|
|
||||||
|
"""Whether to add the agents normalized global position as float values (2,1) to a seperate information slice.
|
||||||
|
More optional informations are to come.
|
||||||
|
"""
|
||||||
|
show_global_position_info: bool = False
|
||||||
|
|
||||||
|
|
||||||
class MarlFrameStack(gym.ObservationWrapper):
|
class MarlFrameStack(gym.ObservationWrapper):
|
||||||
|
"""todo @romue404"""
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
|
||||||
@ -34,4 +104,3 @@ class MarlFrameStack(gym.ObservationWrapper):
|
|||||||
if isinstance(self.env, FrameStack) and self.env.unwrapped.n_agents > 1:
|
if isinstance(self.env, FrameStack) and self.env.unwrapped.n_agents > 1:
|
||||||
return observation[0:].swapaxes(0, 1)
|
return observation[0:].swapaxes(0, 1)
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
|
@ -10,6 +10,45 @@ from environments.helpers import IGNORED_DF_COLUMNS, MODEL_MAP
|
|||||||
from plotting.plotting import prepare_plot
|
from plotting.plotting import prepare_plot
|
||||||
|
|
||||||
|
|
||||||
|
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None):
|
||||||
|
run_path = Path(run_path)
|
||||||
|
df_list = list()
|
||||||
|
if run_path.is_dir():
|
||||||
|
monitor_file = next(run_path.glob('*monitor*.pick'))
|
||||||
|
elif run_path.exists() and run_path.is_file():
|
||||||
|
monitor_file = run_path
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
with monitor_file.open('rb') as f:
|
||||||
|
monitor_df = pickle.load(f)
|
||||||
|
|
||||||
|
monitor_df = monitor_df.fillna(0)
|
||||||
|
df_list.append(monitor_df)
|
||||||
|
|
||||||
|
df = pd.concat(df_list, ignore_index=True)
|
||||||
|
df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
|
||||||
|
if column_keys is not None:
|
||||||
|
columns = [col for col in column_keys if col in df.columns]
|
||||||
|
else:
|
||||||
|
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
|
||||||
|
|
||||||
|
roll_n = 50
|
||||||
|
|
||||||
|
non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean()
|
||||||
|
|
||||||
|
df_melted = df[columns + ['Episode']].reset_index().melt(id_vars=['Episode'],
|
||||||
|
value_vars=columns, var_name="Measurement",
|
||||||
|
value_name="Score")
|
||||||
|
|
||||||
|
if df_melted['Episode'].max() > 800:
|
||||||
|
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||||
|
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||||
|
|
||||||
|
prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||||
|
print('Plotting done.')
|
||||||
|
|
||||||
|
|
||||||
def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
|
def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
|
||||||
run_path = Path(run_path)
|
run_path = Path(run_path)
|
||||||
df_list = list()
|
df_list = list()
|
||||||
@ -37,7 +76,11 @@ def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
|
|||||||
skip_n = round(df_melted['Episode'].max() * 0.02)
|
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||||
|
|
||||||
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
run_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
if run_path.exists() and run_path.is_file():
|
||||||
|
prepare_plot(run_path.parent / f'{run_path.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||||
|
else:
|
||||||
|
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||||
print('Plotting done.')
|
print('Plotting done.')
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
|
import matplotlib as mpl
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
PALETTE = 10 * (
|
PALETTE = 10 * (
|
||||||
@ -21,7 +22,14 @@ PALETTE = 10 * (
|
|||||||
def plot(filepath, ext='png'):
|
def plot(filepath, ext='png'):
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
figure = plt.gcf()
|
figure = plt.gcf()
|
||||||
figure.savefig(str(filepath), format=ext)
|
ax = plt.gca()
|
||||||
|
legends = [c for c in ax.get_children() if isinstance(c, mpl.legend.Legend)]
|
||||||
|
|
||||||
|
if legends:
|
||||||
|
figure.savefig(str(filepath), format=ext, bbox_extra_artists=(*legends,), bbox_inches='tight')
|
||||||
|
else:
|
||||||
|
figure.savefig(str(filepath), format=ext)
|
||||||
|
|
||||||
plt.show()
|
plt.show()
|
||||||
plt.clf()
|
plt.clf()
|
||||||
|
|
||||||
@ -30,7 +38,7 @@ def prepare_tex(df, hue, style, hue_order):
|
|||||||
sns.set(rc={'text.usetex': True}, style='whitegrid')
|
sns.set(rc={'text.usetex': True}, style='whitegrid')
|
||||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
|
lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
|
||||||
hue_order=hue_order, hue=hue, style=style)
|
hue_order=hue_order, hue=hue, style=style)
|
||||||
# lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
||||||
plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
|
plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
return lineplot
|
return lineplot
|
||||||
@ -48,6 +56,19 @@ def prepare_plt(df, hue, style, hue_order):
|
|||||||
return lineplot
|
return lineplot
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_center_double_column_legend(df, hue, style, hue_order):
|
||||||
|
print('Struggling to plot Figure using LaTeX - going back to normal.')
|
||||||
|
plt.close('all')
|
||||||
|
sns.set(rc={'text.usetex': False}, style='whitegrid')
|
||||||
|
fig = plt.figure(figsize=(10, 11))
|
||||||
|
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
|
||||||
|
ci=95, palette=PALETTE, hue_order=hue_order, legend=False)
|
||||||
|
# plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
|
||||||
|
lineplot.legend(hue_order, ncol=3, loc='lower center', title='Parameter Combinations', bbox_to_anchor=(0.5, -0.43))
|
||||||
|
plt.tight_layout()
|
||||||
|
return lineplot
|
||||||
|
|
||||||
|
|
||||||
def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None, use_tex=False):
|
def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None, use_tex=False):
|
||||||
df = results_df.copy()
|
df = results_df.copy()
|
||||||
df[hue] = df[hue].str.replace('_', '-')
|
df[hue] = df[hue].str.replace('_', '-')
|
||||||
|
189
quickstart/combine_and_monitor_rerun.py
Normal file
189
quickstart/combine_and_monitor_rerun.py
Normal file
@ -0,0 +1,189 @@
|
|||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
##############################################
|
||||||
|
# keep this for stand alone script execution #
|
||||||
|
##############################################
|
||||||
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
|
from environments.logging.recorder import EnvRecorder
|
||||||
|
|
||||||
|
try:
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
|
if __package__ is None:
|
||||||
|
DIR = Path(__file__).resolve().parent
|
||||||
|
sys.path.insert(0, str(DIR.parent))
|
||||||
|
__package__ = DIR.name
|
||||||
|
else:
|
||||||
|
DIR = None
|
||||||
|
except NameError:
|
||||||
|
DIR = None
|
||||||
|
pass
|
||||||
|
##############################################
|
||||||
|
##############################################
|
||||||
|
##############################################
|
||||||
|
|
||||||
|
|
||||||
|
import simplejson
|
||||||
|
|
||||||
|
from environments import helpers as h
|
||||||
|
from environments.factory.additional.combined_factories import DestBatteryFactory
|
||||||
|
from environments.factory.additional.dest.factory_dest import DestFactory
|
||||||
|
from environments.factory.additional.dirt.factory_dirt import DirtFactory
|
||||||
|
from environments.factory.additional.item.factory_item import ItemFactory
|
||||||
|
from environments.helpers import ObservationTranslator, ActionTranslator
|
||||||
|
from environments.logging.envmonitor import EnvMonitor
|
||||||
|
from environments.utility_classes import ObservationProperties, AgentRenderOptions, MovementProperties
|
||||||
|
|
||||||
|
|
||||||
|
def policy_model_kwargs():
|
||||||
|
return dict(ent_coef=0.01)
|
||||||
|
|
||||||
|
|
||||||
|
def dqn_model_kwargs():
|
||||||
|
return dict(buffer_size=50000,
|
||||||
|
learning_starts=64,
|
||||||
|
batch_size=64,
|
||||||
|
target_update_interval=5000,
|
||||||
|
exploration_fraction=0.25,
|
||||||
|
exploration_final_eps=0.025
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def encapsule_env_factory(env_fctry, env_kwrgs):
|
||||||
|
|
||||||
|
def _init():
|
||||||
|
with env_fctry(**env_kwrgs) as init_env:
|
||||||
|
return init_env
|
||||||
|
|
||||||
|
return _init
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
render = False
|
||||||
|
# Define Global Env Parameters
|
||||||
|
# Define properties object parameters
|
||||||
|
factory_kwargs = dict(
|
||||||
|
max_steps=400, parse_doors=True,
|
||||||
|
level_name='rooms',
|
||||||
|
doors_have_area=True, verbose=False,
|
||||||
|
mv_prop=MovementProperties(allow_diagonal_movement=True,
|
||||||
|
allow_square_movement=True,
|
||||||
|
allow_no_op=False),
|
||||||
|
obs_prop=ObservationProperties(
|
||||||
|
frames_to_stack=3,
|
||||||
|
cast_shadows=True,
|
||||||
|
omit_agent_self=True,
|
||||||
|
render_agents=AgentRenderOptions.LEVEL,
|
||||||
|
additional_agent_placeholder=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Bundle both environments with global kwargs and parameters
|
||||||
|
# Todo: find a better solution, like outo module loading
|
||||||
|
env_map = {'DirtFactory': DirtFactory,
|
||||||
|
'ItemFactory': ItemFactory,
|
||||||
|
'DestFactory': DestFactory,
|
||||||
|
'DestBatteryFactory': DestBatteryFactory
|
||||||
|
}
|
||||||
|
env_names = list(env_map.keys())
|
||||||
|
|
||||||
|
# Put all your multi-seed agends in a single folder, we do not need specific names etc.
|
||||||
|
available_models = dict()
|
||||||
|
available_envs = dict()
|
||||||
|
available_runs_kwargs = dict()
|
||||||
|
available_runs_agents = dict()
|
||||||
|
max_seed = 0
|
||||||
|
# Define this folder
|
||||||
|
combinations_path = Path('combinations')
|
||||||
|
# Those are all differently trained combinations of mdoels, env and parameters
|
||||||
|
for combination in (x for x in combinations_path.iterdir() if x.is_dir()):
|
||||||
|
# These are all the models for this specific combination
|
||||||
|
for model_run in (x for x in combination.iterdir() if x.is_dir()):
|
||||||
|
model_name, env_name = model_run.name.split('_')[:2]
|
||||||
|
if model_name not in available_models:
|
||||||
|
available_models[model_name] = h.MODEL_MAP[model_name]
|
||||||
|
if env_name not in available_envs:
|
||||||
|
available_envs[env_name] = env_map[env_name]
|
||||||
|
# Those are all available seeds
|
||||||
|
for seed_run in (x for x in model_run.iterdir() if x.is_dir()):
|
||||||
|
max_seed = max(int(seed_run.name.split('_')[0]), max_seed)
|
||||||
|
# Read the env configuration from ROM
|
||||||
|
with next(seed_run.glob('env_params.json')).open('r') as f:
|
||||||
|
env_kwargs = simplejson.load(f)
|
||||||
|
available_runs_kwargs[seed_run.name] = env_kwargs
|
||||||
|
# Read the trained model_path from ROM
|
||||||
|
model_path = next(seed_run.glob('model.zip'))
|
||||||
|
available_runs_agents[seed_run.name] = model_path
|
||||||
|
|
||||||
|
# We start by combining all SAME MODEL CLASSES per available Seed, across ALL available ENVIRONMENTS.
|
||||||
|
for model_name, model_cls in available_models.items():
|
||||||
|
for seed in range(max_seed):
|
||||||
|
combined_env_kwargs = dict()
|
||||||
|
model_paths = list()
|
||||||
|
comparable_runs = {key: val for key, val in available_runs_kwargs.items() if (
|
||||||
|
key.startswith(str(seed)) and model_name in key and key != 'key')
|
||||||
|
}
|
||||||
|
for name, run_kwargs in comparable_runs.items():
|
||||||
|
# Select trained agent as a candidate:
|
||||||
|
model_paths.append(available_runs_agents[name])
|
||||||
|
# Sort Env Kwars:
|
||||||
|
for key, val in run_kwargs.items():
|
||||||
|
if key not in combined_env_kwargs:
|
||||||
|
combined_env_kwargs.update(dict(key=val))
|
||||||
|
else:
|
||||||
|
assert combined_env_kwargs[key] == val, "Check the combinations you try to make!"
|
||||||
|
|
||||||
|
# Update and combine all kwargs to account for multiple agents etc.
|
||||||
|
# We cannot capture all configuration cases!
|
||||||
|
for key, val in factory_kwargs.items():
|
||||||
|
if key not in combined_env_kwargs:
|
||||||
|
combined_env_kwargs[key] = val
|
||||||
|
else:
|
||||||
|
assert combined_env_kwargs[key] == val
|
||||||
|
del combined_env_kwargs['key']
|
||||||
|
combined_env_kwargs.update(n_agents=len(comparable_runs))
|
||||||
|
with type("CombinedEnv", tuple(available_envs.values()), {})(**combined_env_kwargs) as combEnv:
|
||||||
|
# EnvMonitor Init
|
||||||
|
comb = f'comb_{model_name}_{seed}'
|
||||||
|
comb_monitor_path = combinations_path / comb / f'{comb}_monitor.pick'
|
||||||
|
comb_recorder_path = combinations_path / comb / f'{comb}_recorder.json'
|
||||||
|
comb_monitor_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
monitoredCombEnv = EnvMonitor(combEnv, filepath=comb_monitor_path)
|
||||||
|
monitoredCombEnv = EnvRecorder(monitoredCombEnv, filepath=comb_recorder_path, freq=1)
|
||||||
|
|
||||||
|
# Evaluation starts here #####################################################
|
||||||
|
# Load all models
|
||||||
|
loaded_models = [available_models[model_name].load(model_path) for model_path in model_paths]
|
||||||
|
obs_translators = ObservationTranslator(
|
||||||
|
monitoredCombEnv.named_observation_space,
|
||||||
|
*[agent.named_observation_space for agent in loaded_models],
|
||||||
|
placeholder_fill_value='n')
|
||||||
|
act_translators = ActionTranslator(
|
||||||
|
monitoredCombEnv.named_action_space,
|
||||||
|
*(agent.named_action_space for agent in loaded_models)
|
||||||
|
)
|
||||||
|
|
||||||
|
for episode in range(1):
|
||||||
|
obs = monitoredCombEnv.reset()
|
||||||
|
if render: monitoredCombEnv.render()
|
||||||
|
rew, done_bool = 0, False
|
||||||
|
while not done_bool:
|
||||||
|
actions = []
|
||||||
|
for i, model in enumerate(loaded_models):
|
||||||
|
pred = model.predict(obs_translators.translate_observation(i, obs[i]))[0]
|
||||||
|
actions.append(act_translators.translate_action(i, pred))
|
||||||
|
|
||||||
|
obs, step_r, done_bool, info_obj = monitoredCombEnv.step(actions)
|
||||||
|
|
||||||
|
rew += step_r
|
||||||
|
if render: monitoredCombEnv.render()
|
||||||
|
if done_bool:
|
||||||
|
break
|
||||||
|
print(f'Factory run {episode} done, reward is:\n {rew}')
|
||||||
|
# Eval monitor outputs are automatically stored by the monitor object
|
||||||
|
# TODO: Plotting
|
||||||
|
monitoredCombEnv.save_records()
|
||||||
|
monitoredCombEnv.save_run()
|
||||||
|
pass
|
203
quickstart/single_agent_train_battery_target_env.py
Normal file
203
quickstart/single_agent_train_battery_target_env.py
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
import simplejson
|
||||||
|
|
||||||
|
import stable_baselines3 as sb3
|
||||||
|
|
||||||
|
# This is needed, when you put this file in a subfolder.
|
||||||
|
try:
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
|
if __package__ is None:
|
||||||
|
DIR = Path(__file__).resolve().parent
|
||||||
|
sys.path.insert(0, str(DIR.parent))
|
||||||
|
__package__ = DIR.name
|
||||||
|
else:
|
||||||
|
DIR = None
|
||||||
|
except NameError:
|
||||||
|
DIR = None
|
||||||
|
pass
|
||||||
|
|
||||||
|
from environments import helpers as h
|
||||||
|
from environments.factory.additional.dest.dest_util import DestModeOptions, DestProperties
|
||||||
|
from environments.factory.additional.btry.btry_util import BatteryProperties
|
||||||
|
from environments.logging.envmonitor import EnvMonitor
|
||||||
|
from environments.logging.recorder import EnvRecorder
|
||||||
|
from environments.factory.additional.combined_factories import DestBatteryFactory
|
||||||
|
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
||||||
|
|
||||||
|
from plotting.compare_runs import compare_seed_runs
|
||||||
|
|
||||||
|
"""
|
||||||
|
Welcome to this quick start file. Here we will see how to:
|
||||||
|
0. Setup I/O Paths
|
||||||
|
1. Setup parameters for the environments (dirt-factory).
|
||||||
|
2. Setup parameters for the agent training (SB3: PPO) and save metrics.
|
||||||
|
Run the training.
|
||||||
|
3. Save env and agent for later analysis.
|
||||||
|
4. Load the agent from drive
|
||||||
|
5. Rendering the env with a run of the trained agent.
|
||||||
|
6. Plot metrics
|
||||||
|
"""
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
#########################################################
|
||||||
|
# 0. Setup I/O Paths
|
||||||
|
# Define some general parameters
|
||||||
|
train_steps = 1e6
|
||||||
|
n_seeds = 3
|
||||||
|
model_class = sb3.PPO
|
||||||
|
env_class = DestBatteryFactory
|
||||||
|
|
||||||
|
env_params_json = 'env_params.json'
|
||||||
|
|
||||||
|
# Define a global studi save path
|
||||||
|
start_time = int(time.time())
|
||||||
|
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
|
||||||
|
# Create an identifier, which is unique for every combination and easy to read in filesystem
|
||||||
|
identifier = f'{model_class.__name__}_{env_class.__name__}_{start_time}'
|
||||||
|
exp_path = study_root_path / identifier
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 1. Setup parameters for the environments (dirt-factory).
|
||||||
|
|
||||||
|
|
||||||
|
# Define property object parameters.
|
||||||
|
# 'ObservationProperties' are for specifying how the agent sees the env.
|
||||||
|
obs_props = ObservationProperties(render_agents=AgentRenderOptions.NOT, # Agents won`t be shown in the obs at all
|
||||||
|
omit_agent_self=True, # This is default
|
||||||
|
additional_agent_placeholder=None, # We will not take care of future agents
|
||||||
|
frames_to_stack=3, # To give the agent a notion of time
|
||||||
|
pomdp_r=2 # the agents view-radius
|
||||||
|
)
|
||||||
|
# 'MovementProperties' are for specifying how the agent is allowed to move in the env.
|
||||||
|
move_props = MovementProperties(allow_diagonal_movement=True, # Euclidean style (vertices)
|
||||||
|
allow_square_movement=True, # Manhattan (edges)
|
||||||
|
allow_no_op=False) # Pause movement (do nothing)
|
||||||
|
|
||||||
|
# 'DirtProperties' control if and how dirt is spawned
|
||||||
|
# TODO: Comments
|
||||||
|
dest_props = DestProperties(
|
||||||
|
n_dests = 2, # How many destinations are there
|
||||||
|
dwell_time = 0, # How long does the agent need to "wait" on a destination
|
||||||
|
spawn_frequency = 0,
|
||||||
|
spawn_in_other_zone = True, #
|
||||||
|
spawn_mode = DestModeOptions.DONE,
|
||||||
|
)
|
||||||
|
btry_props = BatteryProperties(
|
||||||
|
initial_charge = 0.9, #
|
||||||
|
charge_rate = 0.4, #
|
||||||
|
charge_locations = 3, #
|
||||||
|
per_action_costs = 0.01,
|
||||||
|
done_when_discharged = True,
|
||||||
|
multi_charge = False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# These are the EnvKwargs for initializing the env class, holding all former parameter-classes
|
||||||
|
# TODO: Comments
|
||||||
|
factory_kwargs = dict(n_agents=1,
|
||||||
|
max_steps=400,
|
||||||
|
parse_doors=True,
|
||||||
|
level_name='rooms',
|
||||||
|
doors_have_area=True, #
|
||||||
|
verbose=False,
|
||||||
|
mv_prop=move_props, # See Above
|
||||||
|
obs_prop=obs_props, # See Above
|
||||||
|
done_at_collision=True,
|
||||||
|
dest_prop=dest_props,
|
||||||
|
btry_prop=btry_props
|
||||||
|
)
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 2. Setup parameters for the agent training (SB3: PPO) and save metrics.
|
||||||
|
agent_kwargs = dict()
|
||||||
|
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# Run the Training
|
||||||
|
for seed in range(n_seeds):
|
||||||
|
# Make a copy if you want to alter things in the training loop; like the seed.
|
||||||
|
env_kwargs = factory_kwargs.copy()
|
||||||
|
env_kwargs.update(env_seed=seed)
|
||||||
|
|
||||||
|
# Output folder
|
||||||
|
seed_path = exp_path / f'{str(seed)}_{identifier}'
|
||||||
|
seed_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Parameter Storage
|
||||||
|
param_path = seed_path / env_params_json
|
||||||
|
# Observation (measures) Storage
|
||||||
|
monitor_path = seed_path / 'monitor.pick'
|
||||||
|
recorder_path = seed_path / 'recorder.json'
|
||||||
|
# Model save Path for the trained model
|
||||||
|
model_save_path = seed_path / f'model.zip'
|
||||||
|
|
||||||
|
# Env Init & Model kwargs definition
|
||||||
|
with env_class(**env_kwargs) as env_factory:
|
||||||
|
|
||||||
|
# EnvMonitor Init
|
||||||
|
env_monitor_callback = EnvMonitor(env_factory)
|
||||||
|
|
||||||
|
# EnvRecorder Init
|
||||||
|
env_recorder_callback = EnvRecorder(env_factory, freq=int(train_steps / 400 / 10))
|
||||||
|
|
||||||
|
# Model Init
|
||||||
|
model = model_class("MlpPolicy", env_factory, verbose=1, seed=seed, device='cpu')
|
||||||
|
|
||||||
|
# Model train
|
||||||
|
model.learn(total_timesteps=int(train_steps), callback=[env_monitor_callback, env_recorder_callback])
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 3. Save env and agent for later analysis.
|
||||||
|
# Save the trained Model, the monitor (env measures) and the env parameters
|
||||||
|
model.named_observation_space = env_factory.named_observation_space
|
||||||
|
model.named_action_space = env_factory.named_action_space
|
||||||
|
model.save(model_save_path)
|
||||||
|
env_factory.save_params(param_path)
|
||||||
|
env_monitor_callback.save_run(monitor_path)
|
||||||
|
env_recorder_callback.save_records(recorder_path, save_occupation_map=False)
|
||||||
|
|
||||||
|
# Compare performance runs, for each seed within a model
|
||||||
|
try:
|
||||||
|
compare_seed_runs(exp_path, use_tex=False)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Train ends here ############################################################
|
||||||
|
|
||||||
|
# Evaluation starts here #####################################################
|
||||||
|
# First Iterate over every model and monitor "as trained"
|
||||||
|
print('Start Measurement Tracking')
|
||||||
|
# For trained policy in study_root_path / identifier
|
||||||
|
for policy_path in [x for x in exp_path.iterdir() if x.is_dir()]:
|
||||||
|
|
||||||
|
# retrieve model class
|
||||||
|
model_cls = next(val for key, val in h.MODEL_MAP.items() if key in policy_path.parent.name)
|
||||||
|
# Load the agent agent
|
||||||
|
model = model_cls.load(policy_path / 'model.zip', device='cpu')
|
||||||
|
# Load old env kwargs
|
||||||
|
with next(policy_path.glob(env_params_json)).open('r') as f:
|
||||||
|
env_kwargs = simplejson.load(f)
|
||||||
|
# Make the env stop ar collisions
|
||||||
|
# (you only want to have a single collision per episode hence the statistics)
|
||||||
|
env_kwargs.update(done_at_collision=True)
|
||||||
|
|
||||||
|
# Init Env
|
||||||
|
with env_class(**env_kwargs) as env_factory:
|
||||||
|
monitored_env_factory = EnvMonitor(env_factory)
|
||||||
|
|
||||||
|
# Evaluation Loop for i in range(n Episodes)
|
||||||
|
for episode in range(100):
|
||||||
|
# noinspection PyRedeclaration
|
||||||
|
env_state = monitored_env_factory.reset()
|
||||||
|
rew, done_bool = 0, False
|
||||||
|
while not done_bool:
|
||||||
|
action = model.predict(env_state, deterministic=True)[0]
|
||||||
|
env_state, step_r, done_bool, info_obj = monitored_env_factory.step(action)
|
||||||
|
rew += step_r
|
||||||
|
if done_bool:
|
||||||
|
break
|
||||||
|
print(f'Factory run {episode} done, reward is:\n {rew}')
|
||||||
|
monitored_env_factory.save_run(filepath=policy_path / 'eval_run_monitor.pick')
|
||||||
|
print('Measurements Done')
|
193
quickstart/single_agent_train_dest_env.py
Normal file
193
quickstart/single_agent_train_dest_env.py
Normal file
@ -0,0 +1,193 @@
|
|||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
import simplejson
|
||||||
|
|
||||||
|
import stable_baselines3 as sb3
|
||||||
|
|
||||||
|
# This is needed, when you put this file in a subfolder.
|
||||||
|
try:
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
|
if __package__ is None:
|
||||||
|
DIR = Path(__file__).resolve().parent
|
||||||
|
sys.path.insert(0, str(DIR.parent))
|
||||||
|
__package__ = DIR.name
|
||||||
|
else:
|
||||||
|
DIR = None
|
||||||
|
except NameError:
|
||||||
|
DIR = None
|
||||||
|
pass
|
||||||
|
|
||||||
|
from environments import helpers as h
|
||||||
|
from environments.factory.additional.dest.dest_util import DestModeOptions, DestProperties
|
||||||
|
from environments.logging.envmonitor import EnvMonitor
|
||||||
|
from environments.logging.recorder import EnvRecorder
|
||||||
|
from environments.factory.additional.dest.factory_dest import DestFactory
|
||||||
|
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
||||||
|
|
||||||
|
from plotting.compare_runs import compare_seed_runs
|
||||||
|
|
||||||
|
"""
|
||||||
|
Welcome to this quick start file. Here we will see how to:
|
||||||
|
0. Setup I/O Paths
|
||||||
|
1. Setup parameters for the environments (dest-factory).
|
||||||
|
2. Setup parameters for the agent training (SB3: PPO) and save metrics.
|
||||||
|
Run the training.
|
||||||
|
3. Save env and agent for later analysis.
|
||||||
|
4. Load the agent from drive
|
||||||
|
5. Rendering the env with a run of the trained agent.
|
||||||
|
6. Plot metrics
|
||||||
|
"""
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
#########################################################
|
||||||
|
# 0. Setup I/O Paths
|
||||||
|
# Define some general parameters
|
||||||
|
train_steps = 1e6
|
||||||
|
n_seeds = 3
|
||||||
|
model_class = sb3.PPO
|
||||||
|
env_class = DestFactory
|
||||||
|
|
||||||
|
env_params_json = 'env_params.json'
|
||||||
|
|
||||||
|
# Define a global studi save path
|
||||||
|
start_time = int(time.time())
|
||||||
|
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
|
||||||
|
# Create an identifier, which is unique for every combination and easy to read in filesystem
|
||||||
|
identifier = f'{model_class.__name__}_{env_class.__name__}_{start_time}'
|
||||||
|
exp_path = study_root_path / identifier
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 1. Setup parameters for the environments (dest-factory).
|
||||||
|
|
||||||
|
|
||||||
|
# Define property object parameters.
|
||||||
|
# 'ObservationProperties' are for specifying how the agent sees the env.
|
||||||
|
obs_props = ObservationProperties(render_agents=AgentRenderOptions.NOT, # Agents won`t be shown in the obs at all
|
||||||
|
omit_agent_self=True, # This is default
|
||||||
|
additional_agent_placeholder=None, # We will not take care of future agents
|
||||||
|
frames_to_stack=3, # To give the agent a notion of time
|
||||||
|
pomdp_r=2 # the agents view-radius
|
||||||
|
)
|
||||||
|
# 'MovementProperties' are for specifying how the agent is allowed to move in the env.
|
||||||
|
move_props = MovementProperties(allow_diagonal_movement=True, # Euclidean style (vertices)
|
||||||
|
allow_square_movement=True, # Manhattan (edges)
|
||||||
|
allow_no_op=False) # Pause movement (do nothing)
|
||||||
|
|
||||||
|
# 'DestProperties' control if and how dest is spawned
|
||||||
|
# TODO: Comments
|
||||||
|
dest_props = DestProperties(
|
||||||
|
n_dests = 2, # How many destinations are there
|
||||||
|
dwell_time = 0, # How long does the agent need to "wait" on a destination
|
||||||
|
spawn_frequency = 0,
|
||||||
|
spawn_in_other_zone = True, #
|
||||||
|
spawn_mode = DestModeOptions.DONE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# These are the EnvKwargs for initializing the env class, holding all former parameter-classes
|
||||||
|
# TODO: Comments
|
||||||
|
factory_kwargs = dict(n_agents=1,
|
||||||
|
max_steps=400,
|
||||||
|
parse_doors=True,
|
||||||
|
level_name='rooms',
|
||||||
|
doors_have_area=True, #
|
||||||
|
verbose=False,
|
||||||
|
mv_prop=move_props, # See Above
|
||||||
|
obs_prop=obs_props, # See Above
|
||||||
|
done_at_collision=True,
|
||||||
|
dest_prop=dest_props
|
||||||
|
)
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 2. Setup parameters for the agent training (SB3: PPO) and save metrics.
|
||||||
|
agent_kwargs = dict()
|
||||||
|
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# Run the Training
|
||||||
|
for seed in range(n_seeds):
|
||||||
|
# Make a copy if you want to alter things in the training loop; like the seed.
|
||||||
|
env_kwargs = factory_kwargs.copy()
|
||||||
|
env_kwargs.update(env_seed=seed)
|
||||||
|
|
||||||
|
# Output folder
|
||||||
|
seed_path = exp_path / f'{str(seed)}_{identifier}'
|
||||||
|
seed_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Parameter Storage
|
||||||
|
param_path = seed_path / env_params_json
|
||||||
|
# Observation (measures) Storage
|
||||||
|
monitor_path = seed_path / 'monitor.pick'
|
||||||
|
recorder_path = seed_path / 'recorder.json'
|
||||||
|
# Model save Path for the trained model
|
||||||
|
model_save_path = seed_path / f'model.zip'
|
||||||
|
|
||||||
|
# Env Init & Model kwargs definition
|
||||||
|
with env_class(**env_kwargs) as env_factory:
|
||||||
|
|
||||||
|
# EnvMonitor Init
|
||||||
|
env_monitor_callback = EnvMonitor(env_factory)
|
||||||
|
|
||||||
|
# EnvRecorder Init
|
||||||
|
env_recorder_callback = EnvRecorder(env_factory, freq=int(train_steps / 400 / 10))
|
||||||
|
|
||||||
|
# Model Init
|
||||||
|
model = model_class("MlpPolicy", env_factory,verbose=1, seed=seed, device='cpu')
|
||||||
|
|
||||||
|
# Model train
|
||||||
|
model.learn(total_timesteps=int(train_steps), callback=[env_monitor_callback, env_recorder_callback])
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 3. Save env and agent for later analysis.
|
||||||
|
# Save the trained Model, the monitor (env measures) and the env parameters
|
||||||
|
model.named_observation_space = env_factory.named_observation_space
|
||||||
|
model.named_action_space = env_factory.named_action_space
|
||||||
|
model.save(model_save_path)
|
||||||
|
env_factory.save_params(param_path)
|
||||||
|
env_monitor_callback.save_run(monitor_path)
|
||||||
|
env_recorder_callback.save_records(recorder_path, save_occupation_map=False)
|
||||||
|
|
||||||
|
# Compare performance runs, for each seed within a model
|
||||||
|
try:
|
||||||
|
compare_seed_runs(exp_path, use_tex=False)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Train ends here ############################################################
|
||||||
|
|
||||||
|
# Evaluation starts here #####################################################
|
||||||
|
# First Iterate over every model and monitor "as trained"
|
||||||
|
print('Start Measurement Tracking')
|
||||||
|
# For trained policy in study_root_path / identifier
|
||||||
|
for policy_path in [x for x in exp_path.iterdir() if x.is_dir()]:
|
||||||
|
|
||||||
|
# retrieve model class
|
||||||
|
model_cls = next(val for key, val in h.MODEL_MAP.items() if key in policy_path.parent.name)
|
||||||
|
# Load the agent agent
|
||||||
|
model = model_cls.load(policy_path / 'model.zip', device='cpu')
|
||||||
|
# Load old env kwargs
|
||||||
|
with next(policy_path.glob(env_params_json)).open('r') as f:
|
||||||
|
env_kwargs = simplejson.load(f)
|
||||||
|
# Make the env stop ar collisions
|
||||||
|
# (you only want to have a single collision per episode hence the statistics)
|
||||||
|
env_kwargs.update(done_at_collision=True)
|
||||||
|
|
||||||
|
# Init Env
|
||||||
|
with env_class(**env_kwargs) as env_factory:
|
||||||
|
monitored_env_factory = EnvMonitor(env_factory)
|
||||||
|
|
||||||
|
# Evaluation Loop for i in range(n Episodes)
|
||||||
|
for episode in range(100):
|
||||||
|
# noinspection PyRedeclaration
|
||||||
|
env_state = monitored_env_factory.reset()
|
||||||
|
rew, done_bool = 0, False
|
||||||
|
while not done_bool:
|
||||||
|
action = model.predict(env_state, deterministic=True)[0]
|
||||||
|
env_state, step_r, done_bool, info_obj = monitored_env_factory.step(action)
|
||||||
|
rew += step_r
|
||||||
|
if done_bool:
|
||||||
|
break
|
||||||
|
print(f'Factory run {episode} done, reward is:\n {rew}')
|
||||||
|
monitored_env_factory.save_run(filepath=policy_path / 'eval_run_monitor.pick')
|
||||||
|
print('Measurements Done')
|
195
quickstart/single_agent_train_dirt_env.py
Normal file
195
quickstart/single_agent_train_dirt_env.py
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
import simplejson
|
||||||
|
|
||||||
|
import stable_baselines3 as sb3
|
||||||
|
|
||||||
|
# This is needed, when you put this file in a subfolder.
|
||||||
|
try:
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
|
if __package__ is None:
|
||||||
|
DIR = Path(__file__).resolve().parent
|
||||||
|
sys.path.insert(0, str(DIR.parent))
|
||||||
|
__package__ = DIR.name
|
||||||
|
else:
|
||||||
|
DIR = None
|
||||||
|
except NameError:
|
||||||
|
DIR = None
|
||||||
|
pass
|
||||||
|
|
||||||
|
from environments import helpers as h
|
||||||
|
from environments.logging.envmonitor import EnvMonitor
|
||||||
|
from environments.logging.recorder import EnvRecorder
|
||||||
|
from environments.factory.additional.dirt.dirt_util import DirtProperties
|
||||||
|
from environments.factory.additional.dirt.factory_dirt import DirtFactory
|
||||||
|
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
||||||
|
|
||||||
|
from plotting.compare_runs import compare_seed_runs
|
||||||
|
|
||||||
|
"""
|
||||||
|
Welcome to this quick start file. Here we will see how to:
|
||||||
|
0. Setup I/O Paths
|
||||||
|
1. Setup parameters for the environments (dirt-factory).
|
||||||
|
2. Setup parameters for the agent training (SB3: PPO) and save metrics.
|
||||||
|
Run the training.
|
||||||
|
3. Save env and agent for later analysis.
|
||||||
|
4. Load the agent from drive
|
||||||
|
5. Rendering the env with a run of the trained agent.
|
||||||
|
6. Plot metrics
|
||||||
|
"""
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
#########################################################
|
||||||
|
# 0. Setup I/O Paths
|
||||||
|
# Define some general parameters
|
||||||
|
train_steps = 1e6
|
||||||
|
n_seeds = 3
|
||||||
|
model_class = sb3.PPO
|
||||||
|
env_class = DirtFactory
|
||||||
|
|
||||||
|
env_params_json = 'env_params.json'
|
||||||
|
|
||||||
|
# Define a global studi save path
|
||||||
|
start_time = int(time.time())
|
||||||
|
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
|
||||||
|
# Create an identifier, which is unique for every combination and easy to read in filesystem
|
||||||
|
identifier = f'{model_class.__name__}_{env_class.__name__}_{start_time}'
|
||||||
|
exp_path = study_root_path / identifier
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 1. Setup parameters for the environments (dirt-factory).
|
||||||
|
|
||||||
|
|
||||||
|
# Define property object parameters.
|
||||||
|
# 'ObservationProperties' are for specifying how the agent sees the env.
|
||||||
|
obs_props = ObservationProperties(render_agents=AgentRenderOptions.NOT, # Agents won`t be shown in the obs at all
|
||||||
|
omit_agent_self=True, # This is default
|
||||||
|
additional_agent_placeholder=None, # We will not take care of future agents
|
||||||
|
frames_to_stack=3, # To give the agent a notion of time
|
||||||
|
pomdp_r=2 # the agents view-radius
|
||||||
|
)
|
||||||
|
# 'MovementProperties' are for specifying how the agent is allowed to move in the env.
|
||||||
|
move_props = MovementProperties(allow_diagonal_movement=True, # Euclidean style (vertices)
|
||||||
|
allow_square_movement=True, # Manhattan (edges)
|
||||||
|
allow_no_op=False) # Pause movement (do nothing)
|
||||||
|
|
||||||
|
# 'DirtProperties' control if and how dirt is spawned
|
||||||
|
# TODO: Comments
|
||||||
|
dirt_props = DirtProperties(initial_dirt_ratio=0.35,
|
||||||
|
initial_dirt_spawn_r_var=0.1,
|
||||||
|
clean_amount=0.34,
|
||||||
|
max_spawn_amount=0.1,
|
||||||
|
max_global_amount=20,
|
||||||
|
max_local_amount=1,
|
||||||
|
spawn_frequency=0,
|
||||||
|
max_spawn_ratio=0.05,
|
||||||
|
dirt_smear_amount=0.0)
|
||||||
|
|
||||||
|
# These are the EnvKwargs for initializing the env class, holding all former parameter-classes
|
||||||
|
# TODO: Comments
|
||||||
|
factory_kwargs = dict(n_agents=1,
|
||||||
|
max_steps=400,
|
||||||
|
parse_doors=True,
|
||||||
|
level_name='rooms',
|
||||||
|
doors_have_area=True, #
|
||||||
|
verbose=False,
|
||||||
|
mv_prop=move_props, # See Above
|
||||||
|
obs_prop=obs_props, # See Above
|
||||||
|
done_at_collision=True,
|
||||||
|
dirt_prop=dirt_props
|
||||||
|
)
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 2. Setup parameters for the agent training (SB3: PPO) and save metrics.
|
||||||
|
agent_kwargs = dict()
|
||||||
|
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# Run the Training
|
||||||
|
for seed in range(n_seeds):
|
||||||
|
# Make a copy if you want to alter things in the training loop; like the seed.
|
||||||
|
env_kwargs = factory_kwargs.copy()
|
||||||
|
env_kwargs.update(env_seed=seed)
|
||||||
|
|
||||||
|
# Output folder
|
||||||
|
seed_path = exp_path / f'{str(seed)}_{identifier}'
|
||||||
|
seed_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Parameter Storage
|
||||||
|
param_path = seed_path / env_params_json
|
||||||
|
# Observation (measures) Storage
|
||||||
|
monitor_path = seed_path / 'monitor.pick'
|
||||||
|
recorder_path = seed_path / 'recorder.json'
|
||||||
|
# Model save Path for the trained model
|
||||||
|
model_save_path = seed_path / f'model.zip'
|
||||||
|
|
||||||
|
# Env Init & Model kwargs definition
|
||||||
|
with env_class(**env_kwargs) as env_factory:
|
||||||
|
|
||||||
|
# EnvMonitor Init
|
||||||
|
env_monitor_callback = EnvMonitor(env_factory)
|
||||||
|
|
||||||
|
# EnvRecorder Init
|
||||||
|
env_recorder_callback = EnvRecorder(env_factory, freq=int(train_steps / 400 / 10))
|
||||||
|
|
||||||
|
# Model Init
|
||||||
|
model = model_class("MlpPolicy", env_factory,verbose=1, seed=seed, device='cpu')
|
||||||
|
|
||||||
|
# Model train
|
||||||
|
model.learn(total_timesteps=int(train_steps), callback=[env_monitor_callback, env_recorder_callback])
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 3. Save env and agent for later analysis.
|
||||||
|
# Save the trained Model, the monitor (env measures) and the env parameters
|
||||||
|
model.named_observation_space = env_factory.named_observation_space
|
||||||
|
model.named_action_space = env_factory.named_action_space
|
||||||
|
model.save(model_save_path)
|
||||||
|
env_factory.save_params(param_path)
|
||||||
|
env_monitor_callback.save_run(monitor_path)
|
||||||
|
env_recorder_callback.save_records(recorder_path, save_occupation_map=False)
|
||||||
|
|
||||||
|
# Compare performance runs, for each seed within a model
|
||||||
|
try:
|
||||||
|
compare_seed_runs(exp_path, use_tex=False)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Train ends here ############################################################
|
||||||
|
|
||||||
|
# Evaluation starts here #####################################################
|
||||||
|
# First Iterate over every model and monitor "as trained"
|
||||||
|
print('Start Measurement Tracking')
|
||||||
|
# For trained policy in study_root_path / identifier
|
||||||
|
for policy_path in [x for x in exp_path.iterdir() if x.is_dir()]:
|
||||||
|
|
||||||
|
# retrieve model class
|
||||||
|
model_cls = next(val for key, val in h.MODEL_MAP.items() if key in policy_path.parent.name)
|
||||||
|
# Load the agent agent
|
||||||
|
model = model_cls.load(policy_path / 'model.zip', device='cpu')
|
||||||
|
# Load old env kwargs
|
||||||
|
with next(policy_path.glob(env_params_json)).open('r') as f:
|
||||||
|
env_kwargs = simplejson.load(f)
|
||||||
|
# Make the env stop ar collisions
|
||||||
|
# (you only want to have a single collision per episode hence the statistics)
|
||||||
|
env_kwargs.update(done_at_collision=True)
|
||||||
|
|
||||||
|
# Init Env
|
||||||
|
with env_class(**env_kwargs) as env_factory:
|
||||||
|
monitored_env_factory = EnvMonitor(env_factory)
|
||||||
|
|
||||||
|
# Evaluation Loop for i in range(n Episodes)
|
||||||
|
for episode in range(100):
|
||||||
|
# noinspection PyRedeclaration
|
||||||
|
env_state = monitored_env_factory.reset()
|
||||||
|
rew, done_bool = 0, False
|
||||||
|
while not done_bool:
|
||||||
|
action = model.predict(env_state, deterministic=True)[0]
|
||||||
|
env_state, step_r, done_bool, info_obj = monitored_env_factory.step(action)
|
||||||
|
rew += step_r
|
||||||
|
if done_bool:
|
||||||
|
break
|
||||||
|
print(f'Factory run {episode} done, reward is:\n {rew}')
|
||||||
|
monitored_env_factory.save_run(filepath=policy_path / 'eval_run_monitor.pick')
|
||||||
|
print('Measurements Done')
|
191
quickstart/single_agent_train_item_env.py
Normal file
191
quickstart/single_agent_train_item_env.py
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
import simplejson
|
||||||
|
|
||||||
|
import stable_baselines3 as sb3
|
||||||
|
|
||||||
|
# This is needed, when you put this file in a subfolder.
|
||||||
|
try:
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
|
if __package__ is None:
|
||||||
|
DIR = Path(__file__).resolve().parent
|
||||||
|
sys.path.insert(0, str(DIR.parent))
|
||||||
|
__package__ = DIR.name
|
||||||
|
else:
|
||||||
|
DIR = None
|
||||||
|
except NameError:
|
||||||
|
DIR = None
|
||||||
|
pass
|
||||||
|
|
||||||
|
from environments import helpers as h
|
||||||
|
from environments.factory.additional.item.factory_item import ItemFactory
|
||||||
|
from environments.factory.additional.item.item_util import ItemProperties
|
||||||
|
from environments.logging.envmonitor import EnvMonitor
|
||||||
|
from environments.logging.recorder import EnvRecorder
|
||||||
|
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
||||||
|
|
||||||
|
from plotting.compare_runs import compare_seed_runs
|
||||||
|
|
||||||
|
"""
|
||||||
|
Welcome to this quick start file. Here we will see how to:
|
||||||
|
0. Setup I/O Paths
|
||||||
|
1. Setup parameters for the environments (item-factory).
|
||||||
|
2. Setup parameters for the agent training (SB3: PPO) and save metrics.
|
||||||
|
Run the training.
|
||||||
|
3. Save env and agent for later analysis.
|
||||||
|
4. Load the agent from drive
|
||||||
|
5. Rendering the env with a run of the trained agent.
|
||||||
|
6. Plot metrics
|
||||||
|
"""
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
#########################################################
|
||||||
|
# 0. Setup I/O Paths
|
||||||
|
# Define some general parameters
|
||||||
|
train_steps = 1e6
|
||||||
|
n_seeds = 3
|
||||||
|
model_class = sb3.PPO
|
||||||
|
env_class = ItemFactory
|
||||||
|
|
||||||
|
env_params_json = 'env_params.json'
|
||||||
|
|
||||||
|
# Define a global studi save path
|
||||||
|
start_time = int(time.time())
|
||||||
|
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
|
||||||
|
# Create an identifier, which is unique for every combination and easy to read in filesystem
|
||||||
|
identifier = f'{model_class.__name__}_{env_class.__name__}_{start_time}'
|
||||||
|
exp_path = study_root_path / identifier
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 1. Setup parameters for the environments (item-factory).
|
||||||
|
#
|
||||||
|
# Define property object parameters.
|
||||||
|
# 'ObservationProperties' are for specifying how the agent sees the env.
|
||||||
|
obs_props = ObservationProperties(render_agents=AgentRenderOptions.NOT, # Agents won`t be shown in the obs at all
|
||||||
|
omit_agent_self=True, # This is default
|
||||||
|
additional_agent_placeholder=None, # We will not take care of future agents
|
||||||
|
frames_to_stack=3, # To give the agent a notion of time
|
||||||
|
pomdp_r=2 # the agents view-radius
|
||||||
|
)
|
||||||
|
# 'MovementProperties' are for specifying how the agent is allowed to move in the env.
|
||||||
|
move_props = MovementProperties(allow_diagonal_movement=True, # Euclidean style (vertices)
|
||||||
|
allow_square_movement=True, # Manhattan (edges)
|
||||||
|
allow_no_op=False) # Pause movement (do nothing)
|
||||||
|
|
||||||
|
# 'ItemProperties' control if and how item is spawned
|
||||||
|
# TODO: Comments
|
||||||
|
item_props = ItemProperties(
|
||||||
|
n_items = 7, # How many items are there at the same time
|
||||||
|
spawn_frequency = 50, # Spawn Frequency in Steps
|
||||||
|
n_drop_off_locations = 10, # How many DropOff locations are there at the same time
|
||||||
|
max_dropoff_storage_size = 0, # How many items are needed until the dropoff is full
|
||||||
|
max_agent_inventory_capacity = 5, # How many items are needed until the agent inventory is full)
|
||||||
|
)
|
||||||
|
|
||||||
|
# These are the EnvKwargs for initializing the env class, holding all former parameter-classes
|
||||||
|
# TODO: Comments
|
||||||
|
factory_kwargs = dict(n_agents=1,
|
||||||
|
max_steps=400,
|
||||||
|
parse_doors=True,
|
||||||
|
level_name='rooms',
|
||||||
|
doors_have_area=True, #
|
||||||
|
verbose=False,
|
||||||
|
mv_prop=move_props, # See Above
|
||||||
|
obs_prop=obs_props, # See Above
|
||||||
|
done_at_collision=True,
|
||||||
|
item_prop=item_props
|
||||||
|
)
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 2. Setup parameters for the agent training (SB3: PPO) and save metrics.
|
||||||
|
agent_kwargs = dict()
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# Run the Training
|
||||||
|
for seed in range(n_seeds):
|
||||||
|
# Make a copy if you want to alter things in the training loop; like the seed.
|
||||||
|
env_kwargs = factory_kwargs.copy()
|
||||||
|
env_kwargs.update(env_seed=seed)
|
||||||
|
|
||||||
|
# Output folder
|
||||||
|
seed_path = exp_path / f'{str(seed)}_{identifier}'
|
||||||
|
seed_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Parameter Storage
|
||||||
|
param_path = seed_path / env_params_json
|
||||||
|
# Observation (measures) Storage
|
||||||
|
monitor_path = seed_path / 'monitor.pick'
|
||||||
|
recorder_path = seed_path / 'recorder.json'
|
||||||
|
# Model save Path for the trained model
|
||||||
|
model_save_path = seed_path / f'model.zip'
|
||||||
|
|
||||||
|
# Env Init & Model kwargs definition
|
||||||
|
with ItemFactory(**env_kwargs) as env_factory:
|
||||||
|
|
||||||
|
# EnvMonitor Init
|
||||||
|
env_monitor_callback = EnvMonitor(env_factory)
|
||||||
|
|
||||||
|
# EnvRecorder Init
|
||||||
|
env_recorder_callback = EnvRecorder(env_factory, freq=int(train_steps / 400 / 10))
|
||||||
|
|
||||||
|
# Model Init
|
||||||
|
model = model_class("MlpPolicy", env_factory,verbose=1, seed=seed, device='cpu')
|
||||||
|
|
||||||
|
# Model train
|
||||||
|
model.learn(total_timesteps=int(train_steps), callback=[env_monitor_callback, env_recorder_callback])
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 3. Save env and agent for later analysis.
|
||||||
|
# Save the trained Model, the monitor (env measures) and the env parameters
|
||||||
|
model.named_observation_space = env_factory.named_observation_space
|
||||||
|
model.named_action_space = env_factory.named_action_space
|
||||||
|
model.save(model_save_path)
|
||||||
|
env_factory.save_params(param_path)
|
||||||
|
env_monitor_callback.save_run(monitor_path)
|
||||||
|
env_recorder_callback.save_records(recorder_path, save_occupation_map=False)
|
||||||
|
|
||||||
|
# Compare performance runs, for each seed within a model
|
||||||
|
try:
|
||||||
|
compare_seed_runs(exp_path, use_tex=False)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Train ends here ############################################################
|
||||||
|
|
||||||
|
# Evaluation starts here #####################################################
|
||||||
|
# First Iterate over every model and monitor "as trained"
|
||||||
|
print('Start Measurement Tracking')
|
||||||
|
# For trained policy in study_root_path / identifier
|
||||||
|
for policy_path in [x for x in exp_path.iterdir() if x.is_dir()]:
|
||||||
|
|
||||||
|
# retrieve model class
|
||||||
|
model_cls = next(val for key, val in h.MODEL_MAP.items() if key in policy_path.parent.name)
|
||||||
|
# Load the agent agent
|
||||||
|
model = model_cls.load(policy_path / 'model.zip', device='cpu')
|
||||||
|
# Load old env kwargs
|
||||||
|
with next(policy_path.glob(env_params_json)).open('r') as f:
|
||||||
|
env_kwargs = simplejson.load(f)
|
||||||
|
# Make the env stop ar collisions
|
||||||
|
# (you only want to have a single collision per episode hence the statistics)
|
||||||
|
env_kwargs.update(done_at_collision=True)
|
||||||
|
|
||||||
|
# Init Env
|
||||||
|
with ItemFactory(**env_kwargs) as env_factory:
|
||||||
|
monitored_env_factory = EnvMonitor(env_factory)
|
||||||
|
|
||||||
|
# Evaluation Loop for i in range(n Episodes)
|
||||||
|
for episode in range(100):
|
||||||
|
# noinspection PyRedeclaration
|
||||||
|
env_state = monitored_env_factory.reset()
|
||||||
|
rew, done_bool = 0, False
|
||||||
|
while not done_bool:
|
||||||
|
action = model.predict(env_state, deterministic=True)[0]
|
||||||
|
env_state, step_r, done_bool, info_obj = monitored_env_factory.step(action)
|
||||||
|
rew += step_r
|
||||||
|
if done_bool:
|
||||||
|
break
|
||||||
|
print(f'Factory run {episode} done, reward is:\n {rew}')
|
||||||
|
monitored_env_factory.save_run(filepath=policy_path / 'eval_run_monitor.pick')
|
||||||
|
print('Measurements Done')
|
@ -1,13 +1,13 @@
|
|||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import yaml
|
import yaml
|
||||||
|
from stable_baselines3 import A2C, PPO, DQN
|
||||||
|
|
||||||
from environments import helpers as h
|
from environments.factory.additional.dirt.dirt_util import Constants
|
||||||
from environments.helpers import Constants as c
|
|
||||||
from environments.factory.factory_dirt import DirtFactory
|
from environments.factory.additional.dirt.factory_dirt import DirtFactory
|
||||||
from environments.factory.combined_factories import DirtItemFactory
|
from environments.logging.envmonitor import EnvMonitor
|
||||||
from environments.logging.recorder import EnvRecorder
|
from environments.logging.recorder import EnvRecorder
|
||||||
|
|
||||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||||
@ -18,36 +18,41 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
determin = False
|
determin = False
|
||||||
render = True
|
render = True
|
||||||
record = True
|
record = False
|
||||||
seed = 67
|
verbose = True
|
||||||
|
seed = 13
|
||||||
n_agents = 1
|
n_agents = 1
|
||||||
out_path = Path('study_out/test/dirt')
|
# out_path = Path('study_out/e_1_new_reward/no_obs/dirt/A2C_new_reward/0_A2C_new_reward')
|
||||||
|
out_path = Path('quickstart/combinations/single_agent_train_dirt_env_1659374984/PPO_DirtFactory_1659374984/0_PPO_DirtFactory_1659374984/')
|
||||||
|
model_path = out_path
|
||||||
|
|
||||||
with (out_path / f'env_params.json').open('r') as f:
|
with (out_path / f'env_params.json').open('r') as f:
|
||||||
env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
|
env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
env_kwargs.update(additional_agent_placeholder=None, n_agents=n_agents, max_steps=150)
|
env_kwargs.update(n_agents=n_agents, done_at_collision=False, verbose=verbose)
|
||||||
if gain_amount := env_kwargs.get('dirt_prop', {}).get('gain_amount', None):
|
|
||||||
env_kwargs['dirt_prop']['max_spawn_amount'] = gain_amount
|
|
||||||
del env_kwargs['dirt_prop']['gain_amount']
|
|
||||||
|
|
||||||
env_kwargs.update(record_episodes=record, done_at_collision=True)
|
|
||||||
|
|
||||||
this_model = out_path / 'model.zip'
|
this_model = out_path / 'model.zip'
|
||||||
|
|
||||||
model_cls =h.MODEL_MAP['A2C']
|
model_cls = PPO # next(val for key, val in h.MODEL_MAP.items() if key in out_path.parent.name)
|
||||||
models = [model_cls.load(this_model)]
|
models = [model_cls.load(this_model)]
|
||||||
|
try:
|
||||||
|
# Legacy Cleanups
|
||||||
|
del env_kwargs['dirt_prop']['agent_can_interact']
|
||||||
|
env_kwargs['verbose'] = True
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
# Init Env
|
# Init Env
|
||||||
with DirtFactory(**env_kwargs) as env:
|
with DirtFactory(**env_kwargs) as env:
|
||||||
env = EnvRecorder(env)
|
env = EnvMonitor(env)
|
||||||
|
env = EnvRecorder(env) if record else env
|
||||||
obs_shape = env.observation_space.shape
|
obs_shape = env.observation_space.shape
|
||||||
# Evaluation Loop for i in range(n Episodes)
|
# Evaluation Loop for i in range(n Episodes)
|
||||||
for episode in range(50):
|
for episode in range(500):
|
||||||
env_state = env.reset()
|
env_state = env.reset()
|
||||||
rew, done_bool = 0, False
|
rew, done_bool = 0, False
|
||||||
while not done_bool:
|
while not done_bool:
|
||||||
if n_agents > 1:
|
if n_agents > 1:
|
||||||
actions = [model.predict(env_state[model_idx], deterministic=True)[0]
|
actions = [model.predict(env_state[model_idx], deterministic=determin)[0]
|
||||||
for model_idx, model in enumerate(models)]
|
for model_idx, model in enumerate(models)]
|
||||||
else:
|
else:
|
||||||
actions = models[0].predict(env_state, deterministic=determin)[0]
|
actions = models[0].predict(env_state, deterministic=determin)[0]
|
||||||
@ -56,7 +61,17 @@ if __name__ == '__main__':
|
|||||||
rew += step_r
|
rew += step_r
|
||||||
if render:
|
if render:
|
||||||
env.render()
|
env.render()
|
||||||
|
try:
|
||||||
|
door = next(x for x in env.unwrapped.unwrapped[Constants.DOORS] if x.is_open)
|
||||||
|
print('openDoor found')
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
|
|
||||||
if done_bool:
|
if done_bool:
|
||||||
break
|
break
|
||||||
print(f'Factory run {episode} done, reward is:\n {rew}')
|
print(f'Factory run {episode} done, steps taken {env.unwrapped.unwrapped._steps}, reward is:\n {rew}')
|
||||||
|
env.save_run(out_path / 'reload_monitor.pick',
|
||||||
|
auto_plotting_keys=['step_reward', 'cleanup_valid', 'cleanup_fail'])
|
||||||
|
if record:
|
||||||
|
env.save_records(out_path / 'reload_recorder.pick', save_occupation_map=True)
|
||||||
print('all done')
|
print('all done')
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
numpy
|
numpy
|
||||||
scipy
|
scipy
|
||||||
tqdm
|
tqdm
|
||||||
|
pandas
|
||||||
seaborn>=0.11.1
|
seaborn>=0.11.1
|
||||||
matplotlib>=3.4.1
|
matplotlib>=3.3.4
|
||||||
stable-baselines3>=1.0
|
stable-baselines3>=1.0
|
||||||
pygame>=2.1.0
|
pygame>=2.1.0
|
||||||
gym>=0.18.0
|
gym>=0.18.0
|
||||||
networkx>=2.6.1
|
networkx>=2.6.3
|
||||||
simplejson>=3.17.5
|
simplejson>=3.17.5
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
git+https://github.com/facebookresearch/salina.git@main#egg=salina
|
einops
|
||||||
|
natsort
|
@ -1,7 +1,6 @@
|
|||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
import numpy as np
|
|
||||||
import itertools as it
|
import itertools as it
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -16,26 +15,24 @@ except NameError:
|
|||||||
DIR = None
|
DIR = None
|
||||||
pass
|
pass
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
import simplejson
|
import simplejson
|
||||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||||
|
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.factory.factory_dirt import DirtProperties, DirtFactory
|
from environments.factory.factory_dirt import DirtFactory
|
||||||
|
from environments.factory.dirt_util import DirtProperties
|
||||||
from environments.factory.combined_factories import DirtItemFactory
|
from environments.factory.combined_factories import DirtItemFactory
|
||||||
from environments.factory.factory_item import ItemProperties, ItemFactory
|
from environments.factory.factory_item import ItemFactory
|
||||||
|
from environments.factory.additional.item.item_util import ItemProperties
|
||||||
from environments.logging.envmonitor import EnvMonitor
|
from environments.logging.envmonitor import EnvMonitor
|
||||||
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
||||||
import pickle
|
import pickle
|
||||||
from plotting.compare_runs import compare_seed_runs, compare_model_runs, compare_all_parameter_runs
|
from plotting.compare_runs import compare_seed_runs, compare_model_runs
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
|
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
|
|
||||||
# mp.set_start_method("spawn")
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
In this studie, we want to explore the macro behaviour of multi agents which are trained on the same task,
|
In this studie, we want to explore the macro behaviour of multi agents which are trained on the same task,
|
||||||
but never saw each other in training.
|
but never saw each other in training.
|
||||||
@ -72,10 +69,9 @@ n_agents = 4
|
|||||||
ood_monitor_file = f'e_1_{n_agents}_agents'
|
ood_monitor_file = f'e_1_{n_agents}_agents'
|
||||||
baseline_monitor_file = 'e_1_baseline'
|
baseline_monitor_file = 'e_1_baseline'
|
||||||
|
|
||||||
from stable_baselines3 import A2C
|
|
||||||
|
|
||||||
def policy_model_kwargs():
|
def policy_model_kwargs():
|
||||||
return dict() # gae_lambda=0.25, n_steps=16, max_grad_norm=0.25, use_rms_prop=True)
|
return dict() # gae_lambda=0.25, n_steps=16, max_grad_norm=0.25, use_rms_prop=True)
|
||||||
|
|
||||||
|
|
||||||
def dqn_model_kwargs():
|
def dqn_model_kwargs():
|
||||||
@ -198,7 +194,7 @@ if __name__ == '__main__':
|
|||||||
ood_run = True
|
ood_run = True
|
||||||
plotting = True
|
plotting = True
|
||||||
|
|
||||||
train_steps = 1e7
|
train_steps = 1e6
|
||||||
n_seeds = 3
|
n_seeds = 3
|
||||||
frames_to_stack = 3
|
frames_to_stack = 3
|
||||||
|
|
||||||
@ -221,8 +217,8 @@ if __name__ == '__main__':
|
|||||||
clean_amount=0.34,
|
clean_amount=0.34,
|
||||||
max_spawn_amount=0.1, max_global_amount=20,
|
max_spawn_amount=0.1, max_global_amount=20,
|
||||||
max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05,
|
max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05,
|
||||||
dirt_smear_amount=0.0, agent_can_interact=True)
|
dirt_smear_amount=0.0)
|
||||||
item_props = ItemProperties(n_items=10, agent_can_interact=True,
|
item_props = ItemProperties(n_items=10,
|
||||||
spawn_frequency=30, n_drop_off_locations=2,
|
spawn_frequency=30, n_drop_off_locations=2,
|
||||||
max_agent_inventory_capacity=15)
|
max_agent_inventory_capacity=15)
|
||||||
factory_kwargs = dict(n_agents=1, max_steps=400, parse_doors=True,
|
factory_kwargs = dict(n_agents=1, max_steps=400, parse_doors=True,
|
||||||
@ -355,6 +351,7 @@ if __name__ == '__main__':
|
|||||||
# Env Init & Model kwargs definition
|
# Env Init & Model kwargs definition
|
||||||
if model_cls.__name__ in ["PPO", "A2C"]:
|
if model_cls.__name__ in ["PPO", "A2C"]:
|
||||||
# env_factory = env_class(**env_kwargs)
|
# env_factory = env_class(**env_kwargs)
|
||||||
|
|
||||||
env_factory = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs)
|
env_factory = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs)
|
||||||
for _ in range(6)], start_method="spawn")
|
for _ in range(6)], start_method="spawn")
|
||||||
model_kwargs = policy_model_kwargs()
|
model_kwargs = policy_model_kwargs()
|
||||||
@ -434,8 +431,8 @@ if __name__ == '__main__':
|
|||||||
# Iteration
|
# Iteration
|
||||||
start_mp_baseline_run(env_map, policy_path)
|
start_mp_baseline_run(env_map, policy_path)
|
||||||
|
|
||||||
# for seed_path in (y for y in policy_path.iterdir() if y.is_dir()):
|
# for policy_path in (y for y in policy_path.iterdir() if y.is_dir()):
|
||||||
# load_model_run_baseline(seed_path)
|
# load_model_run_baseline(policy_path)
|
||||||
print('Baseline Tracking done')
|
print('Baseline Tracking done')
|
||||||
|
|
||||||
# Then iterate over every model and monitor "ood behavior" - "is it ood?"
|
# Then iterate over every model and monitor "ood behavior" - "is it ood?"
|
||||||
@ -448,11 +445,11 @@ if __name__ == '__main__':
    for policy_path in [x for x in env_path.iterdir() if x.is_dir()]:
        # FIXME: Pick random seed or iterate over available seeds
        # First seed path version
        # seed_path = next((y for y in policy_path.iterdir() if y.is_dir()))
        # policy_path = next((y for y in policy_path.iterdir() if y.is_dir()))
        # Iteration
        start_mp_study_run(env_map, policy_path)
        # for seed_path in (y for y in policy_path.iterdir() if y.is_dir()):
        # for policy_path in (y for y in policy_path.iterdir() if y.is_dir()):
        #     load_model_run_study(seed_path, env_map[env_path.name][0], observation_modes[obs_mode])
        #     load_model_run_study(policy_path, env_map[env_path.name][0], observation_modes[obs_mode])
    print('OOD Tracking Done')

    # Plotting
23
studies/normalization_study.py
Normal file
@ -0,0 +1,23 @@
from algorithms.utils import Checkpointer
from pathlib import Path
from algorithms.utils import load_yaml_file, add_env_props, instantiate_class, load_class
# from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC


for i in range(0, 5):
    for name in ['snac', 'mappo', 'iac', 'seac']:
        study_root = Path(__file__).parent / name
        cfg = load_yaml_file(study_root / f'{name}.yaml')
        add_env_props(cfg)

        env = instantiate_class(cfg['env'])
        net = instantiate_class(cfg['agent'])
        max_steps = cfg['algorithm']['max_steps']
        n_steps = cfg['algorithm']['n_steps']

        checkpointer = Checkpointer(f'{name}#{i}', study_root, cfg, max_steps, 50)

        loop = load_class(cfg['method'])(cfg)
        df = loop.train_loop(checkpointer)
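normalization_study.py is driven entirely by YAML through the project's load_yaml_file, add_env_props, instantiate_class and load_class helpers, whose implementations are not part of this diff. As a hedged sketch of how a classname-plus-kwargs helper of this kind is commonly written (hypothetical re-implementation, not the repository's actual code):

import importlib


def instantiate_from_config(config: dict):
    # Split 'package.module.ClassName', import the module and build the class
    # with every remaining key of the config dict as a keyword argument.
    module_path, _, class_name = config['classname'].rpartition('.')
    cls = getattr(importlib.import_module(module_path), class_name)
    kwargs = {k: v for k, v in config.items() if k != 'classname'}
    return cls(**kwargs)


# e.g. instantiate_from_config({'classname': 'collections.Counter'}) returns an empty Counter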
22
studies/playground_file.py
Normal file
@ -0,0 +1,22 @@
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns

dfs = []
for name in ['mappo']:
    for c in range(5):
        try:
            study_root = Path(__file__).parent / name / f'{name}#{c}'
            print(study_root)
            df = pd.read_csv(study_root / 'results.csv', index_col=False)
            df.reward = df.reward.rolling(100).mean()
            df['method'] = name.upper()
            dfs.append(df)
        except Exception as e:
            pass

df = pd.concat(dfs).reset_index()
sns.lineplot(data=df, x='steps', y='reward', hue='method', palette='husl', ci='sd', linewidth=1.5, err_style='bars')
plt.savefig('study.png')
print('saved image')
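The smoothing in the loop above relies on pandas' rolling window: df.reward.rolling(100).mean() replaces each reward by the mean of the previous 100 rows, leaving the first 99 rows as NaN. A small self-contained example of the same call:

import pandas as pd

rewards = pd.Series(range(10))
smoothed = rewards.rolling(3).mean()
# [nan, nan, 1.0, 2.0, ...]: each value is the mean of the last three entries
print(smoothed.tolist())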
@ -1,139 +0,0 @@
from salina.agents.gyma import AutoResetGymAgent
from salina.agents import Agents, TemporalAgent
from salina.rl.functional import _index, gae
import torch
import torch.nn as nn
from torch.distributions import Categorical
from salina import TAgent, Workspace, get_arguments, get_class, instantiate_class
from pathlib import Path
import numpy as np
from tqdm import tqdm
import time
from algorithms.utils import (
    add_env_props,
    load_yaml_file,
    CombineActionsAgent,
    AutoResetGymMultiAgent,
    access_str,
    AGENT_PREFIX, REWARD, CUMU_REWARD, OBS, SEP
)


class A2CAgent(TAgent):
    def __init__(self, observation_size, hidden_size, n_actions, agent_id):
        super().__init__()
        observation_size = np.prod(observation_size)
        print(observation_size)
        self.agent_id = agent_id
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(observation_size, hidden_size),
            nn.ELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ELU()
        )
        self.action_head = nn.Linear(hidden_size, n_actions)
        self.critic_head = nn.Linear(hidden_size, 1)

    def get_obs(self, t):
        observation = self.get((f'env/{access_str(self.agent_id, OBS)}', t))
        return observation

    def forward(self, t, stochastic, **kwargs):
        observation = self.get_obs(t)
        features = self.model(observation)
        scores = self.action_head(features)
        probs = torch.softmax(scores, dim=-1)
        critic = self.critic_head(features).squeeze(-1)
        if stochastic:
            action = torch.distributions.Categorical(probs).sample()
        else:
            action = probs.argmax(1)
        self.set((f'{access_str(self.agent_id, "action")}', t), action)
        self.set((f'{access_str(self.agent_id, "action_probs")}', t), probs)
        self.set((f'{access_str(self.agent_id, "critic")}', t), critic)
if __name__ == '__main__':
    # Setup workspace
    uid = time.time()
    workspace = Workspace()
    n_agents = 2

    # load config
    cfg = load_yaml_file(Path(__file__).parent / 'sat_mad.yaml')
    add_env_props(cfg)
    cfg['env'].update({'n_agents': n_agents})

    # instantiate agent and env
    env_agent = AutoResetGymMultiAgent(
        get_class(cfg['env']),
        get_arguments(cfg['env']),
        n_envs=1
    )

    a2c_agents = [instantiate_class({**cfg['agent'],
                                     'agent_id': agent_id})
                  for agent_id in range(n_agents)]

    # combine agents
    acquisition_agent = TemporalAgent(Agents(env_agent, *a2c_agents, CombineActionsAgent()))
    acquisition_agent.seed(69)

    # optimizers & other parameters
    cfg_optim = cfg['algorithm']['optimizer']
    optimizers = [get_class(cfg_optim)(a2c_agent.parameters(), **get_arguments(cfg_optim))
                  for a2c_agent in a2c_agents]
    n_timesteps = cfg['algorithm']['n_timesteps']

    # Decision making loop
    best = -float('inf')
    with tqdm(range(int(cfg['algorithm']['max_epochs'] / n_timesteps))) as pbar:
        for epoch in pbar:
            workspace.zero_grad()
            if epoch > 0:
                workspace.copy_n_last_steps(1)
                acquisition_agent(workspace, t=1, n_steps=n_timesteps-1, stochastic=True)
            else:
                acquisition_agent(workspace, t=0, n_steps=n_timesteps, stochastic=True)

            for agent_id in range(n_agents):
                critic, done, action_probs, reward, action = workspace[
                    access_str(agent_id, 'critic'),
                    "env/done",
                    access_str(agent_id, 'action_probs'),
                    access_str(agent_id, 'reward', 'env/'),
                    access_str(agent_id, 'action')
                ]
                td = gae(critic, reward, done, 0.98, 0.25)
                td_error = td ** 2
                critic_loss = td_error.mean()
                entropy_loss = Categorical(action_probs).entropy().mean()
                action_logp = _index(action_probs, action).log()
                a2c_loss = action_logp[:-1] * td.detach()
                a2c_loss = a2c_loss.mean()
                loss = (
                    -0.001 * entropy_loss
                    + 1.0 * critic_loss
                    - 0.1 * a2c_loss
                )
                optimizer = optimizers[agent_id]
                optimizer.zero_grad()
                loss.backward()
                # torch.nn.utils.clip_grad_norm_(a2c_agents[agent_id].parameters(), .5)
                optimizer.step()

            # Compute the cumulated reward on final_state
            rews = ''
            for agent_i in range(n_agents):
                creward = workspace['env/'+access_str(agent_i, CUMU_REWARD)]
                creward = creward[done]
                if creward.size()[0] > 0:
                    rews += f'{AGENT_PREFIX}{agent_i}: {creward.mean().item():.2f} | '
                """if cum_r > best:
                    torch.save(a2c_agent.state_dict(), Path(__file__).parent / f'agent_{uid}.pt')
                    best = cum_r"""
            pbar.set_description(rews, refresh=True)
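The removed training script above combines an entropy bonus, a squared TD (critic) term and a policy term weighted by the temporal-difference signal that salina's gae helper returns. As a hedged reference, generalised advantage estimation can be sketched as the following recursion in plain PyTorch (not salina's implementation; the constants 0.98 and 0.25 in the call above presumably play the roles of gamma and lambda, which also matches the gae: 0.25 entry in the config below):

import torch


def generalized_advantage_estimate(values, rewards, dones, gamma=0.98, lam=0.25):
    # values, rewards, dones: 1-D tensors of length T; advantages are accumulated backwards in time.
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    last_adv = 0.0
    for t in reversed(range(T - 1)):
        not_done = 1.0 - dones[t + 1].float()
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        last_adv = delta + gamma * lam * not_done * last_adv
        advantages[t] = last_adv
    return advantages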
@ -1,27 +0,0 @@
agent:
  classname: studies.sat_mad.A2CAgent
  observation_size: 4*5*5
  hidden_size: 128
  n_actions: 10

env:
  classname: environments.factory.make
  env_name: "DirtyFactory-v0"
  n_agents: 1
  pomdp_r: 2
  max_steps: 400
  stack_n_frames: 3
  individual_rewards: True

algorithm:
  max_epochs: 1000000
  n_envs: 1
  n_timesteps: 10
  discount_factor: 0.99
  entropy_coef: 0.01
  critic_coef: 1.0
  gae: 0.25
  optimizer:
    classname: torch.optim.Adam
    lr: 0.0003
    weight_decay: 0.0
270
studies/single_run_with_export.py
Normal file
@ -0,0 +1,270 @@
import itertools
import sys
from pathlib import Path

from stable_baselines3.common.vec_env import SubprocVecEnv

try:
    # noinspection PyUnboundLocalVariable
    if __package__ is None:
        DIR = Path(__file__).resolve().parent
        sys.path.insert(0, str(DIR.parent))
        __package__ = DIR.name
    else:
        DIR = None
except NameError:
    DIR = None
    pass

import simplejson
from environments.helpers import ActionTranslator, ObservationTranslator
from environments.logging.recorder import EnvRecorder
from environments import helpers as h
from environments.factory.factory_dirt import DirtFactory
from environments.factory.dirt_util import DirtProperties
from environments.factory.factory_item import ItemFactory
from environments.factory.additional.item.item_util import ItemProperties
from environments.factory.factory_dest import DestFactory
from environments.factory.additional.dest.dest_util import DestModeOptions, DestProperties
from environments.factory.combined_factories import DirtDestItemFactory
from environments.logging.envmonitor import EnvMonitor
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions

"""
In this study, we want to export trained agents for debugging purposes.
"""
def encapsule_env_factory(env_fctry, env_kwrgs):

    def _init():
        with env_fctry(**env_kwrgs) as init_env:
            return init_env

    return _init


def load_model_run_baseline(policy_path, env_to_run):
    # retrieve model class
    model_cls = h.MODEL_MAP['A2C']
    # Load both agents
    model = model_cls.load(policy_path / 'model.zip', device='cpu')
    # Load old env kwargs
    with next(policy_path.glob('*params.json')).open('r') as f:
        env_kwargs = simplejson.load(f)
        env_kwargs.update(done_at_collision=True)
    # Init Env
    with env_to_run(**env_kwargs) as env_factory:
        monitored_env_factory = EnvMonitor(env_factory)
        recorded_env_factory = EnvRecorder(monitored_env_factory)

        # Evaluation Loop for i in range(n Episodes)
        for episode in range(5):
            env_state = recorded_env_factory.reset()
            rew, done_bool = 0, False
            while not done_bool:
                action = model.predict(env_state, deterministic=True)[0]
                env_state, step_r, done_bool, info_obj = recorded_env_factory.step(action)
                rew += step_r
                if done_bool:
                    break
            print(f'Factory run {episode} done, reward is:\n {rew}')
        recorded_env_factory.save_run(filepath=policy_path / f'baseline_monitor.pick')
        recorded_env_factory.save_records(filepath=policy_path / f'baseline_recorder.json')


def load_model_run_combined(root_path, env_to_run, env_kwargs):
    # retrieve model class
    model_cls = h.MODEL_MAP['A2C']
    # Load both agents
    models = [model_cls.load(model_zip, device='cpu') for model_zip in root_path.rglob('model.zip')]
    # Load old env kwargs
    env_kwargs = env_kwargs.copy()
    env_kwargs.update(
        n_agents=len(models),
        done_at_collision=False)

    # Init Env
    with env_to_run(**env_kwargs) as env_factory:

        action_translator = ActionTranslator(env_factory.named_action_space,
                                             *[x.named_action_space for x in models])
        observation_translator = ObservationTranslator(env_factory.observation_space.shape[-2:],
                                                       env_factory.named_observation_space,
                                                       *[x.named_observation_space for x in models])

        env = EnvMonitor(env_factory)
        # Evaluation Loop for i in range(n Episodes)
        for episode in range(5):
            env_state = env.reset()
            rew, done_bool = 0, False
            while not done_bool:
                translated_observations = observation_translator(env_state)
                actions = [model.predict(translated_observations[model_idx], deterministic=True)[0]
                           for model_idx, model in enumerate(models)]
                translated_actions = action_translator(actions)
                env_state, step_r, done_bool, info_obj = env.step(translated_actions)
                rew += step_r
                if done_bool:
                    break
            print(f'Factory run {episode} done, reward is:\n {rew}')
        env.save_run(filepath=root_path / f'monitor_combined.pick')
        # env.save_records(filepath=root_path / f'recorder_combined.json')
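ActionTranslator and ObservationTranslator are imported from environments.helpers and their implementation is not part of this diff; conceptually they remap each model's own named action and observation layout onto the layout of the combined environment. A deliberately simplified, hypothetical illustration of the action side (plain dictionary lookup by action name):

def translate_action(model_action_idx, model_named_space, env_named_space):
    # Look up the name of the chosen action in the model's own space,
    # then return the index that same name has in the combined env's space.
    idx_to_name = {idx: name for name, idx in model_named_space.items()}
    return env_named_space[idx_to_name[model_action_idx]]


# e.g. translate_action(2, {'noop': 0, 'clean_up': 2}, {'clean_up': 5, 'noop': 0}) -> 5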
if __name__ == '__main__':
    # What to do:
    train = True
    individual_run = False
    combined_run = False
    multi_env = False

    train_steps = 1e6
    frames_to_stack = 3

    # Define a global study save path
    paremters_of_interest = dict(
        show_global_position_info=[True, False],
        pomdp_r=[3],
        cast_shadows=[True, False],
        allow_diagonal_movement=[True],
        parse_doors=[True, False],
        doors_have_area=[True, False],
        done_at_collision=[True, False]
    )
    keys, vals = zip(*paremters_of_interest.items())

    # Then we find all permutations for those values
    p = list(itertools.product(*vals))

    # Finally we can create our list of dicts
    result = [{keys[index]: entry[index] for index in range(len(entry))} for entry in p]
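The three statements above build a small parameter grid: zip(*dict.items()) splits the dict into a key tuple and a value tuple, itertools.product enumerates every combination of the value lists, and the final comprehension zips each combination back into a dict. The same idea on a toy grid (illustrative only):

import itertools

grid = dict(cast_shadows=[True, False], pomdp_r=[2, 3])
keys, vals = zip(*grid.items())
combos = [dict(zip(keys, entry)) for entry in itertools.product(*vals)]
# combos == [{'cast_shadows': True, 'pomdp_r': 2}, {'cast_shadows': True, 'pomdp_r': 3},
#            {'cast_shadows': False, 'pomdp_r': 2}, {'cast_shadows': False, 'pomdp_r': 3}]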
    for u in result:
        file_name = '_'.join('_'.join([str(y)[0] for y in x]) for x in u.items())
        study_root_path = Path(__file__).parent.parent / 'study_out' / file_name

        # Model Kwargs
        policy_model_kwargs = dict(ent_coef=0.01)

        # Define Global Env Parameters
        # Define properties object parameters
        obs_props = ObservationProperties(render_agents=AgentRenderOptions.NOT,
                                          additional_agent_placeholder=None,
                                          omit_agent_self=True,
                                          frames_to_stack=frames_to_stack,
                                          pomdp_r=u['pomdp_r'], cast_shadows=u['cast_shadows'],
                                          show_global_position_info=u['show_global_position_info'])
        move_props = MovementProperties(allow_diagonal_movement=u['allow_diagonal_movement'],
                                        allow_square_movement=True,
                                        allow_no_op=False)
        dirt_props = DirtProperties(initial_dirt_ratio=0.35, initial_dirt_spawn_r_var=0.1,
                                    clean_amount=0.34,
                                    max_spawn_amount=0.1, max_global_amount=20,
                                    max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05,
                                    dirt_smear_amount=0.0)
        item_props = ItemProperties(n_items=10, spawn_frequency=30, n_drop_off_locations=2,
                                    max_agent_inventory_capacity=15)
        dest_props = DestProperties(n_dests=4, spawn_mode=DestModeOptions.GROUPED, spawn_frequency=1)
        factory_kwargs = dict(n_agents=1, max_steps=500, parse_doors=u['parse_doors'],
                              level_name='rooms', doors_have_area=u['doors_have_area'],
                              verbose=False,
                              mv_prop=move_props,
                              obs_prop=obs_props,
                              done_at_collision=u['done_at_collision']
                              )

        # Bundle both environments with global kwargs and parameters
        env_map = {}
        env_map.update({'dirt': (DirtFactory, dict(dirt_prop=dirt_props,
                                                   **factory_kwargs.copy()),
                                 ['cleanup_valid', 'cleanup_fail'])})
        # env_map.update({'item': (ItemFactory, dict(item_prop=item_props,
        #                                            **factory_kwargs.copy()),
        #                          ['DROPOFF_FAIL', 'ITEMACTION_FAIL', 'DROPOFF_VALID', 'ITEMACTION_VALID'])})
        # env_map.update({'dest': (DestFactory, dict(dest_prop=dest_props,
        #                                            **factory_kwargs.copy()))})
        env_map.update({'combined': (DirtDestItemFactory, dict(dest_prop=dest_props,
                                                               item_prop=item_props,
                                                               dirt_prop=dirt_props,
                                                               **factory_kwargs.copy()))})
        env_names = list(env_map.keys())

        # Train starts here ############################################################
        # Build Major Loop parameters, parameter versions, Env Classes and models
        if train:
            for env_key in (env_key for env_key in env_map if 'combined' != env_key):
                model_cls = h.MODEL_MAP['PPO']
                combination_path = study_root_path / env_key
                env_class, env_kwargs, env_plot_keys = env_map[env_key]

                # Output folder
                if (combination_path / 'monitor.pick').exists():
                    continue
                combination_path.mkdir(parents=True, exist_ok=True)

                if not multi_env:
                    env_factory = encapsule_env_factory(env_class, env_kwargs)()
                else:
                    env_factory = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs)
                                                 for _ in range(6)], start_method="spawn")

                param_path = combination_path / f'env_params.json'
                try:
                    env_factory.env_method('save_params', param_path)
                except AttributeError:
                    env_factory.save_params(param_path)

                # EnvMonitor Init
                env_monitor = EnvMonitor(env_factory)
                callbacks = [env_monitor]

                # Model Init
                model = model_cls("MlpPolicy", env_factory, **policy_model_kwargs,
                                  verbose=1, seed=69, device='cpu')

                # Model train
                model.learn(total_timesteps=int(train_steps), callback=callbacks)

                # Model save
                try:
                    model.named_action_space = env_factory.unwrapped.named_action_space
                    model.named_observation_space = env_factory.unwrapped.named_observation_space
                except AttributeError:
                    model.named_action_space = env_factory.get_attr("named_action_space")[0]
                    model.named_observation_space = env_factory.get_attr("named_observation_space")[0]
                save_path = combination_path / f'model.zip'
                model.save(save_path)

                # Monitor Save
                env_monitor.save_run(combination_path / 'monitor.pick',
                                     auto_plotting_keys=['step_reward', 'collision'] + env_plot_keys)

                # Better be safe than sorry: Clean up!
                del env_factory, model
                import gc
                gc.collect()

        # Train ends here ############################################################

        # Evaluation starts here #####################################################
        # First Iterate over every model and monitor "as trained"
        if individual_run:
            print('Start Individual Recording')
            for env_key in (env_key for env_key in env_map if 'combined' != env_key):
                # For trained policy in study_root_path / identifier
                policy_path = study_root_path / env_key
                load_model_run_baseline(policy_path, env_map[policy_path.name][0])

                # for policy_path in (y for y in policy_path.iterdir() if y.is_dir()):
                #     load_model_run_baseline(policy_path)
            print('Done Individual Recording')

        # Then iterate over every model and monitor "ood behavior" - "is it ood?"
        if combined_run:
            print('Start combined run')
            for env_key in (env_key for env_key in env_map if 'combined' == env_key):
                # For trained policy in study_root_path / identifier
                factory, kwargs = env_map[env_key]
                load_model_run_combined(study_root_path, factory, kwargs)
            print('OOD Tracking Done')
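The export above attaches named_action_space and named_observation_space to the model before model.save(), and load_model_run_combined reads them back from the loaded models; this works because Stable-Baselines3 serialises most picklable instance attributes into the saved zip. A minimal, hedged round-trip check (paths and the attribute value are placeholders):

from stable_baselines3 import PPO

model = PPO("MlpPolicy", "CartPole-v1")
model.named_action_space = {'noop': 0, 'clean_up': 1}  # placeholder value
model.save("model.zip")

restored = PPO.load("model.zip", device='cpu')
print(restored.named_action_space)  # the custom attribute survives the round trip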
36
studies/viz_policy.py
Normal file
@ -0,0 +1,36 @@
import pandas as pd
from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC
from pathlib import Path
from algorithms.utils import load_yaml_file
from tqdm import trange

study = 'example_config#0'
# study_root = Path(__file__).parent / study
study_root = Path('/Users/romue/PycharmProjects/EDYS/algorithms/marl/')

# ['L2NoAh_gru', 'L2NoCh_gru', 'nomix_gru']:
render = True
eval_eps = 3
for run in range(0, 5):
    for name in ['example_config']:  # ['L2OnlyAh_gru', 'L2OnlyChAh_gru', 'L2OnlyMix_gru'] / ['layernorm_gru', 'basic_gru', 'nonorm_gru', 'spectralnorm_gru']
        cfg = load_yaml_file(study_root / study / 'config.yaml')
        # p_root must be defined before the results are written below
        p_root = Path(study_root / study / f'{name}#{run}')
        dfs = []
        for i in trange(500):
            path = study_root / study / f'checkpoint_{161}'
            print(path)

            snac = LoopSEAC(cfg)
            snac.load_state_dict(path)
            snac.eval()

            df = snac.eval_loop(render=render, n_episodes=eval_eps)
            df['checkpoint'] = i
            dfs.append(df)

        results = pd.concat(dfs)
        results['run'] = run
        results.to_csv(p_root / 'results.csv', index=False)

        # sns.lineplot(data=results, x='checkpoint', y='reward', hue='agent', palette='husl')

        # plt.savefig(f'{experiment_name}.png')
@ -1,39 +0,0 @@
from salina.agents import Agents, TemporalAgent
import torch
from salina import Workspace, get_arguments, get_class, instantiate_class
from pathlib import Path
from salina.agents.gyma import GymAgent
import time
from algorithms.utils import load_yaml_file, add_env_props

if __name__ == '__main__':
    # Setup workspace
    uid = time.time()
    workspace = Workspace()
    weights = Path('/Users/romue/PycharmProjects/EDYS/studies/agent_1636994369.145843.pt')

    cfg = load_yaml_file(Path(__file__).parent / 'sat_mad.yaml')
    add_env_props(cfg)
    cfg['env'].update({'n_agents': 2})

    # instantiate agent and env
    env_agent = GymAgent(
        get_class(cfg['env']),
        get_arguments(cfg['env']),
        n_envs=1
    )

    agents = []
    for _ in range(2):
        a2c_agent = instantiate_class(cfg['agent'])
        if weights:
            a2c_agent.load_state_dict(torch.load(weights))
        agents.append(a2c_agent)

    # combine agents
    acquisition_agent = TemporalAgent(Agents(env_agent, *agents))
    acquisition_agent.seed(42)

    acquisition_agent(workspace, t=0, n_steps=400, stochastic=False, save_render=True)