mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-21 11:21:35 +02:00
Individual Rewards
This commit is contained in:
@ -61,7 +61,8 @@ class BaseFactory(gym.Env):
|
||||
mv_prop: MovementProperties = MovementProperties(),
|
||||
obs_prop: ObservationProperties = ObservationProperties(),
|
||||
parse_doors=False, record_episodes=False, done_at_collision=False,
|
||||
verbose=False, doors_have_area=True, env_seed=time.time_ns(), **kwargs):
|
||||
verbose=False, doors_have_area=True, env_seed=time.time_ns(), individual_rewards=False,
|
||||
**kwargs):
|
||||
|
||||
if isinstance(mv_prop, dict):
|
||||
mv_prop = MovementProperties(**mv_prop)
|
||||
@ -94,6 +95,7 @@ class BaseFactory(gym.Env):
|
||||
self.record_episodes = record_episodes
|
||||
self.parse_doors = parse_doors
|
||||
self.doors_have_area = doors_have_area
|
||||
self.individual_rewards = individual_rewards
|
||||
|
||||
# Reset
|
||||
self.reset()
|
||||
@ -487,31 +489,32 @@ class BaseFactory(gym.Env):
|
||||
def calculate_reward(self) -> (int, dict):
|
||||
# Returns: Reward, Info
|
||||
per_agent_info_dict = defaultdict(dict)
|
||||
reward = 0
|
||||
reward = {}
|
||||
|
||||
for agent in self[c.AGENT]:
|
||||
per_agent_reward = 0
|
||||
if self._actions.is_moving_action(agent.temp_action):
|
||||
if agent.temp_valid:
|
||||
# info_dict.update(movement=1)
|
||||
reward -= 0.01
|
||||
per_agent_reward -= 0.01
|
||||
pass
|
||||
else:
|
||||
reward -= 0.05
|
||||
per_agent_reward -= 0.05
|
||||
self.print(f'{agent.name} just hit the wall at {agent.pos}.')
|
||||
per_agent_info_dict[agent.name].update({f'{agent.name}_vs_LEVEL': 1})
|
||||
|
||||
elif h.EnvActions.USE_DOOR == agent.temp_action:
|
||||
if agent.temp_valid:
|
||||
# reward += 0.00
|
||||
# per_agent_reward += 0.00
|
||||
self.print(f'{agent.name} did just use the door at {agent.pos}.')
|
||||
per_agent_info_dict[agent.name].update(door_used=1)
|
||||
else:
|
||||
# reward -= 0.00
|
||||
# per_agent_reward -= 0.00
|
||||
self.print(f'{agent.name} just tried to use a door at {agent.pos}, but failed.')
|
||||
per_agent_info_dict[agent.name].update({f'{agent.name}_failed_door_open': 1})
|
||||
elif h.EnvActions.NOOP == agent.temp_action:
|
||||
per_agent_info_dict[agent.name].update(no_op=1)
|
||||
# reward -= 0.00
|
||||
# per_agent_reward -= 0.00
|
||||
|
||||
# Monitor Notes
|
||||
if agent.temp_valid:
|
||||
@ -522,7 +525,7 @@ class BaseFactory(gym.Env):
|
||||
per_agent_info_dict[agent.name].update({f'{agent.name}_failed_action': 1})
|
||||
|
||||
additional_reward, additional_info_dict = self.calculate_additional_reward(agent)
|
||||
reward += additional_reward
|
||||
per_agent_reward += additional_reward
|
||||
per_agent_info_dict[agent.name].update(additional_info_dict)
|
||||
|
||||
if agent.temp_collisions:
|
||||
@ -531,6 +534,7 @@ class BaseFactory(gym.Env):
|
||||
|
||||
for other_agent in agent.temp_collisions:
|
||||
per_agent_info_dict[agent.name].update({f'{agent.name}_vs_{other_agent.name}': 1})
|
||||
reward[agent.name] = per_agent_reward
|
||||
|
||||
# Combine the per_agent_info_dict:
|
||||
combined_info_dict = defaultdict(lambda: 0)
|
||||
@ -539,7 +543,13 @@ class BaseFactory(gym.Env):
|
||||
combined_info_dict[key] += value
|
||||
combined_info_dict = dict(combined_info_dict)
|
||||
|
||||
self.print(f"reward is {reward}")
|
||||
if self.individual_rewards:
|
||||
self.print(f"rewards are {reward}")
|
||||
reward = list(reward.values())
|
||||
return reward, combined_info_dict
|
||||
else:
|
||||
reward = sum(reward.values())
|
||||
self.print(f"reward is {reward}")
|
||||
return reward, combined_info_dict
|
||||
|
||||
def render(self, mode='human'):
|
||||
|
@ -1,3 +0,0 @@
|
||||
from environments.policy_adaption.natural_rl_environment import matting
|
||||
from environments.policy_adaption.natural_rl_environment import imgsource
|
||||
from environments.policy_adaption.natural_rl_environment import natural_env
|
@ -1,120 +0,0 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import skvideo.io
|
||||
|
||||
|
||||
class ImageSource(object):
    """Base interface for objects that supply natural background images.

    Subclasses override :meth:`get_image` (and optionally :meth:`reset`);
    the base implementations do nothing and return ``None``.
    """

    def get_image(self):
        """Produce the current background frame.

        Returns:
            Subclasses are expected to return an RGB image of shape
            [h, w, 3] whose shape does not change between calls. The base
            class returns ``None``.
        """
        return None

    def reset(self):
        """Hook invoked when an episode ends; a no-op in the base class."""
        return None
