mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-11-02 13:37:27 +01:00
new rules, new spawn logic, small fixes, default and narrow corridor debugged
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
from . import helpers as h
|
||||
from . import helpers
|
||||
from .results import Result, DoneResult, ActionResult, TickResult
|
||||
|
||||
@@ -1,28 +1,24 @@
|
||||
import ast
|
||||
from collections import defaultdict
|
||||
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Union, List
|
||||
|
||||
import yaml
|
||||
|
||||
from marl_factory_grid.environment.groups.agents import Agents
|
||||
from marl_factory_grid.environment.entity.agent import Agent
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.utils.helpers import locate_and_import_class
|
||||
from marl_factory_grid.environment.constants import DEFAULT_PATH, MODULE_PATH
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
DEFAULT_PATH = 'environment'
|
||||
MODULE_PATH = 'modules'
|
||||
|
||||
|
||||
class FactoryConfigParser(object):
|
||||
default_entites = []
|
||||
default_rules = ['MaxStepsReached', 'Collision']
|
||||
default_rules = ['DoneAtMaxStepsReached', 'WatchCollision']
|
||||
default_actions = [c.MOVE8, c.NOOP]
|
||||
default_observations = [c.WALLS, c.AGENT]
|
||||
|
||||
def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None):
|
||||
def __init__(self, config_path, custom_modules_path: Union[PathLike] = None):
|
||||
self.config_path = Path(config_path)
|
||||
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
|
||||
self.config = yaml.safe_load(self.config_path.open())
|
||||
@@ -46,6 +42,10 @@ class FactoryConfigParser(object):
|
||||
def rules(self):
|
||||
return self.config['Rules']
|
||||
|
||||
@property
|
||||
def tests(self):
|
||||
return self.config.get('Tests', [])
|
||||
|
||||
@property
|
||||
def agents(self):
|
||||
return self.config['Agents']
|
||||
@@ -61,7 +61,6 @@ class FactoryConfigParser(object):
|
||||
return self.config[item]
|
||||
|
||||
def load_entities(self):
|
||||
# entites = Entities()
|
||||
entity_classes = dict()
|
||||
entities = []
|
||||
if c.DEFAULTS in self.entities:
|
||||
@@ -69,28 +68,40 @@ class FactoryConfigParser(object):
|
||||
entities.extend(x for x in self.entities if x != c.DEFAULTS)
|
||||
|
||||
for entity in entities:
|
||||
e1 = e2 = e3 = None
|
||||
try:
|
||||
folder_path = Path(__file__).parent.parent / DEFAULT_PATH
|
||||
entity_class = locate_and_import_class(entity, folder_path)
|
||||
except AttributeError as e1:
|
||||
except AttributeError as e:
|
||||
e1 = e
|
||||
try:
|
||||
folder_path = Path(__file__).parent.parent / MODULE_PATH
|
||||
entity_class = locate_and_import_class(entity, folder_path)
|
||||
except AttributeError as e2:
|
||||
try:
|
||||
folder_path = self.custom_modules_path
|
||||
entity_class = locate_and_import_class(entity, folder_path)
|
||||
except AttributeError as e3:
|
||||
ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x]
|
||||
print('### Error ### Error ### Error ### Error ### Error ###')
|
||||
print()
|
||||
print(f'Class "{entity}" was not found in "{folder_path.name}"')
|
||||
print('Possible Entitys are:', str(ents))
|
||||
print()
|
||||
print('Goodbye')
|
||||
print()
|
||||
exit()
|
||||
# raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents))
|
||||
module_path = Path(__file__).parent.parent / MODULE_PATH
|
||||
entity_class = locate_and_import_class(entity, module_path)
|
||||
except AttributeError as e:
|
||||
e2 = e
|
||||
if self.custom_modules_path:
|
||||
try:
|
||||
entity_class = locate_and_import_class(entity, self.custom_modules_path)
|
||||
except AttributeError as e:
|
||||
e3 = e
|
||||
pass
|
||||
if (e1 and e2) or e3:
|
||||
ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]]
|
||||
print('##############################################################')
|
||||
print('### Error ### Error ### Error ### Error ### Error ###')
|
||||
print('##############################################################')
|
||||
print(f'Class "{entity}" was not found in "{module_path.name}"')
|
||||
print(f'Class "{entity}" was not found in "{folder_path.name}"')
|
||||
print('##############################################################')
|
||||
if self.custom_modules_path:
|
||||
print(f'Class "{entity}" was not found in "{self.custom_modules_path}"')
|
||||
print('Possible Entitys are:', str(ents))
|
||||
print('##############################################################')
|
||||
print('Goodbye')
|
||||
print('##############################################################')
|
||||
print('### Error ### Error ### Error ### Error ### Error ###')
|
||||
print('##############################################################')
|
||||
exit(-99999)
|
||||
|
||||
entity_kwargs = self.entities.get(entity, {})
|
||||
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
|
||||
@@ -128,31 +139,86 @@ class FactoryConfigParser(object):
|
||||
observations.extend(self.default_observations)
|
||||
observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
|
||||
positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])]
|
||||
parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions)
|
||||
other_kwargs = {k: v for k, v in self.agents[name].items() if k not in
|
||||
['Actions', 'Observations', 'Positions']}
|
||||
parsed_agents_conf[name] = dict(
|
||||
actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs
|
||||
)
|
||||
|
||||
return parsed_agents_conf
|
||||
|
||||
def load_rules(self):
|
||||
# entites = Entities()
|
||||
rules_classes = dict()
|
||||
rules = []
|
||||
def load_env_rules(self) -> List[Rule]:
|
||||
rules = self.rules.copy()
|
||||
if c.DEFAULTS in self.rules:
|
||||
for rule in self.default_rules:
|
||||
if rule not in rules:
|
||||
rules.append(rule)
|
||||
rules.extend(x for x in self.rules if x != c.DEFAULTS)
|
||||
rules.append({rule: {}})
|
||||
|
||||
for rule in rules:
|
||||
return self._load_smth(rules, Rule)
|
||||
|
||||
def load_env_tests(self) -> List[Rule]:
|
||||
return self._load_smth(self.tests, None) # Test
|
||||
|
||||
def _load_smth(self, config, class_obj):
|
||||
rules = list()
|
||||
rules_names = list()
|
||||
for rule in config:
|
||||
e1 = e2 = e3 = None
|
||||
try:
|
||||
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
|
||||
rule_class = locate_and_import_class(rule, folder_path)
|
||||
except AttributeError:
|
||||
except AttributeError as e:
|
||||
e1 = e
|
||||
try:
|
||||
folder_path = (Path(__file__).parent.parent / MODULE_PATH)
|
||||
rule_class = locate_and_import_class(rule, folder_path)
|
||||
module_path = (Path(__file__).parent.parent / MODULE_PATH)
|
||||
rule_class = locate_and_import_class(rule, module_path)
|
||||
except AttributeError as e:
|
||||
e2 = e
|
||||
if self.custom_modules_path:
|
||||
try:
|
||||
rule_class = locate_and_import_class(rule, self.custom_modules_path)
|
||||
except AttributeError as e:
|
||||
e3 = e
|
||||
pass
|
||||
if (e1 and e2) or e3:
|
||||
ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]]
|
||||
print('### Error ### Error ### Error ### Error ### Error ###')
|
||||
print('')
|
||||
print(f'Class "{rule}" was not found in "{module_path.name}"')
|
||||
print(f'Class "{rule}" was not found in "{folder_path.name}"')
|
||||
if self.custom_modules_path:
|
||||
print(f'Class "{rule}" was not found in "{self.custom_modules_path}"')
|
||||
print('Possible Entitys are:', str(ents))
|
||||
print('')
|
||||
print('Goodbye')
|
||||
print('')
|
||||
exit(-99999)
|
||||
|
||||
if issubclass(rule_class, class_obj):
|
||||
rule_kwargs = config.get(rule, {})
|
||||
rules.append(rule_class(**(rule_kwargs or {})))
|
||||
return rules
|
||||
|
||||
def load_entity_spawn_rules(self, entities) -> List[Rule]:
|
||||
rules = list()
|
||||
rules_dicts = list()
|
||||
for e in entities:
|
||||
try:
|
||||
if spawn_rule := e.spawn_rule:
|
||||
rules_dicts.append(spawn_rule)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
for rule_dict in rules_dicts:
|
||||
for rule_name, rule_kwargs in rule_dict.items():
|
||||
try:
|
||||
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
|
||||
rule_class = locate_and_import_class(rule_name, folder_path)
|
||||
except AttributeError:
|
||||
rule_class = locate_and_import_class(rule, self.custom_modules_path)
|
||||
# Fixme This check does not work!
|
||||
# assert isinstance(rule_class, Rule), f'{rule_class.__name__} is no valid "Rule".'
|
||||
rule_kwargs = self.rules.get(rule, {})
|
||||
rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}})
|
||||
return rules_classes
|
||||
try:
|
||||
folder_path = (Path(__file__).parent.parent / MODULE_PATH)
|
||||
rule_class = locate_and_import_class(rule_name, folder_path)
|
||||
except AttributeError:
|
||||
rule_class = locate_and_import_class(rule_name, self.custom_modules_path)
|
||||
rules.append(rule_class(**rule_kwargs))
|
||||
return rules
|
||||
|
||||
@@ -2,7 +2,7 @@ import importlib
|
||||
|
||||
from collections import defaultdict
|
||||
from pathlib import PurePath, Path
|
||||
from typing import Union, Dict, List
|
||||
from typing import Union, Dict, List, Iterable, Callable
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import ArrayLike
|
||||
@@ -222,7 +222,7 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
||||
mod = importlib.import_module('.'.join(module_parts))
|
||||
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
|
||||
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
|
||||
'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin',
|
||||
'TickResult', 'ActionResult', 'Action', 'Agent',
|
||||
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
|
||||
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any'
|
||||
]])
|
||||
@@ -240,7 +240,13 @@ def add_bound_name(name_str, bound_e):
|
||||
|
||||
def add_pos_name(name_str, bound_e):
|
||||
if bound_e.var_has_position:
|
||||
return f'{name_str}({bound_e.pos})'
|
||||
return f'{name_str}@{bound_e.pos}'
|
||||
return name_str
|
||||
|
||||
|
||||
def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
|
||||
return next((x for x in iterable if filter_by(x)), None)
|
||||
|
||||
|
||||
def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
|
||||
return next((idx for idx, x in enumerate(iterable) if filter_by(x)), None)
|
||||
|
||||
@@ -47,6 +47,7 @@ class LevelParser(object):
|
||||
# All other
|
||||
for es_name in self.e_p_dict:
|
||||
e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs']
|
||||
e_kwargs = e_kwargs if e_kwargs else {}
|
||||
|
||||
if hasattr(e_class, 'symbol') and e_class.symbol is not None:
|
||||
symbols = e_class.symbol
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
import math
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from itertools import product
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
from numba import njit
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.object import _Object
|
||||
from marl_factory_grid.environment.groups.utils import Combined
|
||||
import marl_factory_grid.utils.helpers as h
|
||||
from marl_factory_grid.utils.states import Gamestate
|
||||
from marl_factory_grid.utils.utility_classes import Floor
|
||||
from marl_factory_grid.utils.ray_caster import RayCaster
|
||||
from marl_factory_grid.utils.states import Gamestate
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
|
||||
|
||||
|
||||
class OBSBuilder(object):
|
||||
@@ -77,11 +77,13 @@ class OBSBuilder(object):
|
||||
|
||||
def place_entity_in_observation(self, obs_array, agent, e):
|
||||
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
||||
try:
|
||||
obs_array[x, y] += e.encoding
|
||||
except IndexError:
|
||||
# Seemded to be visible but is out of range
|
||||
pass
|
||||
if not min([y, x]) < 0:
|
||||
try:
|
||||
obs_array[x, y] += e.encoding
|
||||
except IndexError:
|
||||
# Seemded to be visible but is out of range
|
||||
pass
|
||||
pass
|
||||
|
||||
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
|
||||
assert self._curr_env_step == state.curr_step, (
|
||||
@@ -121,18 +123,24 @@ class OBSBuilder(object):
|
||||
e = self.all_obs[l_name]
|
||||
except KeyError:
|
||||
try:
|
||||
# Look for bound entity names!
|
||||
pattern = re.compile(f'{re.escape(l_name)}(.*){re.escape(agent.name)}')
|
||||
name = next((x for x in self.all_obs if pattern.search(x)), None)
|
||||
# Look for bound entity REPRs!
|
||||
pattern = re.compile(f'{re.escape(l_name)}'
|
||||
f'{re.escape("[")}(.*){re.escape("]")}'
|
||||
f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}')
|
||||
name = next((key for key, val in self.all_obs.items()
|
||||
if pattern.search(str(val)) and isinstance(val, _Object)), None)
|
||||
e = self.all_obs[name]
|
||||
except KeyError:
|
||||
try:
|
||||
e = next(v for k, v in self.all_obs.items() if l_name in k and agent.name in k)
|
||||
except StopIteration:
|
||||
raise KeyError(
|
||||
f'Check for spelling errors! \n '
|
||||
f'No combination of "{l_name} and {agent.name}" could not be found in:\n '
|
||||
f'{list(dict(self.all_obs).keys())}')
|
||||
print(f'# Check for spelling errors!')
|
||||
print(f'# No combination of "{l_name}" and "{agent.name}" could not be found in:')
|
||||
print(f'# {list(dict(self.all_obs).keys())}')
|
||||
print('#')
|
||||
print('# exiting...')
|
||||
print('#')
|
||||
exit(-99999)
|
||||
|
||||
try:
|
||||
positional = e.var_has_position
|
||||
@@ -161,15 +169,14 @@ class OBSBuilder(object):
|
||||
try:
|
||||
light_map = np.zeros(self.obs_shape)
|
||||
visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False)
|
||||
if self.pomdp_r:
|
||||
for f in set(visible_floor):
|
||||
self.place_entity_in_observation(light_map, agent, f)
|
||||
else:
|
||||
for f in set(visible_floor):
|
||||
light_map[f.x, f.y] += f.encoding
|
||||
|
||||
for f in set(visible_floor):
|
||||
self.place_entity_in_observation(light_map, agent, f)
|
||||
# else:
|
||||
# for f in set(visible_floor):
|
||||
# light_map[f.x, f.y] += f.encoding
|
||||
self.curr_lightmaps[agent.name] = light_map
|
||||
except (KeyError, ValueError):
|
||||
print()
|
||||
pass
|
||||
return obs, self.obs_layers[agent.name]
|
||||
|
||||
@@ -185,7 +192,7 @@ class OBSBuilder(object):
|
||||
|
||||
for obs_str in agent.observations:
|
||||
if isinstance(obs_str, dict):
|
||||
obs_str, vals = next(obs_str.items().__iter__())
|
||||
obs_str, vals = h.get_first(obs_str.items())
|
||||
else:
|
||||
vals = None
|
||||
if obs_str == c.SELF:
|
||||
@@ -214,129 +221,3 @@ class OBSBuilder(object):
|
||||
obs_layers.append(obs_str)
|
||||
self.obs_layers[agent.name] = obs_layers
|
||||
self.curr_lightmaps[agent.name] = np.zeros(self.obs_shape)
|
||||
|
||||
|
||||
class RayCaster:
|
||||
def __init__(self, agent, pomdp_r, degs=360):
|
||||
self.agent = agent
|
||||
self.pomdp_r = pomdp_r
|
||||
self.n_rays = (self.pomdp_r + 1) * 8
|
||||
self.degs = degs
|
||||
self.ray_targets = self.build_ray_targets()
|
||||
self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r])
|
||||
self._cache_dict = {}
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self.agent.name})'
|
||||
|
||||
def build_ray_targets(self):
|
||||
north = np.array([0, -1]) * self.pomdp_r
|
||||
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
|
||||
rot_M = [
|
||||
[[math.cos(theta), -math.sin(theta)],
|
||||
[math.sin(theta), math.cos(theta)]] for theta in thetas
|
||||
]
|
||||
rot_M = np.stack(rot_M, 0)
|
||||
rot_M = np.unique(np.round(rot_M @ north), axis=0)
|
||||
return rot_M.astype(int)
|
||||
|
||||
def ray_block_cache(self, key, callback):
|
||||
if key not in self._cache_dict:
|
||||
self._cache_dict[key] = callback()
|
||||
return self._cache_dict[key]
|
||||
|
||||
def visible_entities(self, pos_dict, reset_cache=True):
|
||||
visible = list()
|
||||
if reset_cache:
|
||||
self._cache_dict = {}
|
||||
|
||||
for ray in self.get_rays():
|
||||
rx, ry = ray[0]
|
||||
for x, y in ray:
|
||||
cx, cy = x - rx, y - ry
|
||||
|
||||
entities_hit = pos_dict[(x, y)]
|
||||
hits = self.ray_block_cache((x, y),
|
||||
lambda: any(True for e in entities_hit if e.var_is_blocking_light)
|
||||
)
|
||||
|
||||
diag_hits = all([
|
||||
self.ray_block_cache(
|
||||
key,
|
||||
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light) and bool(
|
||||
pos_dict[key]))
|
||||
for key in ((x, y - cy), (x - cx, y))
|
||||
]) if (cx != 0 and cy != 0) else False
|
||||
|
||||
visible += entities_hit if not diag_hits else []
|
||||
if hits or diag_hits:
|
||||
break
|
||||
rx, ry = x, y
|
||||
return visible
|
||||
|
||||
def get_rays(self):
|
||||
a_pos = self.agent.pos
|
||||
outline = self.ray_targets + a_pos
|
||||
return self.bresenham_loop(a_pos, outline)
|
||||
|
||||
# todo do this once and cache the points!
|
||||
def get_fov_outline(self) -> np.ndarray:
|
||||
return self.ray_targets + self.agent.pos
|
||||
|
||||
def get_square_outline(self):
|
||||
agent = self.agent
|
||||
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
|
||||
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
|
||||
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
|
||||
+ list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
|
||||
return outline
|
||||
|
||||
@staticmethod
|
||||
@njit
|
||||
def bresenham_loop(a_pos, points):
|
||||
results = []
|
||||
for end in points:
|
||||
x1, y1 = a_pos
|
||||
x2, y2 = end
|
||||
dx = x2 - x1
|
||||
dy = y2 - y1
|
||||
|
||||
# Determine how steep the line is
|
||||
is_steep = abs(dy) > abs(dx)
|
||||
|
||||
# Rotate line
|
||||
if is_steep:
|
||||
x1, y1 = y1, x1
|
||||
x2, y2 = y2, x2
|
||||
|
||||
# Swap start and end points if necessary and store swap state
|
||||
swapped = False
|
||||
if x1 > x2:
|
||||
x1, x2 = x2, x1
|
||||
y1, y2 = y2, y1
|
||||
swapped = True
|
||||
|
||||
# Recalculate differentials
|
||||
dx = x2 - x1
|
||||
dy = y2 - y1
|
||||
|
||||
# Calculate error
|
||||
error = int(dx / 2.0)
|
||||
ystep = 1 if y1 < y2 else -1
|
||||
|
||||
# Iterate over bounding box generating points between start and end
|
||||
y = y1
|
||||
points = []
|
||||
for x in range(int(x1), int(x2) + 1):
|
||||
coord = [y, x] if is_steep else [x, y]
|
||||
points.append(coord)
|
||||
error -= abs(dy)
|
||||
if error < 0:
|
||||
y += ystep
|
||||
error += dx
|
||||
|
||||
# Reverse the list if the coordinates were swapped
|
||||
if swapped:
|
||||
points.reverse()
|
||||
results.append(points)
|
||||
return results
|
||||
|
||||
@@ -39,8 +39,9 @@ class RayCaster:
|
||||
if reset_cache:
|
||||
self._cache_dict = dict()
|
||||
|
||||
for ray in self.get_rays():
|
||||
for ray in self.get_rays(): # Do not check, just trust.
|
||||
rx, ry = ray[0]
|
||||
# self.ray_block_cache(ray[0], lambda: False) We do not do that, because of doors etc...
|
||||
for x, y in ray:
|
||||
cx, cy = x - rx, y - ry
|
||||
|
||||
@@ -52,7 +53,8 @@ class RayCaster:
|
||||
diag_hits = all([
|
||||
self.ray_block_cache(
|
||||
key,
|
||||
lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light))
|
||||
lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light))
|
||||
# lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light))
|
||||
for key in ((x, y-cy), (x-cx, y))
|
||||
]) if (cx != 0 and cy != 0) else False
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ class Renderer:
|
||||
|
||||
def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
|
||||
lvl_padded_shape: Union[Tuple[int, int], None] = None,
|
||||
cell_size: int = 40, fps: int = 7,
|
||||
cell_size: int = 40, fps: int = 7, factor: float = 0.9,
|
||||
grid_lines: bool = True, view_radius: int = 2):
|
||||
# TODO: Customn_assets paths
|
||||
self.grid_h, self.grid_w = lvl_shape
|
||||
@@ -45,7 +45,7 @@ class Renderer:
|
||||
self.screen = pygame.display.set_mode(self.screen_size)
|
||||
self.clock = pygame.time.Clock()
|
||||
assets = list(self.ASSETS.rglob('*.png'))
|
||||
self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets}
|
||||
self.assets = {path.stem: self.load_asset(str(path), factor) for path in assets}
|
||||
self.fill_bg()
|
||||
|
||||
now = time.time()
|
||||
@@ -110,22 +110,22 @@ class Renderer:
|
||||
pygame.quit()
|
||||
sys.exit()
|
||||
self.fill_bg()
|
||||
blits = deque()
|
||||
for entity in [x for x in entities]:
|
||||
bp = self.blit_params(entity)
|
||||
blits.append(bp)
|
||||
if entity.name.lower() == AGENT:
|
||||
if self.view_radius > 0:
|
||||
vis_rects = self.visibility_rects(bp, entity.aux)
|
||||
blits.extendleft(vis_rects)
|
||||
if entity.state != BLANK:
|
||||
agent_state_blits = self.blit_params(
|
||||
RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, SCALE)
|
||||
)
|
||||
textsurface = self.font.render(str(entity.id), False, (0, 0, 0))
|
||||
text_blit = dict(source=textsurface, dest=(bp['dest'].center[0]-.07*self.cell_size,
|
||||
bp['dest'].center[1]))
|
||||
blits += [agent_state_blits, text_blit]
|
||||
# First all others
|
||||
blits = deque(self.blit_params(x) for x in entities if not x.name.lower() == AGENT)
|
||||
# Then Agents, so that agents are rendered on top.
|
||||
for agent in (x for x in entities if x.name.lower() == AGENT):
|
||||
agent_blit = self.blit_params(agent)
|
||||
if self.view_radius > 0:
|
||||
vis_rects = self.visibility_rects(agent_blit, agent.aux)
|
||||
blits.extendleft(vis_rects)
|
||||
if agent.state != BLANK:
|
||||
state_blit = self.blit_params(
|
||||
RenderEntity(agent.state, (agent.pos[0] + 0.12, agent.pos[1]), 0.48, SCALE)
|
||||
)
|
||||
textsurface = self.font.render(str(agent.id), False, (0, 0, 0))
|
||||
text_blit = dict(source=textsurface, dest=(agent_blit['dest'].center[0]-.07*self.cell_size,
|
||||
agent_blit['dest'].center[1]))
|
||||
blits += [agent_blit, state_blit, text_blit]
|
||||
|
||||
for blit in blits:
|
||||
self.screen.blit(**blit)
|
||||
|
||||
@@ -28,7 +28,10 @@ class Result:
|
||||
|
||||
def __repr__(self):
|
||||
valid = "not " if not self.validity else ""
|
||||
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid: {self.reward})'
|
||||
reward = f" | Reward: {self.reward}" if self.reward is not None else ""
|
||||
value = f" | Value: {self.value}" if self.value is not None else ""
|
||||
entity = f" | by: {self.entity.name}" if self.entity is not None else ""
|
||||
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value})'
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from itertools import islice
|
||||
from typing import List, Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
@@ -59,14 +60,15 @@ class Gamestate(object):
|
||||
def moving_entites(self):
|
||||
return [y for x in self.entities for y in x if x.var_can_move]
|
||||
|
||||
def __init__(self, entities, agents_conf, rules: Dict[str, dict], env_seed=69, verbose=False):
|
||||
def __init__(self, entities, agents_conf, rules: List[Rule], lvl_shape, env_seed=69, verbose=False):
|
||||
self.lvl_shape = lvl_shape
|
||||
self.entities = entities
|
||||
self.curr_step = 0
|
||||
self.curr_actions = None
|
||||
self.agents_conf = agents_conf
|
||||
self.verbose = verbose
|
||||
self.rng = np.random.default_rng(env_seed)
|
||||
self.rules = StepRules(*(v['class'](**v['kwargs']) for v in rules.values()))
|
||||
self.rules = StepRules(*rules)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.entities[item]
|
||||
@@ -80,6 +82,13 @@ class Gamestate(object):
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'
|
||||
|
||||
@property
|
||||
def random_free_position(self):
|
||||
return self.get_n_random_free_positions(1)[0]
|
||||
|
||||
def get_n_random_free_positions(self, n):
|
||||
return list(islice(self.entities.free_positions_generator, n))
|
||||
|
||||
def tick(self, actions) -> List[Result]:
|
||||
results = list()
|
||||
self.curr_step += 1
|
||||
@@ -115,8 +124,7 @@ class Gamestate(object):
|
||||
return results
|
||||
|
||||
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
|
||||
positions = [pos for pos, entity_list_for_position in self.entities.pos_dict.items()
|
||||
if any([e.var_can_collide for e in entity_list_for_position])]
|
||||
positions = [pos for pos, entities in self.entities.pos_dict.items() if len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)]
|
||||
return positions
|
||||
|
||||
def check_move_validity(self, moving_entity, position):
|
||||
|
||||
@@ -135,4 +135,3 @@ if __name__ == '__main__':
|
||||
ce.get_observations()
|
||||
ce.get_assets()
|
||||
all_conf = ce.get_all()
|
||||
print()
|
||||
|
||||
@@ -52,3 +52,6 @@ class Floor:
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.name)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Floor{self.pos}"
|
||||
|
||||
Reference in New Issue
Block a user