mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-06 09:31:35 +02:00
Experiments look good
This commit is contained in:
@ -24,14 +24,12 @@ class EnvRecorder(BaseCallback):
|
||||
self._entities = [entities]
|
||||
else:
|
||||
self._entities = entities
|
||||
self.started = False
|
||||
self.closed = False
|
||||
|
||||
def __getattr__(self, item):
|
||||
return getattr(self.unwrapped, item)
|
||||
|
||||
def reset(self):
|
||||
self.unwrapped._record_episodes = True
|
||||
self.unwrapped.start_recording()
|
||||
return self.unwrapped.reset()
|
||||
|
||||
def _on_training_start(self) -> None:
|
||||
@ -57,6 +55,14 @@ class EnvRecorder(BaseCallback):
|
||||
else:
|
||||
pass
|
||||
|
||||
def step(self, actions):
|
||||
step_result = self.unwrapped.step(actions)
|
||||
# 0, 1, 2 , 3 = idx
|
||||
# _, _, done_bool, info_obj = step_result
|
||||
self._read_info(0, step_result[3])
|
||||
self._read_done(0, step_result[2])
|
||||
return step_result
|
||||
|
||||
def save_records(self, filepath: Union[Path, str], save_occupation_map=False, save_trajectory_map=False):
|
||||
filepath = Path(filepath)
|
||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||
|
Reference in New Issue
Block a user