|
||||
|
||||
|
||||
class FixedColorSource(ImageSource):
    """Image source that always yields the same uniformly coloured frame."""

    def __init__(self, shape, color):
        """Build the constant frame once.

        Args:
            shape: [h, w]
            color: a 3-tuple, broadcast over every pixel
        """
        height, width = shape[0], shape[1]
        canvas = np.zeros((height, width, 3))
        canvas[:, :] = color
        self.arr = canvas

    def get_image(self):
        """Return a fresh copy of the stored frame so callers may mutate it."""
        return np.copy(self.arr)
|
||||
|
||||
|
||||
class RandomColorSource(ImageSource):
    """Image source yielding a uniform frame whose colour is re-drawn on reset."""

    def __init__(self, shape):
        """Remember the frame size and draw the first colour.

        Args:
            shape: [h, w]
        """
        self.shape = shape
        self.reset()

    def reset(self):
        """Draw a new random colour: three integers in [0, 256)."""
        self._color = np.random.randint(0, 256, size=(3,))

    def get_image(self):
        """Return a newly allocated [h, w, 3] array filled with the current colour."""
        height, width = self.shape[0], self.shape[1]
        frame = np.zeros((height, width, 3))
        frame[:, :] = self._color
        return frame
|
||||
|
||||
|
||||
class NoiseSource(ImageSource):
    """Image source yielding a fresh frame of clamped Gaussian noise per call."""

    def __init__(self, shape, strength=50):
        """Store the frame size and the noise scale.

        Args:
            shape: [h, w]
            strength (int): the strength of noise, in range [0, 255]
        """
        self.shape = shape
        self.strength = strength

    def get_image(self):
        """Return an [h, w, 3] array of scaled normal noise.

        Standard-normal samples are multiplied by ``strength`` and negative
        values are clamped to zero; no upper bound is applied.
        """
        height, width = self.shape[0], self.shape[1]
        scaled_noise = np.random.randn(height, width, 3) * self.strength
        return np.maximum(scaled_noise, 0)
|
||||
|
||||
|
||||
class RandomImageSource(ImageSource):
    """Image source that serves one randomly chosen still image per episode."""

    def __init__(self, shape, filelist):
        """Store the target size and load the first image.

        Args:
            shape: [h, w]
            filelist: a list of image files
        """
        # Kept in (w, h) order because that is what is handed to cv2.resize.
        self.shape_wh = shape[::-1]
        self.filelist = filelist
        self.reset()

    def reset(self):
        """Pick a random file from ``filelist``, load it and resize it.

        Raises:
            IOError: if the chosen file cannot be read as an image.
        """
        fname = np.random.choice(self.filelist)
        im = cv2.imread(fname, cv2.IMREAD_COLOR)
        if im is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the file name rather than on the slice below.
            raise IOError(f"Could not read image file: {fname}")
        # Reverse the channel axis (OpenCV loads BGR, the interface wants RGB).
        im = im[:, :, ::-1]
        im = cv2.resize(im, self.shape_wh)
        self._im = im

    def get_image(self):
        """Return the currently loaded image (the stored array, not a copy)."""
        return self._im
