Compare commits
17 Commits
testing_ol
...
jannis_exp
Author | SHA1 | Date | |
---|---|---|---|
4fe43a23b8 | |||
a9a4274370 | |||
b09c461754 | |||
ffc47752a7 | |||
3e19970a60 | |||
51fb73ebb8 | |||
a16d7e709e | |||
3ce6302e8a | |||
823aa075b9 | |||
d29ccbbb71 | |||
2a2aafa988 | |||
0e8a4af740 | |||
b6c8cbd2e3 | |||
3150757347 | |||
435056f373 | |||
78bf19f7f4 | |||
b43f595207 |
1
.gitignore
vendored
1
.gitignore
vendored
@ -702,3 +702,4 @@ $RECYCLE.BIN/
|
|||||||
|
|
||||||
# End of https://www.toptal.com/developers/gitignore/api/linux,unity,macos,python,windows,pycharm,notepadpp,visualstudiocode,latex
|
# End of https://www.toptal.com/developers/gitignore/api/linux,unity,macos,python,windows,pycharm,notepadpp,visualstudiocode,latex
|
||||||
/studies/e_1/
|
/studies/e_1/
|
||||||
|
/studies/curious_study/
|
||||||
|
@ -5,11 +5,25 @@ from networkx.algorithms.approximation import traveling_salesman as tsp
|
|||||||
from environments.factory.base.objects import Agent
|
from environments.factory.base.objects import Agent
|
||||||
from environments.helpers import points_to_graph
|
from environments.helpers import points_to_graph
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.helpers import Constants as c
|
|
||||||
|
|
||||||
|
from environments.helpers import Constants as BaseConstants
|
||||||
|
from environments.helpers import EnvActions as BaseActions
|
||||||
|
|
||||||
|
|
||||||
|
class Constants(BaseConstants):
|
||||||
|
DIRT = 'Dirt'
|
||||||
|
|
||||||
|
|
||||||
|
class Actions(BaseActions):
|
||||||
|
CLEAN_UP = 'do_cleanup_action'
|
||||||
|
|
||||||
|
|
||||||
|
a = Actions
|
||||||
|
c = Constants
|
||||||
|
|
||||||
future_planning = 7
|
future_planning = 7
|
||||||
|
|
||||||
|
|
||||||
class TSPDirtAgent(Agent):
|
class TSPDirtAgent(Agent):
|
||||||
|
|
||||||
def __init__(self, env, *args,
|
def __init__(self, env, *args,
|
||||||
@ -26,7 +40,7 @@ class TSPDirtAgent(Agent):
|
|||||||
def predict(self, *_, **__):
|
def predict(self, *_, **__):
|
||||||
if self._env[c.DIRT].by_pos(self.pos) is not None:
|
if self._env[c.DIRT].by_pos(self.pos) is not None:
|
||||||
# Translate the action_object to an integer to have the same output as any other model
|
# Translate the action_object to an integer to have the same output as any other model
|
||||||
action = h.EnvActions.CLEAN_UP
|
action = a.CLEAN_UP
|
||||||
elif any('door' in x.name.lower() for x in self.tile.guests):
|
elif any('door' in x.name.lower() for x in self.tile.guests):
|
||||||
door = next(x for x in self.tile.guests if 'door' in x.name.lower())
|
door = next(x for x in self.tile.guests if 'door' in x.name.lower())
|
||||||
if door.is_closed:
|
if door.is_closed:
|
||||||
@ -37,7 +51,7 @@ class TSPDirtAgent(Agent):
|
|||||||
else:
|
else:
|
||||||
action = self._predict_move()
|
action = self._predict_move()
|
||||||
# Translate the action_object to an integer to have the same output as any other model
|
# Translate the action_object to an integer to have the same output as any other model
|
||||||
action_obj = next(action_i for action_i, action_obj in enumerate(self._env._actions) if action_obj == action)
|
action_obj = next(action_i for action_name, action_i in self._env.named_action_space.items() if action_name == action)
|
||||||
return action_obj
|
return action_obj
|
||||||
|
|
||||||
def _predict_move(self):
|
def _predict_move(self):
|
||||||
|
@ -1,221 +0,0 @@
|
|||||||
from typing import NamedTuple, Union
|
|
||||||
from collections import deque, OrderedDict, defaultdict
|
|
||||||
import numpy as np
|
|
||||||
import random
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from tqdm import trange
|
|
||||||
|
|
||||||
class Experience(NamedTuple):
|
|
||||||
# can be use for a single (s_t, a, r s_{t+1}) tuple
|
|
||||||
# or for a batch of tuples
|
|
||||||
observation: np.ndarray
|
|
||||||
next_observation: np.ndarray
|
|
||||||
action: np.ndarray
|
|
||||||
reward: Union[float, np.ndarray]
|
|
||||||
done : Union[bool, np.ndarray]
|
|
||||||
episode: int = -1
|
|
||||||
|
|
||||||
|
|
||||||
class BaseLearner:
|
|
||||||
def __init__(self, env, n_agents=1, train_every=('step', 4), n_grad_steps=1, stack_n_frames=1):
|
|
||||||
assert train_every[0] in ['step', 'episode'], 'train_every[0] must be one of ["step", "episode"]'
|
|
||||||
self.env = env
|
|
||||||
self.n_agents = n_agents
|
|
||||||
self.n_grad_steps = n_grad_steps
|
|
||||||
self.train_every = train_every
|
|
||||||
self.stack_n_frames = deque(maxlen=stack_n_frames)
|
|
||||||
self.device = 'cpu'
|
|
||||||
self.n_updates = 0
|
|
||||||
self.step = 0
|
|
||||||
self.episode_step = 0
|
|
||||||
self.episode = 0
|
|
||||||
self.running_reward = deque(maxlen=5)
|
|
||||||
|
|
||||||
def to(self, device):
|
|
||||||
self.device = device
|
|
||||||
for attr, value in self.__dict__.items():
|
|
||||||
if isinstance(value, nn.Module):
|
|
||||||
value = value.to(self.device)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def get_action(self, obs) -> Union[int, np.ndarray]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_new_experience(self, experience):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_step_end(self, n_steps):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_episode_end(self, n_steps):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_all_done(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def train(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def reward(self, r):
|
|
||||||
return r
|
|
||||||
|
|
||||||
def learn(self, n_steps):
|
|
||||||
train_type, train_freq = self.train_every
|
|
||||||
while self.step < n_steps:
|
|
||||||
obs, done = self.env.reset(), False
|
|
||||||
total_reward = 0
|
|
||||||
self.episode_step = 0
|
|
||||||
while not done:
|
|
||||||
|
|
||||||
action = self.get_action(obs)
|
|
||||||
|
|
||||||
next_obs, reward, done, info = self.env.step(action if not len(action) == 1 else action[0])
|
|
||||||
|
|
||||||
experience = Experience(observation=obs, next_observation=next_obs,
|
|
||||||
action=action, reward=self.reward(reward),
|
|
||||||
done=done, episode=self.episode) # do we really need to copy?
|
|
||||||
self.on_new_experience(experience)
|
|
||||||
# end of step routine
|
|
||||||
obs = next_obs
|
|
||||||
total_reward += reward
|
|
||||||
self.step += 1
|
|
||||||
self.episode_step += 1
|
|
||||||
self.on_step_end(n_steps)
|
|
||||||
if train_type == 'step' and (self.step % train_freq == 0):
|
|
||||||
self.train()
|
|
||||||
self.n_updates += 1
|
|
||||||
self.on_episode_end(n_steps)
|
|
||||||
if train_type == 'episode' and (self.episode % train_freq == 0):
|
|
||||||
self.train()
|
|
||||||
self.n_updates += 1
|
|
||||||
|
|
||||||
self.running_reward.append(total_reward)
|
|
||||||
self.episode += 1
|
|
||||||
try:
|
|
||||||
if self.step % 100 == 0:
|
|
||||||
print(
|
|
||||||
f'Step: {self.step} ({(self.step / n_steps) * 100:.2f}%)\tRunning reward: {sum(list(self.running_reward)) / len(self.running_reward):.2f}\t'
|
|
||||||
f' eps: {self.eps:.4f}\tRunning loss: {sum(list(self.running_loss)) / len(self.running_loss):.4f}\tUpdates:{self.n_updates}')
|
|
||||||
except Exception as e:
|
|
||||||
pass
|
|
||||||
self.on_all_done()
|
|
||||||
|
|
||||||
def evaluate(self, n_episodes=100, render=False):
|
|
||||||
with torch.no_grad():
|
|
||||||
data = []
|
|
||||||
for eval_i in trange(n_episodes):
|
|
||||||
obs, done = self.env.reset(), False
|
|
||||||
while not done:
|
|
||||||
action = self.get_action(obs)
|
|
||||||
next_obs, reward, done, info = self.env.step(action if not len(action) == 1 else action[0])
|
|
||||||
if render: self.env.render()
|
|
||||||
obs = next_obs # srsly i'm so stupid
|
|
||||||
info.update({'reward': reward, 'eval_episode': eval_i})
|
|
||||||
data.append(info)
|
|
||||||
return pd.DataFrame(data).fillna(0)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class BaseBuffer:
|
|
||||||
def __init__(self, size: int):
|
|
||||||
self.size = size
|
|
||||||
self.experience = deque(maxlen=size)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.experience)
|
|
||||||
|
|
||||||
def add(self, exp: Experience):
|
|
||||||
self.experience.append(exp)
|
|
||||||
|
|
||||||
def sample(self, k, cer=4):
|
|
||||||
sample = random.choices(self.experience, k=k-cer)
|
|
||||||
for i in range(cer): sample += [self.experience[-i]]
|
|
||||||
observations = torch.stack([torch.from_numpy(e.observation) for e in sample], 0).float()
|
|
||||||
next_observations = torch.stack([torch.from_numpy(e.next_observation) for e in sample], 0).float()
|
|
||||||
actions = torch.tensor([e.action for e in sample]).long()
|
|
||||||
rewards = torch.tensor([e.reward for e in sample]).float().view(-1, 1)
|
|
||||||
dones = torch.tensor([e.done for e in sample]).float().view(-1, 1)
|
|
||||||
#print(observations.shape, next_observations.shape, actions.shape, rewards.shape, dones.shape)
|
|
||||||
return Experience(observations, next_observations, actions, rewards, dones)
|
|
||||||
|
|
||||||
|
|
||||||
class TrajectoryBuffer(BaseBuffer):
|
|
||||||
def __init__(self, size):
|
|
||||||
super(TrajectoryBuffer, self).__init__(size)
|
|
||||||
self.experience = defaultdict(list)
|
|
||||||
|
|
||||||
def add(self, exp: Experience):
|
|
||||||
self.experience[exp.episode].append(exp)
|
|
||||||
if len(self.experience) > self.size:
|
|
||||||
oldest_traj_key = list(sorted(self.experience.keys()))[0]
|
|
||||||
del self.experience[oldest_traj_key]
|
|
||||||
|
|
||||||
|
|
||||||
def soft_update(local_model, target_model, tau):
|
|
||||||
# taken from https://github.com/BY571/Munchausen-RL/blob/master/M-DQN.ipynb
|
|
||||||
for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
|
|
||||||
target_param.data.copy_(tau*local_param.data + (1.-tau)*target_param.data)
|
|
||||||
|
|
||||||
|
|
||||||
def mlp_maker(dims, flatten=False, activation='elu', activation_last='identity'):
|
|
||||||
activations = {'elu': nn.ELU, 'relu': nn.ReLU, 'sigmoid': nn.Sigmoid,
|
|
||||||
'leaky_relu': nn.LeakyReLU, 'tanh': nn.Tanh,
|
|
||||||
'gelu': nn.GELU, 'identity': nn.Identity}
|
|
||||||
layers = [('Flatten', nn.Flatten())] if flatten else []
|
|
||||||
for i in range(1, len(dims)):
|
|
||||||
layers.append((f'Layer #{i - 1}: Linear', nn.Linear(dims[i - 1], dims[i])))
|
|
||||||
activation_str = activation if i != len(dims)-1 else activation_last
|
|
||||||
layers.append((f'Layer #{i - 1}: {activation_str.capitalize()}', activations[activation_str]()))
|
|
||||||
return nn.Sequential(OrderedDict(layers))
|
|
||||||
|
|
||||||
|
|
||||||
class BaseDQN(nn.Module):
|
|
||||||
def __init__(self, dims=[3*5*5, 64, 64, 9]):
|
|
||||||
super(BaseDQN, self).__init__()
|
|
||||||
self.net = mlp_maker(dims, flatten=True)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def act(self, x) -> np.ndarray:
|
|
||||||
action = self.forward(x).max(-1)[1].numpy()
|
|
||||||
return action
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.net(x)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseDDQN(BaseDQN):
|
|
||||||
def __init__(self,
|
|
||||||
backbone_dims=[3*5*5, 64, 64],
|
|
||||||
value_dims=[64, 1],
|
|
||||||
advantage_dims=[64, 9],
|
|
||||||
activation='elu'):
|
|
||||||
super(BaseDDQN, self).__init__(backbone_dims)
|
|
||||||
self.net = mlp_maker(backbone_dims, activation=activation, flatten=True)
|
|
||||||
self.value_head = mlp_maker(value_dims)
|
|
||||||
self.advantage_head = mlp_maker(advantage_dims)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
features = self.net(x)
|
|
||||||
advantages = self.advantage_head(features)
|
|
||||||
values = self.value_head(features)
|
|
||||||
return values + (advantages - advantages.mean())
|
|
||||||
|
|
||||||
|
|
||||||
class BaseICM(nn.Module):
|
|
||||||
def __init__(self, backbone_dims=[2*3*5*5, 64, 64], head_dims=[2*64, 64, 9]):
|
|
||||||
super(BaseICM, self).__init__()
|
|
||||||
self.backbone = mlp_maker(backbone_dims, flatten=True, activation_last='relu', activation='relu')
|
|
||||||
self.icm = mlp_maker(head_dims)
|
|
||||||
self.ce = nn.CrossEntropyLoss()
|
|
||||||
|
|
||||||
def forward(self, s0, s1, a):
|
|
||||||
phi_s0 = self.backbone(s0)
|
|
||||||
phi_s1 = self.backbone(s1)
|
|
||||||
cat = torch.cat((phi_s0, phi_s1), dim=1)
|
|
||||||
a_prime = torch.softmax(self.icm(cat), dim=-1)
|
|
||||||
ce = self.ce(a_prime, a)
|
|
||||||
return dict(prediction=a_prime, loss=ce)
|
|
@ -1,77 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from algorithms.q_learner import QLearner
|
|
||||||
|
|
||||||
|
|
||||||
class MQLearner(QLearner):
|
|
||||||
# Munchhausen Q-Learning
|
|
||||||
def __init__(self, *args, temperature=0.03, alpha=0.9, clip_l0=-1.0, **kwargs):
|
|
||||||
super(MQLearner, self).__init__(*args, **kwargs)
|
|
||||||
assert self.n_agents == 1, 'M-DQN currently only supports single agent training'
|
|
||||||
self.temperature = temperature
|
|
||||||
self.alpha = alpha
|
|
||||||
self.clip0 = clip_l0
|
|
||||||
|
|
||||||
def tau_ln_pi(self, qs):
|
|
||||||
# computes log(softmax(qs/temperature))
|
|
||||||
# Custom log-sum-exp trick from page 18 to compute the log-policy terms
|
|
||||||
v_k = qs.max(-1)[0].unsqueeze(-1)
|
|
||||||
advantage = qs - v_k
|
|
||||||
logsum = torch.logsumexp(advantage / self.temperature, -1).unsqueeze(-1)
|
|
||||||
tau_ln_pi = advantage - self.temperature * logsum
|
|
||||||
return tau_ln_pi
|
|
||||||
|
|
||||||
def train(self):
|
|
||||||
if len(self.buffer) < self.batch_size: return
|
|
||||||
for _ in range(self.n_grad_steps):
|
|
||||||
|
|
||||||
experience = self.buffer.sample(self.batch_size, cer=self.train_every[-1])
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
q_target_next = self.target_q_net(experience.next_observation)
|
|
||||||
tau_log_pi_next = self.tau_ln_pi(q_target_next)
|
|
||||||
|
|
||||||
q_k_targets = self.target_q_net(experience.observation)
|
|
||||||
log_pi = self.tau_ln_pi(q_k_targets)
|
|
||||||
|
|
||||||
pi_target = F.softmax(q_target_next / self.temperature, dim=-1)
|
|
||||||
q_target = (self.gamma * (pi_target * (q_target_next - tau_log_pi_next) * (1 - experience.done)).sum(-1)).unsqueeze(-1)
|
|
||||||
|
|
||||||
munchausen_addon = log_pi.gather(-1, experience.action)
|
|
||||||
|
|
||||||
munchausen_reward = (experience.reward + self.alpha * torch.clamp(munchausen_addon, min=self.clip0, max=0))
|
|
||||||
|
|
||||||
# Compute Q targets for current states
|
|
||||||
m_q_target = munchausen_reward + q_target
|
|
||||||
|
|
||||||
# Get expected Q values from local model
|
|
||||||
q_k = self.q_net(experience.observation)
|
|
||||||
pred_q = q_k.gather(-1, experience.action)
|
|
||||||
|
|
||||||
# Compute loss
|
|
||||||
loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - m_q_target, 2))
|
|
||||||
self._backprop_loss(loss)
|
|
||||||
|
|
||||||
from tqdm import trange
|
|
||||||
from collections import deque
|
|
||||||
class MQICMLearner(MQLearner):
|
|
||||||
def __init__(self, *args, icm, **kwargs):
|
|
||||||
super(MQICMLearner, self).__init__(*args, **kwargs)
|
|
||||||
self.icm = icm
|
|
||||||
self.icm_optimizer = torch.optim.AdamW(self.icm.parameters())
|
|
||||||
self.normalize_reward = deque(maxlen=1000)
|
|
||||||
|
|
||||||
def on_all_done(self):
|
|
||||||
from collections import deque
|
|
||||||
losses = deque(maxlen=100)
|
|
||||||
for b in trange(10000):
|
|
||||||
batch = self.buffer.sample(128, 0)
|
|
||||||
s0, s1, a = batch.observation, batch.next_observation, batch.action
|
|
||||||
loss = self.icm(s0, s1, a.squeeze())['loss']
|
|
||||||
self.icm_optimizer.zero_grad()
|
|
||||||
loss.backward()
|
|
||||||
self.icm_optimizer.step()
|
|
||||||
losses.append(loss.item())
|
|
||||||
if b%100 == 0:
|
|
||||||
print(np.mean(losses))
|
|
6
algorithms/marl/__init__.py
Normal file
6
algorithms/marl/__init__.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
from algorithms.marl.base_ac import BaseActorCritic
|
||||||
|
from algorithms.marl.iac import LoopIAC
|
||||||
|
from algorithms.marl.snac import LoopSNAC
|
||||||
|
from algorithms.marl.seac import LoopSEAC
|
||||||
|
from algorithms.marl.mappo import LoopMAPPO
|
||||||
|
from algorithms.marl.memory import MARLActorCriticMemory
|
221
algorithms/marl/base_ac.py
Normal file
221
algorithms/marl/base_ac.py
Normal file
@ -0,0 +1,221 @@
|
|||||||
|
import torch
|
||||||
|
from typing import Union, List
|
||||||
|
import numpy as np
|
||||||
|
from torch.distributions import Categorical
|
||||||
|
from algorithms.marl.memory import MARLActorCriticMemory
|
||||||
|
from algorithms.utils import add_env_props, instantiate_class
|
||||||
|
from pathlib import Path
|
||||||
|
import pandas as pd
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
|
||||||
|
class Names:
|
||||||
|
REWARD = 'reward'
|
||||||
|
DONE = 'done'
|
||||||
|
ACTION = 'action'
|
||||||
|
OBSERVATION = 'observation'
|
||||||
|
LOGITS = 'logits'
|
||||||
|
HIDDEN_ACTOR = 'hidden_actor'
|
||||||
|
HIDDEN_CRITIC = 'hidden_critic'
|
||||||
|
AGENT = 'agent'
|
||||||
|
ENV = 'env'
|
||||||
|
N_AGENTS = 'n_agents'
|
||||||
|
ALGORITHM = 'algorithm'
|
||||||
|
MAX_STEPS = 'max_steps'
|
||||||
|
N_STEPS = 'n_steps'
|
||||||
|
BUFFER_SIZE = 'buffer_size'
|
||||||
|
CRITIC = 'critic'
|
||||||
|
BATCH_SIZE = 'bnatch_size'
|
||||||
|
N_ACTIONS = 'n_actions'
|
||||||
|
|
||||||
|
nms = Names
|
||||||
|
ListOrTensor = Union[List, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
class BaseActorCritic:
|
||||||
|
def __init__(self, cfg):
|
||||||
|
add_env_props(cfg)
|
||||||
|
self.__training = True
|
||||||
|
self.cfg = cfg
|
||||||
|
self.n_agents = cfg[nms.ENV][nms.N_AGENTS]
|
||||||
|
self.reset_memory_after_epoch = True
|
||||||
|
self.setup()
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
self.net = instantiate_class(self.cfg[nms.AGENT])
|
||||||
|
self.optimizer = torch.optim.RMSprop(self.net.parameters(), lr=3e-4, eps=1e-5)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _as_torch(cls, x):
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
return torch.from_numpy(x)
|
||||||
|
elif isinstance(x, List):
|
||||||
|
return torch.tensor(x)
|
||||||
|
elif isinstance(x, (int, float)):
|
||||||
|
return torch.tensor([x])
|
||||||
|
return x
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
self.__training = False
|
||||||
|
networks = [self.net] if not isinstance(self.net, List) else self.net
|
||||||
|
for net in networks:
|
||||||
|
net.train()
|
||||||
|
|
||||||
|
def eval(self):
|
||||||
|
self.__training = False
|
||||||
|
networks = [self.net] if not isinstance(self.net, List) else self.net
|
||||||
|
for net in networks:
|
||||||
|
net.eval()
|
||||||
|
|
||||||
|
def load_state_dict(self, path: Path):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_actions(self, out) -> ListOrTensor:
|
||||||
|
actions = [Categorical(logits=logits).sample().item() for logits in out[nms.LOGITS]]
|
||||||
|
return actions
|
||||||
|
|
||||||
|
def init_hidden(self) -> dict[ListOrTensor]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
observations: ListOrTensor,
|
||||||
|
actions: ListOrTensor,
|
||||||
|
hidden_actor: ListOrTensor,
|
||||||
|
hidden_critic: ListOrTensor
|
||||||
|
) -> dict[ListOrTensor]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def train_loop(self, checkpointer=None):
|
||||||
|
env = instantiate_class(self.cfg[nms.ENV])
|
||||||
|
n_steps, max_steps = [self.cfg[nms.ALGORITHM][k] for k in [nms.N_STEPS, nms.MAX_STEPS]]
|
||||||
|
tm = MARLActorCriticMemory(self.n_agents, self.cfg[nms.ALGORITHM].get(nms.BUFFER_SIZE, n_steps))
|
||||||
|
global_steps, episode, df_results = 0, 0, []
|
||||||
|
reward_queue = deque(maxlen=2000)
|
||||||
|
|
||||||
|
while global_steps < max_steps:
|
||||||
|
obs = env.reset()
|
||||||
|
last_hiddens = self.init_hidden()
|
||||||
|
last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
|
||||||
|
done, rew_log = [False] * self.n_agents, 0
|
||||||
|
|
||||||
|
if self.reset_memory_after_epoch:
|
||||||
|
tm.reset()
|
||||||
|
|
||||||
|
tm.add(observation=obs, action=last_action,
|
||||||
|
logits=torch.zeros(self.n_agents, 1, self.cfg[nms.AGENT][nms.N_ACTIONS]),
|
||||||
|
values=torch.zeros(self.n_agents, 1), reward=reward, done=done, **last_hiddens)
|
||||||
|
|
||||||
|
while not all(done):
|
||||||
|
out = self.forward(obs, last_action, **last_hiddens)
|
||||||
|
action = self.get_actions(out)
|
||||||
|
next_obs, reward, done, info = env.step(action)
|
||||||
|
done = [done] * self.n_agents if isinstance(done, bool) else done
|
||||||
|
|
||||||
|
last_hiddens = dict(hidden_actor =out[nms.HIDDEN_ACTOR],
|
||||||
|
hidden_critic=out[nms.HIDDEN_CRITIC])
|
||||||
|
|
||||||
|
|
||||||
|
tm.add(observation=obs, action=action, reward=reward, done=done,
|
||||||
|
logits=out.get(nms.LOGITS, None), values=out.get(nms.CRITIC, None),
|
||||||
|
**last_hiddens)
|
||||||
|
|
||||||
|
obs = next_obs
|
||||||
|
last_action = action
|
||||||
|
|
||||||
|
if (global_steps+1) % n_steps == 0 or all(done):
|
||||||
|
with torch.inference_mode(False):
|
||||||
|
self.learn(tm)
|
||||||
|
|
||||||
|
global_steps += 1
|
||||||
|
rew_log += sum(reward)
|
||||||
|
reward_queue.extend(reward)
|
||||||
|
|
||||||
|
if checkpointer is not None:
|
||||||
|
checkpointer.step([
|
||||||
|
(f'agent#{i}', agent)
|
||||||
|
for i, agent in enumerate([self.net] if not isinstance(self.net, List) else self.net)
|
||||||
|
])
|
||||||
|
|
||||||
|
if global_steps >= max_steps:
|
||||||
|
break
|
||||||
|
print(f'reward at episode: {episode} = {rew_log}')
|
||||||
|
episode += 1
|
||||||
|
df_results.append([episode, rew_log, *reward])
|
||||||
|
df_results = pd.DataFrame(df_results, columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]])
|
||||||
|
if checkpointer is not None:
|
||||||
|
df_results.to_csv(checkpointer.path / 'results.csv', index=False)
|
||||||
|
return df_results
|
||||||
|
|
||||||
|
@torch.inference_mode(True)
|
||||||
|
def eval_loop(self, n_episodes, render=False):
|
||||||
|
env = instantiate_class(self.cfg[nms.ENV])
|
||||||
|
episode, results = 0, []
|
||||||
|
while episode < n_episodes:
|
||||||
|
obs = env.reset()
|
||||||
|
last_hiddens = self.init_hidden()
|
||||||
|
last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
|
||||||
|
done, rew_log, eps_rew = [False] * self.n_agents, 0, torch.zeros(self.n_agents)
|
||||||
|
while not all(done):
|
||||||
|
if render: env.render()
|
||||||
|
|
||||||
|
out = self.forward(obs, last_action, **last_hiddens)
|
||||||
|
action = self.get_actions(out)
|
||||||
|
next_obs, reward, done, info = env.step(action)
|
||||||
|
|
||||||
|
if isinstance(done, bool): done = [done] * obs.shape[0]
|
||||||
|
obs = next_obs
|
||||||
|
last_action = action
|
||||||
|
last_hiddens = dict(hidden_actor=out.get(nms.HIDDEN_ACTOR, None),
|
||||||
|
hidden_critic=out.get(nms.HIDDEN_CRITIC, None)
|
||||||
|
)
|
||||||
|
eps_rew += torch.tensor(reward)
|
||||||
|
results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode])
|
||||||
|
episode += 1
|
||||||
|
agent_columns = [f'agent#{i}' for i in range(self.cfg['env']['n_agents'])]
|
||||||
|
results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode'])
|
||||||
|
results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'], value_name='reward', var_name='agent')
|
||||||
|
return results
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def compute_advantages(critic, reward, done, gamma, gae_coef=0.0):
|
||||||
|
tds = (reward + gamma * (1.0 - done) * critic[:, 1:].detach()) - critic[:, :-1]
|
||||||
|
|
||||||
|
if gae_coef <= 0:
|
||||||
|
return tds
|
||||||
|
|
||||||
|
gae = torch.zeros_like(tds[:, -1])
|
||||||
|
gaes = []
|
||||||
|
for t in range(tds.shape[1]-1, -1, -1):
|
||||||
|
gae = tds[:, t] + gamma * gae_coef * (1.0 - done[:, t]) * gae
|
||||||
|
gaes.insert(0, gae)
|
||||||
|
gaes = torch.stack(gaes, dim=1)
|
||||||
|
return gaes
|
||||||
|
|
||||||
|
def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
|
||||||
|
obs, actions, done, reward = tm.observation, tm.action, tm.done[:, 1:], tm.reward[:, 1:]
|
||||||
|
|
||||||
|
out = network(obs, actions, tm.hidden_actor[:, 0], tm.hidden_critic[:, 0])
|
||||||
|
logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1}
|
||||||
|
critic = out[nms.CRITIC]
|
||||||
|
|
||||||
|
entropy_loss = Categorical(logits=logits).entropy().mean(-1)
|
||||||
|
advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
|
||||||
|
value_loss = advantages.pow(2).mean(-1) # n_agent
|
||||||
|
|
||||||
|
# policy loss
|
||||||
|
log_ap = torch.log_softmax(logits, -1)
|
||||||
|
log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze()
|
||||||
|
a2c_loss = -(advantages.detach() * log_ap).mean(-1)
|
||||||
|
# weighted loss
|
||||||
|
loss = a2c_loss + vf_coef*value_loss - entropy_coef * entropy_loss
|
||||||
|
return loss.mean()
|
||||||
|
|
||||||
|
def learn(self, tm: MARLActorCriticMemory, **kwargs):
|
||||||
|
loss = self.actor_critic(tm, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
|
||||||
|
# remove next_obs, will be added in next iter
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
|
||||||
|
self.optimizer.step()
|
||||||
|
|
24
algorithms/marl/example_config.yaml
Normal file
24
algorithms/marl/example_config.yaml
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
agent:
|
||||||
|
classname: algorithms.marl.networks.RecurrentAC
|
||||||
|
n_agents: 2
|
||||||
|
obs_emb_size: 96
|
||||||
|
action_emb_size: 16
|
||||||
|
hidden_size_actor: 64
|
||||||
|
hidden_size_critic: 64
|
||||||
|
use_agent_embedding: False
|
||||||
|
env:
|
||||||
|
classname: environments.factory.make
|
||||||
|
env_name: "DirtyFactory-v0"
|
||||||
|
n_agents: 2
|
||||||
|
max_steps: 250
|
||||||
|
pomdp_r: 2
|
||||||
|
stack_n_frames: 0
|
||||||
|
individual_rewards: True
|
||||||
|
method: algorithms.marl.LoopSEAC
|
||||||
|
algorithm:
|
||||||
|
gamma: 0.99
|
||||||
|
entropy_coef: 0.01
|
||||||
|
vf_coef: 0.5
|
||||||
|
n_steps: 5
|
||||||
|
max_steps: 1000000
|
||||||
|
|
57
algorithms/marl/iac.py
Normal file
57
algorithms/marl/iac.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import torch
|
||||||
|
from algorithms.marl.base_ac import BaseActorCritic, nms
|
||||||
|
from algorithms.utils import instantiate_class
|
||||||
|
from pathlib import Path
|
||||||
|
from natsort import natsorted
|
||||||
|
from algorithms.marl.memory import MARLActorCriticMemory
|
||||||
|
|
||||||
|
|
||||||
|
class LoopIAC(BaseActorCritic):
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super(LoopIAC, self).__init__(cfg)
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
self.net = [
|
||||||
|
instantiate_class(self.cfg[nms.AGENT]) for _ in range(self.n_agents)
|
||||||
|
]
|
||||||
|
self.optimizer = [
|
||||||
|
torch.optim.RMSprop(self.net[ag_i].parameters(), lr=3e-4, eps=1e-5) for ag_i in range(self.n_agents)
|
||||||
|
]
|
||||||
|
|
||||||
|
def load_state_dict(self, path: Path):
|
||||||
|
paths = natsorted(list(path.glob('*.pt')))
|
||||||
|
for path, net in zip(paths, self.net):
|
||||||
|
net.load_state_dict(torch.load(path))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def merge_dicts(ds): # todo could be recursive for more than 1 hierarchy
|
||||||
|
d = {}
|
||||||
|
for k in ds[0].keys():
|
||||||
|
d[k] = [d[k] for d in ds]
|
||||||
|
return d
|
||||||
|
|
||||||
|
def init_hidden(self):
|
||||||
|
ha = [net.init_hidden_actor() for net in self.net]
|
||||||
|
hc = [net.init_hidden_critic() for net in self.net]
|
||||||
|
return dict(hidden_actor=ha, hidden_critic=hc)
|
||||||
|
|
||||||
|
def forward(self, observations, actions, hidden_actor, hidden_critic):
|
||||||
|
outputs = [
|
||||||
|
net(
|
||||||
|
self._as_torch(observations[ag_i]).unsqueeze(0).unsqueeze(0), # agents x time
|
||||||
|
self._as_torch(actions[ag_i]).unsqueeze(0),
|
||||||
|
hidden_actor[ag_i],
|
||||||
|
hidden_critic[ag_i]
|
||||||
|
) for ag_i, net in enumerate(self.net)
|
||||||
|
]
|
||||||
|
return self.merge_dicts(outputs)
|
||||||
|
|
||||||
|
def learn(self, tms: MARLActorCriticMemory, **kwargs):
|
||||||
|
for ag_i in range(self.n_agents):
|
||||||
|
tm, net = tms(ag_i), self.net[ag_i]
|
||||||
|
loss = self.actor_critic(tm, net, **self.cfg[nms.ALGORITHM], **kwargs)
|
||||||
|
self.optimizer[ag_i].zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(net.parameters(), 0.5)
|
||||||
|
self.optimizer[ag_i].step()
|
67
algorithms/marl/mappo.py
Normal file
67
algorithms/marl/mappo.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
from algorithms.marl.base_ac import Names as nms
|
||||||
|
from algorithms.marl import LoopSNAC
|
||||||
|
from algorithms.marl.memory import MARLActorCriticMemory
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
from torch.distributions import Categorical
|
||||||
|
from algorithms.utils import instantiate_class
|
||||||
|
|
||||||
|
|
||||||
|
class LoopMAPPO(LoopSNAC):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(LoopMAPPO, self).__init__(*args, **kwargs)
|
||||||
|
self.reset_memory_after_epoch = False
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
self.net = instantiate_class(self.cfg[nms.AGENT])
|
||||||
|
self.optimizer = torch.optim.Adam(self.net.parameters(), lr=3e-4, eps=1e-5)
|
||||||
|
|
||||||
|
def learn(self, tm: MARLActorCriticMemory, **kwargs):
|
||||||
|
if len(tm) >= self.cfg['algorithm']['buffer_size']:
|
||||||
|
# only learn when buffer is full
|
||||||
|
for batch_i in range(self.cfg['algorithm']['n_updates']):
|
||||||
|
batch = tm.chunk_dataloader(chunk_len=self.cfg['algorithm']['n_steps'],
|
||||||
|
k=self.cfg['algorithm']['batch_size'])
|
||||||
|
loss = self.mappo(batch, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
def monte_carlo_returns(self, rewards, done, gamma):
|
||||||
|
rewards_ = []
|
||||||
|
discounted_reward = torch.zeros_like(rewards[:, -1])
|
||||||
|
for t in range(rewards.shape[1]-1, -1, -1):
|
||||||
|
discounted_reward = rewards[:, t] + (gamma * (1.0 - done[:, t]) * discounted_reward)
|
||||||
|
rewards_.insert(0, discounted_reward)
|
||||||
|
rewards_ = torch.stack(rewards_, dim=1)
|
||||||
|
return rewards_
|
||||||
|
|
||||||
|
def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **kwargs):
|
||||||
|
out = network(batch[nms.OBSERVATION], batch[nms.ACTION], batch[nms.HIDDEN_ACTOR], batch[nms.HIDDEN_CRITIC])
|
||||||
|
logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1}
|
||||||
|
|
||||||
|
old_log_probs = torch.log_softmax(batch[nms.LOGITS], -1)
|
||||||
|
old_log_probs = torch.gather(old_log_probs, index=batch[nms.ACTION][:, 1:].unsqueeze(-1), dim=-1).squeeze()
|
||||||
|
|
||||||
|
# monte carlo returns
|
||||||
|
mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma)
|
||||||
|
mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) #todo: norm across agents ok?
|
||||||
|
advantages = mc_returns - out[nms.CRITIC][:, :-1]
|
||||||
|
|
||||||
|
# policy loss
|
||||||
|
log_ap = torch.log_softmax(logits, -1)
|
||||||
|
log_ap = torch.gather(log_ap, dim=-1, index=batch[nms.ACTION][:, 1:].unsqueeze(-1)).squeeze()
|
||||||
|
ratio = (log_ap - old_log_probs).exp()
|
||||||
|
surr1 = ratio * advantages.detach()
|
||||||
|
surr2 = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantages.detach()
|
||||||
|
policy_loss = -torch.min(surr1, surr2).mean(-1)
|
||||||
|
|
||||||
|
# entropy & value loss
|
||||||
|
entropy_loss = Categorical(logits=logits).entropy().mean(-1)
|
||||||
|
value_loss = advantages.pow(2).mean(-1) # n_agent
|
||||||
|
|
||||||
|
# weighted loss
|
||||||
|
loss = policy_loss + vf_coef*value_loss - entropy_coef * entropy_loss
|
||||||
|
|
||||||
|
return loss.mean()
|
221
algorithms/marl/memory.py
Normal file
221
algorithms/marl/memory.py
Normal file
@ -0,0 +1,221 @@
|
|||||||
|
import numpy as np
|
||||||
|
from collections import deque
|
||||||
|
import torch
|
||||||
|
from typing import Union
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.utils.data import Dataset, ConcatDataset
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
class ActorCriticMemory(object):
|
||||||
|
def __init__(self, capacity=10):
|
||||||
|
self.capacity = capacity
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.__actions = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||||
|
self.__hidden_actor = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||||
|
self.__hidden_critic = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||||
|
self.__states = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||||
|
self.__rewards = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||||
|
self.__dones = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||||
|
self.__logits = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||||
|
self.__values = LazyTensorFiFoQueue(maxlen=self.capacity+1)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.__rewards) - 1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def observation(self, sls=slice(0, None)): # add time dimension through stacking
|
||||||
|
return self.__states[sls].unsqueeze(0) # 1 x time x hidden dim
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hidden_actor(self, sls=slice(0, None)): # 1 x n_layers x dim
|
||||||
|
return self.__hidden_actor[sls].unsqueeze(0) # 1 x time x n_layers x dim
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hidden_critic(self, sls=slice(0, None)): # 1 x n_layers x dim
|
||||||
|
return self.__hidden_critic[sls].unsqueeze(0) # 1 x time x n_layers x dim
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reward(self, sls=slice(0, None)):
|
||||||
|
return self.__rewards[sls].squeeze().unsqueeze(0) # 1 x time
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action(self, sls=slice(0, None)):
|
||||||
|
return self.__actions[sls].long().squeeze().unsqueeze(0) # 1 x time
|
||||||
|
|
||||||
|
@property
|
||||||
|
def done(self, sls=slice(0, None)):
|
||||||
|
return self.__dones[sls].float().squeeze().unsqueeze(0) # 1 x time
|
||||||
|
|
||||||
|
@property
|
||||||
|
def logits(self, sls=slice(0, None)): # assumes a trailing 1 for time dimension - common when using output from NN
|
||||||
|
return self.__logits[sls].squeeze().unsqueeze(0) # 1 x time x actions
|
||||||
|
|
||||||
|
@property
|
||||||
|
def values(self, sls=slice(0, None)):
|
||||||
|
return self.__values[sls].squeeze().unsqueeze(0) # 1 x time x actions
|
||||||
|
|
||||||
|
def add_observation(self, state: Union[Tensor, np.ndarray]):
|
||||||
|
self.__states.append(state if isinstance(state, Tensor) else torch.from_numpy(state))
|
||||||
|
|
||||||
|
def add_hidden_actor(self, hidden: Tensor):
|
||||||
|
# layers x hidden dim
|
||||||
|
self.__hidden_actor.append(hidden)
|
||||||
|
|
||||||
|
def add_hidden_critic(self, hidden: Tensor):
|
||||||
|
# layers x hidden dim
|
||||||
|
self.__hidden_critic.append(hidden)
|
||||||
|
|
||||||
|
def add_action(self, action: Union[int, Tensor]):
|
||||||
|
if not isinstance(action, Tensor):
|
||||||
|
action = torch.tensor(action)
|
||||||
|
self.__actions.append(action)
|
||||||
|
|
||||||
|
def add_reward(self, reward: Union[float, Tensor]):
|
||||||
|
if not isinstance(reward, Tensor):
|
||||||
|
reward = torch.tensor(reward)
|
||||||
|
self.__rewards.append(reward)
|
||||||
|
|
||||||
|
def add_done(self, done: bool):
|
||||||
|
if not isinstance(done, Tensor):
|
||||||
|
done = torch.tensor(done)
|
||||||
|
self.__dones.append(done)
|
||||||
|
|
||||||
|
def add_logits(self, logits: Tensor):
|
||||||
|
self.__logits.append(logits)
|
||||||
|
|
||||||
|
def add_values(self, values: Tensor):
|
||||||
|
self.__values.append(values)
|
||||||
|
|
||||||
|
def add(self, **kwargs):
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
func = getattr(ActorCriticMemory, f'add_{k}')
|
||||||
|
func(self, v)
|
||||||
|
|
||||||
|
|
||||||
|
class MARLActorCriticMemory(object):
|
||||||
|
def __init__(self, n_agents, capacity):
|
||||||
|
self.n_agents = n_agents
|
||||||
|
self.memories = [
|
||||||
|
ActorCriticMemory(capacity) for _ in range(n_agents)
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(self, agent_i):
|
||||||
|
return self.memories[agent_i]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.memories[0]) # todo add assertion check!
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
for mem in self.memories:
|
||||||
|
mem.reset()
|
||||||
|
|
||||||
|
def add(self, **kwargs):
|
||||||
|
for agent_i in range(self.n_agents):
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
func = getattr(ActorCriticMemory, f'add_{k}')
|
||||||
|
func(self.memories[agent_i], v[agent_i])
|
||||||
|
|
||||||
|
def __getattr__(self, attr):
|
||||||
|
all_attrs = [getattr(mem, attr) for mem in self.memories]
|
||||||
|
return torch.cat(all_attrs, 0) # agents x time ...
|
||||||
|
|
||||||
|
def chunk_dataloader(self, chunk_len, k):
|
||||||
|
datasets = [ExperienceChunks(mem, chunk_len, k) for mem in self.memories]
|
||||||
|
dataset = ConcatDataset(datasets)
|
||||||
|
data = [dataset[i] for i in range(len(dataset))]
|
||||||
|
data = custom_collate_fn(data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def custom_collate_fn(batch):
|
||||||
|
elem = batch[0]
|
||||||
|
return {key: torch.cat([d[key] for d in batch], dim=0) for key in elem}
|
||||||
|
|
||||||
|
|
||||||
|
class ExperienceChunks(Dataset):
|
||||||
|
def __init__(self, memory, chunk_len, k):
|
||||||
|
assert chunk_len <= len(memory), 'chunk_len cannot be longer than the size of the memory'
|
||||||
|
self.memory = memory
|
||||||
|
self.chunk_len = chunk_len
|
||||||
|
self.k = k
|
||||||
|
|
||||||
|
@property
|
||||||
|
def whitelist(self):
|
||||||
|
whitelist = torch.ones(len(self.memory) - self.chunk_len)
|
||||||
|
for d in self.memory.done.squeeze().nonzero().flatten():
|
||||||
|
whitelist[max((0, d-self.chunk_len-1)):d+2] = 0
|
||||||
|
whitelist[0] = 0
|
||||||
|
return whitelist.tolist()
|
||||||
|
|
||||||
|
def sample(self, start=1):
|
||||||
|
cl = self.chunk_len
|
||||||
|
sample = dict(observation=self.memory.observation[:, start:start+cl+1],
|
||||||
|
action=self.memory.action[:, start-1:start+cl],
|
||||||
|
hidden_actor=self.memory.hidden_actor[:, start-1],
|
||||||
|
hidden_critic=self.memory.hidden_critic[:, start-1],
|
||||||
|
reward=self.memory.reward[:, start:start + cl],
|
||||||
|
done=self.memory.done[:, start:start + cl],
|
||||||
|
logits=self.memory.logits[:, start:start + cl],
|
||||||
|
values=self.memory.values[:, start:start + cl])
|
||||||
|
return sample
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.k
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
idx = random.choices(range(0, len(self.memory) - self.chunk_len), weights=self.whitelist, k=1)
|
||||||
|
return self.sample(idx[0])
|
||||||
|
|
||||||
|
|
||||||
|
class LazyTensorFiFoQueue:
|
||||||
|
def __init__(self, maxlen):
|
||||||
|
self.maxlen = maxlen
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.__lazy_queue = deque(maxlen=self.maxlen)
|
||||||
|
self.shape = None
|
||||||
|
self.queue = None
|
||||||
|
|
||||||
|
def shape_init(self, tensor: Tensor):
|
||||||
|
self.shape = torch.Size([self.maxlen, *tensor.shape])
|
||||||
|
|
||||||
|
def build_tensor_queue(self):
|
||||||
|
if len(self.__lazy_queue) > 0:
|
||||||
|
block = torch.stack(list(self.__lazy_queue), dim=0)
|
||||||
|
l = block.shape[0]
|
||||||
|
if self.queue is None:
|
||||||
|
self.queue = block
|
||||||
|
elif self.true_len() <= self.maxlen:
|
||||||
|
self.queue = torch.cat((self.queue, block), dim=0)
|
||||||
|
else:
|
||||||
|
self.queue = torch.cat((self.queue[l:], block), dim=0)
|
||||||
|
self.__lazy_queue.clear()
|
||||||
|
|
||||||
|
def append(self, data):
|
||||||
|
if self.shape is None:
|
||||||
|
self.shape_init(data)
|
||||||
|
self.__lazy_queue.append(data)
|
||||||
|
if len(self.__lazy_queue) >= self.maxlen:
|
||||||
|
self.build_tensor_queue()
|
||||||
|
|
||||||
|
def true_len(self):
|
||||||
|
return len(self.__lazy_queue) + (0 if self.queue is None else self.queue.shape[0])
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return min((self.true_len(), self.maxlen))
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f'LazyTensorFiFoQueue\tmaxlen: {self.maxlen}, shape: {self.shape}, ' \
|
||||||
|
f'len: {len(self)}, true_len: {self.true_len()}, elements in lazy queue: {len(self.__lazy_queue)}'
|
||||||
|
|
||||||
|
def __getitem__(self, item_or_slice):
|
||||||
|
self.build_tensor_queue()
|
||||||
|
return self.queue[item_or_slice]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
104
algorithms/marl/networks.py
Normal file
104
algorithms/marl/networks.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn.utils import spectral_norm
|
||||||
|
|
||||||
|
|
||||||
|
class RecurrentAC(nn.Module):
|
||||||
|
def __init__(self, observation_size, n_actions, obs_emb_size,
|
||||||
|
action_emb_size, hidden_size_actor, hidden_size_critic,
|
||||||
|
n_agents, use_agent_embedding=True):
|
||||||
|
super(RecurrentAC, self).__init__()
|
||||||
|
observation_size = np.prod(observation_size)
|
||||||
|
self.n_layers = 1
|
||||||
|
self.n_actions = n_actions
|
||||||
|
self.use_agent_embedding = use_agent_embedding
|
||||||
|
self.hidden_size_actor = hidden_size_actor
|
||||||
|
self.hidden_size_critic = hidden_size_critic
|
||||||
|
self.action_emb_size = action_emb_size
|
||||||
|
self.obs_proj = nn.Linear(observation_size, obs_emb_size)
|
||||||
|
self.action_emb = nn.Embedding(n_actions+1, action_emb_size, padding_idx=0)
|
||||||
|
self.agent_emb = nn.Embedding(n_agents, action_emb_size)
|
||||||
|
mix_in_size = obs_emb_size+action_emb_size if not use_agent_embedding else obs_emb_size+n_agents*action_emb_size
|
||||||
|
self.mix = nn.Sequential(nn.Tanh(),
|
||||||
|
nn.Linear(mix_in_size, obs_emb_size),
|
||||||
|
nn.Tanh(),
|
||||||
|
nn.Linear(obs_emb_size, obs_emb_size)
|
||||||
|
)
|
||||||
|
self.gru_actor = nn.GRU(obs_emb_size, hidden_size_actor, batch_first=True, num_layers=self.n_layers)
|
||||||
|
self.gru_critic = nn.GRU(obs_emb_size, hidden_size_critic, batch_first=True, num_layers=self.n_layers)
|
||||||
|
self.action_head = nn.Sequential(
|
||||||
|
nn.Linear(hidden_size_actor, hidden_size_actor),
|
||||||
|
nn.Tanh(),
|
||||||
|
nn.Linear(hidden_size_actor, n_actions)
|
||||||
|
)
|
||||||
|
# spectral_norm(nn.Linear(hidden_size_actor, hidden_size_actor)),
|
||||||
|
self.critic_head = nn.Sequential(
|
||||||
|
nn.Linear(hidden_size_critic, hidden_size_critic),
|
||||||
|
nn.Tanh(),
|
||||||
|
nn.Linear(hidden_size_critic, 1)
|
||||||
|
)
|
||||||
|
#self.action_head[-1].weight.data.uniform_(-3e-3, 3e-3)
|
||||||
|
#self.action_head[-1].bias.data.uniform_(-3e-3, 3e-3)
|
||||||
|
|
||||||
|
def init_hidden_actor(self):
|
||||||
|
return torch.zeros(1, self.n_layers, self.hidden_size_actor)
|
||||||
|
|
||||||
|
def init_hidden_critic(self):
|
||||||
|
return torch.zeros(1, self.n_layers, self.hidden_size_critic)
|
||||||
|
|
||||||
|
def forward(self, observations, actions, hidden_actor=None, hidden_critic=None):
|
||||||
|
n_agents, t, *_ = observations.shape
|
||||||
|
obs_emb = self.obs_proj(observations.view(n_agents, t, -1).float())
|
||||||
|
action_emb = self.action_emb(actions+1) # shift by one due to padding idx
|
||||||
|
|
||||||
|
if not self.use_agent_embedding:
|
||||||
|
x_t = torch.cat((obs_emb, action_emb), -1)
|
||||||
|
else:
|
||||||
|
agent_emb = self.agent_emb(
|
||||||
|
torch.cat([torch.arange(0, n_agents, 1).view(-1, 1)] * t, 1)
|
||||||
|
)
|
||||||
|
x_t = torch.cat((obs_emb, agent_emb, action_emb), -1)
|
||||||
|
|
||||||
|
mixed_x_t = self.mix(x_t)
|
||||||
|
output_p, _ = self.gru_actor(input=mixed_x_t, hx=hidden_actor.swapaxes(1, 0))
|
||||||
|
output_c, _ = self.gru_critic(input=mixed_x_t, hx=hidden_critic.swapaxes(1, 0))
|
||||||
|
|
||||||
|
logits = self.action_head(output_p)
|
||||||
|
critic = self.critic_head(output_c).squeeze(-1)
|
||||||
|
return dict(logits=logits, critic=critic, hidden_actor=output_p, hidden_critic=output_c)
|
||||||
|
|
||||||
|
|
||||||
|
class RecurrentACL2(RecurrentAC):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.action_head = nn.Sequential(
|
||||||
|
nn.Linear(self.hidden_size_actor, self.hidden_size_actor),
|
||||||
|
nn.Tanh(),
|
||||||
|
NormalizedLinear(self.hidden_size_actor, self.n_actions, trainable_magnitude=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizedLinear(nn.Linear):
|
||||||
|
def __init__(self, in_features: int, out_features: int,
|
||||||
|
device=None, dtype=None, trainable_magnitude=False):
|
||||||
|
super(NormalizedLinear, self).__init__(in_features, out_features, False, device, dtype)
|
||||||
|
self.d_sqrt = in_features**0.5
|
||||||
|
self.trainable_magnitude = trainable_magnitude
|
||||||
|
self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
normalized_input = F.normalize(input, dim=-1, p=2, eps=1e-5)
|
||||||
|
normalized_weight = F.normalize(self.weight, dim=-1, p=2, eps=1e-5)
|
||||||
|
return F.linear(normalized_input, normalized_weight) * self.d_sqrt * self.scale
|
||||||
|
|
||||||
|
|
||||||
|
class L2Norm(nn.Module):
|
||||||
|
def __init__(self, in_features, trainable_magnitude=False):
|
||||||
|
super(L2Norm, self).__init__()
|
||||||
|
self.d_sqrt = in_features**0.5
|
||||||
|
self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return F.normalize(x, dim=-1, p=2, eps=1e-5) * self.d_sqrt * self.scale
|
56
algorithms/marl/seac.py
Normal file
56
algorithms/marl/seac.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
import torch
|
||||||
|
from torch.distributions import Categorical
|
||||||
|
from algorithms.marl.iac import LoopIAC
|
||||||
|
from algorithms.marl.base_ac import nms
|
||||||
|
from algorithms.marl.memory import MARLActorCriticMemory
|
||||||
|
|
||||||
|
|
||||||
|
class LoopSEAC(LoopIAC):
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super(LoopSEAC, self).__init__(cfg)
|
||||||
|
|
||||||
|
def actor_critic(self, tm, networks, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
|
||||||
|
obs, actions, done, reward = tm.observation, tm.action, tm.done[:, 1:], tm.reward[:, 1:]
|
||||||
|
outputs = [net(obs, actions, tm.hidden_actor[:, 0], tm.hidden_critic[:, 0]) for net in networks]
|
||||||
|
|
||||||
|
with torch.inference_mode(True):
|
||||||
|
true_action_logp = torch.stack([
|
||||||
|
torch.log_softmax(out[nms.LOGITS][ag_i, :-1], -1)
|
||||||
|
.gather(index=actions[ag_i, 1:, None], dim=-1)
|
||||||
|
for ag_i, out in enumerate(outputs)
|
||||||
|
], 0).squeeze()
|
||||||
|
|
||||||
|
losses = []
|
||||||
|
|
||||||
|
for ag_i, out in enumerate(outputs):
|
||||||
|
logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1}
|
||||||
|
critic = out[nms.CRITIC]
|
||||||
|
|
||||||
|
entropy_loss = Categorical(logits=logits[ag_i]).entropy().mean()
|
||||||
|
advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
|
||||||
|
|
||||||
|
# policy loss
|
||||||
|
log_ap = torch.log_softmax(logits, -1)
|
||||||
|
log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze()
|
||||||
|
|
||||||
|
# importance weights
|
||||||
|
iw = (log_ap - true_action_logp).exp().detach() # importance_weights
|
||||||
|
|
||||||
|
a2c_loss = (-iw*log_ap * advantages.detach()).mean(-1)
|
||||||
|
|
||||||
|
|
||||||
|
value_loss = (iw*advantages.pow(2)).mean(-1) # n_agent
|
||||||
|
|
||||||
|
# weighted loss
|
||||||
|
loss = (a2c_loss + vf_coef*value_loss - entropy_coef * entropy_loss).mean()
|
||||||
|
losses.append(loss)
|
||||||
|
|
||||||
|
return losses
|
||||||
|
|
||||||
|
def learn(self, tms: MARLActorCriticMemory, **kwargs):
|
||||||
|
losses = self.actor_critic(tms, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
|
||||||
|
for ag_i, loss in enumerate(losses):
|
||||||
|
self.optimizer[ag_i].zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(self.net[ag_i].parameters(), 0.5)
|
||||||
|
self.optimizer[ag_i].step()
|
33
algorithms/marl/snac.py
Normal file
33
algorithms/marl/snac.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from algorithms.marl.base_ac import BaseActorCritic
|
||||||
|
from algorithms.marl.base_ac import nms
|
||||||
|
import torch
|
||||||
|
from torch.distributions import Categorical
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
class LoopSNAC(BaseActorCritic):
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super().__init__(cfg)
|
||||||
|
|
||||||
|
def load_state_dict(self, path: Path):
|
||||||
|
path2weights = list(path.glob('*.pt'))
|
||||||
|
assert len(path2weights) == 1, f'Expected a single set of weights but got {len(path2weights)}'
|
||||||
|
self.net.load_state_dict(torch.load(path2weights[0]))
|
||||||
|
|
||||||
|
def init_hidden(self):
|
||||||
|
hidden_actor = self.net.init_hidden_actor()
|
||||||
|
hidden_critic = self.net.init_hidden_critic()
|
||||||
|
return dict(hidden_actor=torch.cat([hidden_actor] * self.n_agents, 0),
|
||||||
|
hidden_critic=torch.cat([hidden_critic] * self.n_agents, 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_actions(self, out):
|
||||||
|
actions = Categorical(logits=out[nms.LOGITS]).sample().squeeze()
|
||||||
|
return actions
|
||||||
|
|
||||||
|
def forward(self, observations, actions, hidden_actor, hidden_critic):
|
||||||
|
out = self.net(self._as_torch(observations).unsqueeze(1),
|
||||||
|
self._as_torch(actions).unsqueeze(1),
|
||||||
|
hidden_actor, hidden_critic
|
||||||
|
)
|
||||||
|
return out
|
@ -1,127 +0,0 @@
|
|||||||
from typing import Union
|
|
||||||
import gym
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import numpy as np
|
|
||||||
from collections import deque
|
|
||||||
from pathlib import Path
|
|
||||||
import yaml
|
|
||||||
from algorithms.common import BaseLearner, BaseBuffer, soft_update, Experience
|
|
||||||
|
|
||||||
|
|
||||||
class QLearner(BaseLearner):
|
|
||||||
def __init__(self, q_net, target_q_net, env, buffer_size=1e5, target_update=3000, eps_end=0.05, n_agents=1,
|
|
||||||
gamma=0.99, train_every=('step', 4), n_grad_steps=1, tau=1.0, max_grad_norm=10, weight_decay=1e-2,
|
|
||||||
exploration_fraction=0.2, batch_size=64, lr=1e-4, reg_weight=0.0, eps_start=1):
|
|
||||||
super(QLearner, self).__init__(env, n_agents, train_every, n_grad_steps)
|
|
||||||
self.q_net = q_net
|
|
||||||
self.target_q_net = target_q_net
|
|
||||||
self.target_q_net.eval()
|
|
||||||
#soft_update(self.q_net, self.target_q_net, tau=1.0)
|
|
||||||
self.buffer = BaseBuffer(buffer_size)
|
|
||||||
self.target_update = target_update
|
|
||||||
self.eps = eps_start
|
|
||||||
self.eps_start = eps_start
|
|
||||||
self.eps_end = eps_end
|
|
||||||
self.exploration_fraction = exploration_fraction
|
|
||||||
self.batch_size = batch_size
|
|
||||||
self.gamma = gamma
|
|
||||||
self.tau = tau
|
|
||||||
self.reg_weight = reg_weight
|
|
||||||
self.weight_decay = weight_decay
|
|
||||||
self.lr = lr
|
|
||||||
self.optimizer = torch.optim.AdamW(self.q_net.parameters(), lr=self.lr, weight_decay=self.weight_decay)
|
|
||||||
self.max_grad_norm = max_grad_norm
|
|
||||||
self.running_reward = deque(maxlen=5)
|
|
||||||
self.running_loss = deque(maxlen=5)
|
|
||||||
self.n_updates = 0
|
|
||||||
|
|
||||||
def save(self, path):
|
|
||||||
path = Path(path) # no-op if already instance of Path
|
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
|
||||||
hparams = {k: v for k, v in self.__dict__.items() if not(isinstance(v, BaseBuffer) or
|
|
||||||
isinstance(v, torch.optim.Optimizer) or
|
|
||||||
isinstance(v, gym.Env) or
|
|
||||||
isinstance(v, nn.Module))
|
|
||||||
}
|
|
||||||
hparams.update({'class': self.__class__.__name__})
|
|
||||||
with (path / 'hparams.yaml').open('w') as outfile:
|
|
||||||
yaml.dump(hparams, outfile)
|
|
||||||
torch.save(self.q_net, path / 'q_net.pt')
|
|
||||||
|
|
||||||
def anneal_eps(self, step, n_steps):
|
|
||||||
fraction = min(float(step) / int(self.exploration_fraction*n_steps), 1.0)
|
|
||||||
self.eps = 1 + fraction * (self.eps_end - 1)
|
|
||||||
|
|
||||||
def get_action(self, obs) -> Union[int, np.ndarray]:
|
|
||||||
o = torch.from_numpy(obs).unsqueeze(0) if self.n_agents <= 1 else torch.from_numpy(obs)
|
|
||||||
if np.random.rand() > self.eps:
|
|
||||||
action = self.q_net.act(o.float())
|
|
||||||
else:
|
|
||||||
action = np.array([self.env.action_space.sample() for _ in range(self.n_agents)])
|
|
||||||
return action
|
|
||||||
|
|
||||||
def on_new_experience(self, experience):
|
|
||||||
self.buffer.add(experience)
|
|
||||||
|
|
||||||
def on_step_end(self, n_steps):
|
|
||||||
self.anneal_eps(self.step, n_steps)
|
|
||||||
if self.step % self.target_update == 0:
|
|
||||||
print('UPDATE')
|
|
||||||
soft_update(self.q_net, self.target_q_net, tau=self.tau)
|
|
||||||
|
|
||||||
def _training_routine(self, obs, next_obs, action):
|
|
||||||
current_q_values = self.q_net(obs)
|
|
||||||
current_q_values = torch.gather(current_q_values, dim=-1, index=action)
|
|
||||||
next_q_values_raw = self.target_q_net(next_obs).max(dim=-1)[0].reshape(-1, 1).detach()
|
|
||||||
return current_q_values, next_q_values_raw
|
|
||||||
|
|
||||||
def _backprop_loss(self, loss):
|
|
||||||
# log loss
|
|
||||||
self.running_loss.append(loss.item())
|
|
||||||
# Optimize the model
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
loss.backward()
|
|
||||||
torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
|
|
||||||
self.optimizer.step()
|
|
||||||
|
|
||||||
def train(self):
|
|
||||||
if len(self.buffer) < self.batch_size: return
|
|
||||||
for _ in range(self.n_grad_steps):
|
|
||||||
experience = self.buffer.sample(self.batch_size, cer=self.train_every[-1])
|
|
||||||
pred_q, target_q_raw = self._training_routine(experience.observation,
|
|
||||||
experience.next_observation,
|
|
||||||
experience.action)
|
|
||||||
target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw
|
|
||||||
loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - target_q, 2))
|
|
||||||
self._backprop_loss(loss)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
from environments.factory.factory_dirt import DirtFactory, DirtProperties, MovementProperties
|
|
||||||
from algorithms.common import BaseDDQN, BaseICM
|
|
||||||
from algorithms.m_q_learner import MQLearner, MQICMLearner
|
|
||||||
from algorithms.vdn_learner import VDNLearner
|
|
||||||
|
|
||||||
N_AGENTS = 1
|
|
||||||
|
|
||||||
with (Path(f'../environments/factory/env_default_param.yaml')).open('r') as f:
|
|
||||||
env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
|
|
||||||
|
|
||||||
env = DirtFactory(**env_kwargs)
|
|
||||||
obs_shape = np.prod(env.observation_space.shape)
|
|
||||||
n_actions = env.action_space.n
|
|
||||||
|
|
||||||
dqn, target_dqn = BaseDDQN(backbone_dims=[obs_shape, 128, 128], advantage_dims=[128, n_actions], value_dims=[128, 1], activation='leaky_relu'),\
|
|
||||||
BaseDDQN(backbone_dims=[obs_shape, 128, 128], advantage_dims=[128, n_actions], value_dims=[128, 1], activation='leaky_relu')
|
|
||||||
|
|
||||||
icm = BaseICM(backbone_dims=[obs_shape, 64, 32], head_dims=[2*32, 64, n_actions])
|
|
||||||
|
|
||||||
learner = MQICMLearner(dqn, target_dqn, env, 50000, icm=icm,
|
|
||||||
target_update=5000, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
|
|
||||||
train_every=('step', 4), eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25,
|
|
||||||
batch_size=64, weight_decay=1e-3
|
|
||||||
)
|
|
||||||
#learner.save(Path(__file__).parent / 'test' / 'testexperiment1337')
|
|
||||||
learner.learn(100000)
|
|
@ -1,52 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import stable_baselines3 as sb3
|
|
||||||
from stable_baselines3.common import logger
|
|
||||||
|
|
||||||
|
|
||||||
class RegDQN(sb3.dqn.DQN):
|
|
||||||
def __init__(self, *args, reg_weight=0.1, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.reg_weight = reg_weight
|
|
||||||
|
|
||||||
def train(self, gradient_steps: int, batch_size: int = 100) -> None:
|
|
||||||
# Update learning rate according to schedule
|
|
||||||
self._update_learning_rate(self.policy.optimizer)
|
|
||||||
|
|
||||||
losses = []
|
|
||||||
for _ in range(gradient_steps):
|
|
||||||
# Sample replay buffer
|
|
||||||
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
# Compute the next Q-values using the target network
|
|
||||||
next_q_values = self.q_net_target(replay_data.next_observations)
|
|
||||||
# Follow greedy policy: use the one with the highest value
|
|
||||||
next_q_values, _ = next_q_values.max(dim=1)
|
|
||||||
# Avoid potential broadcast issue
|
|
||||||
next_q_values = next_q_values.reshape(-1, 1)
|
|
||||||
# 1-step TD target
|
|
||||||
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
|
|
||||||
|
|
||||||
# Get current Q-values estimates
|
|
||||||
current_q_values = self.q_net(replay_data.observations)
|
|
||||||
|
|
||||||
# Retrieve the q-values for the actions from the replay buffer
|
|
||||||
current_q_values = torch.gather(current_q_values, dim=1, index=replay_data.actions.long())
|
|
||||||
|
|
||||||
delta = current_q_values - target_q_values
|
|
||||||
loss = torch.mean(self.reg_weight * current_q_values + torch.pow(delta, 2))
|
|
||||||
losses.append(loss.item())
|
|
||||||
|
|
||||||
# Optimize the policy
|
|
||||||
self.policy.optimizer.zero_grad()
|
|
||||||
loss.backward()
|
|
||||||
# Clip gradient norm
|
|
||||||
torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
|
|
||||||
self.policy.optimizer.step()
|
|
||||||
|
|
||||||
# Increase update counter
|
|
||||||
self._n_updates += gradient_steps
|
|
||||||
|
|
||||||
logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
|
|
||||||
logger.record("train/loss", np.mean(losses))
|
|
@ -3,14 +3,51 @@ import torch
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import yaml
|
import yaml
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from salina import instantiate_class
|
|
||||||
from salina import TAgent
|
|
||||||
from salina.agents.gyma import (
|
def load_class(classname):
|
||||||
AutoResetGymAgent,
|
from importlib import import_module
|
||||||
_torch_type,
|
module_path, class_name = classname.rsplit(".", 1)
|
||||||
_format_frame,
|
module = import_module(module_path)
|
||||||
_torch_cat_dict
|
c = getattr(module, class_name)
|
||||||
)
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
def instantiate_class(arguments):
|
||||||
|
from importlib import import_module
|
||||||
|
|
||||||
|
d = dict(arguments)
|
||||||
|
classname = d["classname"]
|
||||||
|
del d["classname"]
|
||||||
|
module_path, class_name = classname.rsplit(".", 1)
|
||||||
|
module = import_module(module_path)
|
||||||
|
c = getattr(module, class_name)
|
||||||
|
return c(**d)
|
||||||
|
|
||||||
|
|
||||||
|
def get_class(arguments):
|
||||||
|
from importlib import import_module
|
||||||
|
|
||||||
|
if isinstance(arguments, dict):
|
||||||
|
classname = arguments["classname"]
|
||||||
|
module_path, class_name = classname.rsplit(".", 1)
|
||||||
|
module = import_module(module_path)
|
||||||
|
c = getattr(module, class_name)
|
||||||
|
return c
|
||||||
|
else:
|
||||||
|
classname = arguments.classname
|
||||||
|
module_path, class_name = classname.rsplit(".", 1)
|
||||||
|
module = import_module(module_path)
|
||||||
|
c = getattr(module, class_name)
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
def get_arguments(arguments):
|
||||||
|
from importlib import import_module
|
||||||
|
d = dict(arguments)
|
||||||
|
if "classname" in d:
|
||||||
|
del d["classname"]
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
def load_yaml_file(path: Path):
|
def load_yaml_file(path: Path):
|
||||||
@ -21,90 +58,29 @@ def load_yaml_file(path: Path):
|
|||||||
|
|
||||||
def add_env_props(cfg):
|
def add_env_props(cfg):
|
||||||
env = instantiate_class(cfg['env'].copy())
|
env = instantiate_class(cfg['env'].copy())
|
||||||
cfg['agent'].update(dict(observation_size=env.observation_space.shape,
|
cfg['agent'].update(dict(observation_size=list(env.observation_space.shape),
|
||||||
n_actions=env.action_space.n))
|
n_actions=env.action_space.n))
|
||||||
|
|
||||||
|
|
||||||
|
class Checkpointer(object):
|
||||||
|
def __init__(self, experiment_name, root, config, total_steps, n_checkpoints):
|
||||||
|
self.path = root / experiment_name
|
||||||
|
self.checkpoint_indices = list(np.linspace(1, total_steps, n_checkpoints, dtype=int) - 1)
|
||||||
|
self.__current_checkpoint = 0
|
||||||
|
self.__current_step = 0
|
||||||
|
self.path.mkdir(exist_ok=True, parents=True)
|
||||||
|
with (self.path / 'config.yaml').open('w') as outfile:
|
||||||
|
yaml.dump(config, outfile, default_flow_style=False)
|
||||||
|
|
||||||
|
def save_experiment(self, name: str, model):
|
||||||
|
cpt_path = self.path / f'checkpoint_{self.__current_checkpoint}'
|
||||||
|
cpt_path.mkdir(exist_ok=True, parents=True)
|
||||||
|
torch.save(model.state_dict(), cpt_path / f'{name}.pt')
|
||||||
|
|
||||||
AGENT_PREFIX = 'agent#'
|
def step(self, to_save):
|
||||||
REWARD = 'reward'
|
if self.__current_step in self.checkpoint_indices:
|
||||||
CUMU_REWARD = 'cumulated_reward'
|
print(f'Checkpointing #{self.__current_checkpoint}')
|
||||||
OBS = 'env_obs'
|
for name, model in to_save:
|
||||||
SEP = '_'
|
self.save_experiment(name, model)
|
||||||
ACTION = 'action'
|
self.__current_checkpoint += 1
|
||||||
|
self.__current_step += 1
|
||||||
|
|
||||||
def access_str(agent_i, name, prefix=''):
|
|
||||||
return f'{prefix}{AGENT_PREFIX}{agent_i}{SEP}{name}'
|
|
||||||
|
|
||||||
|
|
||||||
class AutoResetGymMultiAgent(AutoResetGymAgent):
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super(AutoResetGymMultiAgent, self).__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def per_agent_values(self, name, values):
|
|
||||||
return {access_str(agent_i, name): value
|
|
||||||
for agent_i, value in zip(range(self.n_agents), values)}
|
|
||||||
|
|
||||||
def _initialize_envs(self, n):
|
|
||||||
super()._initialize_envs(n)
|
|
||||||
n_agents_list = [self.envs[i].unwrapped.n_agents for i in range(n)]
|
|
||||||
assert all(n_agents == n_agents_list[0] for n_agents in n_agents_list), \
|
|
||||||
'All envs must have the same number of agents.'
|
|
||||||
self.n_agents = n_agents_list[0]
|
|
||||||
|
|
||||||
def _reset(self, k, save_render):
|
|
||||||
ret = super()._reset(k, save_render)
|
|
||||||
obs = ret['env_obs'].squeeze()
|
|
||||||
self.cumulated_reward[k] = [0.0]*self.n_agents
|
|
||||||
obs = self.per_agent_values(OBS, [_format_frame(obs[i]) for i in range(self.n_agents)])
|
|
||||||
cumu_rew = self.per_agent_values(CUMU_REWARD, torch.zeros(self.n_agents, 1).float().unbind())
|
|
||||||
rewards = self.per_agent_values(REWARD, torch.zeros(self.n_agents, 1).float().unbind())
|
|
||||||
ret.update(cumu_rew)
|
|
||||||
ret.update(rewards)
|
|
||||||
ret.update(obs)
|
|
||||||
for remove in ['env_obs', 'cumulated_reward', 'reward']:
|
|
||||||
del ret[remove]
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def _step(self, k, action, save_render):
|
|
||||||
self.timestep[k] += 1
|
|
||||||
env = self.envs[k]
|
|
||||||
if len(action.size()) == 0:
|
|
||||||
action = action.item()
|
|
||||||
assert isinstance(action, int)
|
|
||||||
else:
|
|
||||||
action = np.array(action.tolist())
|
|
||||||
o, r, d, _ = env.step(action)
|
|
||||||
self.cumulated_reward[k] = [x+y for x, y in zip(r, self.cumulated_reward[k])]
|
|
||||||
observation = self.per_agent_values(OBS, [_format_frame(o[i]) for i in range(self.n_agents)])
|
|
||||||
if d:
|
|
||||||
self.is_running[k] = False
|
|
||||||
if save_render:
|
|
||||||
image = env.render(mode="image").unsqueeze(0)
|
|
||||||
observation["rendering"] = image
|
|
||||||
rewards = self.per_agent_values(REWARD, torch.tensor(r).float().view(-1, 1).unbind())
|
|
||||||
cumulated_rewards = self.per_agent_values(CUMU_REWARD, torch.tensor(self.cumulated_reward[k]).float().view(-1, 1).unbind())
|
|
||||||
ret = {
|
|
||||||
**observation,
|
|
||||||
**rewards,
|
|
||||||
**cumulated_rewards,
|
|
||||||
"done": torch.tensor([d]),
|
|
||||||
"initial_state": torch.tensor([False]),
|
|
||||||
"timestep": torch.tensor([self.timestep[k]])
|
|
||||||
}
|
|
||||||
return _torch_type(ret)
|
|
||||||
|
|
||||||
|
|
||||||
class CombineActionsAgent(TAgent):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.pattern = fr'^{AGENT_PREFIX}\d{SEP}{ACTION}$'
|
|
||||||
|
|
||||||
def forward(self, t, **kwargs):
|
|
||||||
keys = list(self.workspace.keys())
|
|
||||||
action_keys = sorted([k for k in keys if bool(re.match(self.pattern, k))])
|
|
||||||
actions = torch.cat([self.get((k, t)) for k in action_keys], 0)
|
|
||||||
actions = actions if len(action_keys) <= 1 else actions.unsqueeze(0)
|
|
||||||
self.set((f'action', t), actions)
|
|
@ -1,55 +0,0 @@
|
|||||||
from typing import Union
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
from algorithms.q_learner import QLearner
|
|
||||||
|
|
||||||
|
|
||||||
class VDNLearner(QLearner):
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super(VDNLearner, self).__init__(*args, **kwargs)
|
|
||||||
assert self.n_agents >= 2, 'VDN requires more than one agent, use QLearner instead'
|
|
||||||
|
|
||||||
def get_action(self, obs) -> Union[int, np.ndarray]:
|
|
||||||
o = torch.from_numpy(obs).unsqueeze(0) if self.n_agents <= 1 else torch.from_numpy(obs)
|
|
||||||
eps = np.random.rand(self.n_agents)
|
|
||||||
greedy = eps > self.eps
|
|
||||||
agent_actions = None
|
|
||||||
actions = []
|
|
||||||
for i in range(self.n_agents):
|
|
||||||
if greedy[i]:
|
|
||||||
if agent_actions is None: agent_actions = self.q_net.act(o.float())
|
|
||||||
action = agent_actions[i]
|
|
||||||
else:
|
|
||||||
action = self.env.action_space.sample()
|
|
||||||
actions.append(action)
|
|
||||||
return np.array(actions)
|
|
||||||
|
|
||||||
def train(self):
|
|
||||||
if len(self.buffer) < self.batch_size: return
|
|
||||||
for _ in range(self.n_grad_steps):
|
|
||||||
experience = self.buffer.sample(self.batch_size, cer=self.train_every_n_steps)
|
|
||||||
pred_q, target_q_raw = torch.zeros((self.batch_size, 1)), torch.zeros((self.batch_size, 1))
|
|
||||||
for agent_i in range(self.n_agents):
|
|
||||||
q_values, next_q_values_raw = self._training_routine(experience.observation[:, agent_i],
|
|
||||||
experience.next_observation[:, agent_i],
|
|
||||||
experience.action[:, agent_i].unsqueeze(-1))
|
|
||||||
pred_q += q_values
|
|
||||||
target_q_raw += next_q_values_raw
|
|
||||||
target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw
|
|
||||||
loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - target_q, 2))
|
|
||||||
self._backprop_loss(loss)
|
|
||||||
|
|
||||||
def evaluate(self, n_episodes=100, render=False):
|
|
||||||
with torch.no_grad():
|
|
||||||
data = []
|
|
||||||
for eval_i in range(n_episodes):
|
|
||||||
obs, done = self.env.reset(), False
|
|
||||||
while not done:
|
|
||||||
action = self.get_action(obs)
|
|
||||||
next_obs, reward, done, info = self.env.step(action)
|
|
||||||
if render: self.env.render()
|
|
||||||
obs = next_obs # srsly i'm so stupid
|
|
||||||
info.update({'reward': reward, 'eval_episode': eval_i})
|
|
||||||
data.append(info)
|
|
||||||
return pd.DataFrame(data).fillna(0)
|
|
@ -1,22 +1,25 @@
|
|||||||
def make(env_name, pomdp_r=2, max_steps=400, stack_n_frames=3, n_agents=1, individual_rewards=False):
|
def make(env_name, pomdp_r=2, max_steps=400, stack_n_frames=3, n_agents=1, individual_rewards=False):
|
||||||
import yaml
|
import yaml
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from environments.factory.combined_factories import DirtItemFactory
|
from environments.factory.combined_factories import DirtItemFactory
|
||||||
from environments.factory.factory_item import ItemFactory, ItemProperties
|
from environments.factory.factory_item import ItemFactory, ItemProperties
|
||||||
from environments.factory.factory_dirt import DirtProperties, DirtFactory
|
from environments.factory.factory_dirt import DirtProperties, DirtFactory, RewardsDirt
|
||||||
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
from environments.utility_classes import AgentRenderOptions
|
||||||
|
|
||||||
with (Path(__file__).parent / 'levels' / 'parameters' / f'{env_name}.yaml').open('r') as stream:
|
with (Path(__file__).parent / 'levels' / 'parameters' / f'{env_name}.yaml').open('r') as stream:
|
||||||
dictionary = yaml.load(stream, Loader=yaml.FullLoader)
|
dictionary = yaml.load(stream, Loader=yaml.FullLoader)
|
||||||
|
|
||||||
obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED,
|
obs_props = dict(render_agents=AgentRenderOptions.COMBINED,
|
||||||
frames_to_stack=stack_n_frames, pomdp_r=pomdp_r)
|
pomdp_r=pomdp_r,
|
||||||
|
indicate_door_area=True,
|
||||||
|
show_global_position_info=False,
|
||||||
|
frames_to_stack=stack_n_frames)
|
||||||
|
|
||||||
factory_kwargs = dict(n_agents=n_agents, individual_rewards=individual_rewards,
|
factory_kwargs = dict(**dictionary,
|
||||||
max_steps=max_steps, obs_prop=obs_props,
|
n_agents=n_agents,
|
||||||
mv_prop=MovementProperties(**dictionary['movement_props']),
|
individual_rewards=individual_rewards,
|
||||||
dirt_prop=DirtProperties(**dictionary['dirt_props']),
|
max_steps=max_steps,
|
||||||
record_episodes=False, verbose=False, **dictionary['factory_props']
|
obs_prop=obs_props,
|
||||||
|
verbose=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
return DirtFactory(**factory_kwargs).__enter__()
|
return DirtFactory(**factory_kwargs).__enter__()
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import abc
|
import abc
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from enum import Enum
|
from itertools import chain
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Union, Iterable, Dict
|
from typing import List, Union, Iterable, Dict
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -11,10 +11,13 @@ from gym import spaces
|
|||||||
from gym.wrappers import FrameStack
|
from gym.wrappers import FrameStack
|
||||||
|
|
||||||
from environments.factory.base.shadow_casting import Map
|
from environments.factory.base.shadow_casting import Map
|
||||||
from environments.helpers import Constants as c, Constants
|
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.factory.base.objects import Agent, Tile, Action
|
from environments.helpers import Constants as c
|
||||||
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders
|
from environments.helpers import EnvActions as a
|
||||||
|
from environments.helpers import RewardsBase
|
||||||
|
from environments.factory.base.objects import Agent, Floor, Action
|
||||||
|
from environments.factory.base.registers import Actions, Entities, Agents, Doors, Floors, Walls, PlaceHolders, \
|
||||||
|
GlobalPositions
|
||||||
from environments.utility_classes import MovementProperties, ObservationProperties, MarlFrameStack
|
from environments.utility_classes import MovementProperties, ObservationProperties, MarlFrameStack
|
||||||
from environments.utility_classes import AgentRenderOptions as a_obs
|
from environments.utility_classes import AgentRenderOptions as a_obs
|
||||||
|
|
||||||
@ -30,17 +33,30 @@ class BaseFactory(gym.Env):
|
|||||||
def action_space(self):
|
def action_space(self):
|
||||||
return spaces.Discrete(len(self._actions))
|
return spaces.Discrete(len(self._actions))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def named_action_space(self):
|
||||||
|
return {x.identifier: idx for idx, x in enumerate(self._actions.values())}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def observation_space(self):
|
def observation_space(self):
|
||||||
if r := self._pomdp_r:
|
obs, _ = self._build_observations()
|
||||||
z = self._obs_cube.shape[0]
|
if self.n_agents > 1:
|
||||||
xy = r*2 + 1
|
shape = obs[0].shape
|
||||||
level_shape = (z, xy, xy)
|
|
||||||
else:
|
else:
|
||||||
level_shape = self._obs_cube.shape
|
shape = obs.shape
|
||||||
space = spaces.Box(low=0, high=1, shape=level_shape, dtype=np.float32)
|
space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
||||||
return space
|
return space
|
||||||
|
|
||||||
|
@property
|
||||||
|
def named_observation_space(self):
|
||||||
|
# Build it
|
||||||
|
_, named_obs = self._build_observations()
|
||||||
|
if self.n_agents > 1:
|
||||||
|
# Only return the first named obs space, as their structure at the moment is same.
|
||||||
|
return named_obs[list(named_obs.keys())[0]]
|
||||||
|
else:
|
||||||
|
return named_obs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pomdp_diameter(self):
|
def pomdp_diameter(self):
|
||||||
return self._pomdp_r * 2 + 1
|
return self._pomdp_r * 2 + 1
|
||||||
@ -52,6 +68,7 @@ class BaseFactory(gym.Env):
|
|||||||
@property
|
@property
|
||||||
def params(self) -> dict:
|
def params(self) -> dict:
|
||||||
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
|
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
|
||||||
|
d['class_name'] = self.__class__.__name__
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
@ -64,17 +81,26 @@ class BaseFactory(gym.Env):
|
|||||||
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2),
|
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2),
|
||||||
mv_prop: MovementProperties = MovementProperties(),
|
mv_prop: MovementProperties = MovementProperties(),
|
||||||
obs_prop: ObservationProperties = ObservationProperties(),
|
obs_prop: ObservationProperties = ObservationProperties(),
|
||||||
|
rewards_base: RewardsBase = RewardsBase(),
|
||||||
parse_doors=False, done_at_collision=False, inject_agents: Union[None, List] = None,
|
parse_doors=False, done_at_collision=False, inject_agents: Union[None, List] = None,
|
||||||
verbose=False, doors_have_area=True, env_seed=time.time_ns(), individual_rewards=False,
|
verbose=False, doors_have_area=True, env_seed=time.time_ns(), individual_rewards=False,
|
||||||
**kwargs):
|
class_name='', **kwargs):
|
||||||
|
|
||||||
|
if class_name:
|
||||||
|
print(f'You loaded parameters for {class_name}', f'this is: {self.__class__.__name__}')
|
||||||
|
|
||||||
if isinstance(mv_prop, dict):
|
if isinstance(mv_prop, dict):
|
||||||
mv_prop = MovementProperties(**mv_prop)
|
mv_prop = MovementProperties(**mv_prop)
|
||||||
if isinstance(obs_prop, dict):
|
if isinstance(obs_prop, dict):
|
||||||
obs_prop = ObservationProperties(**obs_prop)
|
obs_prop = ObservationProperties(**obs_prop)
|
||||||
|
if isinstance(rewards_base, dict):
|
||||||
|
rewards_base = RewardsBase(**rewards_base)
|
||||||
|
|
||||||
assert obs_prop.frames_to_stack != 1 and \
|
assert obs_prop.frames_to_stack != 1 and \
|
||||||
obs_prop.frames_to_stack >= 0, "'frames_to_stack' cannot be negative or 1."
|
obs_prop.frames_to_stack >= 0, \
|
||||||
|
"'frames_to_stack' cannot be negative or 1."
|
||||||
|
assert doors_have_area or not obs_prop.indicate_door_area, \
|
||||||
|
'"indicate_door_area" can only active, when "doors_have_area"'
|
||||||
if kwargs:
|
if kwargs:
|
||||||
print(f'Following kwargs were passed, but ignored: {kwargs}')
|
print(f'Following kwargs were passed, but ignored: {kwargs}')
|
||||||
|
|
||||||
@ -84,13 +110,17 @@ class BaseFactory(gym.Env):
|
|||||||
self._base_rng = np.random.default_rng(self.env_seed)
|
self._base_rng = np.random.default_rng(self.env_seed)
|
||||||
self.mv_prop = mv_prop
|
self.mv_prop = mv_prop
|
||||||
self.obs_prop = obs_prop
|
self.obs_prop = obs_prop
|
||||||
|
self.rewards_base = rewards_base
|
||||||
self.level_name = level_name
|
self.level_name = level_name
|
||||||
self._level_shape = None
|
self._level_shape = None
|
||||||
|
self._obs_shape = None
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self._renderer = None # expensive - don't use it when not required !
|
self._renderer = None # expensive - don't use it when not required !
|
||||||
self._entities = Entities()
|
self._entities = Entities()
|
||||||
|
|
||||||
self.n_agents = n_agents
|
self.n_agents = n_agents
|
||||||
|
level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.level_name}.txt'
|
||||||
|
self._parsed_level = h.parse_level(level_filepath)
|
||||||
|
|
||||||
self.max_steps = max_steps
|
self.max_steps = max_steps
|
||||||
self._pomdp_r = self.obs_prop.pomdp_r
|
self._pomdp_r = self.obs_prop.pomdp_r
|
||||||
@ -102,7 +132,7 @@ class BaseFactory(gym.Env):
|
|||||||
self.doors_have_area = doors_have_area
|
self.doors_have_area = doors_have_area
|
||||||
self.individual_rewards = individual_rewards
|
self.individual_rewards = individual_rewards
|
||||||
|
|
||||||
# Reset
|
# TODO: Reset ---> document this
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
@ -114,57 +144,59 @@ class BaseFactory(gym.Env):
|
|||||||
# Objects
|
# Objects
|
||||||
self._entities = Entities()
|
self._entities = Entities()
|
||||||
# Level
|
# Level
|
||||||
level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.level_name}.txt'
|
|
||||||
parsed_level = h.parse_level(level_filepath)
|
level_array = h.one_hot_level(self._parsed_level)
|
||||||
level_array = h.one_hot_level(parsed_level)
|
level_array = np.pad(level_array, self.obs_prop.pomdp_r, 'constant', constant_values=1)
|
||||||
|
|
||||||
self._level_shape = level_array.shape
|
self._level_shape = level_array.shape
|
||||||
|
self._obs_shape = self._level_shape if not self.obs_prop.pomdp_r else (self.pomdp_diameter, ) * 2
|
||||||
|
|
||||||
# Walls
|
# Walls
|
||||||
walls = WallTiles.from_argwhere_coordinates(
|
walls = Walls.from_argwhere_coordinates(
|
||||||
np.argwhere(level_array == c.OCCUPIED_CELL.value),
|
np.argwhere(level_array == c.OCCUPIED_CELL),
|
||||||
self._level_shape
|
self._level_shape
|
||||||
)
|
)
|
||||||
self._entities.register_additional_items({c.WALLS: walls})
|
self._entities.register_additional_items({c.WALLS: walls})
|
||||||
|
|
||||||
# Floor
|
# Floor
|
||||||
floor = FloorTiles.from_argwhere_coordinates(
|
floor = Floors.from_argwhere_coordinates(
|
||||||
np.argwhere(level_array == c.FREE_CELL.value),
|
np.argwhere(level_array == c.FREE_CELL),
|
||||||
self._level_shape
|
self._level_shape
|
||||||
)
|
)
|
||||||
self._entities.register_additional_items({c.FLOOR: floor})
|
self._entities.register_additional_items({c.FLOOR: floor})
|
||||||
|
|
||||||
# NOPOS
|
# NOPOS
|
||||||
self._NO_POS_TILE = Tile(c.NO_POS.value)
|
self._NO_POS_TILE = Floor(c.NO_POS, None)
|
||||||
|
|
||||||
# Doors
|
# Doors
|
||||||
if self.parse_doors:
|
if self.parse_doors:
|
||||||
parsed_doors = h.one_hot_level(parsed_level, c.DOOR)
|
parsed_doors = h.one_hot_level(self._parsed_level, c.DOOR)
|
||||||
|
parsed_doors = np.pad(parsed_doors, self.obs_prop.pomdp_r, 'constant', constant_values=0)
|
||||||
if np.any(parsed_doors):
|
if np.any(parsed_doors):
|
||||||
door_tiles = [floor.by_pos(pos) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL.value)]
|
door_tiles = [floor.by_pos(tuple(pos)) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL)]
|
||||||
doors = Doors.from_tiles(door_tiles, self._level_shape,
|
doors = Doors.from_tiles(door_tiles, self._level_shape, have_area=self.obs_prop.indicate_door_area,
|
||||||
entity_kwargs=dict(context=floor)
|
entity_kwargs=dict(context=floor)
|
||||||
)
|
)
|
||||||
self._entities.register_additional_items({c.DOORS: doors})
|
self._entities.register_additional_items({c.DOORS: doors})
|
||||||
|
|
||||||
# Actions
|
# Actions
|
||||||
self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
|
self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
|
||||||
if additional_actions := self.additional_actions:
|
if additional_actions := self.actions_hook:
|
||||||
self._actions.register_additional_items(additional_actions)
|
self._actions.register_additional_items(additional_actions)
|
||||||
|
|
||||||
# Agents
|
# Agents
|
||||||
agents_to_spawn = self.n_agents-len(self._injected_agents)
|
agents_to_spawn = self.n_agents-len(self._injected_agents)
|
||||||
agents_kwargs = dict(level_shape=self._level_shape,
|
agents_kwargs = dict(individual_slices=self.obs_prop.render_agents == a_obs.SEPERATE,
|
||||||
individual_slices=self.obs_prop.render_agents == a_obs.SEPERATE,
|
hide_from_obs_builder=self.obs_prop.render_agents in [a_obs.NOT, a_obs.LEVEL],
|
||||||
hide_from_obs_builder=self.obs_prop.render_agents == a_obs.LEVEL,
|
)
|
||||||
is_observable=self.obs_prop.render_agents != a_obs.NOT)
|
|
||||||
if agents_to_spawn:
|
if agents_to_spawn:
|
||||||
agents = Agents.from_tiles(floor.empty_tiles[:agents_to_spawn], **agents_kwargs)
|
agents = Agents.from_tiles(floor.empty_tiles[:agents_to_spawn], self._level_shape, **agents_kwargs)
|
||||||
else:
|
else:
|
||||||
agents = Agents(**agents_kwargs)
|
agents = Agents(self._level_shape, **agents_kwargs)
|
||||||
if self._injected_agents:
|
if self._injected_agents:
|
||||||
initialized_injections = list()
|
initialized_injections = list()
|
||||||
for i, injection in enumerate(self._injected_agents):
|
for i, injection in enumerate(self._injected_agents):
|
||||||
agents.register_item(injection(self, floor.empty_tiles[agents_to_spawn+i+1], static_problem=False))
|
agents.register_item(injection(self, floor.empty_tiles[0], agents, static_problem=False))
|
||||||
initialized_injections.append(agents[-1])
|
initialized_injections.append(agents[-1])
|
||||||
self._initialized_injections = initialized_injections
|
self._initialized_injections = initialized_injections
|
||||||
self._entities.register_additional_items({c.AGENT: agents})
|
self._entities.register_additional_items({c.AGENT: agents})
|
||||||
@ -173,35 +205,34 @@ class BaseFactory(gym.Env):
|
|||||||
# TODO: Make this accept Lists for multiple placeholders
|
# TODO: Make this accept Lists for multiple placeholders
|
||||||
|
|
||||||
# Empty Observations with either [0, 1, N(0, 1)]
|
# Empty Observations with either [0, 1, N(0, 1)]
|
||||||
placeholder = PlaceHolders.from_tiles([self._NO_POS_TILE], self._level_shape,
|
placeholder = PlaceHolders.from_values(self.obs_prop.additional_agent_placeholder, self._level_shape,
|
||||||
entity_kwargs=dict(
|
entity_kwargs=dict(
|
||||||
fill_value=self.obs_prop.additional_agent_placeholder)
|
fill_value=self.obs_prop.additional_agent_placeholder)
|
||||||
)
|
)
|
||||||
|
|
||||||
self._entities.register_additional_items({c.AGENT_PLACEHOLDER: placeholder})
|
self._entities.register_additional_items({c.AGENT_PLACEHOLDER: placeholder})
|
||||||
|
|
||||||
# Additional Entitites from SubEnvs
|
# Additional Entitites from SubEnvs
|
||||||
if additional_entities := self.additional_entities:
|
if additional_entities := self.entities_hook:
|
||||||
self._entities.register_additional_items(additional_entities)
|
self._entities.register_additional_items(additional_entities)
|
||||||
|
|
||||||
|
if self.obs_prop.show_global_position_info:
|
||||||
|
global_positions = GlobalPositions(self._level_shape)
|
||||||
|
# This moved into the GlobalPosition object
|
||||||
|
# obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2)
|
||||||
|
global_positions.spawn_global_position_objects(self[c.AGENT])
|
||||||
|
self._entities.register_additional_items({c.GLOBAL_POSITION: global_positions})
|
||||||
|
|
||||||
# Return
|
# Return
|
||||||
return self._entities
|
return self._entities
|
||||||
|
|
||||||
def _init_obs_cube(self):
|
def reset(self) -> (np.typing.ArrayLike, int, bool, dict):
|
||||||
arrays = self._entities.obs_arrays
|
|
||||||
|
|
||||||
obs_cube_z = sum([a.shape[0] if not self[key].is_per_agent else 1 for key, a in arrays.items()])
|
|
||||||
obs_cube_z += 1 if self.obs_prop.show_global_position_info else 0
|
|
||||||
self._obs_cube = np.zeros((obs_cube_z, *self._level_shape), dtype=np.float32)
|
|
||||||
|
|
||||||
def reset(self) -> (np.ndarray, int, bool, dict):
|
|
||||||
_ = self._base_init_env()
|
_ = self._base_init_env()
|
||||||
self._init_obs_cube()
|
self.reset_hook()
|
||||||
self.do_additional_reset()
|
|
||||||
|
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
|
|
||||||
obs = self._get_observations()
|
obs, _ = self._build_observations()
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def step(self, actions):
|
def step(self, actions):
|
||||||
@ -213,39 +244,53 @@ class BaseFactory(gym.Env):
|
|||||||
self._steps += 1
|
self._steps += 1
|
||||||
|
|
||||||
# Pre step Hook for later use
|
# Pre step Hook for later use
|
||||||
self.hook_pre_step()
|
self.pre_step_hook()
|
||||||
|
|
||||||
# Move this in a seperate function?
|
|
||||||
for action, agent in zip(actions, self[c.AGENT]):
|
for action, agent in zip(actions, self[c.AGENT]):
|
||||||
agent.clear_temp_state()
|
agent.clear_temp_state()
|
||||||
action_obj = self._actions[int(action)]
|
action_obj = self._actions[int(action)]
|
||||||
# self.print(f'Action #{action} has been resolved to: {action_obj}')
|
step_result = dict(collisions=[], rewards=[], info={}, action_name='', action_valid=False)
|
||||||
if h.MovingAction.is_member(action_obj):
|
# cls.print(f'Action #{action} has been resolved to: {action_obj}')
|
||||||
valid = self._move_or_colide(agent, action_obj)
|
if a.is_move(action_obj):
|
||||||
elif h.EnvActions.NOOP == agent.temp_action:
|
action_valid, reward = self._do_move_action(agent, action_obj)
|
||||||
valid = c.VALID
|
elif a.NOOP == action_obj:
|
||||||
elif h.EnvActions.USE_DOOR == action_obj:
|
action_valid = c.VALID
|
||||||
valid = self._handle_door_interaction(agent)
|
reward = dict(value=self.rewards_base.NOOP, reason=a.NOOP, info={f'{agent.name}_NOOP': 1, 'NOOP': 1})
|
||||||
|
elif a.USE_DOOR == action_obj:
|
||||||
|
action_valid, reward = self._handle_door_interaction(agent)
|
||||||
else:
|
else:
|
||||||
valid = self.do_additional_actions(agent, action_obj)
|
# noinspection PyTupleAssignmentBalance
|
||||||
assert valid is not None, 'This should not happen, every Action musst be detected correctly!'
|
action_valid, reward = self.do_additional_actions(agent, action_obj)
|
||||||
agent.temp_action = action_obj
|
# Not needed any more sice the tuple assignment above will fail in case of a failing action resolvement.
|
||||||
agent.temp_valid = valid
|
# assert step_result is not None, 'This should not happen, every Action musst be detected correctly!'
|
||||||
|
step_result['action_name'] = action_obj.identifier
|
||||||
# In-between step Hook for later use
|
step_result['action_valid'] = action_valid
|
||||||
info = self.do_additional_step()
|
step_result['rewards'].append(reward)
|
||||||
|
agent.step_result = step_result
|
||||||
|
|
||||||
|
# Additional step and Reward, Info Init
|
||||||
|
rewards, info = self.step_hook()
|
||||||
|
# Todo: Make this faster, so that only tiles of entities that can collide are searched.
|
||||||
tiles_with_collisions = self.get_all_tiles_with_collisions()
|
tiles_with_collisions = self.get_all_tiles_with_collisions()
|
||||||
for tile in tiles_with_collisions:
|
for tile in tiles_with_collisions:
|
||||||
guests = tile.guests_that_can_collide
|
guests = tile.guests_that_can_collide
|
||||||
for i, guest in enumerate(guests):
|
for i, guest in enumerate(guests):
|
||||||
|
# This does make a copy, but is faster than.copy()
|
||||||
this_collisions = guests[:]
|
this_collisions = guests[:]
|
||||||
del this_collisions[i]
|
del this_collisions[i]
|
||||||
guest.temp_collisions = this_collisions
|
assert hasattr(guest, 'step_result')
|
||||||
|
for collision in this_collisions:
|
||||||
|
guest.step_result['collisions'].append(collision)
|
||||||
|
|
||||||
done = self.done_at_collision and tiles_with_collisions
|
done = False
|
||||||
|
if self.done_at_collision:
|
||||||
|
if done_at_col := bool(tiles_with_collisions):
|
||||||
|
done = done_at_col
|
||||||
|
info.update(COLLISION_DONE=done_at_col)
|
||||||
|
|
||||||
done = done or self.check_additional_done()
|
additional_done, additional_done_info = self.check_additional_done()
|
||||||
|
done = done or additional_done
|
||||||
|
info.update(additional_done_info)
|
||||||
|
|
||||||
# Step the door close intervall
|
# Step the door close intervall
|
||||||
if self.parse_doors:
|
if self.parse_doors:
|
||||||
@ -253,7 +298,8 @@ class BaseFactory(gym.Env):
|
|||||||
doors.tick_doors()
|
doors.tick_doors()
|
||||||
|
|
||||||
# Finalize
|
# Finalize
|
||||||
reward, reward_info = self.calculate_reward()
|
reward, reward_info = self.build_reward_result(rewards)
|
||||||
|
|
||||||
info.update(reward_info)
|
info.update(reward_info)
|
||||||
if self._steps >= self.max_steps:
|
if self._steps >= self.max_steps:
|
||||||
done = True
|
done = True
|
||||||
@ -262,13 +308,13 @@ class BaseFactory(gym.Env):
|
|||||||
info.update(self._summarize_state())
|
info.update(self._summarize_state())
|
||||||
|
|
||||||
# Post step Hook for later use
|
# Post step Hook for later use
|
||||||
info.update(self.hook_post_step())
|
info.update(self.post_step_hook())
|
||||||
|
|
||||||
obs = self._get_observations()
|
obs, _ = self._build_observations()
|
||||||
|
|
||||||
return obs, reward, done, info
|
return obs, reward, done, info
|
||||||
|
|
||||||
def _handle_door_interaction(self, agent) -> c:
|
def _handle_door_interaction(self, agent) -> (bool, dict):
|
||||||
if doors := self[c.DOORS]:
|
if doors := self[c.DOORS]:
|
||||||
# Check if agent really is standing on a door:
|
# Check if agent really is standing on a door:
|
||||||
if self.doors_have_area:
|
if self.doors_have_area:
|
||||||
@ -277,164 +323,175 @@ class BaseFactory(gym.Env):
|
|||||||
door = doors.by_pos(agent.pos)
|
door = doors.by_pos(agent.pos)
|
||||||
if door is not None:
|
if door is not None:
|
||||||
door.use()
|
door.use()
|
||||||
return c.VALID
|
valid = c.VALID
|
||||||
|
self.print(f'{agent.name} just used a {door.name} at {door.pos}')
|
||||||
|
info_dict = {f'{agent.name}_door_use': 1, f'door_use': 1}
|
||||||
# When he doesn't...
|
# When he doesn't...
|
||||||
else:
|
else:
|
||||||
return c.NOT_VALID
|
valid = c.NOT_VALID
|
||||||
|
info_dict = {f'{agent.name}_failed_door_use': 1, 'failed_door_use': 1}
|
||||||
|
self.print(f'{agent.name} just tried to use a door at {agent.pos}, but there is none.')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return c.NOT_VALID
|
raise RuntimeError('This should not happen, since the door action should not be available.')
|
||||||
|
reward = dict(value=self.rewards_base.USE_DOOR_VALID if valid else self.rewards_base.USE_DOOR_FAIL,
|
||||||
|
reason=a.USE_DOOR, info=info_dict)
|
||||||
|
|
||||||
def _get_observations(self) -> np.ndarray:
|
return valid, reward
|
||||||
state_array_dict = self._entities.obs_arrays
|
|
||||||
if self.n_agents == 1:
|
def _build_observations(self) -> np.typing.ArrayLike:
|
||||||
obs = self._build_per_agent_obs(self[c.AGENT][0], state_array_dict)
|
# Observation dict:
|
||||||
elif self.n_agents >= 2:
|
per_agent_expl_idx = dict()
|
||||||
obs = np.stack([self._build_per_agent_obs(agent, state_array_dict) for agent in self[c.AGENT]])
|
per_agent_obsn = dict()
|
||||||
|
# Generel Observations
|
||||||
|
lvl_obs = self[c.WALLS].as_array()
|
||||||
|
door_obs = self[c.DOORS].as_array() if self.parse_doors else None
|
||||||
|
if self.obs_prop.render_agents == a_obs.NOT:
|
||||||
|
global_agent_obs = None
|
||||||
|
elif self.obs_prop.omit_agent_self and self.n_agents == 1:
|
||||||
|
global_agent_obs = None
|
||||||
else:
|
else:
|
||||||
raise ValueError('n_agents cannot be smaller than 1!!')
|
global_agent_obs = self[c.AGENT].as_array().copy()
|
||||||
return obs
|
placeholder_obs = self[c.AGENT_PLACEHOLDER].as_array() if self[c.AGENT_PLACEHOLDER] else None
|
||||||
|
add_obs_dict = self.observations_hook()
|
||||||
|
|
||||||
def _build_per_agent_obs(self, agent: Agent, state_array_dict) -> np.ndarray:
|
for agent_idx, agent in enumerate(self[c.AGENT]):
|
||||||
agent_pos_is_omitted = False
|
obs_dict = dict()
|
||||||
agent_omit_idx = None
|
# Build Agent Observations
|
||||||
|
if self.obs_prop.render_agents != a_obs.NOT:
|
||||||
if self.obs_prop.omit_agent_self and self.n_agents == 1:
|
if self.obs_prop.omit_agent_self and self.n_agents >= 2:
|
||||||
pass
|
if self.obs_prop.render_agents == a_obs.SEPERATE:
|
||||||
elif self.obs_prop.omit_agent_self and self.obs_prop.render_agents in [a_obs.COMBINED, ] and self.n_agents > 1:
|
other_agent_obs_idx = [x for x in range(self.n_agents) if x != agent_idx]
|
||||||
state_array_dict[c.AGENT][0, agent.x, agent.y] -= agent.encoding
|
agent_obs = np.take(global_agent_obs, other_agent_obs_idx, axis=0)
|
||||||
agent_pos_is_omitted = True
|
|
||||||
elif self.obs_prop.omit_agent_self and self.obs_prop.render_agents == a_obs.SEPERATE and self.n_agents > 1:
|
|
||||||
agent_omit_idx = next((i for i, a in enumerate(self[c.AGENT]) if a == agent))
|
|
||||||
|
|
||||||
running_idx, shadowing_idxs, can_be_shadowed_idxs = 0, [], []
|
|
||||||
self._obs_cube[:] = 0
|
|
||||||
|
|
||||||
# FIXME: Refactor this! Make a globally build observation, then add individual per-agent-obs
|
|
||||||
for key, array in state_array_dict.items():
|
|
||||||
# Flush state array object representation to obs cube
|
|
||||||
if not self[key].hide_from_obs_builder:
|
|
||||||
if self[key].is_per_agent:
|
|
||||||
per_agent_idx = self[key].idx_by_entity(agent)
|
|
||||||
z = 1
|
|
||||||
self._obs_cube[running_idx: running_idx+z] = array[per_agent_idx]
|
|
||||||
else:
|
|
||||||
if key == c.AGENT and agent_omit_idx is not None:
|
|
||||||
z = array.shape[0] - 1
|
|
||||||
for array_idx in range(array.shape[0]):
|
|
||||||
self._obs_cube[running_idx: running_idx+z] = array[[x for x in range(array.shape[0])
|
|
||||||
if x != agent_omit_idx]]
|
|
||||||
# Agent OBS are combined
|
|
||||||
elif key == c.AGENT and self.obs_prop.omit_agent_self \
|
|
||||||
and self.obs_prop.render_agents == a_obs.COMBINED:
|
|
||||||
z = 1
|
|
||||||
self._obs_cube[running_idx: running_idx + z] = array
|
|
||||||
# Each Agent is rendered on a seperate array slice
|
|
||||||
else:
|
else:
|
||||||
z = array.shape[0]
|
agent_obs = global_agent_obs.copy()
|
||||||
self._obs_cube[running_idx: running_idx + z] = array
|
agent_obs[(0, *agent.pos)] -= agent.encoding
|
||||||
# Define which OBS SLices cast a Shadow
|
else:
|
||||||
if self[key].is_blocking_light:
|
agent_obs = global_agent_obs
|
||||||
for i in range(z):
|
|
||||||
shadowing_idxs.append(running_idx + i)
|
|
||||||
# Define which OBS SLices are effected by shadows
|
|
||||||
if self[key].can_be_shadowed:
|
|
||||||
for i in range(z):
|
|
||||||
can_be_shadowed_idxs.append(running_idx + i)
|
|
||||||
running_idx += z
|
|
||||||
|
|
||||||
if agent_pos_is_omitted:
|
|
||||||
state_array_dict[c.AGENT][0, agent.x, agent.y] += agent.encoding
|
|
||||||
|
|
||||||
if self._pomdp_r:
|
|
||||||
obs = self._do_pomdp_obs_cutout(agent, self._obs_cube)
|
|
||||||
else:
|
|
||||||
obs = self._obs_cube
|
|
||||||
|
|
||||||
obs = obs.copy()
|
|
||||||
|
|
||||||
if self.obs_prop.cast_shadows:
|
|
||||||
obs_block_light = [obs[idx] != c.OCCUPIED_CELL.value for idx in shadowing_idxs]
|
|
||||||
door_shadowing = False
|
|
||||||
if self.parse_doors:
|
|
||||||
if doors := self[c.DOORS]:
|
|
||||||
if door := doors.by_pos(agent.pos):
|
|
||||||
if door.is_closed:
|
|
||||||
for group in door.connectivity_subgroups:
|
|
||||||
if agent.last_pos not in group:
|
|
||||||
door_shadowing = True
|
|
||||||
if self._pomdp_r:
|
|
||||||
blocking = [tuple(np.subtract(x, agent.pos) + (self._pomdp_r, self._pomdp_r))
|
|
||||||
for x in group]
|
|
||||||
xs, ys = zip(*blocking)
|
|
||||||
else:
|
|
||||||
xs, ys = zip(*group)
|
|
||||||
|
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
obs_block_light[0][xs, ys] = False
|
|
||||||
|
|
||||||
light_block_map = Map((np.prod(obs_block_light, axis=0) != True).astype(int))
|
|
||||||
if self._pomdp_r:
|
|
||||||
light_block_map = light_block_map.do_fov(self._pomdp_r, self._pomdp_r, max(self._level_shape))
|
|
||||||
else:
|
else:
|
||||||
light_block_map = light_block_map.do_fov(*agent.pos, max(self._level_shape))
|
agent_obs = global_agent_obs
|
||||||
if door_shadowing:
|
|
||||||
# noinspection PyUnboundLocalVariable
|
|
||||||
light_block_map[xs, ys] = 0
|
|
||||||
agent.temp_light_map = light_block_map
|
|
||||||
for obs_idx in can_be_shadowed_idxs:
|
|
||||||
obs[obs_idx] = ((obs[obs_idx] * light_block_map) + 0.) - (1 - light_block_map) # * obs[0])
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Agents observe other agents as wall
|
# Build Level Observations
|
||||||
if self.obs_prop.render_agents == a_obs.LEVEL and self.n_agents > 1:
|
if self.obs_prop.render_agents == a_obs.LEVEL:
|
||||||
other_agent_obs = self[c.AGENT].as_array()
|
lvl_obs = lvl_obs.copy()
|
||||||
if self.obs_prop.omit_agent_self:
|
lvl_obs += global_agent_obs
|
||||||
other_agent_obs[:, agent.x, agent.y] -= agent.encoding
|
|
||||||
|
|
||||||
|
obs_dict[c.WALLS] = lvl_obs
|
||||||
|
if self.obs_prop.render_agents in [a_obs.SEPERATE, a_obs.COMBINED] and agent_obs is not None:
|
||||||
|
obs_dict[c.AGENT] = agent_obs[:]
|
||||||
|
if self[c.AGENT_PLACEHOLDER] and placeholder_obs is not None:
|
||||||
|
obs_dict[c.AGENT_PLACEHOLDER] = placeholder_obs
|
||||||
|
if self.parse_doors and door_obs is not None:
|
||||||
|
obs_dict[c.DOORS] = door_obs[:]
|
||||||
|
obs_dict.update(add_obs_dict)
|
||||||
|
obsn = np.vstack(list(obs_dict.values()))
|
||||||
if self.obs_prop.pomdp_r:
|
if self.obs_prop.pomdp_r:
|
||||||
oobs = self._do_pomdp_obs_cutout(agent, other_agent_obs)[0]
|
obsn = self._do_pomdp_cutout(agent, obsn)
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
mask = (oobs != c.SHADOWED_CELL.value).astype(int)
|
|
||||||
obs[0] += oobs * mask
|
|
||||||
|
|
||||||
|
raw_obs = self.per_agent_raw_observations_hook(agent)
|
||||||
|
raw_obs = {key: np.expand_dims(val, 0) if val.ndim != 3 else val for key, val in raw_obs.items()}
|
||||||
|
obsn = np.vstack((obsn, *raw_obs.values()))
|
||||||
|
|
||||||
|
keys = list(chain(obs_dict.keys(), raw_obs.keys()))
|
||||||
|
idxs = np.cumsum([x.shape[0] for x in chain(obs_dict.values(), raw_obs.values())]) - 1
|
||||||
|
per_agent_expl_idx[agent.name] = {key: list(range(d, b)) for key, d, b in
|
||||||
|
zip(keys, idxs, list(idxs[1:]) + [idxs[-1]+1, ])}
|
||||||
|
|
||||||
|
# Shadow Casting
|
||||||
|
if agent.step_result is not None:
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
obs[0] += other_agent_obs
|
assert self._steps == 0
|
||||||
|
agent.step_result = {'action_name': a.NOOP, 'action_valid': True,
|
||||||
|
'collisions': [], 'lightmap': None}
|
||||||
|
if self.obs_prop.cast_shadows:
|
||||||
|
try:
|
||||||
|
light_block_obs = [obs_idx for key, obs_idx in per_agent_expl_idx[agent.name].items()
|
||||||
|
if self[key].is_blocking_light]
|
||||||
|
# Flatten
|
||||||
|
light_block_obs = [x for y in light_block_obs for x in y]
|
||||||
|
shadowed_obs = [obs_idx for key, obs_idx in per_agent_expl_idx[agent.name].items()
|
||||||
|
if self[key].can_be_shadowed]
|
||||||
|
# Flatten
|
||||||
|
shadowed_obs = [x for y in shadowed_obs for x in y]
|
||||||
|
except AttributeError as e:
|
||||||
|
print('Check your Keys! Only use Constants as Keys!')
|
||||||
|
print(e)
|
||||||
|
raise e
|
||||||
|
|
||||||
# Additional Observation:
|
obs_block_light = obsn[light_block_obs] != c.OCCUPIED_CELL
|
||||||
for additional_obs in self.additional_obs_build():
|
door_shadowing = False
|
||||||
obs[running_idx:running_idx+additional_obs.shape[0]] = additional_obs
|
if self.parse_doors:
|
||||||
running_idx += additional_obs.shape[0]
|
if doors := self[c.DOORS]:
|
||||||
for additional_per_agent_obs in self.additional_per_agent_obs_build(agent):
|
if door := doors.by_pos(agent.pos):
|
||||||
obs[running_idx:running_idx + additional_per_agent_obs.shape[0]] = additional_per_agent_obs
|
if door.is_closed:
|
||||||
running_idx += additional_per_agent_obs.shape[0]
|
for group in door.connectivity_subgroups:
|
||||||
|
if agent.last_pos not in group:
|
||||||
|
door_shadowing = True
|
||||||
|
if self._pomdp_r:
|
||||||
|
blocking = [
|
||||||
|
tuple(np.subtract(x, agent.pos) + (self._pomdp_r, self._pomdp_r))
|
||||||
|
for x in group]
|
||||||
|
xs, ys = zip(*blocking)
|
||||||
|
else:
|
||||||
|
xs, ys = zip(*group)
|
||||||
|
|
||||||
return obs
|
# noinspection PyUnresolvedReferences
|
||||||
|
obs_block_light[:, xs, ys] = False
|
||||||
|
|
||||||
def _do_pomdp_obs_cutout(self, agent, obs_to_be_padded):
|
light_block_map = Map((np.prod(obs_block_light, axis=0) != True).astype(int).squeeze())
|
||||||
|
if self._pomdp_r:
|
||||||
|
light_block_map = light_block_map.do_fov(self._pomdp_r, self._pomdp_r, max(self._level_shape))
|
||||||
|
else:
|
||||||
|
light_block_map = light_block_map.do_fov(*agent.pos, max(self._level_shape))
|
||||||
|
if door_shadowing:
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
|
light_block_map[xs, ys] = 0
|
||||||
|
|
||||||
|
agent.step_result['lightmap'] = light_block_map
|
||||||
|
|
||||||
|
obsn[shadowed_obs] = ((obsn[shadowed_obs] * light_block_map) + 0.) - (1 - light_block_map)
|
||||||
|
else:
|
||||||
|
if self._pomdp_r:
|
||||||
|
agent.step_result['lightmap'] = np.ones(self._obs_shape)
|
||||||
|
else:
|
||||||
|
agent.step_result['lightmap'] = None
|
||||||
|
|
||||||
|
per_agent_obsn[agent.name] = obsn
|
||||||
|
|
||||||
|
if self.n_agents == 1:
|
||||||
|
agent_name = self[c.AGENT][0].name
|
||||||
|
obs, explained_idx = per_agent_obsn[agent_name], per_agent_expl_idx[agent_name]
|
||||||
|
elif self.n_agents >= 2:
|
||||||
|
obs, explained_idx = np.stack(list(per_agent_obsn.values())), per_agent_expl_idx
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
return obs, explained_idx
|
||||||
|
|
||||||
|
def _do_pomdp_cutout(self, agent, obs_to_be_padded):
|
||||||
assert obs_to_be_padded.ndim == 3
|
assert obs_to_be_padded.ndim == 3
|
||||||
r, d = self._pomdp_r, self.pomdp_diameter
|
ra, d = self._pomdp_r, self.pomdp_diameter
|
||||||
x0, x1 = max(0, agent.x - r), min(agent.x + r + 1, self._level_shape[0])
|
x0, x1 = max(0, agent.x - ra), min(agent.x + ra + 1, self._level_shape[0])
|
||||||
y0, y1 = max(0, agent.y - r), min(agent.y + r + 1, self._level_shape[1])
|
y0, y1 = max(0, agent.y - ra), min(agent.y + ra + 1, self._level_shape[1])
|
||||||
# Other Agent Obs = oobs
|
|
||||||
oobs = obs_to_be_padded[:, x0:x1, y0:y1]
|
oobs = obs_to_be_padded[:, x0:x1, y0:y1]
|
||||||
if oobs.shape[0:] != (d, d):
|
if oobs.shape[1:] != (d, d):
|
||||||
if xd := oobs.shape[1] % d:
|
if xd := oobs.shape[1] % d:
|
||||||
if agent.x > r:
|
if agent.x > ra:
|
||||||
x0_pad = 0
|
x0_pad = 0
|
||||||
x1_pad = (d - xd)
|
x1_pad = (d - xd)
|
||||||
else:
|
else:
|
||||||
x0_pad = r - agent.x
|
x0_pad = ra - agent.x
|
||||||
x1_pad = 0
|
x1_pad = 0
|
||||||
else:
|
else:
|
||||||
x0_pad, x1_pad = 0, 0
|
x0_pad, x1_pad = 0, 0
|
||||||
|
|
||||||
if yd := oobs.shape[2] % d:
|
if yd := oobs.shape[2] % d:
|
||||||
if agent.y > r:
|
if agent.y > ra:
|
||||||
y0_pad = 0
|
y0_pad = 0
|
||||||
y1_pad = (d - yd)
|
y1_pad = (d - yd)
|
||||||
else:
|
else:
|
||||||
y0_pad = r - agent.y
|
y0_pad = ra - agent.y
|
||||||
y1_pad = 0
|
y1_pad = 0
|
||||||
else:
|
else:
|
||||||
y0_pad, y1_pad = 0, 0
|
y0_pad, y1_pad = 0, 0
|
||||||
@ -442,25 +499,41 @@ class BaseFactory(gym.Env):
|
|||||||
oobs = np.pad(oobs, ((0, 0), (x0_pad, x1_pad), (y0_pad, y1_pad)), 'constant')
|
oobs = np.pad(oobs, ((0, 0), (x0_pad, x1_pad), (y0_pad, y1_pad)), 'constant')
|
||||||
return oobs
|
return oobs
|
||||||
|
|
||||||
def get_all_tiles_with_collisions(self) -> List[Tile]:
|
def get_all_tiles_with_collisions(self) -> List[Floor]:
|
||||||
tiles_with_collisions = list()
|
tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
|
||||||
for tile in self[c.FLOOR]:
|
if False:
|
||||||
if tile.is_occupied():
|
tiles_with_collisions = list()
|
||||||
guests = tile.guests_that_can_collide
|
for tile in self[c.FLOOR]:
|
||||||
if len(guests) >= 2:
|
if tile.is_occupied():
|
||||||
tiles_with_collisions.append(tile)
|
guests = tile.guests_that_can_collide
|
||||||
return tiles_with_collisions
|
if len(guests) >= 2:
|
||||||
|
tiles_with_collisions.append(tile)
|
||||||
|
return tiles
|
||||||
|
|
||||||
def _move_or_colide(self, agent: Agent, action: Action) -> Constants:
|
def _do_move_action(self, agent: Agent, action: Action) -> (dict, dict):
|
||||||
|
info_dict = dict()
|
||||||
new_tile, valid = self._check_agent_move(agent, action)
|
new_tile, valid = self._check_agent_move(agent, action)
|
||||||
if valid:
|
if valid:
|
||||||
# Does not collide width level boundaries
|
# Does not collide width level boundaries
|
||||||
return agent.move(new_tile)
|
valid = agent.move(new_tile)
|
||||||
|
if valid:
|
||||||
|
# This will spam your logs, beware!
|
||||||
|
self.print(f'{agent.name} just moved {action.identifier} from {agent.last_pos} to {agent.pos}.')
|
||||||
|
info_dict.update({f'{agent.name}_move': 1, 'move': 1})
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
valid = c.NOT_VALID
|
||||||
|
self.print(f'{agent.name} just hit the wall at {agent.pos}. ({action.identifier})')
|
||||||
|
info_dict.update({f'{agent.name}_wall_collide': 1, 'wall_collide': 1})
|
||||||
else:
|
else:
|
||||||
# Agent seems to be trying to collide in this step
|
# Agent seems to be trying to Leave the level
|
||||||
return c.NOT_VALID
|
self.print(f'{agent.name} tried to leave the level {agent.pos}. ({action.identifier})')
|
||||||
|
info_dict.update({f'{agent.name}_wall_collide': 1, 'wall_collide': 1})
|
||||||
|
reward_value = self.rewards_base.MOVEMENTS_VALID if valid else self.rewards_base.MOVEMENTS_FAIL
|
||||||
|
reward = {'value': reward_value, 'reason': action.identifier, 'info': info_dict}
|
||||||
|
return valid, reward
|
||||||
|
|
||||||
def _check_agent_move(self, agent, action: Action) -> (Tile, bool):
|
def _check_agent_move(self, agent, action: Action) -> (Floor, bool):
|
||||||
# Actions
|
# Actions
|
||||||
x_diff, y_diff = h.ACTIONMAP[action.identifier]
|
x_diff, y_diff = h.ACTIONMAP[action.identifier]
|
||||||
x_new = agent.x + x_diff
|
x_new = agent.x + x_diff
|
||||||
@ -478,7 +551,7 @@ class BaseFactory(gym.Env):
|
|||||||
if doors := self[c.DOORS]:
|
if doors := self[c.DOORS]:
|
||||||
if self.doors_have_area:
|
if self.doors_have_area:
|
||||||
if door := doors.by_pos(new_tile.pos):
|
if door := doors.by_pos(new_tile.pos):
|
||||||
if door.can_collide:
|
if door.is_closed:
|
||||||
return agent.tile, c.NOT_VALID
|
return agent.tile, c.NOT_VALID
|
||||||
else: # door.is_closed:
|
else: # door.is_closed:
|
||||||
pass
|
pass
|
||||||
@ -498,78 +571,61 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
return new_tile, valid
|
return new_tile, valid
|
||||||
|
|
||||||
def calculate_reward(self) -> (int, dict):
|
def build_reward_result(self, global_env_rewards: list) -> (int, dict):
|
||||||
# Returns: Reward, Info
|
# Returns: Reward, Info
|
||||||
per_agent_info_dict = defaultdict(dict)
|
info = defaultdict(lambda: 0.0)
|
||||||
reward = {}
|
|
||||||
|
|
||||||
|
# Gather additional sub-env rewards and calculate collisions
|
||||||
for agent in self[c.AGENT]:
|
for agent in self[c.AGENT]:
|
||||||
per_agent_reward = 0
|
|
||||||
if self._actions.is_moving_action(agent.temp_action):
|
|
||||||
if agent.temp_valid:
|
|
||||||
# info_dict.update(movement=1)
|
|
||||||
per_agent_reward -= 0.001
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
per_agent_reward -= 0.05
|
|
||||||
self.print(f'{agent.name} just hit the wall at {agent.pos}.')
|
|
||||||
per_agent_info_dict[agent.name].update({f'{agent.name}_vs_LEVEL': 1})
|
|
||||||
|
|
||||||
elif h.EnvActions.USE_DOOR == agent.temp_action:
|
rewards = self.per_agent_reward_hook(agent)
|
||||||
if agent.temp_valid:
|
for reward in rewards:
|
||||||
# per_agent_reward += 0.00
|
agent.step_result['rewards'].append(reward)
|
||||||
self.print(f'{agent.name} did just use the door at {agent.pos}.')
|
if collisions := agent.step_result['collisions']:
|
||||||
per_agent_info_dict[agent.name].update(door_used=1)
|
self.print(f't = {self._steps}\t{agent.name} has collisions with {collisions}')
|
||||||
else:
|
info[c.COLLISION] += 1
|
||||||
# per_agent_reward -= 0.00
|
reward = {'value': self.rewards_base.COLLISION,
|
||||||
self.print(f'{agent.name} just tried to use a door at {agent.pos}, but failed.')
|
'reason': c.COLLISION,
|
||||||
per_agent_info_dict[agent.name].update({f'{agent.name}_failed_door_open': 1})
|
'info': {f'{agent.name}_{c.COLLISION}': 1}}
|
||||||
elif h.EnvActions.NOOP == agent.temp_action:
|
agent.step_result['rewards'].append(reward)
|
||||||
per_agent_info_dict[agent.name].update(no_op=1)
|
|
||||||
# per_agent_reward -= 0.00
|
|
||||||
|
|
||||||
# EnvMonitor Notes
|
|
||||||
if agent.temp_valid:
|
|
||||||
per_agent_info_dict[agent.name].update(valid_action=1)
|
|
||||||
per_agent_info_dict[agent.name].update({f'{agent.name}_valid_action': 1})
|
|
||||||
else:
|
else:
|
||||||
per_agent_info_dict[agent.name].update(failed_action=1)
|
# No Collisions, nothing to do
|
||||||
per_agent_info_dict[agent.name].update({f'{agent.name}_failed_action': 1})
|
pass
|
||||||
|
|
||||||
additional_reward, additional_info_dict = self.calculate_additional_reward(agent)
|
comb_rewards = {agent.name: sum(x['value'] for x in agent.step_result['rewards']) for agent in self[c.AGENT]}
|
||||||
per_agent_reward += additional_reward
|
|
||||||
per_agent_info_dict[agent.name].update(additional_info_dict)
|
|
||||||
|
|
||||||
if agent.temp_collisions:
|
|
||||||
self.print(f't = {self._steps}\t{agent.name} has collisions with {agent.temp_collisions}')
|
|
||||||
per_agent_info_dict[agent.name].update(collisions=1)
|
|
||||||
|
|
||||||
for other_agent in agent.temp_collisions:
|
|
||||||
per_agent_info_dict[agent.name].update({f'{agent.name}_vs_{other_agent.name}': 1})
|
|
||||||
reward[agent.name] = per_agent_reward
|
|
||||||
|
|
||||||
# Combine the per_agent_info_dict:
|
# Combine the per_agent_info_dict:
|
||||||
combined_info_dict = defaultdict(lambda: 0)
|
combined_info_dict = defaultdict(lambda: 0)
|
||||||
for info_dict in per_agent_info_dict.values():
|
for agent in self[c.AGENT]:
|
||||||
for key, value in info_dict.items():
|
for reward in agent.step_result['rewards']:
|
||||||
combined_info_dict[key] += value
|
combined_info_dict.update(reward['info'])
|
||||||
combined_info_dict = dict(combined_info_dict)
|
|
||||||
|
|
||||||
|
combined_info_dict = dict(combined_info_dict)
|
||||||
|
combined_info_dict.update(info)
|
||||||
|
|
||||||
|
global_reward_sum = sum(global_env_rewards)
|
||||||
if self.individual_rewards:
|
if self.individual_rewards:
|
||||||
self.print(f"rewards are {reward}")
|
self.print(f"rewards are {comb_rewards}")
|
||||||
reward = list(reward.values())
|
reward = list(comb_rewards.values())
|
||||||
|
reward = [x + global_reward_sum for x in reward]
|
||||||
return reward, combined_info_dict
|
return reward, combined_info_dict
|
||||||
else:
|
else:
|
||||||
reward = sum(reward.values())
|
reward = sum(comb_rewards.values()) + global_reward_sum
|
||||||
self.print(f"reward is {reward}")
|
self.print(f"reward is {reward}")
|
||||||
return reward, combined_info_dict
|
return reward, combined_info_dict
|
||||||
|
|
||||||
|
def start_recording(self):
|
||||||
|
self._record_episodes = True
|
||||||
|
|
||||||
|
def stop_recording(self):
|
||||||
|
self._record_episodes = False
|
||||||
|
|
||||||
# noinspection PyGlobalUndefined
|
# noinspection PyGlobalUndefined
|
||||||
def render(self, mode='human'):
|
def render(self, mode='human'):
|
||||||
if not self._renderer: # lazy init
|
if not self._renderer: # lazy init
|
||||||
from environments.factory.base.renderer import Renderer, RenderEntity
|
from environments.factory.base.renderer import Renderer, RenderEntity
|
||||||
global Renderer, RenderEntity
|
global Renderer, RenderEntity
|
||||||
height, width = self._obs_cube.shape[1:]
|
height, width = self._level_shape
|
||||||
self._renderer = Renderer(width, height, view_radius=self._pomdp_r, fps=5)
|
self._renderer = Renderer(width, height, view_radius=self._pomdp_r, fps=5)
|
||||||
|
|
||||||
# noinspection PyUnboundLocalVariable
|
# noinspection PyUnboundLocalVariable
|
||||||
@ -578,13 +634,13 @@ class BaseFactory(gym.Env):
|
|||||||
agents = []
|
agents = []
|
||||||
for i, agent in enumerate(self[c.AGENT]):
|
for i, agent in enumerate(self[c.AGENT]):
|
||||||
name, state = h.asset_str(agent)
|
name, state = h.asset_str(agent)
|
||||||
agents.append(RenderEntity(name, agent.pos, 1, 'none', state, i + 1, agent.temp_light_map))
|
agents.append(RenderEntity(name, agent.pos, 1, 'none', state, i + 1, agent.step_result['lightmap']))
|
||||||
doors = []
|
doors = []
|
||||||
if self.parse_doors:
|
if self.parse_doors:
|
||||||
for i, door in enumerate(self[c.DOORS]):
|
for i, door in enumerate(self[c.DOORS]):
|
||||||
name, state = 'door_open' if door.is_open else 'door_closed', 'blank'
|
name, state = 'door_open' if door.is_open else 'door_closed', 'blank'
|
||||||
doors.append(RenderEntity(name, door.pos, 1, 'none', state, i + 1))
|
doors.append(RenderEntity(name, door.pos, 1, 'none', state, i + 1))
|
||||||
additional_assets = self.render_additional_assets()
|
additional_assets = self.render_assets_hook()
|
||||||
|
|
||||||
return self._renderer.render(walls + doors + additional_assets + agents)
|
return self._renderer.render(walls + doors + additional_assets + agents)
|
||||||
|
|
||||||
@ -615,7 +671,8 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
# Properties which are called by the base class to extend beyond attributes of the base class
|
# Properties which are called by the base class to extend beyond attributes of the base class
|
||||||
@property
|
@property
|
||||||
def additional_actions(self) -> Union[Action, List[Action]]:
|
@abc.abstractmethod
|
||||||
|
def actions_hook(self) -> Union[Action, List[Action]]:
|
||||||
"""
|
"""
|
||||||
When heriting from this Base Class, you musst implement this methode!!!
|
When heriting from this Base Class, you musst implement this methode!!!
|
||||||
|
|
||||||
@ -625,7 +682,8 @@ class BaseFactory(gym.Env):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def additional_entities(self) -> Dict[(Enum, Entities)]:
|
@abc.abstractmethod
|
||||||
|
def entities_hook(self) -> Dict[(str, Entities)]:
|
||||||
"""
|
"""
|
||||||
When heriting from this Base Class, you musst implement this methode!!!
|
When heriting from this Base Class, you musst implement this methode!!!
|
||||||
|
|
||||||
@ -637,49 +695,46 @@ class BaseFactory(gym.Env):
|
|||||||
# Functions which provide additions to functions of the base class
|
# Functions which provide additions to functions of the base class
|
||||||
# Always call super!!!!!!
|
# Always call super!!!!!!
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def additional_obs_build(self) -> List[np.ndarray]:
|
def reset_hook(self) -> None:
|
||||||
return []
|
|
||||||
|
|
||||||
def additional_per_agent_obs_build(self, agent) -> List[np.ndarray]:
|
|
||||||
additional_per_agent_obs = []
|
|
||||||
if self.obs_prop.show_global_position_info:
|
|
||||||
pos_array = np.zeros(self.observation_space.shape[1:])
|
|
||||||
for xy in range(1):
|
|
||||||
pos_array[0, xy] = agent.pos[xy] / self._level_shape[xy]
|
|
||||||
additional_per_agent_obs.append(pos_array)
|
|
||||||
|
|
||||||
return additional_per_agent_obs
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def do_additional_reset(self) -> None:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def do_additional_step(self) -> dict:
|
def pre_step_hook(self) -> None:
|
||||||
return {}
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
|
def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def check_additional_done(self) -> bool:
|
def step_hook(self) -> (List[dict], dict):
|
||||||
return False
|
return [], {}
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
def check_additional_done(self) -> (bool, dict):
|
||||||
return 0, {}
|
return False, {}
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def render_additional_assets(self):
|
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
|
||||||
return []
|
|
||||||
|
|
||||||
# Hooks for in between operations.
|
|
||||||
# Always call super!!!!!!
|
|
||||||
@abc.abstractmethod
|
|
||||||
def hook_pre_step(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def hook_post_step(self) -> dict:
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def post_step_hook(self) -> dict:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||||
|
additional_raw_observations = {}
|
||||||
|
if self.obs_prop.show_global_position_info:
|
||||||
|
global_pos_obs = np.zeros(self._obs_shape)
|
||||||
|
global_pos_obs[:2, 0] = self[c.GLOBAL_POSITION].by_entity(agent).encoding
|
||||||
|
additional_raw_observations.update({c.GLOBAL_POSITION: global_pos_obs})
|
||||||
|
return additional_raw_observations
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def render_assets_hook(self):
|
||||||
|
return []
|
||||||
|
@ -1,54 +1,51 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from enum import Enum
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as c
|
||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
|
##########################################################################
|
||||||
|
# ##################### Base Object Building Blocks ######################### #
|
||||||
|
##########################################################################
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Missing Documentation
|
||||||
class Object:
|
class Object:
|
||||||
|
|
||||||
|
"""Generell Objects for Organisation and Maintanance such as Actions etc..."""
|
||||||
|
|
||||||
_u_idx = defaultdict(lambda: 0)
|
_u_idx = defaultdict(lambda: 0)
|
||||||
|
|
||||||
def __bool__(self):
|
def __bool__(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@property
|
|
||||||
def is_blocking_light(self):
|
|
||||||
return self._is_blocking_light
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return self._name
|
return self._name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def identifier(self):
|
def identifier(self):
|
||||||
if self._enum_ident is not None:
|
if self._str_ident is not None:
|
||||||
return self._enum_ident
|
|
||||||
elif self._str_ident is not None:
|
|
||||||
return self._str_ident
|
return self._str_ident
|
||||||
else:
|
else:
|
||||||
return self._name
|
return self._name
|
||||||
|
|
||||||
def __init__(self, str_ident: Union[str, None] = None, enum_ident: Union[Enum, None] = None,
|
def __init__(self, str_ident: Union[str, None] = None, **kwargs):
|
||||||
is_blocking_light=False, **kwargs):
|
|
||||||
|
|
||||||
self._str_ident = str_ident
|
self._str_ident = str_ident
|
||||||
self._enum_ident = enum_ident
|
|
||||||
|
|
||||||
if self._enum_ident is not None and self._str_ident is None:
|
if self._str_ident is not None:
|
||||||
self._name = f'{self.__class__.__name__}[{self._enum_ident.name}]'
|
|
||||||
elif self._str_ident is not None and self._enum_ident is None:
|
|
||||||
self._name = f'{self.__class__.__name__}[{self._str_ident}]'
|
self._name = f'{self.__class__.__name__}[{self._str_ident}]'
|
||||||
elif self._str_ident is None and self._enum_ident is None:
|
elif self._str_ident is None:
|
||||||
self._name = f'{self.__class__.__name__}#{self._u_idx[self.__class__.__name__]}'
|
self._name = f'{self.__class__.__name__}#{Object._u_idx[self.__class__.__name__]}'
|
||||||
Object._u_idx[self.__class__.__name__] += 1
|
Object._u_idx[self.__class__.__name__] += 1
|
||||||
else:
|
else:
|
||||||
raise ValueError('Please use either of the idents.')
|
raise ValueError('Please use either of the idents.')
|
||||||
|
|
||||||
self._is_blocking_light = is_blocking_light
|
|
||||||
if kwargs:
|
if kwargs:
|
||||||
print(f'Following kwargs were passed, but ignored: {kwargs}')
|
print(f'Following kwargs were passed, but ignored: {kwargs}')
|
||||||
|
|
||||||
@ -56,27 +53,44 @@ class Object:
|
|||||||
return f'{self.name}'
|
return f'{self.name}'
|
||||||
|
|
||||||
def __eq__(self, other) -> bool:
|
def __eq__(self, other) -> bool:
|
||||||
if self._enum_ident is not None:
|
return other == self.identifier
|
||||||
if isinstance(other, Enum):
|
# Base
|
||||||
return other == self._enum_ident
|
|
||||||
elif isinstance(other, Object):
|
|
||||||
return other._enum_ident == self._enum_ident
|
|
||||||
else:
|
|
||||||
raise ValueError('Must be evaluated against an Enunm Identifier or Object with such.')
|
|
||||||
else:
|
|
||||||
assert isinstance(other, Object), ' This Object can only be compared to other Objects.'
|
|
||||||
return other.name == self.name
|
|
||||||
|
|
||||||
|
|
||||||
class Entity(Object):
|
# TODO: Missing Documentation
|
||||||
|
class EnvObject(Object):
|
||||||
|
|
||||||
|
"""Objects that hold Information that are observable, but have no position on the env grid. Inventories etc..."""
|
||||||
|
|
||||||
|
_u_idx = defaultdict(lambda: 0)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def can_collide(self):
|
def can_collide(self):
|
||||||
return True
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
return c.OCCUPIED_CELL.value
|
return c.OCCUPIED_CELL
|
||||||
|
|
||||||
|
def __init__(self, register, **kwargs):
|
||||||
|
super(EnvObject, self).__init__(**kwargs)
|
||||||
|
self._register = register
|
||||||
|
|
||||||
|
def change_register(self, register):
|
||||||
|
register.register_item(self)
|
||||||
|
self._register.delete_env_object(self)
|
||||||
|
self._register = register
|
||||||
|
return self._register == register
|
||||||
|
# With Rendering
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Missing Documentation
|
||||||
|
class Entity(EnvObject):
|
||||||
|
"""Full Env Entity that lives on the env Grid. Doors, Items, Dirt etc..."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def can_collide(self):
|
||||||
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def x(self):
|
def x(self):
|
||||||
@ -94,8 +108,8 @@ class Entity(Object):
|
|||||||
def tile(self):
|
def tile(self):
|
||||||
return self._tile
|
return self._tile
|
||||||
|
|
||||||
def __init__(self, tile, **kwargs):
|
def __init__(self, tile, *args, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._tile = tile
|
self._tile = tile
|
||||||
tile.enter(self)
|
tile.enter(self)
|
||||||
|
|
||||||
@ -104,9 +118,11 @@ class Entity(Object):
|
|||||||
tile=str(self.tile.name), can_collide=bool(self.can_collide))
|
tile=str(self.tile.name), can_collide=bool(self.can_collide))
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.name}(@{self.pos})'
|
return super(Entity, self).__repr__() + f'(@{self.pos})'
|
||||||
|
# With Position in Env
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Missing Documentation
|
||||||
class MoveableEntity(Entity):
|
class MoveableEntity(Entity):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -137,9 +153,36 @@ class MoveableEntity(Entity):
|
|||||||
curr_tile.leave(self)
|
curr_tile.leave(self)
|
||||||
self._tile = next_tile
|
self._tile = next_tile
|
||||||
self._last_tile = curr_tile
|
self._last_tile = curr_tile
|
||||||
return True
|
self._register.notify_change_to_value(self)
|
||||||
|
return c.VALID
|
||||||
else:
|
else:
|
||||||
return False
|
return c.NOT_VALID
|
||||||
|
# Can Move
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Missing Documentation
|
||||||
|
class BoundingMixin(Object):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bound_entity(self):
|
||||||
|
return self._bound_entity
|
||||||
|
|
||||||
|
def __init__(self,entity_to_be_bound, *args, **kwargs):
|
||||||
|
super(BoundingMixin, self).__init__(*args, **kwargs)
|
||||||
|
assert entity_to_be_bound is not None
|
||||||
|
self._bound_entity = entity_to_be_bound
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return f'{super(BoundingMixin, self).name}({self._bound_entity.name})'
|
||||||
|
|
||||||
|
def belongs_to_entity(self, entity):
|
||||||
|
return entity == self.bound_entity
|
||||||
|
|
||||||
|
|
||||||
|
##########################################################################
|
||||||
|
# ####################### Objects and Entitys ########################## #
|
||||||
|
##########################################################################
|
||||||
|
|
||||||
|
|
||||||
class Action(Object):
|
class Action(Object):
|
||||||
@ -148,34 +191,45 @@ class Action(Object):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class PlaceHolder(MoveableEntity):
|
class PlaceHolder(Object):
|
||||||
|
|
||||||
def __init__(self, *args, fill_value=0, **kwargs):
|
def __init__(self, *args, fill_value=0, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._fill_value = fill_value
|
self._fill_value = fill_value
|
||||||
|
|
||||||
@property
|
|
||||||
def last_tile(self):
|
|
||||||
return self.tile
|
|
||||||
|
|
||||||
@property
|
|
||||||
def direction_of_view(self):
|
|
||||||
return self.pos
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def can_collide(self):
|
def can_collide(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
return c.NO_POS.value[0]
|
return self._fill_value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return "PlaceHolder"
|
return "PlaceHolder"
|
||||||
|
|
||||||
|
|
||||||
class Tile(Object):
|
class GlobalPosition(BoundingMixin, EnvObject):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
if self._normalized:
|
||||||
|
return tuple(np.divide(self._bound_entity.pos, self._level_shape))
|
||||||
|
else:
|
||||||
|
return self.bound_entity.pos
|
||||||
|
|
||||||
|
def __init__(self, level_shape: (int, int), *args, normalized: bool = True, **kwargs):
|
||||||
|
super(GlobalPosition, self).__init__(*args, **kwargs)
|
||||||
|
self._level_shape = level_shape
|
||||||
|
self._normalized = normalized
|
||||||
|
|
||||||
|
|
||||||
|
class Floor(EnvObject):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
return c.FREE_CELL
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def guests_that_can_collide(self):
|
def guests_that_can_collide(self):
|
||||||
@ -197,8 +251,8 @@ class Tile(Object):
|
|||||||
def pos(self):
|
def pos(self):
|
||||||
return self._pos
|
return self._pos
|
||||||
|
|
||||||
def __init__(self, pos, **kwargs):
|
def __init__(self, pos, *args, **kwargs):
|
||||||
super(Tile, self).__init__(**kwargs)
|
super(Floor, self).__init__(*args, **kwargs)
|
||||||
self._guests = dict()
|
self._guests = dict()
|
||||||
self._pos = tuple(pos)
|
self._pos = tuple(pos)
|
||||||
|
|
||||||
@ -232,7 +286,16 @@ class Tile(Object):
|
|||||||
return dict(name=self.name, x=int(self.x), y=int(self.y))
|
return dict(name=self.name, x=int(self.x), y=int(self.y))
|
||||||
|
|
||||||
|
|
||||||
class Wall(Tile):
|
class Wall(Floor):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def can_collide(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
return c.OCCUPIED_CELL
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -247,7 +310,8 @@ class Door(Entity):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
return 1 if self.is_closed else 2
|
# This is important as it shadow is checked by occupation value
|
||||||
|
return c.CLOSED_DOOR_CELL if self.is_closed else c.OPEN_DOOR_CELL
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def str_state(self):
|
def str_state(self):
|
||||||
@ -307,11 +371,13 @@ class Door(Entity):
|
|||||||
def _open(self):
|
def _open(self):
|
||||||
self.connectivity.add_edges_from([(self.pos, x) for x in range(len(self.connectivity_subgroups))])
|
self.connectivity.add_edges_from([(self.pos, x) for x in range(len(self.connectivity_subgroups))])
|
||||||
self._state = c.OPEN_DOOR
|
self._state = c.OPEN_DOOR
|
||||||
|
self._register.notify_change_to_value(self)
|
||||||
self.time_to_close = self.auto_close_interval
|
self.time_to_close = self.auto_close_interval
|
||||||
|
|
||||||
def _close(self):
|
def _close(self):
|
||||||
self.connectivity.remove_node(self.pos)
|
self.connectivity.remove_node(self.pos)
|
||||||
self._state = c.CLOSED_DOOR
|
self._state = c.CLOSED_DOOR
|
||||||
|
self._register.notify_change_to_value(self)
|
||||||
|
|
||||||
def is_linked(self, old_pos, new_pos):
|
def is_linked(self, old_pos, new_pos):
|
||||||
try:
|
try:
|
||||||
@ -323,20 +389,21 @@ class Door(Entity):
|
|||||||
|
|
||||||
class Agent(MoveableEntity):
|
class Agent(MoveableEntity):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def can_collide(self):
|
||||||
|
return True
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Agent, self).__init__(*args, **kwargs)
|
super(Agent, self).__init__(*args, **kwargs)
|
||||||
self.clear_temp_state()
|
self.clear_temp_state()
|
||||||
|
|
||||||
# noinspection PyAttributeOutsideInit
|
# noinspection PyAttributeOutsideInit
|
||||||
def clear_temp_state(self):
|
def clear_temp_state(self):
|
||||||
# for attr in self.__dict__:
|
# for attr in cls.__dict__:
|
||||||
# if attr.startswith('temp'):
|
# if attr.startswith('temp'):
|
||||||
self.temp_collisions = []
|
self.step_result = None
|
||||||
self.temp_valid = None
|
|
||||||
self.temp_action = None
|
|
||||||
self.temp_light_map = None
|
|
||||||
|
|
||||||
def summarize_state(self, **kwargs):
|
def summarize_state(self, **kwargs):
|
||||||
state_dict = super().summarize_state(**kwargs)
|
state_dict = super().summarize_state(**kwargs)
|
||||||
state_dict.update(valid=bool(self.temp_valid), action=str(self.temp_action))
|
state_dict.update(valid=bool(self.step_result['action_valid']), action=str(self.step_result['action_name']))
|
||||||
return state_dict
|
return state_dict
|
||||||
|
@ -1,18 +1,24 @@
|
|||||||
import numbers
|
import numbers
|
||||||
import random
|
import random
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import List, Union, Dict
|
from typing import List, Union, Dict, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import six
|
||||||
|
|
||||||
from environments.factory.base.objects import Entity, Tile, Agent, Door, Action, Wall, Object, PlaceHolder
|
from environments.factory.base.objects import Entity, Floor, Agent, Door, Action, Wall, PlaceHolder, GlobalPosition, \
|
||||||
|
Object, EnvObject
|
||||||
from environments.utility_classes import MovementProperties
|
from environments.utility_classes import MovementProperties
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as c
|
||||||
|
|
||||||
|
##########################################################################
|
||||||
|
# ##################### Base Register Definition ####################### #
|
||||||
|
##########################################################################
|
||||||
|
|
||||||
class Register:
|
|
||||||
_accepted_objects = Entity
|
class ObjectRegister:
|
||||||
|
_accepted_objects = Object
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
@ -48,6 +54,12 @@ class Register:
|
|||||||
def items(self):
|
def items(self):
|
||||||
return self._register.items()
|
return self._register.items()
|
||||||
|
|
||||||
|
def _get_index(self, item):
|
||||||
|
try:
|
||||||
|
return next(i for i, v in enumerate(self._register.values()) if v == item)
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
if isinstance(item, (int, np.int64, np.int32)):
|
if isinstance(item, (int, np.int64, np.int32)):
|
||||||
if item < 0:
|
if item < 0:
|
||||||
@ -62,42 +74,102 @@ class Register:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.__class__.__name__}({self._register})'
|
return f'{self.__class__.__name__}[{self._register}]'
|
||||||
|
|
||||||
|
|
||||||
class ObjectRegister(Register):
|
class EnvObjectRegister(ObjectRegister):
|
||||||
|
|
||||||
hide_from_obs_builder = False
|
_accepted_objects = EnvObject
|
||||||
|
|
||||||
def __init__(self, level_shape: (int, int), *args, individual_slices=False, is_per_agent=False, **kwargs):
|
@property
|
||||||
super(ObjectRegister, self).__init__(*args, **kwargs)
|
def encodings(self):
|
||||||
self.is_per_agent = is_per_agent
|
return [x.encoding for x in self]
|
||||||
self.individual_slices = individual_slices
|
|
||||||
self._level_shape = level_shape
|
def __init__(self, obs_shape: (int, int), *args,
|
||||||
|
individual_slices: bool = False,
|
||||||
|
is_blocking_light: bool = False,
|
||||||
|
can_collide: bool = False,
|
||||||
|
can_be_shadowed: bool = True, **kwargs):
|
||||||
|
super(EnvObjectRegister, self).__init__(*args, **kwargs)
|
||||||
|
self._shape = obs_shape
|
||||||
self._array = None
|
self._array = None
|
||||||
|
self._individual_slices = individual_slices
|
||||||
|
self._lazy_eval_transforms = []
|
||||||
|
self.is_blocking_light = is_blocking_light
|
||||||
|
self.can_be_shadowed = can_be_shadowed
|
||||||
|
self.can_collide = can_collide
|
||||||
|
|
||||||
def register_item(self, other):
|
def register_item(self, other: EnvObject):
|
||||||
super(ObjectRegister, self).register_item(other)
|
super(EnvObjectRegister, self).register_item(other)
|
||||||
if self._array is None:
|
if self._array is None:
|
||||||
self._array = np.zeros((1, *self._level_shape))
|
self._array = np.zeros((1, *self._shape))
|
||||||
else:
|
else:
|
||||||
if self.individual_slices:
|
if self._individual_slices:
|
||||||
self._array = np.concatenate((self._array, np.zeros((1, *self._array.shape[1:]))))
|
self._array = np.vstack((self._array, np.zeros((1, *self._shape))))
|
||||||
|
self.notify_change_to_value(other)
|
||||||
|
|
||||||
|
def as_array(self):
|
||||||
|
if self._lazy_eval_transforms:
|
||||||
|
idxs, values = zip(*self._lazy_eval_transforms)
|
||||||
|
# nuumpy put repects the ordering so that
|
||||||
|
np.put(self._array, idxs, values)
|
||||||
|
self._lazy_eval_transforms = []
|
||||||
|
return self._array
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
def summarize_states(self, n_steps=None):
|
||||||
return [val.summarize_state(n_steps=n_steps) for val in self.values()]
|
return [val.summarize_state(n_steps=n_steps) for val in self.values()]
|
||||||
|
|
||||||
|
def notify_change_to_free(self, env_object: EnvObject):
|
||||||
|
self._array_change_notifyer(env_object, value=c.FREE_CELL)
|
||||||
|
|
||||||
class EntityObjectRegister(ObjectRegister, ABC):
|
def notify_change_to_value(self, env_object: EnvObject):
|
||||||
|
self._array_change_notifyer(env_object)
|
||||||
|
|
||||||
def as_array(self):
|
def _array_change_notifyer(self, env_object: EnvObject, value=None):
|
||||||
raise NotImplementedError
|
pos = self._get_index(env_object)
|
||||||
|
value = value if value is not None else env_object.encoding
|
||||||
|
self._lazy_eval_transforms.append((pos, value))
|
||||||
|
if self._individual_slices:
|
||||||
|
idx = (self._get_index(env_object) * np.prod(self._shape[1:]), value)
|
||||||
|
self._lazy_eval_transforms.append((idx, value))
|
||||||
|
else:
|
||||||
|
self._lazy_eval_transforms.append((pos, value))
|
||||||
|
|
||||||
|
def _refresh_arrays(self):
|
||||||
|
poss, values = zip(*[(idx, x.encoding) for idx,x in enumerate(self.values())])
|
||||||
|
for pos, value in zip(poss, values):
|
||||||
|
self._lazy_eval_transforms.append((pos, value))
|
||||||
|
|
||||||
|
def __delitem__(self, name):
|
||||||
|
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
|
||||||
|
if self._individual_slices:
|
||||||
|
self._array = np.delete(self._array, idx, axis=0)
|
||||||
|
else:
|
||||||
|
self.notify_change_to_free(self._register[name])
|
||||||
|
# Dirty Hack to check if not beeing subclassed. In that case we need to refresh the array since positions
|
||||||
|
# in the observation array are result of enumeration. They can overide each other.
|
||||||
|
# Todo: Find a better solution
|
||||||
|
if not issubclass(self.__class__, EntityRegister) and issubclass(self.__class__, EnvObjectRegister):
|
||||||
|
self._refresh_arrays()
|
||||||
|
del self._register[name]
|
||||||
|
|
||||||
|
def delete_env_object(self, env_object: EnvObject):
|
||||||
|
del self[env_object.name]
|
||||||
|
|
||||||
|
def delete_env_object_by_name(self, name):
|
||||||
|
del self[name]
|
||||||
|
|
||||||
|
|
||||||
|
class EntityRegister(EnvObjectRegister, ABC):
|
||||||
|
|
||||||
|
_accepted_objects = Entity
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
|
def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
|
||||||
# objects_name = cls._accepted_objects.__name__
|
# objects_name = cls._accepted_objects.__name__
|
||||||
register_obj = cls(*args, **kwargs)
|
register_obj = cls(*args, **kwargs)
|
||||||
entities = [cls._accepted_objects(tile, str_ident=i, **entity_kwargs if entity_kwargs is not None else {})
|
entities = [cls._accepted_objects(tile, register_obj, str_ident=i,
|
||||||
|
**entity_kwargs if entity_kwargs is not None else {})
|
||||||
for i, tile in enumerate(tiles)]
|
for i, tile in enumerate(tiles)]
|
||||||
register_obj.register_additional_items(entities)
|
register_obj.register_additional_items(entities)
|
||||||
return register_obj
|
return register_obj
|
||||||
@ -115,86 +187,168 @@ class EntityObjectRegister(ObjectRegister, ABC):
|
|||||||
def tiles(self):
|
def tiles(self):
|
||||||
return [entity.tile for entity in self]
|
return [entity.tile for entity in self]
|
||||||
|
|
||||||
def __init__(self, *args, is_blocking_light=False, is_observable=True, can_be_shadowed=True, **kwargs):
|
def __init__(self, level_shape, *args, **kwargs):
|
||||||
super(EntityObjectRegister, self).__init__(*args, **kwargs)
|
super(EntityRegister, self).__init__(level_shape, *args, **kwargs)
|
||||||
self.can_be_shadowed = can_be_shadowed
|
self._lazy_eval_transforms = []
|
||||||
self.is_blocking_light = is_blocking_light
|
|
||||||
self.is_observable = is_observable
|
|
||||||
|
|
||||||
def by_pos(self, pos):
|
def __delitem__(self, name):
|
||||||
if isinstance(pos, np.ndarray):
|
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
|
||||||
pos = tuple(pos)
|
obj.tile.leave(obj)
|
||||||
|
super(EntityRegister, self).__delitem__(name)
|
||||||
|
|
||||||
|
def as_array(self):
|
||||||
|
if self._lazy_eval_transforms:
|
||||||
|
idxs, values = zip(*self._lazy_eval_transforms)
|
||||||
|
# numpy put repects the ordering so that
|
||||||
|
# Todo: Export the index building in a seperate function
|
||||||
|
np.put(self._array, [np.ravel_multi_index(idx, self._array.shape) for idx in idxs], values)
|
||||||
|
self._lazy_eval_transforms = []
|
||||||
|
return self._array
|
||||||
|
|
||||||
|
def _array_change_notifyer(self, entity, pos=None, value=None):
|
||||||
|
# Todo: Export the contruction in a seperate function
|
||||||
|
pos = pos if pos is not None else entity.pos
|
||||||
|
value = value if value is not None else entity.encoding
|
||||||
|
x, y = pos
|
||||||
|
if self._individual_slices:
|
||||||
|
idx = (self._get_index(entity), x, y)
|
||||||
|
else:
|
||||||
|
idx = (0, x, y)
|
||||||
|
self._lazy_eval_transforms.append((idx, value))
|
||||||
|
|
||||||
|
def by_pos(self, pos: Tuple[int, int]):
|
||||||
try:
|
try:
|
||||||
return next(item for item in self.values() if item.pos == pos)
|
return next(item for item in self if item.pos == tuple(pos))
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class MovingEntityObjectRegister(EntityObjectRegister, ABC):
|
class BoundEnvObjRegister(EnvObjectRegister, ABC):
|
||||||
|
|
||||||
|
def __init__(self, entity_to_be_bound, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._bound_entity = entity_to_be_bound
|
||||||
|
|
||||||
|
def belongs_to_entity(self, entity):
|
||||||
|
return self._bound_entity == entity
|
||||||
|
|
||||||
|
def by_entity(self, entity):
|
||||||
|
try:
|
||||||
|
return next((x for x in self if x.belongs_to_entity(entity)))
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def idx_by_entity(self, entity):
|
||||||
|
try:
|
||||||
|
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def as_array_by_entity(self, entity):
|
||||||
|
return self._array[self.idx_by_entity(entity)]
|
||||||
|
|
||||||
|
|
||||||
|
class MovingEntityObjectRegister(EntityRegister, ABC):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(MovingEntityObjectRegister, self).__init__(*args, **kwargs)
|
super(MovingEntityObjectRegister, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
def by_pos(self, pos):
|
def notify_change_to_value(self, entity):
|
||||||
if isinstance(pos, np.ndarray):
|
super(MovingEntityObjectRegister, self).notify_change_to_value(entity)
|
||||||
pos = tuple(pos)
|
if entity.last_pos != c.NO_POS:
|
||||||
|
try:
|
||||||
|
self._array_change_notifyer(entity, entity.last_pos, value=c.FREE_CELL)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
##########################################################################
|
||||||
|
# ################# Objects and Entity Registers ####################### #
|
||||||
|
##########################################################################
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalPositions(EnvObjectRegister):
|
||||||
|
|
||||||
|
_accepted_objects = GlobalPosition
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(GlobalPositions, self).__init__(*args, is_per_agent=True, individual_slices=True, is_blocking_light = False,
|
||||||
|
can_be_shadowed = False, can_collide = False, **kwargs)
|
||||||
|
|
||||||
|
def as_array(self):
|
||||||
|
# FIXME DEBUG!!! make this lazy?
|
||||||
|
return np.stack([gp.as_array() for inv_idx, gp in enumerate(self)])
|
||||||
|
|
||||||
|
def as_array_by_entity(self, entity):
|
||||||
|
# FIXME DEBUG!!! make this lazy?
|
||||||
|
return np.stack([gp.as_array() for inv_idx, gp in enumerate(self)])
|
||||||
|
|
||||||
|
def spawn_global_position_objects(self, agents):
|
||||||
|
# Todo, change to 'from xy'-form
|
||||||
|
global_positions = [self._accepted_objects(self._shape, agent, self)
|
||||||
|
for _, agent in enumerate(agents)]
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
self.register_additional_items(global_positions)
|
||||||
|
|
||||||
|
def summarize_states(self, n_steps=None):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def idx_by_entity(self, entity):
|
||||||
try:
|
try:
|
||||||
return next(x for x in self if x.pos == pos)
|
return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity)))
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def __delitem__(self, name):
|
def by_entity(self, entity):
|
||||||
idx = next(i for i, entity in enumerate(self) if entity.name == name)
|
try:
|
||||||
del self._register[name]
|
return next((inv for inv in self if inv.belongs_to_entity(entity)))
|
||||||
if self.individual_slices:
|
except StopIteration:
|
||||||
self._array = np.delete(self._array, idx, axis=0)
|
return None
|
||||||
|
|
||||||
def delete_entity(self, item):
|
|
||||||
self.delete_entity_by_name(item.name)
|
|
||||||
|
|
||||||
def delete_entity_by_name(self, name):
|
|
||||||
del self[name]
|
|
||||||
|
|
||||||
|
|
||||||
class PlaceHolders(MovingEntityObjectRegister):
|
class PlaceHolders(EnvObjectRegister):
|
||||||
|
|
||||||
_accepted_objects = PlaceHolder
|
_accepted_objects = PlaceHolder
|
||||||
|
|
||||||
def __init__(self, *args, fill_value: Union[str, int] = 0, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
assert 'individual_slices' not in kwargs, 'Keyword - "individual_slices": "True" and must not be altered'
|
||||||
|
kwargs.update(individual_slices=False)
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.fill_value = fill_value
|
|
||||||
|
@classmethod
|
||||||
|
def from_values(cls, values: Union[str, numbers.Number, List[Union[str, numbers.Number]]],
|
||||||
|
*args, object_kwargs=None, **kwargs):
|
||||||
|
# objects_name = cls._accepted_objects.__name__
|
||||||
|
if isinstance(values, (str, numbers.Number)):
|
||||||
|
values = [values]
|
||||||
|
register_obj = cls(*args, **kwargs)
|
||||||
|
objects = [cls._accepted_objects(register_obj, str_ident=i, fill_value=value,
|
||||||
|
**object_kwargs if object_kwargs is not None else {})
|
||||||
|
for i, value in enumerate(values)]
|
||||||
|
register_obj.register_additional_items(objects)
|
||||||
|
return register_obj
|
||||||
|
|
||||||
# noinspection DuplicatedCode
|
# noinspection DuplicatedCode
|
||||||
def as_array(self):
|
def as_array(self):
|
||||||
if isinstance(self.fill_value, numbers.Number):
|
for idx, placeholder in enumerate(self):
|
||||||
self._array[:] = self.fill_value
|
if isinstance(placeholder.encoding, numbers.Number):
|
||||||
elif isinstance(self.fill_value, str):
|
self._array[idx][:] = placeholder.fill_value
|
||||||
if self.fill_value.lower() in ['normal', 'n']:
|
elif isinstance(placeholder.fill_value, str):
|
||||||
self._array = np.random.normal(size=self._array.shape)
|
if placeholder.fill_value.lower() in ['normal', 'n']:
|
||||||
|
self._array[:] = np.random.normal(size=self._array.shape)
|
||||||
|
else:
|
||||||
|
raise ValueError('Choose one of: ["normal", "N"]')
|
||||||
else:
|
else:
|
||||||
raise ValueError('Choose one of: ["normal", "N"]')
|
raise TypeError('Objects of type "str" or "number" is required here.')
|
||||||
else:
|
|
||||||
raise TypeError('Objects of type "str" or "number" is required here.')
|
|
||||||
|
|
||||||
if self.individual_slices:
|
return self._array
|
||||||
return self._array
|
|
||||||
else:
|
|
||||||
return self._array[None, 0]
|
|
||||||
|
|
||||||
|
|
||||||
class Entities(Register):
|
class Entities(ObjectRegister):
|
||||||
|
_accepted_objects = EntityRegister
|
||||||
_accepted_objects = EntityObjectRegister
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def observable_arrays(self):
|
def arrays(self):
|
||||||
# FIXME: Find a better name
|
return {key: val.as_array() for key, val in self.items()}
|
||||||
return {key: val.as_array() for key, val in self.items() if val.is_observable}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def obs_arrays(self):
|
|
||||||
# FIXME: Find a better name
|
|
||||||
return {key: val.as_array() for key, val in self.items() if val.is_observable and not val.hide_from_obs_builder}
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def names(self):
|
def names(self):
|
||||||
@ -220,34 +374,30 @@ class Entities(Register):
|
|||||||
return found_entities
|
return found_entities
|
||||||
|
|
||||||
|
|
||||||
class WallTiles(EntityObjectRegister):
|
class Walls(EntityRegister):
|
||||||
_accepted_objects = Wall
|
_accepted_objects = Wall
|
||||||
_light_blocking = True
|
|
||||||
|
|
||||||
def as_array(self):
|
def as_array(self):
|
||||||
if not np.any(self._array):
|
if not np.any(self._array):
|
||||||
|
# Which is Faster?
|
||||||
|
# indices = [x.pos for x in cls]
|
||||||
|
# np.put(cls._array, [np.ravel_multi_index((0, *x), cls._array.shape) for x in indices], cls.encodings)
|
||||||
x, y = zip(*[x.pos for x in self])
|
x, y = zip(*[x.pos for x in self])
|
||||||
self._array[0, x, y] = self.encoding
|
self._array[0, x, y] = self._value
|
||||||
return self._array
|
return self._array
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, is_blocking_light=True, **kwargs):
|
||||||
super(WallTiles, self).__init__(*args, individual_slices=False,
|
super(Walls, self).__init__(*args, individual_slices=False,
|
||||||
is_blocking_light=self._light_blocking, **kwargs)
|
can_collide=True,
|
||||||
|
is_blocking_light=is_blocking_light, **kwargs)
|
||||||
@property
|
self._value = c.OCCUPIED_CELL
|
||||||
def encoding(self):
|
|
||||||
return c.OCCUPIED_CELL.value
|
|
||||||
|
|
||||||
@property
|
|
||||||
def array(self):
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs):
|
def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs):
|
||||||
tiles = cls(*args, **kwargs)
|
tiles = cls(*args, **kwargs)
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
tiles.register_additional_items(
|
tiles.register_additional_items(
|
||||||
[cls._accepted_objects(pos, is_blocking_light=cls._light_blocking)
|
[cls._accepted_objects(pos, tiles)
|
||||||
for pos in argwhere_coordinates]
|
for pos in argwhere_coordinates]
|
||||||
)
|
)
|
||||||
return tiles
|
return tiles
|
||||||
@ -258,22 +408,17 @@ class WallTiles(EntityObjectRegister):
|
|||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
def summarize_states(self, n_steps=None):
|
||||||
if n_steps == h.STEPS_START:
|
if n_steps == h.STEPS_START:
|
||||||
return super(WallTiles, self).summarize_states(n_steps=n_steps)
|
return super(Walls, self).summarize_states(n_steps=n_steps)
|
||||||
else:
|
else:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
class FloorTiles(WallTiles):
|
class Floors(Walls):
|
||||||
|
_accepted_objects = Floor
|
||||||
|
|
||||||
_accepted_objects = Tile
|
def __init__(self, *args, is_blocking_light=False, **kwargs):
|
||||||
_light_blocking = False
|
super(Floors, self).__init__(*args, is_blocking_light=is_blocking_light, **kwargs)
|
||||||
|
self._value = c.FREE_CELL
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super(FloorTiles, self).__init__(*args, is_observable=False, **kwargs)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def encoding(self):
|
|
||||||
return c.FREE_CELL.value
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def occupied_tiles(self):
|
def occupied_tiles(self):
|
||||||
@ -282,7 +427,7 @@ class FloorTiles(WallTiles):
|
|||||||
return tiles
|
return tiles
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def empty_tiles(self) -> List[Tile]:
|
def empty_tiles(self) -> List[Floor]:
|
||||||
tiles = [tile for tile in self if tile.is_empty()]
|
tiles = [tile for tile in self if tile.is_empty()]
|
||||||
random.shuffle(tiles)
|
random.shuffle(tiles)
|
||||||
return tiles
|
return tiles
|
||||||
@ -297,26 +442,10 @@ class FloorTiles(WallTiles):
|
|||||||
|
|
||||||
|
|
||||||
class Agents(MovingEntityObjectRegister):
|
class Agents(MovingEntityObjectRegister):
|
||||||
|
|
||||||
_accepted_objects = Agent
|
_accepted_objects = Agent
|
||||||
|
|
||||||
def __init__(self, *args, hide_from_obs_builder=False, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, can_collide=True, **kwargs)
|
||||||
self.hide_from_obs_builder = hide_from_obs_builder
|
|
||||||
|
|
||||||
# noinspection DuplicatedCode
|
|
||||||
def as_array(self):
|
|
||||||
self._array[:] = c.FREE_CELL.value
|
|
||||||
# noinspection PyTupleAssignmentBalance
|
|
||||||
for z, x, y, v in zip(range(len(self)), *zip(*[x.pos for x in self]), [x.encoding for x in self]):
|
|
||||||
if self.individual_slices:
|
|
||||||
self._array[z, x, y] += v
|
|
||||||
else:
|
|
||||||
self._array[0, x, y] += v
|
|
||||||
if self.individual_slices:
|
|
||||||
return self._array
|
|
||||||
else:
|
|
||||||
return self._array.sum(axis=0, keepdims=True)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def positions(self):
|
def positions(self):
|
||||||
@ -329,16 +458,12 @@ class Agents(MovingEntityObjectRegister):
|
|||||||
self._register[agent.name] = agent
|
self._register[agent.name] = agent
|
||||||
|
|
||||||
|
|
||||||
class Doors(EntityObjectRegister):
|
class Doors(EntityRegister):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, have_area: bool = False, **kwargs):
|
||||||
super(Doors, self).__init__(*args, is_blocking_light=True, **kwargs)
|
self.have_area = have_area
|
||||||
|
self._area_marked = False
|
||||||
def as_array(self):
|
super(Doors, self).__init__(*args, is_blocking_light=True, can_collide=True, **kwargs)
|
||||||
self._array[:] = 0
|
|
||||||
for door in self:
|
|
||||||
self._array[0, door.x, door.y] = door.encoding
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
_accepted_objects = Door
|
_accepted_objects = Door
|
||||||
|
|
||||||
@ -352,9 +477,20 @@ class Doors(EntityObjectRegister):
|
|||||||
for door in self:
|
for door in self:
|
||||||
door.tick()
|
door.tick()
|
||||||
|
|
||||||
|
def as_array(self):
|
||||||
|
if self.have_area and not self._area_marked:
|
||||||
|
for door in self:
|
||||||
|
for pos in door.access_area:
|
||||||
|
if self._individual_slices:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
pos = (0, *pos)
|
||||||
|
self._lazy_eval_transforms.append((pos, c.ACCESS_DOOR_CELL))
|
||||||
|
self._area_marked = True
|
||||||
|
return super(Doors, self).as_array()
|
||||||
|
|
||||||
class Actions(Register):
|
|
||||||
|
|
||||||
|
class Actions(ObjectRegister):
|
||||||
_accepted_objects = Action
|
_accepted_objects = Action
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -369,27 +505,28 @@ class Actions(Register):
|
|||||||
self.can_use_doors = can_use_doors
|
self.can_use_doors = can_use_doors
|
||||||
super(Actions, self).__init__()
|
super(Actions, self).__init__()
|
||||||
|
|
||||||
|
# Move this to Baseclass, Env init?
|
||||||
if self.allow_square_movement:
|
if self.allow_square_movement:
|
||||||
self.register_additional_items([self._accepted_objects(enum_ident=direction)
|
self.register_additional_items([self._accepted_objects(str_ident=direction)
|
||||||
for direction in h.MovingAction.square()])
|
for direction in h.EnvActions.square_move()])
|
||||||
if self.allow_diagonal_movement:
|
if self.allow_diagonal_movement:
|
||||||
self.register_additional_items([self._accepted_objects(enum_ident=direction)
|
self.register_additional_items([self._accepted_objects(str_ident=direction)
|
||||||
for direction in h.MovingAction.diagonal()])
|
for direction in h.EnvActions.diagonal_move()])
|
||||||
self._movement_actions = self._register.copy()
|
self._movement_actions = self._register.copy()
|
||||||
if self.can_use_doors:
|
if self.can_use_doors:
|
||||||
self.register_additional_items([self._accepted_objects(enum_ident=h.EnvActions.USE_DOOR)])
|
self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)])
|
||||||
if self.allow_no_op:
|
if self.allow_no_op:
|
||||||
self.register_additional_items([self._accepted_objects(enum_ident=h.EnvActions.NOOP)])
|
self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)])
|
||||||
|
|
||||||
def is_moving_action(self, action: Union[int]):
|
def is_moving_action(self, action: Union[int]):
|
||||||
return action in self.movement_actions.values()
|
return action in self.movement_actions.values()
|
||||||
|
|
||||||
|
|
||||||
class Zones(Register):
|
class Zones(ObjectRegister):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def accounting_zones(self):
|
def accounting_zones(self):
|
||||||
return [self[idx] for idx, name in self.items() if name != c.DANGER_ZONE.value]
|
return [self[idx] for idx, name in self.items() if name != c.DANGER_ZONE]
|
||||||
|
|
||||||
def __init__(self, parsed_level):
|
def __init__(self, parsed_level):
|
||||||
raise NotImplementedError('This needs a Rework')
|
raise NotImplementedError('This needs a Rework')
|
||||||
@ -398,9 +535,9 @@ class Zones(Register):
|
|||||||
self._accounting_zones = list()
|
self._accounting_zones = list()
|
||||||
self._danger_zones = list()
|
self._danger_zones = list()
|
||||||
for symbol in np.unique(parsed_level):
|
for symbol in np.unique(parsed_level):
|
||||||
if symbol == c.WALL.value:
|
if symbol == c.WALL:
|
||||||
continue
|
continue
|
||||||
elif symbol == c.DANGER_ZONE.value:
|
elif symbol == c.DANGER_ZONE:
|
||||||
self + symbol
|
self + symbol
|
||||||
slices.append(h.one_hot_level(parsed_level, symbol))
|
slices.append(h.one_hot_level(parsed_level, symbol))
|
||||||
self._danger_zones.append(symbol)
|
self._danger_zones.append(symbol)
|
||||||
|
@ -2,6 +2,7 @@ import numpy as np
|
|||||||
|
|
||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as c
|
||||||
|
|
||||||
|
# Multipliers for transforming coordinates to other octants:
|
||||||
mult_array = np.asarray([
|
mult_array = np.asarray([
|
||||||
[1, 0, 0, -1, -1, 0, 0, 1],
|
[1, 0, 0, -1, -1, 0, 0, 1],
|
||||||
[0, 1, -1, 0, 0, -1, 1, 0],
|
[0, 1, -1, 0, 0, -1, 1, 0],
|
||||||
@ -11,19 +12,17 @@ mult_array = np.asarray([
|
|||||||
|
|
||||||
|
|
||||||
class Map(object):
|
class Map(object):
|
||||||
# Multipliers for transforming coordinates to other octants:
|
def __init__(self, map_array: np.typing.ArrayLike, diamond_slope: float = 0.9):
|
||||||
|
|
||||||
def __init__(self, map_array: np.ndarray, diamond_slope: float = 0.9):
|
|
||||||
self.data = map_array
|
self.data = map_array
|
||||||
self.width, self.height = map_array.shape
|
self.width, self.height = map_array.shape
|
||||||
self.light = np.full_like(self.data, c.FREE_CELL.value)
|
self.light = np.full_like(self.data, c.FREE_CELL)
|
||||||
self.flag = c.FREE_CELL.value
|
self.flag = c.FREE_CELL
|
||||||
self.d_slope = diamond_slope
|
self.d_slope = diamond_slope
|
||||||
|
|
||||||
def blocked(self, x, y):
|
def blocked(self, x, y):
|
||||||
return (x < 0 or y < 0
|
return (x < 0 or y < 0
|
||||||
or x >= self.width or y >= self.height
|
or x >= self.width or y >= self.height
|
||||||
or self.data[x, y] == c.OCCUPIED_CELL.value)
|
or self.data[x, y] == c.OCCUPIED_CELL)
|
||||||
|
|
||||||
def lit(self, x, y):
|
def lit(self, x, y):
|
||||||
return self.light[x, y] == self.flag
|
return self.light[x, y] == self.flag
|
||||||
@ -33,7 +32,7 @@ class Map(object):
|
|||||||
self.light[x, y] = self.flag
|
self.light[x, y] = self.flag
|
||||||
|
|
||||||
def _cast_light(self, cx, cy, row, start, end, radius, xx, xy, yx, yy, id):
|
def _cast_light(self, cx, cy, row, start, end, radius, xx, xy, yx, yy, id):
|
||||||
"Recursive lightcasting function"
|
"""Recursive lightcasting function"""
|
||||||
if start < end:
|
if start < end:
|
||||||
return
|
return
|
||||||
radius_squared = radius*radius
|
radius_squared = radius*radius
|
||||||
@ -46,14 +45,14 @@ class Map(object):
|
|||||||
# Translate the dx, dy coordinates into map coordinates:
|
# Translate the dx, dy coordinates into map coordinates:
|
||||||
X, Y = cx + dx * xx + dy * xy, cy + dx * yx + dy * yy
|
X, Y = cx + dx * xx + dy * xy, cy + dx * yx + dy * yy
|
||||||
# l_slope and r_slope store the slopes of the left and right
|
# l_slope and r_slope store the slopes of the left and right
|
||||||
# extremities of the square we're considering:
|
# extremities of the square_move we're considering:
|
||||||
l_slope, r_slope = (dx-self.d_slope)/(dy+self.d_slope), (dx+self.d_slope)/(dy-self.d_slope)
|
l_slope, r_slope = (dx-self.d_slope)/(dy+self.d_slope), (dx+self.d_slope)/(dy-self.d_slope)
|
||||||
if start < r_slope:
|
if start < r_slope:
|
||||||
continue
|
continue
|
||||||
elif end > l_slope:
|
elif end > l_slope:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# Our light beam is touching this square; light it:
|
# Our light beam is touching this square_move; light it:
|
||||||
if dx*dx + dy*dy < radius_squared:
|
if dx*dx + dy*dy < radius_squared:
|
||||||
self.set_lit(X, Y)
|
self.set_lit(X, Y)
|
||||||
if blocked:
|
if blocked:
|
||||||
@ -66,12 +65,12 @@ class Map(object):
|
|||||||
start = new_start
|
start = new_start
|
||||||
else:
|
else:
|
||||||
if self.blocked(X, Y) and j < radius:
|
if self.blocked(X, Y) and j < radius:
|
||||||
# This is a blocking square, start a child scan:
|
# This is a blocking square_move, start a child scan:
|
||||||
blocked = True
|
blocked = True
|
||||||
self._cast_light(cx, cy, j+1, start, l_slope,
|
self._cast_light(cx, cy, j+1, start, l_slope,
|
||||||
radius, xx, xy, yx, yy, id+1)
|
radius, xx, xy, yx, yy, id+1)
|
||||||
new_start = r_slope
|
new_start = r_slope
|
||||||
# Row is scanned; do next row unless last square was blocked:
|
# Row is scanned; do next row unless last square_move was blocked:
|
||||||
if blocked:
|
if blocked:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
from environments.factory.factory_battery import BatteryFactory, BatteryProperties
|
from environments.factory.factory_battery import BatteryFactory, BatteryProperties
|
||||||
|
from environments.factory.factory_dest import DestFactory
|
||||||
from environments.factory.factory_dirt import DirtFactory, DirtProperties
|
from environments.factory.factory_dirt import DirtFactory, DirtProperties
|
||||||
from environments.factory.factory_item import ItemFactory
|
from environments.factory.factory_item import ItemFactory
|
||||||
|
|
||||||
@ -17,6 +18,12 @@ class DirtBatteryFactory(DirtFactory, BatteryFactory):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyAbstractClass
|
||||||
|
class DirtDestItemFactory(ItemFactory, DirtFactory, DestFactory):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties
|
from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties
|
||||||
|
|
||||||
|
@ -1,18 +1,33 @@
|
|||||||
from typing import Union, NamedTuple
|
from typing import Union, NamedTuple, Dict, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
from environments.factory.base.objects import Agent, Action, Entity
|
from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin
|
||||||
from environments.factory.base.registers import EntityObjectRegister, ObjectRegister
|
from environments.factory.base.registers import EntityRegister, EnvObjectRegister
|
||||||
from environments.factory.base.renderer import RenderEntity
|
from environments.factory.base.renderer import RenderEntity
|
||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as BaseConstants
|
||||||
|
from environments.helpers import EnvActions as BaseActions
|
||||||
|
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
|
|
||||||
|
|
||||||
CHARGE_ACTION = h.EnvActions.CHARGE
|
class Constants(BaseConstants):
|
||||||
ITEM_DROP_OFF = 1
|
# Battery Env
|
||||||
|
CHARGE_PODS = 'Charge_Pod'
|
||||||
|
BATTERIES = 'BATTERIES'
|
||||||
|
BATTERY_DISCHARGED = 'DISCHARGED'
|
||||||
|
CHARGE_POD = 1
|
||||||
|
|
||||||
|
|
||||||
|
class Actions(BaseActions):
|
||||||
|
CHARGE = 'do_charge_action'
|
||||||
|
|
||||||
|
|
||||||
|
class RewardsBtry(NamedTuple):
|
||||||
|
CHARGE_VALID: float = 0.1
|
||||||
|
CHARGE_FAIL: float = -0.1
|
||||||
|
BATTERY_DISCHARGED: float = -1.0
|
||||||
|
|
||||||
|
|
||||||
class BatteryProperties(NamedTuple):
|
class BatteryProperties(NamedTuple):
|
||||||
@ -24,44 +39,24 @@ class BatteryProperties(NamedTuple):
|
|||||||
multi_charge: bool = False
|
multi_charge: bool = False
|
||||||
|
|
||||||
|
|
||||||
class Battery(object):
|
c = Constants
|
||||||
|
a = Actions
|
||||||
|
|
||||||
|
|
||||||
|
class Battery(BoundingMixin, EnvObject):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_discharged(self):
|
def is_discharged(self):
|
||||||
return self.charge_level == 0
|
return self.charge_level == 0
|
||||||
|
|
||||||
@property
|
def __init__(self, initial_charge_level: float, *args, **kwargs):
|
||||||
def is_blocking_light(self):
|
super(Battery, self).__init__(*args, **kwargs)
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def can_collide(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
return f'{self.__class__.__name__}({self.agent.name})'
|
|
||||||
|
|
||||||
def __init__(self, pomdp_r: int, level_shape: (int, int), agent: Agent, initial_charge_level: float):
|
|
||||||
super().__init__()
|
|
||||||
self.agent = agent
|
|
||||||
self._pomdp_r = pomdp_r
|
|
||||||
self._level_shape = level_shape
|
|
||||||
if self._pomdp_r:
|
|
||||||
self._array = np.zeros((1, pomdp_r * 2 + 1, pomdp_r * 2 + 1))
|
|
||||||
else:
|
|
||||||
self._array = np.zeros((1, *self._level_shape))
|
|
||||||
self.charge_level = initial_charge_level
|
self.charge_level = initial_charge_level
|
||||||
|
|
||||||
def as_array(self):
|
def encoding(self):
|
||||||
self._array[:] = c.FREE_CELL.value
|
return self.charge_level
|
||||||
self._array[0, 0] = self.charge_level
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
def __repr__(self):
|
def do_charge_action(self, amount):
|
||||||
return f'{self.__class__.__name__}[{self.agent.name}]({self.charge_level})'
|
|
||||||
|
|
||||||
def charge(self, amount) -> c:
|
|
||||||
if self.charge_level < 1:
|
if self.charge_level < 1:
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
self.charge_level = min(1, amount + self.charge_level)
|
self.charge_level = min(1, amount + self.charge_level)
|
||||||
@ -73,69 +68,57 @@ class Battery(object):
|
|||||||
if self.charge_level != 0:
|
if self.charge_level != 0:
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
self.charge_level = max(0, amount + self.charge_level)
|
self.charge_level = max(0, amount + self.charge_level)
|
||||||
|
self._register.notify_change_to_value(self)
|
||||||
return c.VALID
|
return c.VALID
|
||||||
else:
|
else:
|
||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
|
|
||||||
def belongs_to_entity(self, entity):
|
def summarize_state(self, **_):
|
||||||
return self.agent == entity
|
|
||||||
|
|
||||||
def summarize_state(self, **kwargs):
|
|
||||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||||
attr_dict.update(dict(name=self.name))
|
attr_dict.update(dict(name=self.name))
|
||||||
return attr_dict
|
return attr_dict
|
||||||
|
|
||||||
|
|
||||||
class BatteriesRegister(ObjectRegister):
|
class BatteriesRegister(EnvObjectRegister):
|
||||||
|
|
||||||
_accepted_objects = Battery
|
_accepted_objects = Battery
|
||||||
is_blocking_light = False
|
|
||||||
can_be_shadowed = False
|
|
||||||
hide_from_obs_builder = True
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(BatteriesRegister, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
|
super(BatteriesRegister, self).__init__(*args, individual_slices=True,
|
||||||
|
is_blocking_light=False, can_be_shadowed=False, **kwargs)
|
||||||
self.is_observable = True
|
self.is_observable = True
|
||||||
|
|
||||||
def as_array(self):
|
def spawn_batteries(self, agents, initial_charge_level):
|
||||||
# self._array[:] = c.FREE_CELL.value
|
batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)]
|
||||||
for inv_idx, battery in enumerate(self):
|
self.register_additional_items(batteries)
|
||||||
self._array[inv_idx] = battery.as_array()
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
def spawn_batteries(self, agents, pomdp_r, initial_charge_level):
|
|
||||||
inventories = [self._accepted_objects(pomdp_r, self._level_shape, agent,
|
|
||||||
initial_charge_level)
|
|
||||||
for _, agent in enumerate(agents)]
|
|
||||||
self.register_additional_items(inventories)
|
|
||||||
|
|
||||||
def idx_by_entity(self, entity):
|
|
||||||
try:
|
|
||||||
return next((idx for idx, bat in enumerate(self) if bat.belongs_to_entity(entity)))
|
|
||||||
except StopIteration:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def by_entity(self, entity):
|
|
||||||
try:
|
|
||||||
return next((bat for bat in self if bat.belongs_to_entity(entity)))
|
|
||||||
except StopIteration:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
def summarize_states(self, n_steps=None):
|
||||||
# as dict with additional nesting
|
# as dict with additional nesting
|
||||||
# return dict(items=super(Inventories, self).summarize_states())
|
# return dict(items=super(Inventories, cls).summarize_states())
|
||||||
return super(BatteriesRegister, self).summarize_states(n_steps=n_steps)
|
return super(BatteriesRegister, self).summarize_states(n_steps=n_steps)
|
||||||
|
|
||||||
|
# Todo Move this to Mixin!
|
||||||
|
def by_entity(self, entity):
|
||||||
|
try:
|
||||||
|
return next((x for x in self if x.belongs_to_entity(entity)))
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def idx_by_entity(self, entity):
|
||||||
|
try:
|
||||||
|
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def as_array_by_entity(self, entity):
|
||||||
|
return self._array[self.idx_by_entity(entity)]
|
||||||
|
|
||||||
|
|
||||||
class ChargePod(Entity):
|
class ChargePod(Entity):
|
||||||
|
|
||||||
@property
|
|
||||||
def can_collide(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
return ITEM_DROP_OFF
|
return c.CHARGE_POD
|
||||||
|
|
||||||
def __init__(self, *args, charge_rate: float = 0.4,
|
def __init__(self, *args, charge_rate: float = 0.4,
|
||||||
multi_charge: bool = False, **kwargs):
|
multi_charge: bool = False, **kwargs):
|
||||||
@ -146,10 +129,10 @@ class ChargePod(Entity):
|
|||||||
def charge_battery(self, battery: Battery):
|
def charge_battery(self, battery: Battery):
|
||||||
if battery.charge_level == 1.0:
|
if battery.charge_level == 1.0:
|
||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
if sum(guest for guest in self.tile.guests if c.AGENT.name in guest.name) > 1:
|
if sum(guest for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1:
|
||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
battery.charge(self.charge_rate)
|
valid = battery.do_charge_action(self.charge_rate)
|
||||||
return c.VALID
|
return valid
|
||||||
|
|
||||||
def summarize_state(self, n_steps=None) -> dict:
|
def summarize_state(self, n_steps=None) -> dict:
|
||||||
if n_steps == h.STEPS_START:
|
if n_steps == h.STEPS_START:
|
||||||
@ -157,32 +140,39 @@ class ChargePod(Entity):
|
|||||||
return summary
|
return summary
|
||||||
|
|
||||||
|
|
||||||
class ChargePods(EntityObjectRegister):
|
class ChargePods(EntityRegister):
|
||||||
|
|
||||||
_accepted_objects = ChargePod
|
_accepted_objects = ChargePod
|
||||||
|
|
||||||
def as_array(self):
|
|
||||||
self._array[:] = c.FREE_CELL.value
|
|
||||||
for item in self:
|
|
||||||
if item.pos != c.NO_POS.value:
|
|
||||||
self._array[0, item.x, item.y] = item.encoding
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
super(ChargePods, self).__repr__()
|
super(ChargePods, self).__repr__()
|
||||||
|
|
||||||
|
|
||||||
class BatteryFactory(BaseFactory):
|
class BatteryFactory(BaseFactory):
|
||||||
|
|
||||||
def __init__(self, *args, btry_prop=BatteryProperties(), **kwargs):
|
def __init__(self, *args, btry_prop=BatteryProperties(), rewards_dest: RewardsBtry = RewardsBtry(),
|
||||||
|
**kwargs):
|
||||||
if isinstance(btry_prop, dict):
|
if isinstance(btry_prop, dict):
|
||||||
btry_prop = BatteryProperties(**btry_prop)
|
btry_prop = BatteryProperties(**btry_prop)
|
||||||
|
if isinstance(rewards_dest, dict):
|
||||||
|
rewards_dest = BatteryProperties(**rewards_dest)
|
||||||
self.btry_prop = btry_prop
|
self.btry_prop = btry_prop
|
||||||
|
self.rewards_dest = rewards_dest
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||||
|
additional_raw_observations = super().per_agent_raw_observations_hook(agent)
|
||||||
|
additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].as_array_by_entity(agent)})
|
||||||
|
return additional_raw_observations
|
||||||
|
|
||||||
|
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
|
||||||
|
additional_observations = super().observations_hook()
|
||||||
|
additional_observations.update({c.CHARGE_PODS: self[c.CHARGE_PODS].as_array()})
|
||||||
|
return additional_observations
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def additional_entities(self):
|
def entities_hook(self):
|
||||||
super_entities = super().additional_entities
|
super_entities = super().entities_hook
|
||||||
|
|
||||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.btry_prop.charge_locations]
|
empty_tiles = self[c.FLOOR].empty_tiles[:self.btry_prop.charge_locations]
|
||||||
charge_pods = ChargePods.from_tiles(
|
charge_pods = ChargePods.from_tiles(
|
||||||
@ -193,12 +183,12 @@ class BatteryFactory(BaseFactory):
|
|||||||
|
|
||||||
batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
|
batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
|
||||||
)
|
)
|
||||||
batteries.spawn_batteries(self[c.AGENT], self._pomdp_r, self.btry_prop.initial_charge)
|
batteries.spawn_batteries(self[c.AGENT], self.btry_prop.initial_charge)
|
||||||
super_entities.update({c.BATTERIES: batteries, c.CHARGE_POD: charge_pods})
|
super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods})
|
||||||
return super_entities
|
return super_entities
|
||||||
|
|
||||||
def do_additional_step(self) -> dict:
|
def step_hook(self) -> (List[dict], dict):
|
||||||
info_dict = super(BatteryFactory, self).do_additional_step()
|
super_reward_info = super(BatteryFactory, self).step_hook()
|
||||||
|
|
||||||
# Decharge
|
# Decharge
|
||||||
batteries = self[c.BATTERIES]
|
batteries = self[c.BATTERIES]
|
||||||
@ -211,65 +201,73 @@ class BatteryFactory(BaseFactory):
|
|||||||
|
|
||||||
batteries.by_entity(agent).decharge(energy_consumption)
|
batteries.by_entity(agent).decharge(energy_consumption)
|
||||||
|
|
||||||
return info_dict
|
return super_reward_info
|
||||||
|
|
||||||
def do_charge(self, agent) -> c:
|
def do_charge_action(self, agent) -> (dict, dict):
|
||||||
if charge_pod := self[c.CHARGE_POD].by_pos(agent.pos):
|
if charge_pod := self[c.CHARGE_PODS].by_pos(agent.pos):
|
||||||
return charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent))
|
valid = charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent))
|
||||||
|
if valid:
|
||||||
|
info_dict = {f'{agent.name}_{a.CHARGE}_VALID': 1}
|
||||||
|
self.print(f'{agent.name} just charged batteries at {charge_pod.name}.')
|
||||||
|
else:
|
||||||
|
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
|
||||||
|
self.print(f'{agent.name} failed to charged batteries at {charge_pod.name}.')
|
||||||
else:
|
else:
|
||||||
return c.NOT_VALID
|
valid = c.NOT_VALID
|
||||||
|
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
|
||||||
|
# info_dict = {f'{agent.name}_no_charger': 1}
|
||||||
|
self.print(f'{agent.name} failed to charged batteries at {agent.pos}.')
|
||||||
|
reward = dict(value=self.rewards_dest.CHARGE_VALID if valid else self.rewards_dest.CHARGE_FAIL,
|
||||||
|
reason=a.CHARGE, info=info_dict)
|
||||||
|
return valid, reward
|
||||||
|
|
||||||
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
|
def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
|
||||||
valid = super().do_additional_actions(agent, action)
|
action_result = super().do_additional_actions(agent, action)
|
||||||
if valid is None:
|
if action_result is None:
|
||||||
if action == CHARGE_ACTION:
|
if action == a.CHARGE:
|
||||||
valid = self.do_charge(agent)
|
action_result = self.do_charge_action(agent)
|
||||||
return valid
|
return action_result
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return valid
|
return action_result
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def do_additional_reset(self) -> None:
|
def reset_hook(self) -> None:
|
||||||
# There is Nothing to reset.
|
# There is Nothing to reset.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def check_additional_done(self) -> bool:
|
def check_additional_done(self) -> (bool, dict):
|
||||||
super_done = super(BatteryFactory, self).check_additional_done()
|
super_done, super_dict = super(BatteryFactory, self).check_additional_done()
|
||||||
if super_done:
|
if super_done:
|
||||||
return super_done
|
return super_done, super_dict
|
||||||
else:
|
else:
|
||||||
return self.btry_prop.done_when_discharged and any(battery.is_discharged for battery in self[c.BATTERIES])
|
if self.btry_prop.done_when_discharged:
|
||||||
|
if btry_done := any(battery.is_discharged for battery in self[c.BATTERIES]):
|
||||||
|
super_dict.update(DISCHARGE_DONE=1)
|
||||||
|
return btry_done, super_dict
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
pass
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
|
||||||
reward, info_dict = super(BatteryFactory, self).calculate_additional_reward(agent)
|
reward_event_dict = super(BatteryFactory, self).per_agent_reward_hook(agent)
|
||||||
if h.EnvActions.CHARGE == agent.temp_action:
|
|
||||||
if agent.temp_valid:
|
|
||||||
charge_pod = self[c.CHARGE_POD].by_pos(agent.pos)
|
|
||||||
info_dict.update({f'{agent.name}_charge': 1})
|
|
||||||
info_dict.update(agent_charged=1)
|
|
||||||
self.print(f'{agent.name} just charged batteries at {charge_pod.pos}.')
|
|
||||||
reward += 0.1
|
|
||||||
else:
|
|
||||||
self[c.DROP_OFF].by_pos(agent.pos)
|
|
||||||
info_dict.update({f'{agent.name}_failed_charge': 1})
|
|
||||||
info_dict.update(failed_charge=1)
|
|
||||||
self.print(f'{agent.name} just tried to charge at {agent.pos}, but failed.')
|
|
||||||
reward -= 0.1
|
|
||||||
|
|
||||||
if self[c.BATTERIES].by_entity(agent).is_discharged:
|
if self[c.BATTERIES].by_entity(agent).is_discharged:
|
||||||
info_dict.update({f'{agent.name}_discharged': 1})
|
self.print(f'{agent.name} Battery is discharged!')
|
||||||
reward -= 1
|
info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1}
|
||||||
|
reward_event_dict.update({c.BATTERY_DISCHARGED: {'reward': self.rewards_dest.BATTERY_DISCHARGED,
|
||||||
|
'info': info_dict}}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
info_dict.update({f'{agent.name}_battery_level': self[c.BATTERIES].by_entity(agent).charge_level})
|
# All Fine
|
||||||
return reward, info_dict
|
pass
|
||||||
|
return reward_event_dict
|
||||||
|
|
||||||
def render_additional_assets(self):
|
def render_assets_hook(self):
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
additional_assets = super().render_additional_assets()
|
additional_assets = super().render_assets_hook()
|
||||||
charge_pods = [RenderEntity(c.CHARGE_POD.value, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_POD]]
|
charge_pods = [RenderEntity(c.CHARGE_PODS, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_PODS]]
|
||||||
additional_assets.extend(charge_pods)
|
additional_assets.extend(charge_pods)
|
||||||
return additional_assets
|
return additional_assets
|
||||||
|
|
||||||
|
@ -6,16 +6,31 @@ import numpy as np
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as BaseConstants
|
||||||
from environments import helpers as h
|
from environments.helpers import EnvActions as BaseActions
|
||||||
from environments.factory.base.objects import Agent, Entity, Action, Tile
|
from environments.factory.base.objects import Agent, Entity, Action
|
||||||
from environments.factory.base.registers import Entities, MovingEntityObjectRegister
|
from environments.factory.base.registers import Entities, EntityRegister
|
||||||
|
|
||||||
from environments.factory.base.renderer import RenderEntity
|
from environments.factory.base.renderer import RenderEntity
|
||||||
|
|
||||||
|
|
||||||
DESTINATION = 1
|
class Constants(BaseConstants):
|
||||||
DESTINATION_DONE = 0.5
|
# Destination Env
|
||||||
|
DEST = 'Destination'
|
||||||
|
DESTINATION = 1
|
||||||
|
DESTINATION_DONE = 0.5
|
||||||
|
DEST_REACHED = 'ReachedDestination'
|
||||||
|
|
||||||
|
|
||||||
|
class Actions(BaseActions):
|
||||||
|
WAIT_ON_DEST = 'WAIT'
|
||||||
|
|
||||||
|
|
||||||
|
class RewardsDest(NamedTuple):
|
||||||
|
|
||||||
|
WAIT_VALID: float = 0.1
|
||||||
|
WAIT_FAIL: float = -0.1
|
||||||
|
DEST_REACHED: float = 5.0
|
||||||
|
|
||||||
|
|
||||||
class Destination(Entity):
|
class Destination(Entity):
|
||||||
@ -28,20 +43,16 @@ class Destination(Entity):
|
|||||||
def currently_dwelling_names(self):
|
def currently_dwelling_names(self):
|
||||||
return self._per_agent_times.keys()
|
return self._per_agent_times.keys()
|
||||||
|
|
||||||
@property
|
|
||||||
def can_collide(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
return DESTINATION
|
return c.DESTINATION
|
||||||
|
|
||||||
def __init__(self, *args, dwell_time: int = 0, **kwargs):
|
def __init__(self, *args, dwell_time: int = 0, **kwargs):
|
||||||
super(Destination, self).__init__(*args, **kwargs)
|
super(Destination, self).__init__(*args, **kwargs)
|
||||||
self.dwell_time = dwell_time
|
self.dwell_time = dwell_time
|
||||||
self._per_agent_times = defaultdict(lambda: dwell_time)
|
self._per_agent_times = defaultdict(lambda: dwell_time)
|
||||||
|
|
||||||
def wait(self, agent: Agent):
|
def do_wait_action(self, agent: Agent):
|
||||||
self._per_agent_times[agent.name] -= 1
|
self._per_agent_times[agent.name] -= 1
|
||||||
return c.VALID
|
return c.VALID
|
||||||
|
|
||||||
@ -50,7 +61,7 @@ class Destination(Entity):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def is_considered_reached(self):
|
def is_considered_reached(self):
|
||||||
agent_at_position = any(c.AGENT.name.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
|
agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
|
||||||
return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values())
|
return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values())
|
||||||
|
|
||||||
def agent_is_dwelling(self, agent: Agent):
|
def agent_is_dwelling(self, agent: Agent):
|
||||||
@ -62,15 +73,22 @@ class Destination(Entity):
|
|||||||
return state_summary
|
return state_summary
|
||||||
|
|
||||||
|
|
||||||
class Destinations(MovingEntityObjectRegister):
|
class Destinations(EntityRegister):
|
||||||
|
|
||||||
_accepted_objects = Destination
|
_accepted_objects = Destination
|
||||||
_light_blocking = False
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.is_blocking_light = False
|
||||||
|
self.can_be_shadowed = False
|
||||||
|
|
||||||
def as_array(self):
|
def as_array(self):
|
||||||
self._array[:] = c.FREE_CELL.value
|
self._array[:] = c.FREE_CELL
|
||||||
|
# ToDo: Switch to new Style Array Put
|
||||||
|
# indices = list(zip(range(len(cls)), *zip(*[x.pos for x in cls])))
|
||||||
|
# np.put(cls._array, [np.ravel_multi_index(x, cls._array.shape) for x in indices], cls.encodings)
|
||||||
for item in self:
|
for item in self:
|
||||||
if item.pos != c.NO_POS.value:
|
if item.pos != c.NO_POS:
|
||||||
self._array[0, item.x, item.y] = item.encoding
|
self._array[0, item.x, item.y] = item.encoding
|
||||||
return self._array
|
return self._array
|
||||||
|
|
||||||
@ -80,59 +98,67 @@ class Destinations(MovingEntityObjectRegister):
|
|||||||
|
|
||||||
class ReachedDestinations(Destinations):
|
class ReachedDestinations(Destinations):
|
||||||
_accepted_objects = Destination
|
_accepted_objects = Destination
|
||||||
_light_blocking = False
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(ReachedDestinations, self).__init__(*args, is_observable=False, **kwargs)
|
super(ReachedDestinations, self).__init__(*args, **kwargs)
|
||||||
|
self.can_be_shadowed = False
|
||||||
|
self.is_blocking_light = False
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
def summarize_states(self, n_steps=None):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
class DestSpawnMode(object):
|
class DestModeOptions(object):
|
||||||
DONE = 'DONE'
|
DONE = 'DONE'
|
||||||
GROUPED = 'GROUPED'
|
GROUPED = 'GROUPED'
|
||||||
PER_DEST = 'PER_DEST'
|
PER_DEST = 'PER_DEST'
|
||||||
|
|
||||||
|
|
||||||
class DestinationProperties(NamedTuple):
|
class DestProperties(NamedTuple):
|
||||||
n_dests: int = 1 # How many destinations are there
|
n_dests: int = 1 # How many destinations are there
|
||||||
dwell_time: int = 0 # How long does the agent need to "wait" on a destination
|
dwell_time: int = 0 # How long does the agent need to "wait" on a destination
|
||||||
spawn_frequency: int = 0
|
spawn_frequency: int = 0
|
||||||
spawn_in_other_zone: bool = True #
|
spawn_in_other_zone: bool = True #
|
||||||
spawn_mode: str = DestSpawnMode.DONE
|
spawn_mode: str = DestModeOptions.DONE
|
||||||
|
|
||||||
assert dwell_time >= 0, 'dwell_time cannot be < 0!'
|
assert dwell_time >= 0, 'dwell_time cannot be < 0!'
|
||||||
assert spawn_frequency >= 0, 'spawn_frequency cannot be < 0!'
|
assert spawn_frequency >= 0, 'spawn_frequency cannot be < 0!'
|
||||||
assert n_dests >= 0, 'n_destinations cannot be < 0!'
|
assert n_dests >= 0, 'n_destinations cannot be < 0!'
|
||||||
assert (spawn_mode == DestSpawnMode.DONE) != bool(spawn_frequency)
|
assert (spawn_mode == DestModeOptions.DONE) != bool(spawn_frequency)
|
||||||
|
|
||||||
|
|
||||||
|
c = Constants
|
||||||
|
a = Actions
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
||||||
class DestinationFactory(BaseFactory):
|
class DestFactory(BaseFactory):
|
||||||
# noinspection PyMissingConstructor
|
# noinspection PyMissingConstructor
|
||||||
|
|
||||||
def __init__(self, *args, dest_prop: DestinationProperties = DestinationProperties(),
|
def __init__(self, *args, dest_prop: DestProperties = DestProperties(), rewards_dest: RewardsDest = RewardsDest(),
|
||||||
env_seed=time.time_ns(), **kwargs):
|
env_seed=time.time_ns(), **kwargs):
|
||||||
if isinstance(dest_prop, dict):
|
if isinstance(dest_prop, dict):
|
||||||
dest_prop = DestinationProperties(**dest_prop)
|
dest_prop = DestProperties(**dest_prop)
|
||||||
|
if isinstance(rewards_dest, dict):
|
||||||
|
rewards_dest = RewardsDest(**rewards_dest)
|
||||||
self.dest_prop = dest_prop
|
self.dest_prop = dest_prop
|
||||||
|
self.rewards_dest = rewards_dest
|
||||||
kwargs.update(env_seed=env_seed)
|
kwargs.update(env_seed=env_seed)
|
||||||
self._dest_rng = np.random.default_rng(env_seed)
|
self._dest_rng = np.random.default_rng(env_seed)
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def additional_actions(self) -> Union[Action, List[Action]]:
|
def actions_hook(self) -> Union[Action, List[Action]]:
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
super_actions = super().additional_actions
|
super_actions = super().actions_hook
|
||||||
if self.dest_prop.dwell_time:
|
if self.dest_prop.dwell_time:
|
||||||
super_actions.append(Action(enum_ident=h.EnvActions.WAIT_ON_DEST))
|
super_actions.append(Action(enum_ident=a.WAIT_ON_DEST))
|
||||||
return super_actions
|
return super_actions
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def additional_entities(self) -> Dict[(Enum, Entities)]:
|
def entities_hook(self) -> Dict[(Enum, Entities)]:
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
super_entities = super().additional_entities
|
super_entities = super().entities_hook
|
||||||
|
|
||||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.dest_prop.n_dests]
|
empty_tiles = self[c.FLOOR].empty_tiles[:self.dest_prop.n_dests]
|
||||||
destinations = Destinations.from_tiles(
|
destinations = Destinations.from_tiles(
|
||||||
@ -142,35 +168,37 @@ class DestinationFactory(BaseFactory):
|
|||||||
)
|
)
|
||||||
reached_destinations = ReachedDestinations(level_shape=self._level_shape)
|
reached_destinations = ReachedDestinations(level_shape=self._level_shape)
|
||||||
|
|
||||||
super_entities.update({c.DESTINATION: destinations, c.REACHEDDESTINATION: reached_destinations})
|
super_entities.update({c.DEST: destinations, c.DEST_REACHED: reached_destinations})
|
||||||
return super_entities
|
return super_entities
|
||||||
|
|
||||||
def additional_per_agent_obs_build(self, agent) -> List[np.ndarray]:
|
def do_wait_action(self, agent: Agent) -> (dict, dict):
|
||||||
additional_per_agent_obs_build = super().additional_per_agent_obs_build(agent)
|
if destination := self[c.DEST].by_pos(agent.pos):
|
||||||
return additional_per_agent_obs_build
|
valid = destination.do_wait_action(agent)
|
||||||
|
self.print(f'{agent.name} just waited at {agent.pos}')
|
||||||
def wait(self, agent: Agent):
|
info_dict = {f'{agent.name}_{a.WAIT_ON_DEST}_VALID': 1}
|
||||||
if destiantion := self[c.DESTINATION].by_pos(agent.pos):
|
|
||||||
valid = destiantion.wait(agent)
|
|
||||||
return valid
|
|
||||||
else:
|
else:
|
||||||
return c.NOT_VALID
|
valid = c.NOT_VALID
|
||||||
|
self.print(f'{agent.name} just tried to do_wait_action do_wait_action at {agent.pos} but failed')
|
||||||
|
info_dict = {f'{agent.name}_{a.WAIT_ON_DEST}_FAIL': 1}
|
||||||
|
reward = dict(value=self.rewards_dest.WAIT_VALID if valid else self.rewards_dest.WAIT_FAIL,
|
||||||
|
reason=a.WAIT_ON_DEST, info=info_dict)
|
||||||
|
return valid, reward
|
||||||
|
|
||||||
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
|
def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
valid = super().do_additional_actions(agent, action)
|
super_action_result = super().do_additional_actions(agent, action)
|
||||||
if valid is None:
|
if super_action_result is None:
|
||||||
if action == h.EnvActions.WAIT_ON_DEST:
|
if action == a.WAIT_ON_DEST:
|
||||||
valid = self.wait(agent)
|
action_result = self.do_wait_action(agent)
|
||||||
return valid
|
return action_result
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return valid
|
return super_action_result
|
||||||
|
|
||||||
def do_additional_reset(self) -> None:
|
def reset_hook(self) -> None:
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
super().do_additional_reset()
|
super().reset_hook()
|
||||||
self._dest_spawn_timer = dict()
|
self._dest_spawn_timer = dict()
|
||||||
|
|
||||||
def trigger_destination_spawn(self):
|
def trigger_destination_spawn(self):
|
||||||
@ -178,15 +206,15 @@ class DestinationFactory(BaseFactory):
|
|||||||
if val == self.dest_prop.spawn_frequency]
|
if val == self.dest_prop.spawn_frequency]
|
||||||
if destinations_to_spawn:
|
if destinations_to_spawn:
|
||||||
n_dest_to_spawn = len(destinations_to_spawn)
|
n_dest_to_spawn = len(destinations_to_spawn)
|
||||||
if self.dest_prop.spawn_mode != DestSpawnMode.GROUPED:
|
if self.dest_prop.spawn_mode != DestModeOptions.GROUPED:
|
||||||
destinations = [Destination(tile) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
destinations = [Destination(tile, c.DEST) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||||
self[c.DESTINATION].register_additional_items(destinations)
|
self[c.DEST].register_additional_items(destinations)
|
||||||
for dest in destinations_to_spawn:
|
for dest in destinations_to_spawn:
|
||||||
del self._dest_spawn_timer[dest]
|
del self._dest_spawn_timer[dest]
|
||||||
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
||||||
elif self.dest_prop.spawn_mode == DestSpawnMode.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests:
|
elif self.dest_prop.spawn_mode == DestModeOptions.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests:
|
||||||
destinations = [Destination(tile) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
destinations = [Destination(tile, self[c.DEST]) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||||
self[c.DESTINATION].register_additional_items(destinations)
|
self[c.DEST].register_additional_items(destinations)
|
||||||
for dest in destinations_to_spawn:
|
for dest in destinations_to_spawn:
|
||||||
del self._dest_spawn_timer[dest]
|
del self._dest_spawn_timer[dest]
|
||||||
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
||||||
@ -196,15 +224,14 @@ class DestinationFactory(BaseFactory):
|
|||||||
else:
|
else:
|
||||||
self.print('No Items are spawning, limit is reached.')
|
self.print('No Items are spawning, limit is reached.')
|
||||||
|
|
||||||
def do_additional_step(self) -> dict:
|
def step_hook(self) -> (List[dict], dict):
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
info_dict = super().do_additional_step()
|
super_reward_info = super().step_hook()
|
||||||
for key, val in self._dest_spawn_timer.items():
|
for key, val in self._dest_spawn_timer.items():
|
||||||
self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1)
|
self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1)
|
||||||
for dest in list(self[c.DESTINATION].values()):
|
for dest in list(self[c.DEST].values()):
|
||||||
if dest.is_considered_reached:
|
if dest.is_considered_reached:
|
||||||
self[c.REACHEDDESTINATION].register_item(dest)
|
dest.change_register(self[c.DEST])
|
||||||
self[c.DESTINATION].delete_entity(dest)
|
|
||||||
self._dest_spawn_timer[dest.name] = 0
|
self._dest_spawn_timer[dest.name] = 0
|
||||||
self.print(f'{dest.name} is reached now, removing...')
|
self.print(f'{dest.name} is reached now, removing...')
|
||||||
else:
|
else:
|
||||||
@ -217,59 +244,53 @@ class DestinationFactory(BaseFactory):
|
|||||||
dest.leave(agent)
|
dest.leave(agent)
|
||||||
self.print(f'{agent.name} left the destination early.')
|
self.print(f'{agent.name} left the destination early.')
|
||||||
self.trigger_destination_spawn()
|
self.trigger_destination_spawn()
|
||||||
return info_dict
|
return super_reward_info
|
||||||
|
|
||||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
|
||||||
|
additional_observations = super().observations_hook()
|
||||||
|
additional_observations.update({c.DEST: self[c.DEST].as_array()})
|
||||||
|
return additional_observations
|
||||||
|
|
||||||
|
def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
reward, info_dict = super().calculate_additional_reward(agent)
|
reward_event_dict = super().per_agent_reward_hook(agent)
|
||||||
if h.EnvActions.WAIT_ON_DEST == agent.temp_action:
|
if len(self[c.DEST_REACHED]):
|
||||||
if agent.temp_valid:
|
for reached_dest in list(self[c.DEST_REACHED]):
|
||||||
info_dict.update({f'{agent.name}_waiting_at_dest': 1})
|
|
||||||
info_dict.update(agent_waiting_at_dest=1)
|
|
||||||
self.print(f'{agent.name} just waited at {agent.pos}')
|
|
||||||
reward += 0.1
|
|
||||||
else:
|
|
||||||
info_dict.update({f'{agent.name}_tried_failed': 1})
|
|
||||||
info_dict.update(agent_waiting_failed=1)
|
|
||||||
self.print(f'{agent.name} just tried to wait wait at {agent.pos} but failed')
|
|
||||||
reward -= 0.1
|
|
||||||
if len(self[c.REACHEDDESTINATION]):
|
|
||||||
for reached_dest in list(self[c.REACHEDDESTINATION]):
|
|
||||||
if agent.pos == reached_dest.pos:
|
if agent.pos == reached_dest.pos:
|
||||||
info_dict.update({f'{agent.name}_reached_destination': 1})
|
|
||||||
info_dict.update(agent_reached_destination=1)
|
|
||||||
self.print(f'{agent.name} just reached destination at {agent.pos}')
|
self.print(f'{agent.name} just reached destination at {agent.pos}')
|
||||||
reward += 0.5
|
self[c.DEST_REACHED].delete_env_object(reached_dest)
|
||||||
self[c.REACHEDDESTINATION].delete_entity(reached_dest)
|
info_dict = {f'{agent.name}_{c.DEST_REACHED}': 1}
|
||||||
return reward, info_dict
|
reward_event_dict.update({c.DEST_REACHED: {'reward': self.rewards_dest.DEST_REACHED,
|
||||||
|
'info': info_dict}})
|
||||||
|
return reward_event_dict
|
||||||
|
|
||||||
def render_additional_assets(self, mode='human'):
|
def render_assets_hook(self, mode='human'):
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
additional_assets = super().render_additional_assets()
|
additional_assets = super().render_assets_hook()
|
||||||
destinations = [RenderEntity(c.DESTINATION.value, dest.pos) for dest in self[c.DESTINATION]]
|
destinations = [RenderEntity(c.DEST, dest.pos) for dest in self[c.DEST]]
|
||||||
additional_assets.extend(destinations)
|
additional_assets.extend(destinations)
|
||||||
return additional_assets
|
return additional_assets
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties
|
from environments.utility_classes import AgentRenderOptions as aro, ObservationProperties
|
||||||
|
|
||||||
render = True
|
render = True
|
||||||
|
|
||||||
dest_probs = DestinationProperties(n_dests=2, spawn_frequency=5, spawn_mode=DestSpawnMode.GROUPED)
|
dest_probs = DestProperties(n_dests=2, spawn_frequency=5, spawn_mode=DestModeOptions.GROUPED)
|
||||||
|
|
||||||
obs_props = ObservationProperties(render_agents=ARO.LEVEL, omit_agent_self=True, pomdp_r=2)
|
obs_props = ObservationProperties(render_agents=aro.LEVEL, omit_agent_self=True, pomdp_r=2)
|
||||||
|
|
||||||
move_props = {'allow_square_movement': True,
|
move_props = {'allow_square_movement': True,
|
||||||
'allow_diagonal_movement': False,
|
'allow_diagonal_movement': False,
|
||||||
'allow_no_op': False}
|
'allow_no_op': False}
|
||||||
|
|
||||||
factory = DestinationFactory(n_agents=10, done_at_collision=False,
|
factory = DestFactory(n_agents=10, done_at_collision=False,
|
||||||
level_name='rooms', max_steps=400,
|
level_name='rooms', max_steps=400,
|
||||||
obs_prop=obs_props, parse_doors=True,
|
obs_prop=obs_props, parse_doors=True,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
mv_prop=move_props, dest_prop=dest_probs
|
mv_prop=move_props, dest_prop=dest_probs
|
||||||
)
|
)
|
||||||
|
|
||||||
# noinspection DuplicatedCode
|
# noinspection DuplicatedCode
|
||||||
n_actions = factory.action_space.n - 1
|
n_actions = factory.action_space.n - 1
|
@ -1,47 +1,56 @@
|
|||||||
import time
|
import time
|
||||||
from enum import Enum
|
from pathlib import Path
|
||||||
from typing import List, Union, NamedTuple, Dict
|
from typing import List, Union, NamedTuple, Dict
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from algorithms.TSP_dirt_agent import TSPDirtAgent
|
from algorithms.TSP_dirt_agent import TSPDirtAgent
|
||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as BaseConstants
|
||||||
from environments import helpers as h
|
from environments.helpers import EnvActions as BaseActions
|
||||||
|
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
from environments.factory.base.objects import Agent, Action, Entity, Tile
|
from environments.factory.base.objects import Agent, Action, Entity, Floor
|
||||||
from environments.factory.base.registers import Entities, MovingEntityObjectRegister
|
from environments.factory.base.registers import Entities, EntityRegister
|
||||||
|
|
||||||
from environments.factory.base.renderer import RenderEntity
|
from environments.factory.base.renderer import RenderEntity
|
||||||
from environments.utility_classes import ObservationProperties
|
from environments.utility_classes import ObservationProperties
|
||||||
|
|
||||||
CLEAN_UP_ACTION = h.EnvActions.CLEAN_UP
|
|
||||||
|
class Constants(BaseConstants):
|
||||||
|
DIRT = 'Dirt'
|
||||||
|
|
||||||
|
|
||||||
|
class Actions(BaseActions):
|
||||||
|
CLEAN_UP = 'do_cleanup_action'
|
||||||
|
|
||||||
|
|
||||||
|
class RewardsDirt(NamedTuple):
|
||||||
|
CLEAN_UP_VALID: float = 0.5
|
||||||
|
CLEAN_UP_FAIL: float = -0.1
|
||||||
|
CLEAN_UP_LAST_PIECE: float = 4.5
|
||||||
|
|
||||||
|
|
||||||
class DirtProperties(NamedTuple):
|
class DirtProperties(NamedTuple):
|
||||||
initial_dirt_ratio: float = 0.3 # On INIT, on max how much tiles does the dirt spawn in percent.
|
initial_dirt_ratio: float = 0.3 # On INIT, on max how many tiles does the dirt spawn in percent.
|
||||||
initial_dirt_spawn_r_var: float = 0.05 # How much does the dirt spawn amount vary?
|
initial_dirt_spawn_r_var: float = 0.05 # How much does the dirt spawn amount vary?
|
||||||
clean_amount: float = 1 # How much does the robot clean with one actions.
|
clean_amount: float = 1 # How much does the robot clean with one actions.
|
||||||
max_spawn_ratio: float = 0.20 # On max how much tiles does the dirt spawn in percent.
|
max_spawn_ratio: float = 0.20 # On max how many tiles does the dirt spawn in percent.
|
||||||
max_spawn_amount: float = 0.3 # How much dirt does spawn per tile at max.
|
max_spawn_amount: float = 0.3 # How much dirt does spawn per tile at max.
|
||||||
spawn_frequency: int = 0 # Spawn Frequency in Steps.
|
spawn_frequency: int = 0 # Spawn Frequency in Steps.
|
||||||
max_local_amount: int = 2 # Max dirt amount per tile.
|
max_local_amount: int = 2 # Max dirt amount per tile.
|
||||||
max_global_amount: int = 20 # Max dirt amount in the whole environment.
|
max_global_amount: int = 20 # Max dirt amount in the whole environment.
|
||||||
dirt_smear_amount: float = 0.2 # Agents smear dirt, when not cleaning up in place.
|
dirt_smear_amount: float = 0.2 # Agents smear dirt, when not cleaning up in place.
|
||||||
agent_can_interact: bool = True # Whether the agents can interact with the dirt in this environment.
|
|
||||||
done_when_clean: bool = True
|
done_when_clean: bool = True
|
||||||
|
|
||||||
|
|
||||||
class Dirt(Entity):
|
class Dirt(Entity):
|
||||||
|
|
||||||
@property
|
|
||||||
def can_collide(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def amount(self):
|
def amount(self):
|
||||||
return self._amount
|
return self._amount
|
||||||
|
|
||||||
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
# Edit this if you want items to be drawn in the ops differntly
|
# Edit this if you want items to be drawn in the ops differntly
|
||||||
return self._amount
|
return self._amount
|
||||||
@ -52,6 +61,7 @@ class Dirt(Entity):
|
|||||||
|
|
||||||
def set_new_amount(self, amount):
|
def set_new_amount(self, amount):
|
||||||
self._amount = amount
|
self._amount = amount
|
||||||
|
self._register.notify_change_to_value(self)
|
||||||
|
|
||||||
def summarize_state(self, **kwargs):
|
def summarize_state(self, **kwargs):
|
||||||
state_dict = super().summarize_state(**kwargs)
|
state_dict = super().summarize_state(**kwargs)
|
||||||
@ -59,18 +69,7 @@ class Dirt(Entity):
|
|||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
class DirtRegister(MovingEntityObjectRegister):
|
class DirtRegister(EntityRegister):
|
||||||
|
|
||||||
def as_array(self):
|
|
||||||
if self._array is not None:
|
|
||||||
self._array[:] = c.FREE_CELL.value
|
|
||||||
for dirt in list(self.values()):
|
|
||||||
if dirt.amount == 0:
|
|
||||||
self.delete_entity(dirt)
|
|
||||||
self._array[0, dirt.x, dirt.y] = dirt.amount
|
|
||||||
else:
|
|
||||||
self._array = np.zeros((1, *self._level_shape))
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
_accepted_objects = Dirt
|
_accepted_objects = Dirt
|
||||||
|
|
||||||
@ -86,14 +85,14 @@ class DirtRegister(MovingEntityObjectRegister):
|
|||||||
super(DirtRegister, self).__init__(*args)
|
super(DirtRegister, self).__init__(*args)
|
||||||
self._dirt_properties: DirtProperties = dirt_properties
|
self._dirt_properties: DirtProperties = dirt_properties
|
||||||
|
|
||||||
def spawn_dirt(self, then_dirty_tiles) -> c:
|
def spawn_dirt(self, then_dirty_tiles) -> bool:
|
||||||
if isinstance(then_dirty_tiles, Tile):
|
if isinstance(then_dirty_tiles, Floor):
|
||||||
then_dirty_tiles = [then_dirty_tiles]
|
then_dirty_tiles = [then_dirty_tiles]
|
||||||
for tile in then_dirty_tiles:
|
for tile in then_dirty_tiles:
|
||||||
if not self.amount > self.dirt_properties.max_global_amount:
|
if not self.amount > self.dirt_properties.max_global_amount:
|
||||||
dirt = self.by_pos(tile.pos)
|
dirt = self.by_pos(tile.pos)
|
||||||
if dirt is None:
|
if dirt is None:
|
||||||
dirt = Dirt(tile, amount=self.dirt_properties.max_spawn_amount)
|
dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount)
|
||||||
self.register_item(dirt)
|
self.register_item(dirt)
|
||||||
else:
|
else:
|
||||||
new_value = dirt.amount + self.dirt_properties.max_spawn_amount
|
new_value = dirt.amount + self.dirt_properties.max_spawn_amount
|
||||||
@ -117,50 +116,71 @@ def entropy(x):
|
|||||||
return -(x * np.log(x + 1e-8)).sum()
|
return -(x * np.log(x + 1e-8)).sum()
|
||||||
|
|
||||||
|
|
||||||
|
c = Constants
|
||||||
|
a = Actions
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
||||||
class DirtFactory(BaseFactory):
|
class DirtFactory(BaseFactory):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def additional_actions(self) -> Union[Action, List[Action]]:
|
def actions_hook(self) -> Union[Action, List[Action]]:
|
||||||
super_actions = super().additional_actions
|
super_actions = super().actions_hook
|
||||||
if self.dirt_prop.agent_can_interact:
|
super_actions.append(Action(str_ident=a.CLEAN_UP))
|
||||||
super_actions.append(Action(enum_ident=CLEAN_UP_ACTION))
|
|
||||||
return super_actions
|
return super_actions
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def additional_entities(self) -> Dict[(Enum, Entities)]:
|
def entities_hook(self) -> Dict[(str, Entities)]:
|
||||||
super_entities = super().additional_entities
|
super_entities = super().entities_hook
|
||||||
dirt_register = DirtRegister(self.dirt_prop, self._level_shape)
|
dirt_register = DirtRegister(self.dirt_prop, self._level_shape)
|
||||||
super_entities.update(({c.DIRT: dirt_register}))
|
super_entities.update(({c.DIRT: dirt_register}))
|
||||||
return super_entities
|
return super_entities
|
||||||
|
|
||||||
def __init__(self, *args, dirt_prop: DirtProperties = DirtProperties(), env_seed=time.time_ns(), **kwargs):
|
def __init__(self, *args,
|
||||||
|
dirt_prop: DirtProperties = DirtProperties(), rewards_dirt: RewardsDirt = RewardsDirt(),
|
||||||
|
env_seed=time.time_ns(), **kwargs):
|
||||||
if isinstance(dirt_prop, dict):
|
if isinstance(dirt_prop, dict):
|
||||||
dirt_prop = DirtProperties(**dirt_prop)
|
dirt_prop = DirtProperties(**dirt_prop)
|
||||||
|
if isinstance(rewards_dirt, dict):
|
||||||
|
rewards_dirt = RewardsDirt(**rewards_dirt)
|
||||||
self.dirt_prop = dirt_prop
|
self.dirt_prop = dirt_prop
|
||||||
|
self.rewards_dirt = rewards_dirt
|
||||||
self._dirt_rng = np.random.default_rng(env_seed)
|
self._dirt_rng = np.random.default_rng(env_seed)
|
||||||
self._dirt: DirtRegister
|
self._dirt: DirtRegister
|
||||||
kwargs.update(env_seed=env_seed)
|
kwargs.update(env_seed=env_seed)
|
||||||
|
# TODO: Reset ---> document this
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def render_additional_assets(self, mode='human'):
|
def render_assets_hook(self, mode='human'):
|
||||||
additional_assets = super().render_additional_assets()
|
additional_assets = super().render_assets_hook()
|
||||||
dirt = [RenderEntity('dirt', dirt.tile.pos, min(0.15 + dirt.amount, 1.5), 'scale')
|
dirt = [RenderEntity('dirt', dirt.tile.pos, min(0.15 + dirt.amount, 1.5), 'scale')
|
||||||
for dirt in self[c.DIRT]]
|
for dirt in self[c.DIRT]]
|
||||||
additional_assets.extend(dirt)
|
additional_assets.extend(dirt)
|
||||||
return additional_assets
|
return additional_assets
|
||||||
|
|
||||||
def clean_up(self, agent: Agent) -> c:
|
def do_cleanup_action(self, agent: Agent) -> (dict, dict):
|
||||||
if dirt := self[c.DIRT].by_pos(agent.pos):
|
if dirt := self[c.DIRT].by_pos(agent.pos):
|
||||||
new_dirt_amount = dirt.amount - self.dirt_prop.clean_amount
|
new_dirt_amount = dirt.amount - self.dirt_prop.clean_amount
|
||||||
|
|
||||||
if new_dirt_amount <= 0:
|
if new_dirt_amount <= 0:
|
||||||
self[c.DIRT].delete_entity(dirt)
|
self[c.DIRT].delete_env_object(dirt)
|
||||||
else:
|
else:
|
||||||
dirt.set_new_amount(max(new_dirt_amount, c.FREE_CELL.value))
|
dirt.set_new_amount(max(new_dirt_amount, c.FREE_CELL.value))
|
||||||
return c.VALID
|
valid = c.VALID
|
||||||
|
self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.')
|
||||||
|
info_dict = {f'{agent.name}_{a.CLEAN_UP}_VALID': 1, 'cleanup_valid': 1}
|
||||||
|
reward = self.rewards_dirt.CLEAN_UP_VALID
|
||||||
else:
|
else:
|
||||||
return c.NOT_VALID
|
valid = c.NOT_VALID
|
||||||
|
self.print(f'{agent.name} just tried to clean up some dirt at {agent.pos}, but failed.')
|
||||||
|
info_dict = {f'{agent.name}_{a.CLEAN_UP}_FAIL': 1, 'cleanup_fail': 1}
|
||||||
|
reward = self.rewards_dirt.CLEAN_UP_FAIL
|
||||||
|
|
||||||
|
if valid and self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0):
|
||||||
|
reward += self.rewards_dirt.CLEAN_UP_LAST_PIECE
|
||||||
|
self.print(f'{agent.name} picked up the last piece of dirt!')
|
||||||
|
info_dict = {f'{agent.name}_{a.CLEAN_UP}_LAST_PIECE': 1}
|
||||||
|
return valid, dict(value=reward, reason=a.CLEAN_UP, info=info_dict)
|
||||||
|
|
||||||
def trigger_dirt_spawn(self, initial_spawn=False):
|
def trigger_dirt_spawn(self, initial_spawn=False):
|
||||||
dirt_rng = self._dirt_rng
|
dirt_rng = self._dirt_rng
|
||||||
@ -176,21 +196,21 @@ class DirtFactory(BaseFactory):
|
|||||||
n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt)))
|
n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt)))
|
||||||
self[c.DIRT].spawn_dirt(free_for_dirt[:n_dirt_tiles])
|
self[c.DIRT].spawn_dirt(free_for_dirt[:n_dirt_tiles])
|
||||||
|
|
||||||
def do_additional_step(self) -> dict:
|
def step_hook(self) -> (List[dict], dict):
|
||||||
info_dict = super().do_additional_step()
|
super_reward_info = super().step_hook()
|
||||||
if smear_amount := self.dirt_prop.dirt_smear_amount:
|
# if smear_amount := self.dirt_prop.dirt_smear_amount:
|
||||||
for agent in self[c.AGENT]:
|
# for agent in self[c.AGENT]:
|
||||||
if agent.temp_valid and agent.last_pos != c.NO_POS:
|
# if agent.temp_valid and agent.last_pos != c.NO_POS:
|
||||||
if self._actions.is_moving_action(agent.temp_action):
|
# if self._actions.is_moving_action(agent.temp_action):
|
||||||
if old_pos_dirt := self[c.DIRT].by_pos(agent.last_pos):
|
# if old_pos_dirt := self[c.DIRT].by_pos(agent.last_pos):
|
||||||
if smeared_dirt := round(old_pos_dirt.amount * smear_amount, 2):
|
# if smeared_dirt := round(old_pos_dirt.amount * smear_amount, 2):
|
||||||
old_pos_dirt.set_new_amount(max(0, old_pos_dirt.amount-smeared_dirt))
|
# old_pos_dirt.set_new_amount(max(0, old_pos_dirt.amount-smeared_dirt))
|
||||||
if new_pos_dirt := self[c.DIRT].by_pos(agent.pos):
|
# if new_pos_dirt := self[c.DIRT].by_pos(agent.pos):
|
||||||
new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt))
|
# new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt))
|
||||||
else:
|
# else:
|
||||||
if self[c.DIRT].spawn_dirt(agent.tile):
|
# if self[c.DIRT].spawn_dirt(agent.tile):
|
||||||
new_pos_dirt = self[c.DIRT].by_pos(agent.pos)
|
# new_pos_dirt = self[c.DIRT].by_pos(agent.pos)
|
||||||
new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt))
|
# new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt))
|
||||||
if self._next_dirt_spawn < 0:
|
if self._next_dirt_spawn < 0:
|
||||||
pass # No Dirt Spawn
|
pass # No Dirt Spawn
|
||||||
elif not self._next_dirt_spawn:
|
elif not self._next_dirt_spawn:
|
||||||
@ -198,70 +218,58 @@ class DirtFactory(BaseFactory):
|
|||||||
self._next_dirt_spawn = self.dirt_prop.spawn_frequency
|
self._next_dirt_spawn = self.dirt_prop.spawn_frequency
|
||||||
else:
|
else:
|
||||||
self._next_dirt_spawn -= 1
|
self._next_dirt_spawn -= 1
|
||||||
return info_dict
|
return super_reward_info
|
||||||
|
|
||||||
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
|
def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
|
||||||
valid = super().do_additional_actions(agent, action)
|
action_result = super().do_additional_actions(agent, action)
|
||||||
if valid is None:
|
if action_result is None:
|
||||||
if action == CLEAN_UP_ACTION:
|
if action == a.CLEAN_UP:
|
||||||
if self.dirt_prop.agent_can_interact:
|
return self.do_cleanup_action(agent)
|
||||||
valid = self.clean_up(agent)
|
|
||||||
return valid
|
|
||||||
else:
|
|
||||||
return c.NOT_VALID
|
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return valid
|
return action_result
|
||||||
|
|
||||||
def do_additional_reset(self) -> None:
|
def reset_hook(self) -> None:
|
||||||
super().do_additional_reset()
|
super().reset_hook()
|
||||||
self.trigger_dirt_spawn(initial_spawn=True)
|
self.trigger_dirt_spawn(initial_spawn=True)
|
||||||
self._next_dirt_spawn = self.dirt_prop.spawn_frequency if self.dirt_prop.spawn_frequency else -1
|
self._next_dirt_spawn = self.dirt_prop.spawn_frequency if self.dirt_prop.spawn_frequency else -1
|
||||||
|
|
||||||
def check_additional_done(self):
|
def check_additional_done(self) -> (bool, dict):
|
||||||
super_done = super().check_additional_done()
|
super_done, super_dict = super().check_additional_done()
|
||||||
done = self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0)
|
if self.dirt_prop.done_when_clean:
|
||||||
return super_done or done
|
if all_cleaned := len(self[c.DIRT]) == 0:
|
||||||
|
super_dict.update(ALL_CLEAN_DONE=all_cleaned)
|
||||||
|
return all_cleaned, super_dict
|
||||||
|
return super_done, super_dict
|
||||||
|
|
||||||
|
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
|
||||||
|
additional_observations = super().observations_hook()
|
||||||
|
additional_observations.update({c.DIRT: self[c.DIRT].as_array()})
|
||||||
|
return additional_observations
|
||||||
|
|
||||||
|
def gather_additional_info(self, agent: Agent) -> dict:
|
||||||
|
event_reward_dict = super().per_agent_reward_hook(agent)
|
||||||
|
info_dict = dict()
|
||||||
|
|
||||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
|
||||||
reward, info_dict = super().calculate_additional_reward(agent)
|
|
||||||
dirt = [dirt.amount for dirt in self[c.DIRT]]
|
dirt = [dirt.amount for dirt in self[c.DIRT]]
|
||||||
current_dirt_amount = sum(dirt)
|
current_dirt_amount = sum(dirt)
|
||||||
dirty_tile_count = len(dirt)
|
dirty_tile_count = len(dirt)
|
||||||
|
|
||||||
# if dirty_tile_count:
|
# if dirty_tile_count:
|
||||||
# dirt_distribution_score = entropy(softmax(np.asarray(dirt)) / dirty_tile_count)
|
# dirt_distribution_score = entropy(softmax(np.asarray(dirt)) / dirty_tile_count)
|
||||||
#else:
|
# else:
|
||||||
# dirt_distribution_score = 0
|
# dirt_distribution_score = 0
|
||||||
|
|
||||||
info_dict.update(dirt_amount=current_dirt_amount)
|
info_dict.update(dirt_amount=current_dirt_amount)
|
||||||
info_dict.update(dirty_tile_count=dirty_tile_count)
|
info_dict.update(dirty_tile_count=dirty_tile_count)
|
||||||
# info_dict.update(dirt_distribution_score=dirt_distribution_score)
|
|
||||||
|
|
||||||
if agent.temp_action == CLEAN_UP_ACTION:
|
event_reward_dict.update({'info': info_dict})
|
||||||
if agent.temp_valid:
|
return event_reward_dict
|
||||||
# Reward if pickup succeds,
|
|
||||||
# 0.5 on every pickup
|
|
||||||
reward += 0.5
|
|
||||||
info_dict.update(dirt_cleaned=1)
|
|
||||||
if self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0):
|
|
||||||
# 0.5 additional reward for the very last pickup
|
|
||||||
reward += 4.5
|
|
||||||
info_dict.update(done_clean=1)
|
|
||||||
self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.')
|
|
||||||
else:
|
|
||||||
reward -= 0.01
|
|
||||||
self.print(f'{agent.name} just tried to clean up some dirt at {agent.pos}, but failed.')
|
|
||||||
info_dict.update({f'{agent.name}_failed_dirt_cleanup': 1})
|
|
||||||
info_dict.update(failed_dirt_clean=1)
|
|
||||||
|
|
||||||
# Potential based rewards ->
|
|
||||||
# track the last reward , minus the current reward = potential
|
|
||||||
return reward, info_dict
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from environments.utility_classes import AgentRenderOptions as ARO
|
from environments.utility_classes import AgentRenderOptions as aro
|
||||||
render = True
|
render = True
|
||||||
|
|
||||||
dirt_props = DirtProperties(
|
dirt_props = DirtProperties(
|
||||||
@ -273,46 +281,62 @@ if __name__ == '__main__':
|
|||||||
max_local_amount=1,
|
max_local_amount=1,
|
||||||
spawn_frequency=0,
|
spawn_frequency=0,
|
||||||
max_spawn_ratio=0.05,
|
max_spawn_ratio=0.05,
|
||||||
dirt_smear_amount=0.0,
|
dirt_smear_amount=0.0
|
||||||
agent_can_interact=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
obs_props = ObservationProperties(render_agents=ARO.COMBINED, omit_agent_self=True,
|
obs_props = ObservationProperties(render_agents=aro.COMBINED, omit_agent_self=True,
|
||||||
pomdp_r=2, additional_agent_placeholder=None)
|
pomdp_r=2, additional_agent_placeholder=None, cast_shadows=True,
|
||||||
|
indicate_door_area=False)
|
||||||
|
|
||||||
move_props = {'allow_square_movement': True,
|
move_props = {'allow_square_movement': True,
|
||||||
'allow_diagonal_movement': False,
|
'allow_diagonal_movement': False,
|
||||||
'allow_no_op': False}
|
'allow_no_op': False}
|
||||||
|
import time
|
||||||
|
global_timings = []
|
||||||
|
for i in range(10):
|
||||||
|
|
||||||
factory = DirtFactory(n_agents=1, done_at_collision=False,
|
factory = DirtFactory(n_agents=10, done_at_collision=False,
|
||||||
level_name='rooms', max_steps=400,
|
level_name='rooms', max_steps=1000,
|
||||||
doors_have_area=False,
|
doors_have_area=False,
|
||||||
obs_prop=obs_props, parse_doors=True,
|
obs_prop=obs_props, parse_doors=True,
|
||||||
record_episodes=True, verbose=True,
|
verbose=True,
|
||||||
mv_prop=move_props, dirt_prop=dirt_props,
|
mv_prop=move_props, dirt_prop=dirt_props,
|
||||||
inject_agents=[TSPDirtAgent]
|
# inject_agents=[TSPDirtAgent],
|
||||||
)
|
)
|
||||||
|
|
||||||
# noinspection DuplicatedCode
|
factory.save_params(Path('rewards_param'))
|
||||||
n_actions = factory.action_space.n - 1
|
|
||||||
_ = factory.observation_space
|
|
||||||
|
|
||||||
for epoch in range(10):
|
# noinspection DuplicatedCode
|
||||||
random_actions = [[random.randint(0, n_actions) for _
|
n_actions = factory.action_space.n - 1
|
||||||
in range(factory.n_agents)] for _
|
_ = factory.observation_space
|
||||||
in range(factory.max_steps+1)]
|
obs_space = factory.observation_space
|
||||||
env_state = factory.reset()
|
obs_space_named = factory.named_observation_space
|
||||||
if render:
|
action_space_named = factory.named_action_space
|
||||||
factory.render()
|
times = []
|
||||||
tsp_agent = factory.get_injected_agents()[0]
|
for epoch in range(10):
|
||||||
|
start_time = time.time()
|
||||||
r = 0
|
random_actions = [[random.randint(0, n_actions) for _
|
||||||
for agent_i_action in random_actions:
|
in range(factory.n_agents)] for _
|
||||||
env_state, step_r, done_bool, info_obj = factory.step(tsp_agent.predict())
|
in range(factory.max_steps+1)]
|
||||||
r += step_r
|
env_state = factory.reset()
|
||||||
if render:
|
if render:
|
||||||
factory.render()
|
factory.render()
|
||||||
if done_bool:
|
# tsp_agent = factory.get_injected_agents()[0]
|
||||||
break
|
|
||||||
print(f'Factory run {epoch} done, reward is:\n {r}')
|
rwrd = 0
|
||||||
|
for agent_i_action in random_actions:
|
||||||
|
# agent_i_action = tsp_agent.predict()
|
||||||
|
env_state, step_rwrd, done_bool, info_obj = factory.step(agent_i_action)
|
||||||
|
rwrd += step_rwrd
|
||||||
|
if render:
|
||||||
|
factory.render()
|
||||||
|
if done_bool:
|
||||||
|
break
|
||||||
|
times.append(time.time() - start_time)
|
||||||
|
# print(f'Factory run {epoch} done, reward is:\n {r}')
|
||||||
|
print('Mean Time Taken: ', sum(times) / 10)
|
||||||
|
global_timings.extend(times)
|
||||||
|
print('Mean Time Taken: ', sum(global_timings) / len(global_timings))
|
||||||
|
print('Median Time Taken: ', global_timings[len(global_timings)//2])
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
58
environments/factory/factory_dirt_stationary_machines.py
Normal file
58
environments/factory/factory_dirt_stationary_machines.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from environments.factory.base.objects import Agent, Entity, Action
|
||||||
|
from environments.factory.factory_dirt import Dirt, DirtRegister, DirtFactory
|
||||||
|
from environments.factory.base.objects import Floor
|
||||||
|
from environments.factory.base.registers import Floors, Entities, EntityRegister
|
||||||
|
|
||||||
|
|
||||||
|
class Machines(EntityRegister):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Machine(Entity):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class StationaryMachinesDirtFactory(DirtFactory):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self._machine_coords = [(6, 6), (12, 13)]
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def entities_hook(self) -> Dict[(str, Entities)]:
|
||||||
|
super_entities = super().entities_hook()
|
||||||
|
|
||||||
|
return super_entities
|
||||||
|
|
||||||
|
def reset_hook(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def actions_hook(self) -> Union[Action, List[Action]]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def step_hook(self) -> (List[dict], dict):
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||||
|
super_per_agent_raw_observations = super().per_agent_raw_observations_hook(agent)
|
||||||
|
return super_per_agent_raw_observations
|
||||||
|
|
||||||
|
def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def pre_step_hook(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def post_step_hook(self) -> dict:
|
||||||
|
pass
|
@ -1,25 +1,40 @@
|
|||||||
import time
|
import time
|
||||||
from collections import deque, UserList
|
from collections import deque
|
||||||
from enum import Enum
|
|
||||||
from typing import List, Union, NamedTuple, Dict
|
from typing import List, Union, NamedTuple, Dict
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as BaseConstants
|
||||||
|
from environments.helpers import EnvActions as BaseActions
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.factory.base.objects import Agent, Entity, Action, Tile, MoveableEntity
|
from environments.factory.base.objects import Agent, Entity, Action, Floor
|
||||||
from environments.factory.base.registers import Entities, EntityObjectRegister, ObjectRegister, \
|
from environments.factory.base.registers import Entities, EntityRegister, BoundEnvObjRegister, ObjectRegister
|
||||||
MovingEntityObjectRegister
|
|
||||||
|
|
||||||
from environments.factory.base.renderer import RenderEntity
|
from environments.factory.base.renderer import RenderEntity
|
||||||
|
|
||||||
|
|
||||||
NO_ITEM = 0
|
class Constants(BaseConstants):
|
||||||
ITEM_DROP_OFF = 1
|
NO_ITEM = 0
|
||||||
|
ITEM_DROP_OFF = 1
|
||||||
|
# Item Env
|
||||||
|
ITEM = 'Item'
|
||||||
|
INVENTORY = 'Inventory'
|
||||||
|
DROP_OFF = 'Drop_Off'
|
||||||
|
|
||||||
|
|
||||||
class Item(MoveableEntity):
|
class Actions(BaseActions):
|
||||||
|
ITEM_ACTION = 'ITEMACTION'
|
||||||
|
|
||||||
|
|
||||||
|
class RewardsItem(NamedTuple):
|
||||||
|
DROP_OFF_VALID: float = 0.1
|
||||||
|
DROP_OFF_FAIL: float = -0.1
|
||||||
|
PICK_UP_FAIL: float = -0.1
|
||||||
|
PICK_UP_VALID: float = 0.1
|
||||||
|
|
||||||
|
|
||||||
|
class Item(Entity):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@ -29,10 +44,6 @@ class Item(MoveableEntity):
|
|||||||
def auto_despawn(self):
|
def auto_despawn(self):
|
||||||
return self._auto_despawn
|
return self._auto_despawn
|
||||||
|
|
||||||
@property
|
|
||||||
def can_collide(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
# Edit this if you want items to be drawn in the ops differently
|
# Edit this if you want items to be drawn in the ops differently
|
||||||
@ -41,20 +52,17 @@ class Item(MoveableEntity):
|
|||||||
def set_auto_despawn(self, auto_despawn):
|
def set_auto_despawn(self, auto_despawn):
|
||||||
self._auto_despawn = auto_despawn
|
self._auto_despawn = auto_despawn
|
||||||
|
|
||||||
|
def set_tile_to(self, no_pos_tile):
|
||||||
|
assert self._register.__class__.__name__ != ItemRegister.__class__
|
||||||
|
self._tile = no_pos_tile
|
||||||
|
|
||||||
class ItemRegister(MovingEntityObjectRegister):
|
|
||||||
|
|
||||||
def as_array(self):
|
class ItemRegister(EntityRegister):
|
||||||
self._array[:] = c.FREE_CELL.value
|
|
||||||
for item in self:
|
|
||||||
if item.pos != c.NO_POS.value:
|
|
||||||
self._array[0, item.x, item.y] = item.encoding
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
_accepted_objects = Item
|
_accepted_objects = Item
|
||||||
|
|
||||||
def spawn_items(self, tiles: List[Tile]):
|
def spawn_items(self, tiles: List[Floor]):
|
||||||
items = [Item(tile) for tile in tiles]
|
items = [Item(tile, self) for tile in tiles]
|
||||||
self.register_additional_items(items)
|
self.register_additional_items(items)
|
||||||
|
|
||||||
def despawn_items(self, items: List[Item]):
|
def despawn_items(self, items: List[Item]):
|
||||||
@ -63,72 +71,48 @@ class ItemRegister(MovingEntityObjectRegister):
|
|||||||
del self[item]
|
del self[item]
|
||||||
|
|
||||||
|
|
||||||
class Inventory(UserList):
|
class Inventory(BoundEnvObjRegister):
|
||||||
|
|
||||||
@property
|
|
||||||
def is_blocking_light(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return f'{self.__class__.__name__}({self.agent.name})'
|
return f'{self.__class__.__name__}({self._bound_entity.name})'
|
||||||
|
|
||||||
def __init__(self, pomdp_r: int, level_shape: (int, int), agent: Agent, capacity: int):
|
def __init__(self, agent: Agent, capacity: int, *args, **kwargs):
|
||||||
super(Inventory, self).__init__()
|
super(Inventory, self).__init__(agent, *args, is_blocking_light=False, can_be_shadowed=False, **kwargs)
|
||||||
self.agent = agent
|
self.capacity = capacity
|
||||||
self.pomdp_r = pomdp_r
|
|
||||||
self._level_shape = level_shape
|
|
||||||
if self.pomdp_r:
|
|
||||||
self._array = np.zeros((1, pomdp_r * 2 + 1, pomdp_r * 2 + 1))
|
|
||||||
else:
|
|
||||||
self._array = np.zeros((1, *self._level_shape))
|
|
||||||
self.capacity = min(capacity, self._array.size)
|
|
||||||
|
|
||||||
def as_array(self):
|
def as_array(self):
|
||||||
self._array[:] = c.FREE_CELL.value
|
if self._array is None:
|
||||||
for item_idx, item in enumerate(self):
|
self._array = np.zeros((1, *self._shape))
|
||||||
x_diff, y_diff = divmod(item_idx, self._array.shape[1])
|
return super(Inventory, self).as_array()
|
||||||
self._array[0, int(x_diff), int(y_diff)] = item.encoding
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
def __repr__(self):
|
def summarize_states(self, **kwargs):
|
||||||
return f'{self.__class__.__name__}[{self.agent.name}]({self.data})'
|
|
||||||
|
|
||||||
def append(self, item) -> None:
|
|
||||||
if len(self) < self.capacity:
|
|
||||||
super(Inventory, self).append(item)
|
|
||||||
else:
|
|
||||||
raise RuntimeError('Inventory is full')
|
|
||||||
|
|
||||||
def belongs_to_entity(self, entity):
|
|
||||||
return self.agent == entity
|
|
||||||
|
|
||||||
def summarize_state(self, **kwargs):
|
|
||||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||||
attr_dict.update(dict(items={val.name: val.summarize_state(**kwargs) for val in self}))
|
attr_dict.update(dict(items={key: val.summarize_state(**kwargs) for key, val in self.items()}))
|
||||||
attr_dict.update(dict(name=self.name))
|
attr_dict.update(dict(name=self.name))
|
||||||
return attr_dict
|
return attr_dict
|
||||||
|
|
||||||
|
def pop(self):
|
||||||
|
item_to_pop = self[0]
|
||||||
|
self.delete_env_object(item_to_pop)
|
||||||
|
return item_to_pop
|
||||||
|
|
||||||
|
|
||||||
class Inventories(ObjectRegister):
|
class Inventories(ObjectRegister):
|
||||||
|
|
||||||
_accepted_objects = Inventory
|
_accepted_objects = Inventory
|
||||||
is_blocking_light = False
|
is_blocking_light = False
|
||||||
can_be_shadowed = False
|
can_be_shadowed = False
|
||||||
hide_from_obs_builder = True
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, obs_shape, *args, **kwargs):
|
||||||
super(Inventories, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
|
super(Inventories, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
|
||||||
self.is_observable = True
|
self._obs_shape = obs_shape
|
||||||
|
|
||||||
def as_array(self):
|
def as_array(self):
|
||||||
# self._array[:] = c.FREE_CELL.value
|
return np.stack([inventory.as_array() for inv_idx, inventory in enumerate(self)])
|
||||||
for inv_idx, inventory in enumerate(self):
|
|
||||||
self._array[inv_idx] = inventory.as_array()
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
def spawn_inventories(self, agents, pomdp_r, capacity):
|
def spawn_inventories(self, agents, capacity):
|
||||||
inventories = [self._accepted_objects(pomdp_r, self._level_shape, agent, capacity)
|
inventories = [self._accepted_objects(agent, capacity, self._obs_shape)
|
||||||
for _, agent in enumerate(agents)]
|
for _, agent in enumerate(agents)]
|
||||||
self.register_additional_items(inventories)
|
self.register_additional_items(inventories)
|
||||||
|
|
||||||
@ -144,21 +128,15 @@ class Inventories(ObjectRegister):
|
|||||||
except StopIteration:
|
except StopIteration:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
def summarize_states(self, **kwargs):
|
||||||
# as dict with additional nesting
|
return {key: val.summarize_states(**kwargs) for key, val in self.items()}
|
||||||
# return dict(items=super(Inventories, self).summarize_states())
|
|
||||||
return super(Inventories, self).summarize_states(n_steps=n_steps)
|
|
||||||
|
|
||||||
|
|
||||||
class DropOffLocation(Entity):
|
class DropOffLocation(Entity):
|
||||||
|
|
||||||
@property
|
|
||||||
def can_collide(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
return ITEM_DROP_OFF
|
return Constants.ITEM_DROP_OFF
|
||||||
|
|
||||||
def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs):
|
def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs):
|
||||||
super(DropOffLocation, self).__init__(*args, **kwargs)
|
super(DropOffLocation, self).__init__(*args, **kwargs)
|
||||||
@ -183,53 +161,50 @@ class DropOffLocation(Entity):
|
|||||||
return super().summarize_state(n_steps=n_steps)
|
return super().summarize_state(n_steps=n_steps)
|
||||||
|
|
||||||
|
|
||||||
class DropOffLocations(EntityObjectRegister):
|
class DropOffLocations(EntityRegister):
|
||||||
|
|
||||||
_accepted_objects = DropOffLocation
|
_accepted_objects = DropOffLocation
|
||||||
|
|
||||||
def as_array(self):
|
|
||||||
self._array[:] = c.FREE_CELL.value
|
|
||||||
for item in self:
|
|
||||||
if item.pos != c.NO_POS.value:
|
|
||||||
self._array[0, item.x, item.y] = item.encoding
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
super(DropOffLocations, self).__repr__()
|
|
||||||
|
|
||||||
|
|
||||||
class ItemProperties(NamedTuple):
|
class ItemProperties(NamedTuple):
|
||||||
n_items: int = 5 # How many items are there at the same time
|
n_items: int = 5 # How many items are there at the same time
|
||||||
spawn_frequency: int = 10 # Spawn Frequency in Steps
|
spawn_frequency: int = 10 # Spawn Frequency in Steps
|
||||||
n_drop_off_locations: int = 5 # How many DropOff locations are there at the same time
|
n_drop_off_locations: int = 5 # How many DropOff locations are there at the same time
|
||||||
max_dropoff_storage_size: int = 0 # How many items are needed until the drop off is full
|
max_dropoff_storage_size: int = 0 # How many items are needed until the dropoff is full
|
||||||
max_agent_inventory_capacity: int = 5 # How many items are needed until the agent inventory is full
|
max_agent_inventory_capacity: int = 5 # How many items are needed until the agent inventory is full
|
||||||
agent_can_interact: bool = True # Whether agents have the possibility to interact with the domain items
|
|
||||||
|
|
||||||
|
c = Constants
|
||||||
|
a = Actions
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
# noinspection PyAttributeOutsideInit, PyAbstractClass
|
||||||
class ItemFactory(BaseFactory):
|
class ItemFactory(BaseFactory):
|
||||||
# noinspection PyMissingConstructor
|
# noinspection PyMissingConstructor
|
||||||
def __init__(self, *args, item_prop: ItemProperties = ItemProperties(), env_seed=time.time_ns(), **kwargs):
|
def __init__(self, *args, item_prop: ItemProperties = ItemProperties(), env_seed=time.time_ns(),
|
||||||
|
rewards_item: RewardsItem = RewardsItem(), **kwargs):
|
||||||
if isinstance(item_prop, dict):
|
if isinstance(item_prop, dict):
|
||||||
item_prop = ItemProperties(**item_prop)
|
item_prop = ItemProperties(**item_prop)
|
||||||
|
if isinstance(rewards_item, dict):
|
||||||
|
rewards_item = RewardsItem(**rewards_item)
|
||||||
self.item_prop = item_prop
|
self.item_prop = item_prop
|
||||||
|
self.rewards_item = rewards_item
|
||||||
kwargs.update(env_seed=env_seed)
|
kwargs.update(env_seed=env_seed)
|
||||||
self._item_rng = np.random.default_rng(env_seed)
|
self._item_rng = np.random.default_rng(env_seed)
|
||||||
assert (item_prop.n_items <= ((1 + kwargs.get('_pomdp_r', 0) * 2) ** 2)) or not kwargs.get('_pomdp_r', 0)
|
assert (item_prop.n_items <= ((1 + kwargs.get('_pomdp_r', 0) * 2) ** 2)) or not kwargs.get('_pomdp_r', 0)
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def additional_actions(self) -> Union[Action, List[Action]]:
|
def actions_hook(self) -> Union[Action, List[Action]]:
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
super_actions = super().additional_actions
|
super_actions = super().actions_hook
|
||||||
super_actions.append(Action(enum_ident=h.EnvActions.ITEM_ACTION))
|
super_actions.append(Action(str_ident=a.ITEM_ACTION))
|
||||||
return super_actions
|
return super_actions
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def additional_entities(self) -> Dict[(Enum, Entities)]:
|
def entities_hook(self) -> Dict[(str, Entities)]:
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
super_entities = super().additional_entities
|
super_entities = super().entities_hook
|
||||||
|
|
||||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_drop_off_locations]
|
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_drop_off_locations]
|
||||||
drop_offs = DropOffLocations.from_tiles(
|
drop_offs = DropOffLocations.from_tiles(
|
||||||
@ -241,54 +216,65 @@ class ItemFactory(BaseFactory):
|
|||||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_items]
|
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_items]
|
||||||
item_register.spawn_items(empty_tiles)
|
item_register.spawn_items(empty_tiles)
|
||||||
|
|
||||||
inventories = Inventories(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2))
|
inventories = Inventories(self._obs_shape, self._level_shape)
|
||||||
inventories.spawn_inventories(self[c.AGENT], self._pomdp_r,
|
inventories.spawn_inventories(self[c.AGENT], self.item_prop.max_agent_inventory_capacity)
|
||||||
self.item_prop.max_agent_inventory_capacity)
|
|
||||||
|
|
||||||
super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories})
|
super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories})
|
||||||
return super_entities
|
return super_entities
|
||||||
|
|
||||||
def additional_per_agent_obs_build(self, agent) -> List[np.ndarray]:
|
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||||
additional_per_agent_obs_build = super().additional_per_agent_obs_build(agent)
|
additional_raw_observations = super().per_agent_raw_observations_hook(agent)
|
||||||
additional_per_agent_obs_build.append(self[c.INVENTORY].by_entity(agent).as_array())
|
additional_raw_observations.update({c.INVENTORY: self[c.INVENTORY].by_entity(agent).as_array()})
|
||||||
return additional_per_agent_obs_build
|
return additional_raw_observations
|
||||||
|
|
||||||
def do_item_action(self, agent: Agent):
|
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
|
||||||
|
additional_observations = super().observations_hook()
|
||||||
|
additional_observations.update({c.ITEM: self[c.ITEM].as_array()})
|
||||||
|
additional_observations.update({c.DROP_OFF: self[c.DROP_OFF].as_array()})
|
||||||
|
return additional_observations
|
||||||
|
|
||||||
|
def do_item_action(self, agent: Agent) -> (dict, dict):
|
||||||
inventory = self[c.INVENTORY].by_entity(agent)
|
inventory = self[c.INVENTORY].by_entity(agent)
|
||||||
if drop_off := self[c.DROP_OFF].by_pos(agent.pos):
|
if drop_off := self[c.DROP_OFF].by_pos(agent.pos):
|
||||||
if inventory:
|
if inventory:
|
||||||
valid = drop_off.place_item(inventory.pop(0))
|
valid = drop_off.place_item(inventory.pop())
|
||||||
return valid
|
|
||||||
else:
|
else:
|
||||||
return c.NOT_VALID
|
valid = c.NOT_VALID
|
||||||
|
if valid:
|
||||||
|
self.print(f'{agent.name} just dropped of an item at {drop_off.pos}.')
|
||||||
|
info_dict = {f'{agent.name}_DROPOFF_VALID': 1, 'DROPOFF_VALID': 1}
|
||||||
|
else:
|
||||||
|
self.print(f'{agent.name} just tried to drop off at {agent.pos}, but failed.')
|
||||||
|
info_dict = {f'{agent.name}_DROPOFF_FAIL': 1, 'DROPOFF_FAIL': 1}
|
||||||
|
reward = dict(value=self.rewards_item.DROP_OFF_VALID if valid else self.rewards_item.DROP_OFF_FAIL,
|
||||||
|
reason=a.ITEM_ACTION, info=info_dict)
|
||||||
|
return valid, reward
|
||||||
elif item := self[c.ITEM].by_pos(agent.pos):
|
elif item := self[c.ITEM].by_pos(agent.pos):
|
||||||
try:
|
item.change_register(inventory)
|
||||||
inventory.append(item)
|
item.set_tile_to(self._NO_POS_TILE)
|
||||||
item.move(self._NO_POS_TILE)
|
self.print(f'{agent.name} just picked up an item at {agent.pos}')
|
||||||
return c.VALID
|
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1, f'{a.ITEM_ACTION}_VALID': 1}
|
||||||
except RuntimeError:
|
return c.VALID, dict(value=self.rewards_item.PICK_UP_VALID, reason=a.ITEM_ACTION, info=info_dict)
|
||||||
return c.NOT_VALID
|
|
||||||
else:
|
else:
|
||||||
return c.NOT_VALID
|
self.print(f'{agent.name} just tried to pick up an item at {agent.pos}, but failed.')
|
||||||
|
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_FAIL': 1, f'{a.ITEM_ACTION}_FAIL': 1}
|
||||||
|
return c.NOT_VALID, dict(value=self.rewards_item.PICK_UP_FAIL, reason=a.ITEM_ACTION, info=info_dict)
|
||||||
|
|
||||||
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
|
def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
valid = super().do_additional_actions(agent, action)
|
action_result = super().do_additional_actions(agent, action)
|
||||||
if valid is None:
|
if action_result is None:
|
||||||
if action == h.EnvActions.ITEM_ACTION:
|
if action == a.ITEM_ACTION:
|
||||||
if self.item_prop.agent_can_interact:
|
action_result = self.do_item_action(agent)
|
||||||
valid = self.do_item_action(agent)
|
return action_result
|
||||||
return valid
|
|
||||||
else:
|
|
||||||
return c.NOT_VALID
|
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return valid
|
return action_result
|
||||||
|
|
||||||
def do_additional_reset(self) -> None:
|
def reset_hook(self) -> None:
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
super().do_additional_reset()
|
super().reset_hook()
|
||||||
self._next_item_spawn = self.item_prop.spawn_frequency
|
self._next_item_spawn = self.item_prop.spawn_frequency
|
||||||
self.trigger_item_spawn()
|
self.trigger_item_spawn()
|
||||||
|
|
||||||
@ -301,14 +287,14 @@ class ItemFactory(BaseFactory):
|
|||||||
else:
|
else:
|
||||||
self.print('No Items are spawning, limit is reached.')
|
self.print('No Items are spawning, limit is reached.')
|
||||||
|
|
||||||
def do_additional_step(self) -> dict:
|
def step_hook(self) -> (List[dict], dict):
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
info_dict = super().do_additional_step()
|
super_reward_info = super().step_hook()
|
||||||
for item in list(self[c.ITEM].values()):
|
for item in list(self[c.ITEM].values()):
|
||||||
if item.auto_despawn >= 1:
|
if item.auto_despawn >= 1:
|
||||||
item.set_auto_despawn(item.auto_despawn-1)
|
item.set_auto_despawn(item.auto_despawn-1)
|
||||||
elif not item.auto_despawn:
|
elif not item.auto_despawn:
|
||||||
self[c.ITEM].delete_entity(item)
|
self[c.ITEM].delete_env_object(item)
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -316,60 +302,32 @@ class ItemFactory(BaseFactory):
|
|||||||
self.trigger_item_spawn()
|
self.trigger_item_spawn()
|
||||||
else:
|
else:
|
||||||
self._next_item_spawn = max(0, self._next_item_spawn-1)
|
self._next_item_spawn = max(0, self._next_item_spawn-1)
|
||||||
return info_dict
|
return super_reward_info
|
||||||
|
|
||||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
def render_assets_hook(self, mode='human'):
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
reward, info_dict = super().calculate_additional_reward(agent)
|
additional_assets = super().render_assets_hook()
|
||||||
if h.EnvActions.ITEM_ACTION == agent.temp_action:
|
items = [RenderEntity(c.ITEM, item.tile.pos) for item in self[c.ITEM] if item.tile != self._NO_POS_TILE]
|
||||||
if agent.temp_valid:
|
|
||||||
if drop_off := self[c.DROP_OFF].by_pos(agent.pos):
|
|
||||||
info_dict.update({f'{agent.name}_item_drop_off': 1})
|
|
||||||
info_dict.update(item_drop_off=1)
|
|
||||||
self.print(f'{agent.name} just dropped of an item at {drop_off.pos}.')
|
|
||||||
reward += 0.5
|
|
||||||
else:
|
|
||||||
info_dict.update({f'{agent.name}_item_pickup': 1})
|
|
||||||
info_dict.update(item_pickup=1)
|
|
||||||
self.print(f'{agent.name} just picked up an item at {agent.pos}')
|
|
||||||
reward += 0.1
|
|
||||||
else:
|
|
||||||
if self[c.DROP_OFF].by_pos(agent.pos):
|
|
||||||
info_dict.update({f'{agent.name}_failed_drop_off': 1})
|
|
||||||
info_dict.update(failed_drop_off=1)
|
|
||||||
self.print(f'{agent.name} just tried to drop off at {agent.pos}, but failed.')
|
|
||||||
reward -= 0.1
|
|
||||||
else:
|
|
||||||
info_dict.update({f'{agent.name}_failed_item_action': 1})
|
|
||||||
info_dict.update(failed_pick_up=1)
|
|
||||||
self.print(f'{agent.name} just tried to pick up an item at {agent.pos}, but failed.')
|
|
||||||
reward -= 0.1
|
|
||||||
return reward, info_dict
|
|
||||||
|
|
||||||
def render_additional_assets(self, mode='human'):
|
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
additional_assets = super().render_additional_assets()
|
|
||||||
items = [RenderEntity(c.ITEM.value, item.tile.pos) for item in self[c.ITEM]]
|
|
||||||
additional_assets.extend(items)
|
additional_assets.extend(items)
|
||||||
drop_offs = [RenderEntity(c.DROP_OFF.value, drop_off.tile.pos) for drop_off in self[c.DROP_OFF]]
|
drop_offs = [RenderEntity(c.DROP_OFF, drop_off.tile.pos) for drop_off in self[c.DROP_OFF]]
|
||||||
additional_assets.extend(drop_offs)
|
additional_assets.extend(drop_offs)
|
||||||
return additional_assets
|
return additional_assets
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties
|
from environments.utility_classes import AgentRenderOptions as aro, ObservationProperties
|
||||||
|
|
||||||
render = True
|
render = True
|
||||||
|
|
||||||
item_probs = ItemProperties()
|
item_probs = ItemProperties(n_items=30, n_drop_off_locations=6)
|
||||||
|
|
||||||
obs_props = ObservationProperties(render_agents=ARO.LEVEL, omit_agent_self=True, pomdp_r=2)
|
obs_props = ObservationProperties(render_agents=aro.SEPERATE, omit_agent_self=True, pomdp_r=2)
|
||||||
|
|
||||||
move_props = {'allow_square_movement': True,
|
move_props = {'allow_square_movement': True,
|
||||||
'allow_diagonal_movement': False,
|
'allow_diagonal_movement': True,
|
||||||
'allow_no_op': False}
|
'allow_no_op': False}
|
||||||
|
|
||||||
factory = ItemFactory(n_agents=3, done_at_collision=False,
|
factory = ItemFactory(n_agents=6, done_at_collision=False,
|
||||||
level_name='rooms', max_steps=400,
|
level_name='rooms', max_steps=400,
|
||||||
obs_prop=obs_props, parse_doors=True,
|
obs_prop=obs_props, parse_doors=True,
|
||||||
record_episodes=True, verbose=True,
|
record_episodes=True, verbose=True,
|
||||||
@ -378,20 +336,21 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
# noinspection DuplicatedCode
|
# noinspection DuplicatedCode
|
||||||
n_actions = factory.action_space.n - 1
|
n_actions = factory.action_space.n - 1
|
||||||
_ = factory.observation_space
|
obs_space = factory.observation_space
|
||||||
|
obs_space_named = factory.named_observation_space
|
||||||
|
|
||||||
for epoch in range(4):
|
for epoch in range(400):
|
||||||
random_actions = [[random.randint(0, n_actions) for _
|
random_actions = [[random.randint(0, n_actions) for _
|
||||||
in range(factory.n_agents)] for _
|
in range(factory.n_agents)] for _
|
||||||
in range(factory.max_steps + 1)]
|
in range(factory.max_steps + 1)]
|
||||||
env_state = factory.reset()
|
env_state = factory.reset()
|
||||||
r = 0
|
rwrd = 0
|
||||||
for agent_i_action in random_actions:
|
for agent_i_action in random_actions:
|
||||||
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
|
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
|
||||||
r += step_r
|
rwrd += step_r
|
||||||
if render:
|
if render:
|
||||||
factory.render()
|
factory.render()
|
||||||
if done_bool:
|
if done_bool:
|
||||||
break
|
break
|
||||||
print(f'Factory run {epoch} done, reward is:\n {r}')
|
print(f'Factory run {epoch} done, reward is:\n {rwrd}')
|
||||||
pass
|
pass
|
||||||
|
@ -1,8 +1,12 @@
|
|||||||
movement_props:
|
parse_doors: True
|
||||||
|
doors_have_area: True
|
||||||
|
done_at_collision: False
|
||||||
|
level_name: "rooms"
|
||||||
|
mv_prop:
|
||||||
allow_diagonal_movement: True
|
allow_diagonal_movement: True
|
||||||
allow_square_movement: True
|
allow_square_movement: True
|
||||||
allow_no_op: False
|
allow_no_op: False
|
||||||
dirt_props:
|
dirt_prop:
|
||||||
initial_dirt_ratio: 0.35
|
initial_dirt_ratio: 0.35
|
||||||
initial_dirt_spawn_r_var : 0.1
|
initial_dirt_spawn_r_var : 0.1
|
||||||
clean_amount: 0.34
|
clean_amount: 0.34
|
||||||
@ -12,8 +16,15 @@ dirt_props:
|
|||||||
spawn_frequency: 0
|
spawn_frequency: 0
|
||||||
max_spawn_ratio: 0.05
|
max_spawn_ratio: 0.05
|
||||||
dirt_smear_amount: 0.0
|
dirt_smear_amount: 0.0
|
||||||
agent_can_interact: True
|
done_when_clean: True
|
||||||
factory_props:
|
rewards_base:
|
||||||
parse_doors: True
|
MOVEMENTS_VALID: 0
|
||||||
level_name: "rooms"
|
MOVEMENTS_FAIL: 0
|
||||||
doors_have_area: False
|
NOOP: 0
|
||||||
|
USE_DOOR_VALID: 0
|
||||||
|
USE_DOOR_FAIL: 0
|
||||||
|
COLLISION: 0
|
||||||
|
rewards_dirt:
|
||||||
|
CLEAN_UP_VALID: 1
|
||||||
|
CLEAN_UP_FAIL: 0
|
||||||
|
CLEAN_UP_LAST_PIECE: 5
|
@ -1,12 +1,10 @@
|
|||||||
import itertools
|
import itertools
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from enum import Enum, auto
|
from typing import Tuple, Union, Dict, List, NamedTuple
|
||||||
from typing import Tuple, Union
|
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from numpy.typing import ArrayLike
|
||||||
|
|
||||||
from stable_baselines3 import PPO, DQN, A2C
|
from stable_baselines3 import PPO, DQN, A2C
|
||||||
|
|
||||||
MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C)
|
MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C)
|
||||||
@ -20,7 +18,7 @@ IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amo
|
|||||||
|
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
class Constants(Enum):
|
class Constants:
|
||||||
WALL = '#'
|
WALL = '#'
|
||||||
WALLS = 'Walls'
|
WALLS = 'Walls'
|
||||||
FLOOR = 'Floor'
|
FLOOR = 'Floor'
|
||||||
@ -29,86 +27,132 @@ class Constants(Enum):
|
|||||||
LEVEL = 'Level'
|
LEVEL = 'Level'
|
||||||
AGENT = 'Agent'
|
AGENT = 'Agent'
|
||||||
AGENT_PLACEHOLDER = 'AGENT_PLACEHOLDER'
|
AGENT_PLACEHOLDER = 'AGENT_PLACEHOLDER'
|
||||||
|
GLOBAL_POSITION = 'GLOBAL_POSITION'
|
||||||
FREE_CELL = 0
|
FREE_CELL = 0
|
||||||
OCCUPIED_CELL = 1
|
OCCUPIED_CELL = 1
|
||||||
SHADOWED_CELL = -1
|
SHADOWED_CELL = -1
|
||||||
|
ACCESS_DOOR_CELL = 1/3
|
||||||
|
OPEN_DOOR_CELL = 2/3
|
||||||
|
CLOSED_DOOR_CELL = 3/3
|
||||||
NO_POS = (-9999, -9999)
|
NO_POS = (-9999, -9999)
|
||||||
|
|
||||||
DOORS = 'Doors'
|
DOORS = 'Doors'
|
||||||
CLOSED_DOOR = 'closed'
|
CLOSED_DOOR = 'closed'
|
||||||
OPEN_DOOR = 'open'
|
OPEN_DOOR = 'open'
|
||||||
|
ACCESS_DOOR = 'access'
|
||||||
|
|
||||||
ACTION = 'action'
|
ACTION = 'action'
|
||||||
COLLISIONS = 'collision'
|
COLLISION = 'collision'
|
||||||
VALID = 'valid'
|
VALID = True
|
||||||
NOT_VALID = 'not_valid'
|
NOT_VALID = False
|
||||||
|
|
||||||
# Dirt Env
|
|
||||||
DIRT = 'Dirt'
|
|
||||||
|
|
||||||
# Item Env
|
|
||||||
ITEM = 'Item'
|
|
||||||
INVENTORY = 'Inventory'
|
|
||||||
DROP_OFF = 'Drop_Off'
|
|
||||||
|
|
||||||
# Battery Env
|
|
||||||
CHARGE_POD = 'Charge_Pod'
|
|
||||||
BATTERIES = 'BATTERIES'
|
|
||||||
|
|
||||||
# Destination Env
|
|
||||||
DESTINATION = 'Destination'
|
|
||||||
REACHEDDESTINATION = 'ReachedDestination'
|
|
||||||
|
|
||||||
def __bool__(self):
|
|
||||||
if 'not_' in self.value:
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
return bool(self.value)
|
|
||||||
|
|
||||||
|
|
||||||
class MovingAction(Enum):
|
class EnvActions:
|
||||||
NORTH = 'north'
|
# Movements
|
||||||
EAST = 'east'
|
NORTH = 'north'
|
||||||
SOUTH = 'south'
|
EAST = 'east'
|
||||||
WEST = 'west'
|
SOUTH = 'south'
|
||||||
NORTHEAST = 'north_east'
|
WEST = 'west'
|
||||||
SOUTHEAST = 'south_east'
|
NORTHEAST = 'north_east'
|
||||||
SOUTHWEST = 'south_west'
|
SOUTHEAST = 'south_east'
|
||||||
NORTHWEST = 'north_west'
|
SOUTHWEST = 'south_west'
|
||||||
|
NORTHWEST = 'north_west'
|
||||||
|
|
||||||
|
# Other
|
||||||
|
# MOVE = 'move'
|
||||||
|
NOOP = 'no_op'
|
||||||
|
USE_DOOR = 'use_door'
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_member(cls, other):
|
def is_move(cls, other):
|
||||||
return any([other == direction for direction in cls])
|
return any([other == direction for direction in cls.movement_actions()])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def square(cls):
|
def square_move(cls):
|
||||||
return [cls.NORTH, cls.EAST, cls.SOUTH, cls.WEST]
|
return [cls.NORTH, cls.EAST, cls.SOUTH, cls.WEST]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def diagonal(cls):
|
def diagonal_move(cls):
|
||||||
return [cls.NORTHEAST, cls.SOUTHEAST, cls.SOUTHWEST, cls.NORTHWEST]
|
return [cls.NORTHEAST, cls.SOUTHEAST, cls.SOUTHWEST, cls.NORTHWEST]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
class EnvActions(Enum):
|
def movement_actions(cls):
|
||||||
NOOP = 'no_op'
|
return list(itertools.chain(cls.square_move(), cls.diagonal_move()))
|
||||||
USE_DOOR = 'use_door'
|
|
||||||
CLEAN_UP = 'clean_up'
|
|
||||||
ITEM_ACTION = 'item_action'
|
|
||||||
CHARGE = 'charge'
|
|
||||||
WAIT_ON_DEST = 'wait'
|
|
||||||
|
|
||||||
|
|
||||||
m = MovingAction
|
class RewardsBase(NamedTuple):
|
||||||
|
MOVEMENTS_VALID: float = -0.001
|
||||||
|
MOVEMENTS_FAIL: float = -0.05
|
||||||
|
NOOP: float = -0.01
|
||||||
|
USE_DOOR_VALID: float = -0.00
|
||||||
|
USE_DOOR_FAIL: float = -0.01
|
||||||
|
COLLISION: float = -0.5
|
||||||
|
|
||||||
|
|
||||||
|
m = EnvActions
|
||||||
c = Constants
|
c = Constants
|
||||||
|
|
||||||
ACTIONMAP = defaultdict(lambda: (0, 0), {m.NORTH: (-1, 0), m.NORTHEAST: (-1, +1),
|
ACTIONMAP = defaultdict(lambda: (0, 0),
|
||||||
m.EAST: (0, 1), m.SOUTHEAST: (1, 1),
|
{m.NORTH: (-1, 0), m.NORTHEAST: (-1, 1),
|
||||||
m.SOUTH: (1, 0), m.SOUTHWEST: (+1, -1),
|
m.EAST: (0, 1), m.SOUTHEAST: (1, 1),
|
||||||
m.WEST: (0, -1), m.NORTHWEST: (-1, -1)
|
m.SOUTH: (1, 0), m.SOUTHWEST: (1, -1),
|
||||||
}
|
m.WEST: (0, -1), m.NORTHWEST: (-1, -1)
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ObservationTranslator:
|
||||||
|
|
||||||
|
def __init__(self, obs_shape_2d: (int, int), this_named_observation_space: Dict[str, dict],
|
||||||
|
*per_agent_named_obs_space: Dict[str, dict],
|
||||||
|
placeholder_fill_value: Union[int, str] = 'N'):
|
||||||
|
assert len(obs_shape_2d) == 2
|
||||||
|
self.obs_shape = obs_shape_2d
|
||||||
|
if isinstance(placeholder_fill_value, str):
|
||||||
|
if placeholder_fill_value.lower() in ['normal', 'n']:
|
||||||
|
self.random_fill = lambda: np.random.normal(size=self.obs_shape)
|
||||||
|
elif placeholder_fill_value.lower() in ['uniform', 'u']:
|
||||||
|
self.random_fill = lambda: np.random.uniform(size=self.obs_shape)
|
||||||
|
else:
|
||||||
|
raise ValueError('Please chooe between "uniform" or "normal"')
|
||||||
|
else:
|
||||||
|
self.random_fill = None
|
||||||
|
|
||||||
|
self._this_named_obs_space = this_named_observation_space
|
||||||
|
self._per_agent_named_obs_space = list(per_agent_named_obs_space)
|
||||||
|
|
||||||
|
def translate_observation(self, agent_idx: int, obs: np.ndarray):
|
||||||
|
target_obs_space = self._per_agent_named_obs_space[agent_idx]
|
||||||
|
translation = [idx_space_dict for name, idx_space_dict in target_obs_space.items()]
|
||||||
|
flat_translation = [x for y in translation for x in y]
|
||||||
|
return np.take(obs, flat_translation, axis=1 if obs.ndim == 4 else 0)
|
||||||
|
|
||||||
|
def translate_observations(self, observations: List[ArrayLike]):
|
||||||
|
return [self.translate_observation(idx, observation) for idx, observation in enumerate(observations)]
|
||||||
|
|
||||||
|
def __call__(self, observations):
|
||||||
|
return self.translate_observations(observations)
|
||||||
|
|
||||||
|
|
||||||
|
class ActionTranslator:
|
||||||
|
|
||||||
|
def __init__(self, target_named_action_space: Dict[str, int], *per_agent_named_action_space: Dict[str, int]):
|
||||||
|
self._target_named_action_space = target_named_action_space
|
||||||
|
self._per_agent_named_action_space = list(per_agent_named_action_space)
|
||||||
|
self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space]
|
||||||
|
|
||||||
|
def translate_action(self, agent_idx: int, action: int):
|
||||||
|
named_action = self._per_agent_idx_actions[agent_idx][action]
|
||||||
|
translated_action = self._target_named_action_space[named_action]
|
||||||
|
return translated_action
|
||||||
|
|
||||||
|
def translate_actions(self, actions: List[int]):
|
||||||
|
return [self.translate_action(idx, action) for idx, action in enumerate(actions)]
|
||||||
|
|
||||||
|
def __call__(self, actions):
|
||||||
|
return self.translate_actions(actions)
|
||||||
|
|
||||||
|
|
||||||
# Utility functions
|
# Utility functions
|
||||||
def parse_level(path):
|
def parse_level(path):
|
||||||
with path.open('r') as lvl:
|
with path.open('r') as lvl:
|
||||||
@ -118,17 +162,14 @@ def parse_level(path):
|
|||||||
return level
|
return level
|
||||||
|
|
||||||
|
|
||||||
def one_hot_level(level, wall_char: Union[c, str] = c.WALL):
|
def one_hot_level(level, wall_char: str = c.WALL):
|
||||||
grid = np.array(level)
|
grid = np.array(level)
|
||||||
binary_grid = np.zeros(grid.shape, dtype=np.int8)
|
binary_grid = np.zeros(grid.shape, dtype=np.int8)
|
||||||
if wall_char in c:
|
binary_grid[grid == wall_char] = c.OCCUPIED_CELL
|
||||||
binary_grid[grid == wall_char.value] = c.OCCUPIED_CELL.value
|
|
||||||
else:
|
|
||||||
binary_grid[grid == wall_char] = c.OCCUPIED_CELL.value
|
|
||||||
return binary_grid
|
return binary_grid
|
||||||
|
|
||||||
|
|
||||||
def check_position(slice_to_check_against: np.ndarray, position_to_check: Tuple[int, int]):
|
def check_position(slice_to_check_against: ArrayLike, position_to_check: Tuple[int, int]):
|
||||||
x_pos, y_pos = position_to_check
|
x_pos, y_pos = position_to_check
|
||||||
|
|
||||||
# Check if agent colides with grid boundrys
|
# Check if agent colides with grid boundrys
|
||||||
@ -145,19 +186,24 @@ def check_position(slice_to_check_against: np.ndarray, position_to_check: Tuple[
|
|||||||
|
|
||||||
def asset_str(agent):
|
def asset_str(agent):
|
||||||
# What does this abonimation do?
|
# What does this abonimation do?
|
||||||
# if any([x is None for x in [self._slices[j] for j in agent.collisions]]):
|
# if any([x is None for x in [cls._slices[j] for j in agent.collisions]]):
|
||||||
# print('error')
|
# print('error')
|
||||||
col_names = [x.name for x in agent.temp_collisions]
|
if step_result := agent.step_result:
|
||||||
if any(c.AGENT.value in name for name in col_names):
|
action = step_result['action_name']
|
||||||
return 'agent_collision', 'blank'
|
valid = step_result['action_valid']
|
||||||
elif not agent.temp_valid or c.LEVEL.name in col_names or c.AGENT.name in col_names:
|
col_names = [x.name for x in step_result['collisions']]
|
||||||
return c.AGENT.value, 'invalid'
|
if any(c.AGENT in name for name in col_names):
|
||||||
elif agent.temp_valid and not MovingAction.is_member(agent.temp_action):
|
return 'agent_collision', 'blank'
|
||||||
return c.AGENT.value, 'valid'
|
elif not valid or c.LEVEL in col_names or c.AGENT in col_names:
|
||||||
elif agent.temp_valid and MovingAction.is_member(agent.temp_action):
|
return c.AGENT, 'invalid'
|
||||||
return c.AGENT.value, 'move'
|
elif valid and not EnvActions.is_move(action):
|
||||||
|
return c.AGENT, 'valid'
|
||||||
|
elif valid and EnvActions.is_move(action):
|
||||||
|
return c.AGENT, 'move'
|
||||||
|
else:
|
||||||
|
return c.AGENT, 'idle'
|
||||||
else:
|
else:
|
||||||
return c.AGENT.value, 'idle'
|
return c.AGENT, 'idle'
|
||||||
|
|
||||||
|
|
||||||
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
|
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
|
||||||
@ -176,8 +222,3 @@ def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, all
|
|||||||
elif allow_manhattan_connections and not allow_euclidean_connections and not all(diff) and any(diff):
|
elif allow_manhattan_connections and not allow_euclidean_connections and not all(diff) and any(diff):
|
||||||
graph.add_edge(a, b)
|
graph.add_edge(a, b)
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
parsed_level = parse_level(Path(__file__).parent / 'factory' / 'levels' / 'simple.txt')
|
|
||||||
y = one_hot_level(parsed_level)
|
|
||||||
print(np.argwhere(y == 0))
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import pickle
|
import pickle
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from os import PathLike
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Dict, Union
|
from typing import List, Dict, Union
|
||||||
|
|
||||||
@ -9,14 +10,17 @@ from environments.helpers import IGNORED_DF_COLUMNS
|
|||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
from plotting.compare_runs import plot_single_run
|
||||||
|
|
||||||
|
|
||||||
class EnvMonitor(BaseCallback):
|
class EnvMonitor(BaseCallback):
|
||||||
|
|
||||||
ext = 'png'
|
ext = 'png'
|
||||||
|
|
||||||
def __init__(self, env):
|
def __init__(self, env, filepath: Union[str, PathLike] = None):
|
||||||
super(EnvMonitor, self).__init__()
|
super(EnvMonitor, self).__init__()
|
||||||
self.unwrapped = env
|
self.unwrapped = env
|
||||||
|
self._filepath = filepath
|
||||||
self._monitor_df = pd.DataFrame()
|
self._monitor_df = pd.DataFrame()
|
||||||
self._monitor_dicts = defaultdict(dict)
|
self._monitor_dicts = defaultdict(dict)
|
||||||
|
|
||||||
@ -43,7 +47,7 @@ class EnvMonitor(BaseCallback):
|
|||||||
self._read_info(env_idx, info)
|
self._read_info(env_idx, info)
|
||||||
|
|
||||||
for env_idx, done in list(
|
for env_idx, done in list(
|
||||||
enumerate(self.locals.get('dones', []))) + list(enumerate(self.locals.get('done', []))):
|
enumerate(self.locals.get('dones', []))): # + list(enumerate(self.locals.get('done', []))):
|
||||||
self._read_done(env_idx, done)
|
self._read_done(env_idx, done)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -67,8 +71,10 @@ class EnvMonitor(BaseCallback):
|
|||||||
pass
|
pass
|
||||||
return
|
return
|
||||||
|
|
||||||
def save_run(self, filepath: Union[Path, str]):
|
def save_run(self, filepath: Union[Path, str], auto_plotting_keys=None):
|
||||||
filepath = Path(filepath)
|
filepath = Path(filepath)
|
||||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||||
with filepath.open('wb') as f:
|
with filepath.open('wb') as f:
|
||||||
pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
|
pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
if auto_plotting_keys:
|
||||||
|
plot_single_run(filepath, column_keys=auto_plotting_keys)
|
||||||
|
@ -24,14 +24,12 @@ class EnvRecorder(BaseCallback):
|
|||||||
self._entities = [entities]
|
self._entities = [entities]
|
||||||
else:
|
else:
|
||||||
self._entities = entities
|
self._entities = entities
|
||||||
self.started = False
|
|
||||||
self.closed = False
|
|
||||||
|
|
||||||
def __getattr__(self, item):
|
def __getattr__(self, item):
|
||||||
return getattr(self.unwrapped, item)
|
return getattr(self.unwrapped, item)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.unwrapped._record_episodes = True
|
self.unwrapped.start_recording()
|
||||||
return self.unwrapped.reset()
|
return self.unwrapped.reset()
|
||||||
|
|
||||||
def _on_training_start(self) -> None:
|
def _on_training_start(self) -> None:
|
||||||
@ -57,10 +55,18 @@ class EnvRecorder(BaseCallback):
|
|||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def step(self, actions):
|
||||||
|
step_result = self.unwrapped.step(actions)
|
||||||
|
# 0, 1, 2 , 3 = idx
|
||||||
|
# _, _, done_bool, info_obj = step_result
|
||||||
|
self._read_info(0, step_result[3])
|
||||||
|
self._read_done(0, step_result[2])
|
||||||
|
return step_result
|
||||||
|
|
||||||
def save_records(self, filepath: Union[Path, str], save_occupation_map=False, save_trajectory_map=False):
|
def save_records(self, filepath: Union[Path, str], save_occupation_map=False, save_trajectory_map=False):
|
||||||
filepath = Path(filepath)
|
filepath = Path(filepath)
|
||||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||||
# self.out_file.unlink(missing_ok=True)
|
# cls.out_file.unlink(missing_ok=True)
|
||||||
with filepath.open('w') as f:
|
with filepath.open('w') as f:
|
||||||
out_dict = {'episodes': self._recorder_out_list, 'header': self.unwrapped.params}
|
out_dict = {'episodes': self._recorder_out_list, 'header': self.unwrapped.params}
|
||||||
try:
|
try:
|
||||||
|
@ -17,13 +17,15 @@ class MovementProperties(NamedTuple):
|
|||||||
|
|
||||||
|
|
||||||
class ObservationProperties(NamedTuple):
|
class ObservationProperties(NamedTuple):
|
||||||
|
# Todo: Add Description
|
||||||
render_agents: AgentRenderOptions = AgentRenderOptions.SEPERATE
|
render_agents: AgentRenderOptions = AgentRenderOptions.SEPERATE
|
||||||
omit_agent_self: bool = True
|
omit_agent_self: bool = True
|
||||||
additional_agent_placeholder: Union[None, str, int] = None
|
additional_agent_placeholder: Union[None, str, int] = None
|
||||||
cast_shadows = True
|
cast_shadows: bool = True
|
||||||
frames_to_stack: int = 0
|
frames_to_stack: int = 0
|
||||||
pomdp_r: int = 0
|
pomdp_r: int = 0
|
||||||
show_global_position_info: bool = True
|
indicate_door_area: bool = False
|
||||||
|
show_global_position_info: bool = False
|
||||||
|
|
||||||
|
|
||||||
class MarlFrameStack(gym.ObservationWrapper):
|
class MarlFrameStack(gym.ObservationWrapper):
|
||||||
@ -34,4 +36,3 @@ class MarlFrameStack(gym.ObservationWrapper):
|
|||||||
if isinstance(self.env, FrameStack) and self.env.unwrapped.n_agents > 1:
|
if isinstance(self.env, FrameStack) and self.env.unwrapped.n_agents > 1:
|
||||||
return observation[0:].swapaxes(0, 1)
|
return observation[0:].swapaxes(0, 1)
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
|
119
experiments/simple_example.py
Normal file
119
experiments/simple_example.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from stable_baselines3 import PPO
|
||||||
|
|
||||||
|
from environments.factory.factory_dirt import DirtProperties, DirtFactory, RewardsDirt
|
||||||
|
from environments.logging.envmonitor import EnvMonitor
|
||||||
|
from environments.logging.recorder import EnvRecorder
|
||||||
|
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
||||||
|
from environments.factory.factory_dirt import Constants as c
|
||||||
|
|
||||||
|
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||||
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
TRAIN_AGENT = True
|
||||||
|
LOAD_AND_REPLAY = True
|
||||||
|
record = True
|
||||||
|
render = False
|
||||||
|
|
||||||
|
study_root_path = Path(__file__).parent.parent / 'experiment_out'
|
||||||
|
|
||||||
|
parameter_path = Path(__file__).parent.parent / 'environments' / 'factory' / 'levels' / 'parameters' / 'DirtyFactory-v0.yaml'
|
||||||
|
|
||||||
|
save_path = study_root_path / f'model.zip'
|
||||||
|
|
||||||
|
# Output folder
|
||||||
|
|
||||||
|
study_root_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
train_steps = 2*1e5
|
||||||
|
frames_to_stack = 0
|
||||||
|
|
||||||
|
u = dict(
|
||||||
|
show_global_position_info=True,
|
||||||
|
pomdp_r=3,
|
||||||
|
cast_shadows=True,
|
||||||
|
allow_diagonal_movement=False,
|
||||||
|
parse_doors=True,
|
||||||
|
doors_have_area=False,
|
||||||
|
done_at_collision=True
|
||||||
|
)
|
||||||
|
obs_props = ObservationProperties(render_agents=AgentRenderOptions.SEPERATE,
|
||||||
|
additional_agent_placeholder=None,
|
||||||
|
omit_agent_self=True,
|
||||||
|
frames_to_stack=frames_to_stack,
|
||||||
|
pomdp_r=u['pomdp_r'], cast_shadows=u['cast_shadows'],
|
||||||
|
show_global_position_info=u['show_global_position_info'])
|
||||||
|
move_props = MovementProperties(allow_diagonal_movement=u['allow_diagonal_movement'],
|
||||||
|
allow_square_movement=True,
|
||||||
|
allow_no_op=False)
|
||||||
|
dirt_props = DirtProperties(initial_dirt_ratio=0.35, initial_dirt_spawn_r_var=0.1,
|
||||||
|
clean_amount=0.34,
|
||||||
|
max_spawn_amount=0.1, max_global_amount=20,
|
||||||
|
max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05,
|
||||||
|
dirt_smear_amount=0.0)
|
||||||
|
rewards_dirt = RewardsDirt(CLEAN_UP_FAIL=-0.5, CLEAN_UP_VALID=1, CLEAN_UP_LAST_PIECE=5)
|
||||||
|
factory_kwargs = dict(n_agents=1, max_steps=500, parse_doors=u['parse_doors'],
|
||||||
|
level_name='rooms', doors_have_area=u['doors_have_area'],
|
||||||
|
verbose=True,
|
||||||
|
mv_prop=move_props,
|
||||||
|
obs_prop=obs_props,
|
||||||
|
rewards_dirt=rewards_dirt,
|
||||||
|
done_at_collision=u['done_at_collision']
|
||||||
|
)
|
||||||
|
|
||||||
|
# with (parameter_path).open('r') as f:
|
||||||
|
# factory_kwargs = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
# factory_kwargs.update(n_agents=1, done_at_collision=False, verbose=True)
|
||||||
|
|
||||||
|
if TRAIN_AGENT:
|
||||||
|
env = DirtFactory(**factory_kwargs)
|
||||||
|
callbacks = EnvMonitor(env)
|
||||||
|
obs_shape = env.observation_space.shape
|
||||||
|
|
||||||
|
model = PPO("MlpPolicy", env, verbose=1, device='cpu')
|
||||||
|
|
||||||
|
model.learn(total_timesteps=train_steps, callback=callbacks)
|
||||||
|
|
||||||
|
callbacks.save_run(study_root_path / 'monitor.pick', auto_plotting_keys=['step_reward', 'collision'] + ['cleanup_valid', 'cleanup_fail']) # + env_plot_keys)
|
||||||
|
|
||||||
|
|
||||||
|
model.save(save_path)
|
||||||
|
|
||||||
|
if LOAD_AND_REPLAY:
|
||||||
|
with DirtFactory(**factory_kwargs) as env:
|
||||||
|
env = EnvMonitor(env)
|
||||||
|
env = EnvRecorder(env) if record else env
|
||||||
|
obs_shape = env.observation_space.shape
|
||||||
|
model = PPO.load(save_path)
|
||||||
|
# Evaluation Loop for i in range(n Episodes)
|
||||||
|
for episode in range(10):
|
||||||
|
env_state = env.reset()
|
||||||
|
rew, done_bool = 0, False
|
||||||
|
while not done_bool:
|
||||||
|
actions = model.predict(env_state, deterministic=True)[0]
|
||||||
|
env_state, step_r, done_bool, info_obj = env.step(actions)
|
||||||
|
|
||||||
|
rew += step_r
|
||||||
|
|
||||||
|
if render:
|
||||||
|
env.render()
|
||||||
|
|
||||||
|
try:
|
||||||
|
door = next(x for x in env.unwrapped.unwrapped.unwrapped[c.DOORS] if x.is_open)
|
||||||
|
print('openDoor found')
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if done_bool:
|
||||||
|
break
|
||||||
|
print(
|
||||||
|
f'Factory run {episode} done, steps taken {env.unwrapped.unwrapped.unwrapped._steps}, reward is:\n {rew}')
|
||||||
|
|
||||||
|
env.save_records(study_root_path / 'reload_recorder.pick', save_occupation_map=False)
|
||||||
|
#env.save_run(study_root_path / 'reload_monitor.pick',
|
||||||
|
# auto_plotting_keys=['step_reward', 'cleanup_valid', 'cleanup_fail'])
|
@ -10,6 +10,45 @@ from environments.helpers import IGNORED_DF_COLUMNS, MODEL_MAP
|
|||||||
from plotting.plotting import prepare_plot
|
from plotting.plotting import prepare_plot
|
||||||
|
|
||||||
|
|
||||||
|
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None):
|
||||||
|
run_path = Path(run_path)
|
||||||
|
df_list = list()
|
||||||
|
if run_path.is_dir():
|
||||||
|
monitor_file = next(run_path.glob('*monitor*.pick'))
|
||||||
|
elif run_path.exists() and run_path.is_file():
|
||||||
|
monitor_file = run_path
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
with monitor_file.open('rb') as f:
|
||||||
|
monitor_df = pickle.load(f)
|
||||||
|
|
||||||
|
monitor_df = monitor_df.fillna(0)
|
||||||
|
df_list.append(monitor_df)
|
||||||
|
|
||||||
|
df = pd.concat(df_list, ignore_index=True)
|
||||||
|
df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
|
||||||
|
if column_keys is not None:
|
||||||
|
columns = [col for col in column_keys if col in df.columns]
|
||||||
|
else:
|
||||||
|
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
|
||||||
|
|
||||||
|
roll_n = 50
|
||||||
|
|
||||||
|
non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean()
|
||||||
|
|
||||||
|
df_melted = df[columns + ['Episode']].reset_index().melt(id_vars=['Episode'],
|
||||||
|
value_vars=columns, var_name="Measurement",
|
||||||
|
value_name="Score")
|
||||||
|
|
||||||
|
if df_melted['Episode'].max() > 800:
|
||||||
|
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||||
|
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||||
|
|
||||||
|
prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||||
|
print('Plotting done.')
|
||||||
|
|
||||||
|
|
||||||
def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
|
def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
|
||||||
run_path = Path(run_path)
|
run_path = Path(run_path)
|
||||||
df_list = list()
|
df_list = list()
|
||||||
@ -37,7 +76,10 @@ def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
|
|||||||
skip_n = round(df_melted['Episode'].max() * 0.02)
|
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||||
|
|
||||||
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
if run_path.is_dir():
|
||||||
|
prepare_plot(run_path / f'{run_path}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||||
|
elif run_path.exists() and run_path.is_file():
|
||||||
|
prepare_plot(run_path.parent / f'{run_path.parent}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||||
print('Plotting done.')
|
print('Plotting done.')
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
|
import matplotlib as mpl
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
PALETTE = 10 * (
|
PALETTE = 10 * (
|
||||||
@ -21,7 +22,14 @@ PALETTE = 10 * (
|
|||||||
def plot(filepath, ext='png'):
|
def plot(filepath, ext='png'):
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
figure = plt.gcf()
|
figure = plt.gcf()
|
||||||
figure.savefig(str(filepath), format=ext)
|
ax = plt.gca()
|
||||||
|
legends = [c for c in ax.get_children() if isinstance(c, mpl.legend.Legend)]
|
||||||
|
|
||||||
|
if legends:
|
||||||
|
figure.savefig(str(filepath), format=ext, bbox_extra_artists=(*legends,), bbox_inches='tight')
|
||||||
|
else:
|
||||||
|
figure.savefig(str(filepath), format=ext)
|
||||||
|
|
||||||
plt.show()
|
plt.show()
|
||||||
plt.clf()
|
plt.clf()
|
||||||
|
|
||||||
@ -30,7 +38,7 @@ def prepare_tex(df, hue, style, hue_order):
|
|||||||
sns.set(rc={'text.usetex': True}, style='whitegrid')
|
sns.set(rc={'text.usetex': True}, style='whitegrid')
|
||||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
|
lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
|
||||||
hue_order=hue_order, hue=hue, style=style)
|
hue_order=hue_order, hue=hue, style=style)
|
||||||
# lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
||||||
plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
|
plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
return lineplot
|
return lineplot
|
||||||
@ -48,6 +56,19 @@ def prepare_plt(df, hue, style, hue_order):
|
|||||||
return lineplot
|
return lineplot
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_center_double_column_legend(df, hue, style, hue_order):
|
||||||
|
print('Struggling to plot Figure using LaTeX - going back to normal.')
|
||||||
|
plt.close('all')
|
||||||
|
sns.set(rc={'text.usetex': False}, style='whitegrid')
|
||||||
|
fig = plt.figure(figsize=(10, 11))
|
||||||
|
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
|
||||||
|
ci=95, palette=PALETTE, hue_order=hue_order, legend=False)
|
||||||
|
# plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
|
||||||
|
lineplot.legend(hue_order, ncol=3, loc='lower center', title='Parameter Combinations', bbox_to_anchor=(0.5, -0.43))
|
||||||
|
plt.tight_layout()
|
||||||
|
return lineplot
|
||||||
|
|
||||||
|
|
||||||
def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None, use_tex=False):
|
def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None, use_tex=False):
|
||||||
df = results_df.copy()
|
df = results_df.copy()
|
||||||
df[hue] = df[hue].str.replace('_', '-')
|
df[hue] = df[hue].str.replace('_', '-')
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import yaml
|
import yaml
|
||||||
|
from stable_baselines3 import A2C, PPO, DQN
|
||||||
|
|
||||||
|
from environments.factory.factory_dirt import Constants as c
|
||||||
|
|
||||||
from environments import helpers as h
|
|
||||||
from environments.helpers import Constants as c
|
|
||||||
from environments.factory.factory_dirt import DirtFactory
|
from environments.factory.factory_dirt import DirtFactory
|
||||||
from environments.factory.combined_factories import DirtItemFactory
|
from environments.logging.envmonitor import EnvMonitor
|
||||||
from environments.logging.recorder import EnvRecorder
|
from environments.logging.recorder import EnvRecorder
|
||||||
|
|
||||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||||
@ -18,39 +18,41 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
determin = False
|
determin = False
|
||||||
render = True
|
render = True
|
||||||
record = True
|
record = False
|
||||||
seed = 67
|
verbose = True
|
||||||
|
seed = 13
|
||||||
n_agents = 1
|
n_agents = 1
|
||||||
out_path = Path('study_out/e_1_new_reward/no_obs/dirt/A2C_new_reward/0_A2C_new_reward')
|
# out_path = Path('study_out/e_1_new_reward/no_obs/dirt/A2C_new_reward/0_A2C_new_reward')
|
||||||
out_path_2 = Path('study_out/e_1_obs_stack_3_gae_0.25_n_steps_16/seperate_N/dirt/A2C_obs_stack_3_gae_0.25_n_steps_16/1_A2C_obs_stack_3_gae_0.25_n_steps_16')
|
out_path = Path('study_out/reload')
|
||||||
model_path = out_path
|
model_path = out_path
|
||||||
|
|
||||||
with (out_path / f'env_params.json').open('r') as f:
|
with (out_path / f'env_params.json').open('r') as f:
|
||||||
env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
|
env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
env_kwargs.update(additional_agent_placeholder=None, n_agents=n_agents, max_steps=150)
|
env_kwargs.update(n_agents=n_agents, done_at_collision=False, verbose=verbose)
|
||||||
if gain_amount := env_kwargs.get('dirt_prop', {}).get('gain_amount', None):
|
|
||||||
env_kwargs['dirt_prop']['max_spawn_amount'] = gain_amount
|
|
||||||
del env_kwargs['dirt_prop']['gain_amount']
|
|
||||||
|
|
||||||
env_kwargs.update(record_episodes=record, done_at_collision=True)
|
|
||||||
|
|
||||||
this_model = out_path / 'model.zip'
|
this_model = out_path / 'model.zip'
|
||||||
other_model = out_path / 'model.zip'
|
|
||||||
|
|
||||||
model_cls = next(val for key, val in h.MODEL_MAP.items() if key in out_path.parent.name)
|
model_cls = PPO # next(val for key, val in h.MODEL_MAP.items() if key in out_path.parent.name)
|
||||||
models = [model_cls.load(this_model)] # , model_cls.load(other_model)]
|
models = [model_cls.load(this_model)]
|
||||||
|
try:
|
||||||
|
# Legacy Cleanups
|
||||||
|
del env_kwargs['dirt_prop']['agent_can_interact']
|
||||||
|
env_kwargs['verbose'] = True
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
# Init Env
|
# Init Env
|
||||||
with DirtFactory(**env_kwargs) as env:
|
with DirtFactory(**env_kwargs) as env:
|
||||||
env = EnvRecorder(env)
|
env = EnvMonitor(env)
|
||||||
|
env = EnvRecorder(env) if record else env
|
||||||
obs_shape = env.observation_space.shape
|
obs_shape = env.observation_space.shape
|
||||||
# Evaluation Loop for i in range(n Episodes)
|
# Evaluation Loop for i in range(n Episodes)
|
||||||
for episode in range(50):
|
for episode in range(500):
|
||||||
env_state = env.reset()
|
env_state = env.reset()
|
||||||
rew, done_bool = 0, False
|
rew, done_bool = 0, False
|
||||||
while not done_bool:
|
while not done_bool:
|
||||||
if n_agents > 1:
|
if n_agents > 1:
|
||||||
actions = [model.predict(env_state[model_idx], deterministic=True)[0]
|
actions = [model.predict(env_state[model_idx], deterministic=determin)[0]
|
||||||
for model_idx, model in enumerate(models)]
|
for model_idx, model in enumerate(models)]
|
||||||
else:
|
else:
|
||||||
actions = models[0].predict(env_state, deterministic=determin)[0]
|
actions = models[0].predict(env_state, deterministic=determin)[0]
|
||||||
@ -59,7 +61,17 @@ if __name__ == '__main__':
|
|||||||
rew += step_r
|
rew += step_r
|
||||||
if render:
|
if render:
|
||||||
env.render()
|
env.render()
|
||||||
|
try:
|
||||||
|
door = next(x for x in env.unwrapped.unwrapped[c.DOORS] if x.is_open)
|
||||||
|
print('openDoor found')
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
|
|
||||||
if done_bool:
|
if done_bool:
|
||||||
break
|
break
|
||||||
print(f'Factory run {episode} done, reward is:\n {rew}')
|
print(f'Factory run {episode} done, steps taken {env.unwrapped.unwrapped._steps}, reward is:\n {rew}')
|
||||||
|
env.save_run(out_path / 'reload_monitor.pick',
|
||||||
|
auto_plotting_keys=['step_reward', 'cleanup_valid', 'cleanup_fail'])
|
||||||
|
if record:
|
||||||
|
env.save_records(out_path / 'reload_recorder.pick', save_occupation_map=True)
|
||||||
print('all done')
|
print('all done')
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
numpy
|
numpy
|
||||||
scipy
|
scipy
|
||||||
tqdm
|
tqdm
|
||||||
|
pandas
|
||||||
seaborn>=0.11.1
|
seaborn>=0.11.1
|
||||||
matplotlib>=3.4.1
|
matplotlib>=3.3.4
|
||||||
stable-baselines3>=1.0
|
stable-baselines3>=1.0
|
||||||
pygame>=2.1.0
|
pygame>=2.1.0
|
||||||
gym>=0.18.0
|
gym>=0.18.0
|
||||||
networkx>=2.6.1
|
networkx>=2.6.3
|
||||||
simplejson>=3.17.5
|
simplejson>=3.17.5
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
git+https://github.com/facebookresearch/salina.git@main#egg=salina
|
einops
|
||||||
|
natsort
|
@ -1,7 +1,6 @@
|
|||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
import numpy as np
|
|
||||||
import itertools as it
|
import itertools as it
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -16,8 +15,6 @@ except NameError:
|
|||||||
DIR = None
|
DIR = None
|
||||||
pass
|
pass
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
import simplejson
|
import simplejson
|
||||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||||
|
|
||||||
@ -28,14 +25,12 @@ from environments.factory.factory_item import ItemProperties, ItemFactory
|
|||||||
from environments.logging.envmonitor import EnvMonitor
|
from environments.logging.envmonitor import EnvMonitor
|
||||||
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
||||||
import pickle
|
import pickle
|
||||||
from plotting.compare_runs import compare_seed_runs, compare_model_runs, compare_all_parameter_runs
|
from plotting.compare_runs import compare_seed_runs, compare_model_runs
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
|
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
|
|
||||||
# mp.set_start_method("spawn")
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
In this studie, we want to explore the macro behaviour of multi agents which are trained on the same task,
|
In this studie, we want to explore the macro behaviour of multi agents which are trained on the same task,
|
||||||
but never saw each other in training.
|
but never saw each other in training.
|
||||||
@ -72,10 +67,9 @@ n_agents = 4
|
|||||||
ood_monitor_file = f'e_1_{n_agents}_agents'
|
ood_monitor_file = f'e_1_{n_agents}_agents'
|
||||||
baseline_monitor_file = 'e_1_baseline'
|
baseline_monitor_file = 'e_1_baseline'
|
||||||
|
|
||||||
from stable_baselines3 import A2C
|
|
||||||
|
|
||||||
def policy_model_kwargs():
|
def policy_model_kwargs():
|
||||||
return dict() # gae_lambda=0.25, n_steps=16, max_grad_norm=0.25, use_rms_prop=True)
|
return dict() # gae_lambda=0.25, n_steps=16, max_grad_norm=0.25, use_rms_prop=True)
|
||||||
|
|
||||||
|
|
||||||
def dqn_model_kwargs():
|
def dqn_model_kwargs():
|
||||||
@ -198,7 +192,7 @@ if __name__ == '__main__':
|
|||||||
ood_run = True
|
ood_run = True
|
||||||
plotting = True
|
plotting = True
|
||||||
|
|
||||||
train_steps = 1e7
|
train_steps = 1e6
|
||||||
n_seeds = 3
|
n_seeds = 3
|
||||||
frames_to_stack = 3
|
frames_to_stack = 3
|
||||||
|
|
||||||
@ -222,7 +216,7 @@ if __name__ == '__main__':
|
|||||||
max_spawn_amount=0.1, max_global_amount=20,
|
max_spawn_amount=0.1, max_global_amount=20,
|
||||||
max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05,
|
max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05,
|
||||||
dirt_smear_amount=0.0, agent_can_interact=True)
|
dirt_smear_amount=0.0, agent_can_interact=True)
|
||||||
item_props = ItemProperties(n_items=10, agent_can_interact=True,
|
item_props = ItemProperties(n_items=10,
|
||||||
spawn_frequency=30, n_drop_off_locations=2,
|
spawn_frequency=30, n_drop_off_locations=2,
|
||||||
max_agent_inventory_capacity=15)
|
max_agent_inventory_capacity=15)
|
||||||
factory_kwargs = dict(n_agents=1, max_steps=400, parse_doors=True,
|
factory_kwargs = dict(n_agents=1, max_steps=400, parse_doors=True,
|
||||||
@ -434,8 +428,8 @@ if __name__ == '__main__':
|
|||||||
# Iteration
|
# Iteration
|
||||||
start_mp_baseline_run(env_map, policy_path)
|
start_mp_baseline_run(env_map, policy_path)
|
||||||
|
|
||||||
# for seed_path in (y for y in policy_path.iterdir() if y.is_dir()):
|
# for policy_path in (y for y in policy_path.iterdir() if y.is_dir()):
|
||||||
# load_model_run_baseline(seed_path)
|
# load_model_run_baseline(policy_path)
|
||||||
print('Baseline Tracking done')
|
print('Baseline Tracking done')
|
||||||
|
|
||||||
# Then iterate over every model and monitor "ood behavior" - "is it ood?"
|
# Then iterate over every model and monitor "ood behavior" - "is it ood?"
|
||||||
@ -448,11 +442,11 @@ if __name__ == '__main__':
|
|||||||
for policy_path in [x for x in env_path.iterdir() if x. is_dir()]:
|
for policy_path in [x for x in env_path.iterdir() if x. is_dir()]:
|
||||||
# FIXME: Pick random seed or iterate over available seeds
|
# FIXME: Pick random seed or iterate over available seeds
|
||||||
# First seed path version
|
# First seed path version
|
||||||
# seed_path = next((y for y in policy_path.iterdir() if y.is_dir()))
|
# policy_path = next((y for y in policy_path.iterdir() if y.is_dir()))
|
||||||
# Iteration
|
# Iteration
|
||||||
start_mp_study_run(env_map, policy_path)
|
start_mp_study_run(env_map, policy_path)
|
||||||
#for seed_path in (y for y in policy_path.iterdir() if y.is_dir()):
|
#for policy_path in (y for y in policy_path.iterdir() if y.is_dir()):
|
||||||
# load_model_run_study(seed_path, env_map[env_path.name][0], observation_modes[obs_mode])
|
# load_model_run_study(policy_path, env_map[env_path.name][0], observation_modes[obs_mode])
|
||||||
print('OOD Tracking Done')
|
print('OOD Tracking Done')
|
||||||
|
|
||||||
# Plotting
|
# Plotting
|
||||||
|
23
studies/normalization_study.py
Normal file
23
studies/normalization_study.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from algorithms.utils import Checkpointer
|
||||||
|
from pathlib import Path
|
||||||
|
from algorithms.utils import load_yaml_file, add_env_props, instantiate_class, load_class
|
||||||
|
#from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
for i in range(0, 5):
|
||||||
|
for name in ['snac', 'mappo', 'iac', 'seac']:
|
||||||
|
study_root = Path(__file__).parent / name
|
||||||
|
cfg = load_yaml_file(study_root / f'{name}.yaml')
|
||||||
|
add_env_props(cfg)
|
||||||
|
|
||||||
|
env = instantiate_class(cfg['env'])
|
||||||
|
net = instantiate_class(cfg['agent'])
|
||||||
|
max_steps = cfg['algorithm']['max_steps']
|
||||||
|
n_steps = cfg['algorithm']['n_steps']
|
||||||
|
|
||||||
|
checkpointer = Checkpointer(f'{name}#{i}', study_root, cfg, max_steps, 50)
|
||||||
|
|
||||||
|
loop = load_class(cfg['method'])(cfg)
|
||||||
|
df = loop.train_loop(checkpointer)
|
||||||
|
|
22
studies/playground_file.py
Normal file
22
studies/playground_file.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
import pandas as pd
|
||||||
|
from pathlib import Path
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
|
|
||||||
|
dfs = []
|
||||||
|
for name in ['mappo']:
|
||||||
|
for c in range(5):
|
||||||
|
try:
|
||||||
|
study_root = Path(__file__).parent / name / f'{name}#{c}'
|
||||||
|
print(study_root)
|
||||||
|
df = pd.read_csv(study_root / 'results.csv', index_col=False)
|
||||||
|
df.reward = df.reward.rolling(100).mean()
|
||||||
|
df['method'] = name.upper()
|
||||||
|
dfs.append(df)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
df = pd.concat(dfs).reset_index()
|
||||||
|
sns.lineplot(data=df, x='steps', y='reward', hue='method', palette='husl', ci='sd', linewidth=1.5, err_style='bars')
|
||||||
|
plt.savefig('study.png')
|
||||||
|
print('saved image')
|
@ -1,139 +0,0 @@
|
|||||||
from salina.agents.gyma import AutoResetGymAgent
|
|
||||||
from salina.agents import Agents, TemporalAgent
|
|
||||||
from salina.rl.functional import _index, gae
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.distributions import Categorical
|
|
||||||
from salina import TAgent, Workspace, get_arguments, get_class, instantiate_class
|
|
||||||
from pathlib import Path
|
|
||||||
import numpy as np
|
|
||||||
from tqdm import tqdm
|
|
||||||
import time
|
|
||||||
from algorithms.utils import (
|
|
||||||
add_env_props,
|
|
||||||
load_yaml_file,
|
|
||||||
CombineActionsAgent,
|
|
||||||
AutoResetGymMultiAgent,
|
|
||||||
access_str,
|
|
||||||
AGENT_PREFIX, REWARD, CUMU_REWARD, OBS, SEP
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class A2CAgent(TAgent):
|
|
||||||
def __init__(self, observation_size, hidden_size, n_actions, agent_id):
|
|
||||||
super().__init__()
|
|
||||||
observation_size = np.prod(observation_size)
|
|
||||||
print(observation_size)
|
|
||||||
self.agent_id = agent_id
|
|
||||||
self.model = nn.Sequential(
|
|
||||||
nn.Flatten(),
|
|
||||||
nn.Linear(observation_size, hidden_size),
|
|
||||||
nn.ELU(),
|
|
||||||
nn.Linear(hidden_size, hidden_size),
|
|
||||||
nn.ELU(),
|
|
||||||
nn.Linear(hidden_size, hidden_size),
|
|
||||||
nn.ELU()
|
|
||||||
)
|
|
||||||
self.action_head = nn.Linear(hidden_size, n_actions)
|
|
||||||
self.critic_head = nn.Linear(hidden_size, 1)
|
|
||||||
|
|
||||||
def get_obs(self, t):
|
|
||||||
observation = self.get((f'env/{access_str(self.agent_id, OBS)}', t))
|
|
||||||
return observation
|
|
||||||
|
|
||||||
def forward(self, t, stochastic, **kwargs):
|
|
||||||
observation = self.get_obs(t)
|
|
||||||
features = self.model(observation)
|
|
||||||
scores = self.action_head(features)
|
|
||||||
probs = torch.softmax(scores, dim=-1)
|
|
||||||
critic = self.critic_head(features).squeeze(-1)
|
|
||||||
if stochastic:
|
|
||||||
action = torch.distributions.Categorical(probs).sample()
|
|
||||||
else:
|
|
||||||
action = probs.argmax(1)
|
|
||||||
self.set((f'{access_str(self.agent_id, "action")}', t), action)
|
|
||||||
self.set((f'{access_str(self.agent_id, "action_probs")}', t), probs)
|
|
||||||
self.set((f'{access_str(self.agent_id, "critic")}', t), critic)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
# Setup workspace
|
|
||||||
uid = time.time()
|
|
||||||
workspace = Workspace()
|
|
||||||
n_agents = 2
|
|
||||||
|
|
||||||
# load config
|
|
||||||
cfg = load_yaml_file(Path(__file__).parent / 'sat_mad.yaml')
|
|
||||||
add_env_props(cfg)
|
|
||||||
cfg['env'].update({'n_agents': n_agents})
|
|
||||||
|
|
||||||
# instantiate agent and env
|
|
||||||
env_agent = AutoResetGymMultiAgent(
|
|
||||||
get_class(cfg['env']),
|
|
||||||
get_arguments(cfg['env']),
|
|
||||||
n_envs=1
|
|
||||||
)
|
|
||||||
|
|
||||||
a2c_agents = [instantiate_class({**cfg['agent'],
|
|
||||||
'agent_id': agent_id})
|
|
||||||
for agent_id in range(n_agents)]
|
|
||||||
|
|
||||||
# combine agents
|
|
||||||
acquisition_agent = TemporalAgent(Agents(env_agent, *a2c_agents, CombineActionsAgent()))
|
|
||||||
acquisition_agent.seed(69)
|
|
||||||
|
|
||||||
# optimizers & other parameters
|
|
||||||
cfg_optim = cfg['algorithm']['optimizer']
|
|
||||||
optimizers = [get_class(cfg_optim)(a2c_agent.parameters(), **get_arguments(cfg_optim))
|
|
||||||
for a2c_agent in a2c_agents]
|
|
||||||
n_timesteps = cfg['algorithm']['n_timesteps']
|
|
||||||
|
|
||||||
# Decision making loop
|
|
||||||
best = -float('inf')
|
|
||||||
with tqdm(range(int(cfg['algorithm']['max_epochs'] / n_timesteps))) as pbar:
|
|
||||||
for epoch in pbar:
|
|
||||||
workspace.zero_grad()
|
|
||||||
if epoch > 0:
|
|
||||||
workspace.copy_n_last_steps(1)
|
|
||||||
acquisition_agent(workspace, t=1, n_steps=n_timesteps-1, stochastic=True)
|
|
||||||
else:
|
|
||||||
acquisition_agent(workspace, t=0, n_steps=n_timesteps, stochastic=True)
|
|
||||||
|
|
||||||
for agent_id in range(n_agents):
|
|
||||||
critic, done, action_probs, reward, action = workspace[
|
|
||||||
access_str(agent_id, 'critic'),
|
|
||||||
"env/done",
|
|
||||||
access_str(agent_id, 'action_probs'),
|
|
||||||
access_str(agent_id, 'reward', 'env/'),
|
|
||||||
access_str(agent_id, 'action')
|
|
||||||
]
|
|
||||||
td = gae(critic, reward, done, 0.98, 0.25)
|
|
||||||
td_error = td ** 2
|
|
||||||
critic_loss = td_error.mean()
|
|
||||||
entropy_loss = Categorical(action_probs).entropy().mean()
|
|
||||||
action_logp = _index(action_probs, action).log()
|
|
||||||
a2c_loss = action_logp[:-1] * td.detach()
|
|
||||||
a2c_loss = a2c_loss.mean()
|
|
||||||
loss = (
|
|
||||||
-0.001 * entropy_loss
|
|
||||||
+ 1.0 * critic_loss
|
|
||||||
- 0.1 * a2c_loss
|
|
||||||
)
|
|
||||||
optimizer = optimizers[agent_id]
|
|
||||||
optimizer.zero_grad()
|
|
||||||
loss.backward()
|
|
||||||
#torch.nn.utils.clip_grad_norm_(a2c_agents[agent_id].parameters(), .5)
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
# Compute the cumulated reward on final_state
|
|
||||||
rews = ''
|
|
||||||
for agent_i in range(n_agents):
|
|
||||||
creward = workspace['env/'+access_str(agent_i, CUMU_REWARD)]
|
|
||||||
creward = creward[done]
|
|
||||||
if creward.size()[0] > 0:
|
|
||||||
rews += f'{AGENT_PREFIX}{agent_i}: {creward.mean().item():.2f} | '
|
|
||||||
"""if cum_r > best:
|
|
||||||
torch.save(a2c_agent.state_dict(), Path(__file__).parent / f'agent_{uid}.pt')
|
|
||||||
best = cum_r"""
|
|
||||||
pbar.set_description(rews, refresh=True)
|
|
||||||
|
|
@ -1,27 +0,0 @@
|
|||||||
agent:
|
|
||||||
classname: studies.sat_mad.A2CAgent
|
|
||||||
observation_size: 4*5*5
|
|
||||||
hidden_size: 128
|
|
||||||
n_actions: 10
|
|
||||||
|
|
||||||
env:
|
|
||||||
classname: environments.factory.make
|
|
||||||
env_name: "DirtyFactory-v0"
|
|
||||||
n_agents: 1
|
|
||||||
pomdp_r: 2
|
|
||||||
max_steps: 400
|
|
||||||
stack_n_frames: 3
|
|
||||||
individual_rewards: True
|
|
||||||
|
|
||||||
algorithm:
|
|
||||||
max_epochs: 1000000
|
|
||||||
n_envs: 1
|
|
||||||
n_timesteps: 10
|
|
||||||
discount_factor: 0.99
|
|
||||||
entropy_coef: 0.01
|
|
||||||
critic_coef: 1.0
|
|
||||||
gae: 0.25
|
|
||||||
optimizer:
|
|
||||||
classname: torch.optim.Adam
|
|
||||||
lr: 0.0003
|
|
||||||
weight_decay: 0.0
|
|
266
studies/single_run_with_export.py
Normal file
266
studies/single_run_with_export.py
Normal file
@ -0,0 +1,266 @@
|
|||||||
|
import itertools
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||||
|
|
||||||
|
try:
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
|
if __package__ is None:
|
||||||
|
DIR = Path(__file__).resolve().parent
|
||||||
|
sys.path.insert(0, str(DIR.parent))
|
||||||
|
__package__ = DIR.name
|
||||||
|
else:
|
||||||
|
DIR = None
|
||||||
|
except NameError:
|
||||||
|
DIR = None
|
||||||
|
pass
|
||||||
|
|
||||||
|
import simplejson
|
||||||
|
from environments.helpers import ActionTranslator, ObservationTranslator
|
||||||
|
from environments.logging.recorder import EnvRecorder
|
||||||
|
from environments import helpers as h
|
||||||
|
from environments.factory.factory_dirt import DirtProperties, DirtFactory
|
||||||
|
from environments.factory.factory_item import ItemProperties, ItemFactory
|
||||||
|
from environments.factory.factory_dest import DestProperties, DestFactory, DestModeOptions
|
||||||
|
from environments.factory.combined_factories import DirtDestItemFactory
|
||||||
|
from environments.logging.envmonitor import EnvMonitor
|
||||||
|
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
||||||
|
|
||||||
|
"""
|
||||||
|
In this studie, we want to export trained Agents for debugging purposes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def encapsule_env_factory(env_fctry, env_kwrgs):
|
||||||
|
|
||||||
|
def _init():
|
||||||
|
with env_fctry(**env_kwrgs) as init_env:
|
||||||
|
return init_env
|
||||||
|
|
||||||
|
return _init
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_run_baseline(policy_path, env_to_run):
|
||||||
|
# retrieve model class
|
||||||
|
model_cls = h.MODEL_MAP['A2C']
|
||||||
|
# Load both agents
|
||||||
|
model = model_cls.load(policy_path / 'model.zip', device='cpu')
|
||||||
|
# Load old env kwargs
|
||||||
|
with next(policy_path.glob('*params.json')).open('r') as f:
|
||||||
|
env_kwargs = simplejson.load(f)
|
||||||
|
env_kwargs.update(done_at_collision=True)
|
||||||
|
# Init Env
|
||||||
|
with env_to_run(**env_kwargs) as env_factory:
|
||||||
|
monitored_env_factory = EnvMonitor(env_factory)
|
||||||
|
recorded_env_factory = EnvRecorder(monitored_env_factory)
|
||||||
|
|
||||||
|
# Evaluation Loop for i in range(n Episodes)
|
||||||
|
for episode in range(5):
|
||||||
|
env_state = recorded_env_factory.reset()
|
||||||
|
rew, done_bool = 0, False
|
||||||
|
while not done_bool:
|
||||||
|
action = model.predict(env_state, deterministic=True)[0]
|
||||||
|
env_state, step_r, done_bool, info_obj = recorded_env_factory.step(action)
|
||||||
|
rew += step_r
|
||||||
|
if done_bool:
|
||||||
|
break
|
||||||
|
print(f'Factory run {episode} done, reward is:\n {rew}')
|
||||||
|
recorded_env_factory.save_run(filepath=policy_path / f'baseline_monitor.pick')
|
||||||
|
recorded_env_factory.save_records(filepath=policy_path / f'baseline_recorder.json')
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_run_combined(root_path, env_to_run, env_kwargs):
|
||||||
|
# retrieve model class
|
||||||
|
model_cls = h.MODEL_MAP['A2C']
|
||||||
|
# Load both agents
|
||||||
|
models = [model_cls.load(model_zip, device='cpu') for model_zip in root_path.rglob('model.zip')]
|
||||||
|
# Load old env kwargs
|
||||||
|
env_kwargs = env_kwargs.copy()
|
||||||
|
env_kwargs.update(
|
||||||
|
n_agents=len(models),
|
||||||
|
done_at_collision=False)
|
||||||
|
|
||||||
|
# Init Env
|
||||||
|
with env_to_run(**env_kwargs) as env_factory:
|
||||||
|
|
||||||
|
action_translator = ActionTranslator(env_factory.named_action_space,
|
||||||
|
*[x.named_action_space for x in models])
|
||||||
|
observation_translator = ObservationTranslator(env_factory.observation_space.shape[-2:],
|
||||||
|
env_factory.named_observation_space,
|
||||||
|
*[x.named_observation_space for x in models])
|
||||||
|
|
||||||
|
env = EnvMonitor(env_factory)
|
||||||
|
# Evaluation Loop for i in range(n Episodes)
|
||||||
|
for episode in range(5):
|
||||||
|
env_state = env.reset()
|
||||||
|
rew, done_bool = 0, False
|
||||||
|
while not done_bool:
|
||||||
|
translated_observations = observation_translator(env_state)
|
||||||
|
actions = [model.predict(translated_observations[model_idx], deterministic=True)[0]
|
||||||
|
for model_idx, model in enumerate(models)]
|
||||||
|
translated_actions = action_translator(actions)
|
||||||
|
env_state, step_r, done_bool, info_obj = env.step(translated_actions)
|
||||||
|
rew += step_r
|
||||||
|
if done_bool:
|
||||||
|
break
|
||||||
|
print(f'Factory run {episode} done, reward is:\n {rew}')
|
||||||
|
env.save_run(filepath=root_path / f'monitor_combined.pick')
|
||||||
|
# env.save_records(filepath=root_path / f'recorder_combined.json')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# What to do:
|
||||||
|
train = True
|
||||||
|
individual_run = False
|
||||||
|
combined_run = False
|
||||||
|
multi_env = False
|
||||||
|
|
||||||
|
train_steps = 1e6
|
||||||
|
frames_to_stack = 3
|
||||||
|
|
||||||
|
# Define a global studi save path
|
||||||
|
paremters_of_interest = dict(
|
||||||
|
show_global_position_info=[True, False],
|
||||||
|
pomdp_r=[3],
|
||||||
|
cast_shadows=[True, False],
|
||||||
|
allow_diagonal_movement=[True],
|
||||||
|
parse_doors=[True, False],
|
||||||
|
doors_have_area=[True, False],
|
||||||
|
done_at_collision=[True, False]
|
||||||
|
)
|
||||||
|
keys, vals = zip(*paremters_of_interest.items())
|
||||||
|
|
||||||
|
# Then we find all permutations for those values
|
||||||
|
p = list(itertools.product(*vals))
|
||||||
|
|
||||||
|
# Finally we can create out list of dicts
|
||||||
|
result = [{keys[index]: entry[index] for index in range(len(entry))} for entry in p]
|
||||||
|
|
||||||
|
for u in result:
|
||||||
|
file_name = '_'.join('_'.join([str(y)[0] for y in x]) for x in u.items())
|
||||||
|
study_root_path = Path(__file__).parent.parent / 'study_out' / file_name
|
||||||
|
|
||||||
|
# Model Kwargs
|
||||||
|
policy_model_kwargs = dict(ent_coef=0.01)
|
||||||
|
|
||||||
|
# Define Global Env Parameters
|
||||||
|
# Define properties object parameters
|
||||||
|
obs_props = ObservationProperties(render_agents=AgentRenderOptions.NOT,
|
||||||
|
additional_agent_placeholder=None,
|
||||||
|
omit_agent_self=True,
|
||||||
|
frames_to_stack=frames_to_stack,
|
||||||
|
pomdp_r=u['pomdp_r'], cast_shadows=u['cast_shadows'],
|
||||||
|
show_global_position_info=u['show_global_position_info'])
|
||||||
|
move_props = MovementProperties(allow_diagonal_movement=u['allow_diagonal_movement'],
|
||||||
|
allow_square_movement=True,
|
||||||
|
allow_no_op=False)
|
||||||
|
dirt_props = DirtProperties(initial_dirt_ratio=0.35, initial_dirt_spawn_r_var=0.1,
|
||||||
|
clean_amount=0.34,
|
||||||
|
max_spawn_amount=0.1, max_global_amount=20,
|
||||||
|
max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05,
|
||||||
|
dirt_smear_amount=0.0)
|
||||||
|
item_props = ItemProperties(n_items=10, spawn_frequency=30, n_drop_off_locations=2,
|
||||||
|
max_agent_inventory_capacity=15)
|
||||||
|
dest_props = DestProperties(n_dests=4, spawn_mode=DestModeOptions.GROUPED, spawn_frequency=1)
|
||||||
|
factory_kwargs = dict(n_agents=1, max_steps=500, parse_doors=u['parse_doors'],
|
||||||
|
level_name='rooms', doors_have_area=u['doors_have_area'],
|
||||||
|
verbose=False,
|
||||||
|
mv_prop=move_props,
|
||||||
|
obs_prop=obs_props,
|
||||||
|
done_at_collision=u['done_at_collision']
|
||||||
|
)
|
||||||
|
|
||||||
|
# Bundle both environments with global kwargs and parameters
|
||||||
|
env_map = {}
|
||||||
|
env_map.update({'dirt': (DirtFactory, dict(dirt_prop=dirt_props,
|
||||||
|
**factory_kwargs.copy()),
|
||||||
|
['cleanup_valid', 'cleanup_fail'])})
|
||||||
|
# env_map.update({'item': (ItemFactory, dict(item_prop=item_props,
|
||||||
|
# **factory_kwargs.copy()),
|
||||||
|
# ['DROPOFF_FAIL', 'ITEMACTION_FAIL', 'DROPOFF_VALID', 'ITEMACTION_VALID'])})
|
||||||
|
# env_map.update({'dest': (DestFactory, dict(dest_prop=dest_props,
|
||||||
|
# **factory_kwargs.copy()))})
|
||||||
|
env_map.update({'combined': (DirtDestItemFactory, dict(dest_prop=dest_props,
|
||||||
|
item_prop=item_props,
|
||||||
|
dirt_prop=dirt_props,
|
||||||
|
**factory_kwargs.copy()))})
|
||||||
|
env_names = list(env_map.keys())
|
||||||
|
|
||||||
|
# Train starts here ############################################################
|
||||||
|
# Build Major Loop parameters, parameter versions, Env Classes and models
|
||||||
|
if train:
|
||||||
|
for env_key in (env_key for env_key in env_map if 'combined' != env_key):
|
||||||
|
model_cls = h.MODEL_MAP['PPO']
|
||||||
|
combination_path = study_root_path / env_key
|
||||||
|
env_class, env_kwargs, env_plot_keys = env_map[env_key]
|
||||||
|
|
||||||
|
# Output folder
|
||||||
|
if (combination_path / 'monitor.pick').exists():
|
||||||
|
continue
|
||||||
|
combination_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if not multi_env:
|
||||||
|
env_factory = encapsule_env_factory(env_class, env_kwargs)()
|
||||||
|
else:
|
||||||
|
env_factory = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs)
|
||||||
|
for _ in range(6)], start_method="spawn")
|
||||||
|
|
||||||
|
param_path = combination_path / f'env_params.json'
|
||||||
|
try:
|
||||||
|
env_factory.env_method('save_params', param_path)
|
||||||
|
except AttributeError:
|
||||||
|
env_factory.save_params(param_path)
|
||||||
|
|
||||||
|
# EnvMonitor Init
|
||||||
|
callbacks = [EnvMonitor(env_factory)]
|
||||||
|
|
||||||
|
# Model Init
|
||||||
|
model = model_cls("MlpPolicy", env_factory, **policy_model_kwargs,
|
||||||
|
verbose=1, seed=69, device='cpu')
|
||||||
|
|
||||||
|
# Model train
|
||||||
|
model.learn(total_timesteps=int(train_steps), callback=callbacks)
|
||||||
|
|
||||||
|
# Model save
|
||||||
|
try:
|
||||||
|
model.named_action_space = env_factory.unwrapped.named_action_space
|
||||||
|
model.named_observation_space = env_factory.unwrapped.named_observation_space
|
||||||
|
except AttributeError:
|
||||||
|
model.named_action_space = env_factory.get_attr("named_action_space")[0]
|
||||||
|
model.named_observation_space = env_factory.get_attr("named_observation_space")[0]
|
||||||
|
save_path = combination_path / f'model.zip'
|
||||||
|
model.save(save_path)
|
||||||
|
|
||||||
|
# Monitor Save
|
||||||
|
callbacks[0].save_run(combination_path / 'monitor.pick',
|
||||||
|
auto_plotting_keys=['step_reward', 'collision'] + env_plot_keys)
|
||||||
|
|
||||||
|
# Better be save then sorry: Clean up!
|
||||||
|
del env_factory, model
|
||||||
|
import gc
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# Train ends here ############################################################
|
||||||
|
|
||||||
|
# Evaluation starts here #####################################################
|
||||||
|
# First Iterate over every model and monitor "as trained"
|
||||||
|
if individual_run:
|
||||||
|
print('Start Individual Recording')
|
||||||
|
for env_key in (env_key for env_key in env_map if 'combined' != env_key):
|
||||||
|
# For trained policy in study_root_path / identifier
|
||||||
|
policy_path = study_root_path / env_key
|
||||||
|
load_model_run_baseline(policy_path, env_map[policy_path.name][0])
|
||||||
|
|
||||||
|
# for policy_path in (y for y in policy_path.iterdir() if y.is_dir()):
|
||||||
|
# load_model_run_baseline(policy_path)
|
||||||
|
print('Done Individual Recording')
|
||||||
|
|
||||||
|
# Then iterate over every model and monitor "ood behavior" - "is it ood?"
|
||||||
|
if combined_run:
|
||||||
|
print('Start combined run')
|
||||||
|
for env_key in (env_key for env_key in env_map if 'combined' == env_key):
|
||||||
|
# For trained policy in study_root_path / identifier
|
||||||
|
factory, kwargs = env_map[env_key]
|
||||||
|
load_model_run_combined(study_root_path, factory, kwargs)
|
||||||
|
print('OOD Tracking Done')
|
36
studies/viz_policy.py
Normal file
36
studies/viz_policy.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
import pandas as pd
|
||||||
|
from algorithms.marl import LoopSNAC, LoopIAC, LoopSEAC
|
||||||
|
from pathlib import Path
|
||||||
|
from algorithms.utils import load_yaml_file
|
||||||
|
from tqdm import trange
|
||||||
|
study = 'example_config#0'
|
||||||
|
#study_root = Path(__file__).parent / study
|
||||||
|
study_root = Path('/Users/romue/PycharmProjects/EDYS/algorithms/marl/')
|
||||||
|
|
||||||
|
#['L2NoAh_gru', 'L2NoCh_gru', 'nomix_gru']:
|
||||||
|
render = True
|
||||||
|
eval_eps = 3
|
||||||
|
for run in range(0, 5):
|
||||||
|
for name in ['example_config']:#['L2OnlyAh_gru', 'L2OnlyChAh_gru', 'L2OnlyMix_gru']: #['layernorm_gru', 'basic_gru', 'nonorm_gru', 'spectralnorm_gru']:
|
||||||
|
cfg = load_yaml_file(study_root / study / 'config.yaml')
|
||||||
|
#p_root = Path(study_root / study / f'{name}#{run}')
|
||||||
|
dfs = []
|
||||||
|
for i in trange(500):
|
||||||
|
path = study_root / study / f'checkpoint_{161}'
|
||||||
|
print(path)
|
||||||
|
|
||||||
|
snac = LoopSEAC(cfg)
|
||||||
|
snac.load_state_dict(path)
|
||||||
|
snac.eval()
|
||||||
|
|
||||||
|
df = snac.eval_loop(render=render, n_episodes=eval_eps)
|
||||||
|
df['checkpoint'] = i
|
||||||
|
dfs.append(df)
|
||||||
|
|
||||||
|
results = pd.concat(dfs)
|
||||||
|
results['run'] = run
|
||||||
|
results.to_csv(p_root / 'results.csv', index=False)
|
||||||
|
|
||||||
|
#sns.lineplot(data=results, x='checkpoint', y='reward', hue='agent', palette='husl')
|
||||||
|
|
||||||
|
#plt.savefig(f'{experiment_name}.png')
|
@ -1,39 +0,0 @@
|
|||||||
from salina.agents import Agents, TemporalAgent
|
|
||||||
import torch
|
|
||||||
from salina import Workspace, get_arguments, get_class, instantiate_class
|
|
||||||
from pathlib import Path
|
|
||||||
from salina.agents.gyma import GymAgent
|
|
||||||
import time
|
|
||||||
from algorithms.utils import load_yaml_file, add_env_props
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
# Setup workspace
|
|
||||||
uid = time.time()
|
|
||||||
workspace = Workspace()
|
|
||||||
weights = Path('/Users/romue/PycharmProjects/EDYS/studies/agent_1636994369.145843.pt')
|
|
||||||
|
|
||||||
cfg = load_yaml_file(Path(__file__).parent / 'sat_mad.yaml')
|
|
||||||
add_env_props(cfg)
|
|
||||||
cfg['env'].update({'n_agents': 2})
|
|
||||||
|
|
||||||
# instantiate agent and env
|
|
||||||
env_agent = GymAgent(
|
|
||||||
get_class(cfg['env']),
|
|
||||||
get_arguments(cfg['env']),
|
|
||||||
n_envs=1
|
|
||||||
)
|
|
||||||
|
|
||||||
agents = []
|
|
||||||
for _ in range(2):
|
|
||||||
a2c_agent = instantiate_class(cfg['agent'])
|
|
||||||
if weights:
|
|
||||||
a2c_agent.load_state_dict(torch.load(weights))
|
|
||||||
agents.append(a2c_agent)
|
|
||||||
|
|
||||||
# combine agents
|
|
||||||
acquisition_agent = TemporalAgent(Agents(env_agent, *agents))
|
|
||||||
acquisition_agent.seed(42)
|
|
||||||
|
|
||||||
acquisition_agent(workspace, t=0, n_steps=400, stochastic=False, save_render=True)
|
|
||||||
|
|
||||||
|
|
Reference in New Issue
Block a user