actions are now objects :P
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
@ -12,7 +12,7 @@ from environments.logging.monitor import MonitorCallback
|
||||
from environments.factory.renderer import Renderer, Entity
|
||||
|
||||
DIRT_INDEX = -1
|
||||
|
||||
CLEAN_UP_ACTION = 'clean_up'
|
||||
|
||||
@dataclass
|
||||
class DirtProperties:
|
||||
@ -26,13 +26,14 @@ class DirtProperties:
|
||||
|
||||
class SimpleFactory(BaseFactory):
|
||||
|
||||
def register_additional_actions(self):
|
||||
return 1
|
||||
@property
|
||||
def additional_actions(self) -> Union[str, List[str]]:
|
||||
return CLEAN_UP_ACTION
|
||||
|
||||
def _is_clean_up_action(self, action):
|
||||
return self.action_space.n - 1 == action
|
||||
return self._actions[action] == CLEAN_UP_ACTION
|
||||
|
||||
def __init__(self, *args, dirt_properties: DirtProperties, verbose=False, force_skip_render=False, **kwargs):
|
||||
def __init__(self, *args, dirt_properties: DirtProperties, verbose=False, **kwargs):
|
||||
self._dirt_properties = dirt_properties
|
||||
self.verbose = verbose
|
||||
self.max_dirt = 20
|
||||
@ -98,7 +99,7 @@ class SimpleFactory(BaseFactory):
|
||||
self.next_dirt_spawn -= 1
|
||||
return self.state, r, done, info
|
||||
|
||||
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||
def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||
if action != self._is_moving_action(action):
|
||||
if self._is_clean_up_action(action):
|
||||
agent_i_pos = self.agent_i_position(agent_i)
|
||||
@ -175,9 +176,10 @@ if __name__ == '__main__':
|
||||
|
||||
dirt_props = DirtProperties()
|
||||
factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props)
|
||||
n_actions = factory.action_space.n - 1
|
||||
with MonitorCallback(factory):
|
||||
for epoch in range(100):
|
||||
random_actions = [(random.randint(0, 8), random.randint(0, 8)) for _ in range(200)]
|
||||
random_actions = [(random.randint(0, n_actions), random.randint(0, n_actions)) for _ in range(200)]
|
||||
env_state, this_reward, done_bool, _ = factory.reset()
|
||||
for agent_i_action in random_actions:
|
||||
env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
|
||||
|
Reference in New Issue
Block a user