mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-06 09:31:35 +02:00
Items and combination of item and dirt
This commit is contained in:
@ -1,9 +1,7 @@
|
||||
import itertools
|
||||
import random
|
||||
from enum import Enum
|
||||
from typing import List, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
from environments.factory.base.objects import Entity, Tile, Agent, Door, Slice, Action
|
||||
@ -16,11 +14,8 @@ class Register:
|
||||
_accepted_objects = Entity
|
||||
|
||||
@classmethod
|
||||
def from_argwhere_coordinates(cls, positions: (int, int), tiles):
|
||||
entities = [cls._accepted_objects(i, tiles.by_pos(position)) for i, position in enumerate(positions)]
|
||||
registered_obj = cls()
|
||||
registered_obj.register_additional_items(entities)
|
||||
return registered_obj
|
||||
def from_argwhere_coordinates(cls, positions: [(int, int)], tiles):
|
||||
return cls.from_tiles([tiles.by_pos(position) for position in positions])
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
@ -72,8 +67,8 @@ class Register:
|
||||
def by_name(self, item):
|
||||
return self[self._names[item]]
|
||||
|
||||
def by_enum(self, enum: Enum):
|
||||
return self[self._names[enum.name]]
|
||||
def by_enum(self, enum_obj: Enum):
|
||||
return self[self._names[enum_obj.name]]
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self._register})'
|
||||
@ -84,13 +79,13 @@ class Register:
|
||||
def get_idx_by_name(self, item):
|
||||
return self._names[item]
|
||||
|
||||
def get_idx(self, enum: Enum):
|
||||
return self._names[enum.name]
|
||||
def get_idx(self, enum_obj: Enum):
|
||||
return self._names[enum_obj.name]
|
||||
|
||||
@classmethod
|
||||
def from_tiles(cls, tiles, **kwargs):
|
||||
entities = [cls._accepted_objects(f'{cls._accepted_objects.__name__.upper()}#{i}', tile, **kwargs)
|
||||
for i, tile in enumerate(tiles)]
|
||||
# objects_name = cls._accepted_objects.__name__
|
||||
entities = [cls._accepted_objects(i, tile, name_is_identifier=True, **kwargs) for i, tile in enumerate(tiles)]
|
||||
registered_obj = cls()
|
||||
registered_obj.register_additional_items(entities)
|
||||
return registered_obj
|
||||
@ -98,14 +93,6 @@ class Register:
|
||||
|
||||
class EntityRegister(Register):
|
||||
|
||||
@classmethod
|
||||
def from_argwhere_coordinates(cls, argwhere_coordinates, **kwargs):
|
||||
tiles = cls()
|
||||
tiles.register_additional_items(
|
||||
[cls._accepted_objects(i, pos, **kwargs) for i, pos in enumerate(argwhere_coordinates)]
|
||||
)
|
||||
return tiles
|
||||
|
||||
def __init__(self):
|
||||
super(EntityRegister, self).__init__()
|
||||
self._tiles = dict()
|
||||
@ -141,6 +128,15 @@ class Entities(Register):
|
||||
class FloorTiles(EntityRegister):
|
||||
_accepted_objects = Tile
|
||||
|
||||
@classmethod
|
||||
def from_argwhere_coordinates(cls, argwhere_coordinates):
|
||||
tiles = cls()
|
||||
# noinspection PyTypeChecker
|
||||
tiles.register_additional_items(
|
||||
[cls._accepted_objects(i, pos, name_is_identifier=True) for i, pos in enumerate(argwhere_coordinates)]
|
||||
)
|
||||
return tiles
|
||||
|
||||
@property
|
||||
def occupied_tiles(self):
|
||||
tiles = [tile for tile in self if tile.is_occupied()]
|
||||
@ -148,7 +144,7 @@ class FloorTiles(EntityRegister):
|
||||
return tiles
|
||||
|
||||
@property
|
||||
def empty_tiles(self):
|
||||
def empty_tiles(self) -> List[Tile]:
|
||||
tiles = [tile for tile in self if tile.is_empty()]
|
||||
random.shuffle(tiles)
|
||||
return tiles
|
||||
@ -185,6 +181,7 @@ class Actions(Register):
|
||||
def movement_actions(self):
|
||||
return self._movement_actions
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
def __init__(self, movement_properties: MovementProperties, can_use_doors=False):
|
||||
self.allow_no_op = movement_properties.allow_no_op
|
||||
self.allow_diagonal_movement = movement_properties.allow_diagonal_movement
|
||||
@ -193,43 +190,47 @@ class Actions(Register):
|
||||
super(Actions, self).__init__()
|
||||
|
||||
if self.allow_square_movement:
|
||||
self.register_additional_items([self._accepted_objects(direction) for direction in h.MANHATTAN_MOVES])
|
||||
self.register_additional_items([self._accepted_objects(direction) for direction in h.ManhattanMoves])
|
||||
if self.allow_diagonal_movement:
|
||||
self.register_additional_items([self._accepted_objects(direction) for direction in h.DIAGONAL_MOVES])
|
||||
self.register_additional_items([self._accepted_objects(direction) for direction in h.DiagonalMoves])
|
||||
self._movement_actions = self._register.copy()
|
||||
if self.can_use_doors:
|
||||
self.register_additional_items([self._accepted_objects('use_door')])
|
||||
self.register_additional_items([self._accepted_objects(h.EnvActions.USE_DOOR)])
|
||||
if self.allow_no_op:
|
||||
self.register_additional_items([self._accepted_objects('no-op')])
|
||||
self.register_additional_items([self._accepted_objects(h.EnvActions.NOOP)])
|
||||
|
||||
def is_moving_action(self, action: Union[int]):
|
||||
#if isinstance(action, Action):
|
||||
# return (action.name in h.MANHATTAN_MOVES and self.allow_square_movement) or \
|
||||
# (action.name in h.DIAGONAL_MOVES and self.allow_diagonal_movement)
|
||||
#else:
|
||||
return action in self.movement_actions.keys()
|
||||
return action in self.movement_actions.values()
|
||||
|
||||
def is_no_op(self, action: Union[str, int]):
|
||||
if isinstance(action, str):
|
||||
action = self.by_name(action)
|
||||
return self[action].name == 'no-op'
|
||||
def is_no_op(self, action: Union[str, Action, int]):
|
||||
if isinstance(action, int):
|
||||
action = self[action]
|
||||
if isinstance(action, Action):
|
||||
action = action.name
|
||||
return action == h.EnvActions.NOOP.name
|
||||
|
||||
def is_door_usage(self, action: Union[str, int]):
|
||||
if isinstance(action, str):
|
||||
action = self.by_name(action)
|
||||
return self[action].name == 'use_door'
|
||||
if isinstance(action, int):
|
||||
action = self[action]
|
||||
if isinstance(action, Action):
|
||||
action = action.name
|
||||
return action == h.EnvActions.USE_DOOR.name
|
||||
|
||||
|
||||
class StateSlices(Register):
|
||||
|
||||
_accepted_objects = Slice
|
||||
@property
|
||||
def n_observable_slices(self):
|
||||
return len([x for x in self if x.is_observable])
|
||||
|
||||
|
||||
@property
|
||||
def AGENTSTARTIDX(self):
|
||||
if self._agent_start_idx:
|
||||
return self._agent_start_idx
|
||||
else:
|
||||
self._agent_start_idx = min([idx for idx, x in self.items() if c.AGENT.name in x.name])
|
||||
self._agent_start_idx = min([idx for idx, x in self.items() if c.AGENT.value in x.name])
|
||||
return self._agent_start_idx
|
||||
|
||||
def __init__(self):
|
||||
|
Reference in New Issue
Block a user