|
||||
|
||||
|
||||
class RandomVideoSource(ImageSource):
    """Image source that plays back frames of a randomly chosen video."""

    def __init__(self, shape, filelist):
        """Store the target size and load the first video.

        Args:
            shape: [h, w]
            filelist: a list of video files
        """
        # Kept in (w, h) order because that is what is handed to cv2.resize.
        self.shape_wh = shape[::-1]
        self.filelist = filelist
        self.reset()

    def reset(self):
        """Load all frames of a randomly chosen video and rewind to frame 0."""
        fname = np.random.choice(self.filelist)
        self.frames = skvideo.io.vread(fname)
        self.frame_idx = 0

    def get_image(self):
        """Return the next frame, resized to the configured shape.

        When the current video is exhausted a new random video is loaded
        first. The frame's channel order is left exactly as delivered by
        ``skvideo.io.vread``: the previous implementation reversed the
        channel axis twice in a row, which cancelled out.
        """
        if self.frame_idx >= self.frames.shape[0]:
            self.reset()
        im = self.frames[self.frame_idx]
        self.frame_idx += 1
        return cv2.resize(im, self.shape_wh)
|
@ -1,32 +0,0 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
class BackgroundMatting(object):
    """
    Produce a mask of a given image which will be replaced by natural signals.
    """

    def get_mask(self, img):
        """
        Take an image of [H, W, 3]. Returns a mask of [H, W]

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError()


class BackgroundMattingWithColor(BackgroundMatting):
    """
    Produce a mask by masking the given color. This is a simple strategy
    but effective for many games.
    """

    def __init__(self, color):
        """
        Args:
            color: a (r, g, b) tuple
        """
        self._color = color

    def get_mask(self, img):
        """Return a boolean [H, W] mask of the pixels equal to the color.

        A pixel is selected only when all of its channels match. The plain
        element-wise comparison used before produced an [H, W, 3] array and
        flagged single channels of pixels that merely shared one component
        with the background color. A [H, W] mask remains valid for boolean
        indexing of [H, W, 3] images (``obs[mask] = img[mask]``).

        Args:
            img: array of shape [H, W, 3] (must support ``==`` broadcasting
                and ``.all``, e.g. a numpy array).
        """
        return (img == self._color).all(axis=-1)
|
@ -1,119 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import glob
|
||||
import gym
|
||||
from gym.utils import play
|
||||
|
||||
from .matting import BackgroundMattingWithColor
|
||||
from .imgsource import (
|
||||
RandomImageSource,
|
||||
RandomColorSource,
|
||||
NoiseSource,
|
||||
RandomVideoSource,
|
||||
)
|
||||
|
||||
|
||||
class ReplaceBackgroundEnv(gym.ObservationWrapper):
    """Observation wrapper that paints natural images over masked pixels.

    For every observation, the pixels selected by ``bg_matting`` are
    overwritten with the corresponding pixels from ``natural_source``.
    """

    viewer = None

    def __init__(self, env, bg_matting, natural_source):
        """
        The source must produce a image with a shape that's compatible to
        `env.observation_space`.

        Args:
            env: the environment to wrap.
            bg_matting: object whose ``get_mask(obs)`` selects the pixels
                to replace.
            natural_source: object whose ``get_image()`` supplies the
                replacement pixels and whose ``reset()`` is called on
                every environment reset.
        """
        super(ReplaceBackgroundEnv, self).__init__(env)
        self._bg_matting = bg_matting
        self._natural_source = natural_source

    def observation(self, obs):
        """Replace the masked pixels of ``obs`` in place and return it.

        The modified observation is also remembered for :meth:`render`.
        """
        mask = self._bg_matting.get_mask(obs)
        img = self._natural_source.get_image()
        obs[mask] = img[mask]
        self._last_ob = obs
        return obs

    def reset(self, **kwargs):
        """Reset the image source, then the wrapped environment.

        Keyword arguments are forwarded unchanged to the parent ``reset``.
        """
        self._natural_source.reset()
        return super(ReplaceBackgroundEnv, self).reset(**kwargs)

    # modified from gym/envs/atari/atari_env.py
    # This makes the monitor work
    def render(self, mode="human"):
        """Render the last processed observation.

        Returns:
            The last observation for ``mode == "rgb_array"``; the viewer's
            ``isopen`` flag for ``mode == "human"``; ``None`` for any other
            mode.
        """
        img = self._last_ob
        if mode == "rgb_array":
            return img
        elif mode == "human":
            # Imported lazily so headless use ("rgb_array") never touches it.
            from gym.envs.classic_control import rendering

            if self.viewer is None:
                self.viewer = rendering.SimpleImageViewer()
            self.viewer.imshow(img)
            # The viewer belongs to this wrapper; the previous code read
            # `env.viewer`, a name that is not defined in this scope.
            return self.viewer.isopen
|
||||
|
||||
|
||||
def make(name='Pong-v0', imgsource='color', files=None):
    """Create a gym environment whose background is replaced by natural signals.

    Args:
        name: gym environment id, e.g. gravitar, breakout, MsPacman,
            Space Invaders.
        imgsource: one of ``'video'`` (``'videos'`` is accepted as an
            alias), ``'color'``, ``'noise'`` or ``'images'``.
        files: list of image/video files for the ``'images'`` and
            ``'video'`` sources. For ``'video'`` the historical built-in
            default clip is used when no files are given.

    Returns:
        The environment wrapped in ``ReplaceBackgroundEnv``.

    Raises:
        ValueError: if ``imgsource == 'images'`` and no files are given.
        NotImplementedError: if ``imgsource`` is not a supported source.
    """
    env = gym.make(name)
    shape2d = env.observation_space.shape[:2]
    # Color treated as background: a dedicated value for Pong, black otherwise.
    color = (144, 72, 17) if 'Pong' in name else (0, 0, 0)

    if imgsource in ('video', 'videos'):
        # Honour caller-supplied files; fall back to the previously
        # hard-coded clip so existing calls keep working.
        video_files = files if files else [
            '/Users/romue/PycharmProjects/EDYS/environments/policy_adaption/natural_rl_environment/videos/stars.mp4'
        ]
        source = RandomVideoSource(shape2d, video_files)
    elif imgsource == "color":
        source = RandomColorSource(shape2d)
    elif imgsource == "noise":
        source = NoiseSource(shape2d)
    elif imgsource == "images":
        if not files:
            raise ValueError("imgsource='images' requires a non-empty list of files")
        source = RandomImageSource(shape2d, files)
    else:
        raise NotImplementedError(
            f'{imgsource} is not supported, use one of {{video, color, noise, images}}'
        )

    return ReplaceBackgroundEnv(env, BackgroundMattingWithColor(color), source)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: optionally wrap a gym environment with a
    # natural background source and play it interactively.
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", required=True, help="The gym environment to base on")
    parser.add_argument("--imgsource", choices=["color", "noise", "images", "videos"])
    parser.add_argument(
        "--resource-files", help="A glob pattern to obtain images or videos"
    )
    parser.add_argument("--dump-video", help="If given, a directory to dump video")
    args = parser.parse_args()

    # Validate user input up front with proper usage errors; `assert` is
    # stripped under `python -O` and a missing pattern used to crash inside
    # os.path.expanduser(None).
    files = None
    if args.imgsource in ("images", "videos"):
        if not args.resource_files:
            parser.error("--resource-files is required for --imgsource {}".format(args.imgsource))
        files = glob.glob(os.path.expanduser(args.resource_files))
        if not files:
            parser.error("Pattern {} does not match any files".format(args.resource_files))
    if args.dump_video and not os.path.isdir(args.dump_video):
        parser.error("--dump-video {} is not a directory".format(args.dump_video))

    env = gym.make(args.env)
    shape2d = env.observation_space.shape[:2]

    if args.imgsource:
        if args.imgsource == "color":
            imgsource = RandomColorSource(shape2d)
        elif args.imgsource == "noise":
            imgsource = NoiseSource(shape2d)
        elif args.imgsource == "images":
            imgsource = RandomImageSource(shape2d, files)
        else:
            imgsource = RandomVideoSource(shape2d, files)

        wrapped_env = ReplaceBackgroundEnv(
            env, BackgroundMattingWithColor((0, 0, 0)), imgsource
        )
    else:
        wrapped_env = env

    if args.dump_video:
        wrapped_env = gym.wrappers.Monitor(wrapped_env, args.dump_video)
    play.play(wrapped_env, zoom=4)
|
Binary file not shown.
Binary file not shown.
@ -1,8 +0,0 @@
|
||||
import gym
|
||||
import glob
|
||||
from environments.policy_adaption.natural_rl_environment.imgsource import *
|
||||
from environments.policy_adaption.natural_rl_environment.natural_env import *
|
||||
|
||||
if __name__ == "__main__":
    # Interactive demo: Space Invaders with a video background
    # (other candidates: gravitar, breakout, MsPacman).
    demo_env = make(name='SpaceInvaders-v0', imgsource='video')
    play.play(demo_env, zoom=4)
|
Reference in New Issue
Block a user