# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import cv2
import skvideo.io


class ImageSource(object):
    """
    Base interface for objects that supply natural images to be
    composited into a simulated environment.

    Both methods are no-op placeholders here; concrete sources override
    whichever ones they need.
    """

    def get_image(self):
        """
        Produce the current image.

        Returns:
            an RGB image of [h, w, 3] with a fixed shape. This base
            implementation returns None.
        """
        return None

    def reset(self):
        """ Hook invoked when an episode ends; does nothing by default. """
        return None


class FixedColorSource(ImageSource):
    """
    Image source that always yields the same uniformly colored image.
    """

    def __init__(self, shape, color):
        """
        Args:
            shape: [h, w]
            color: a 3-tuple
        """
        height, width = shape[0], shape[1]
        canvas = np.zeros((height, width, 3))
        # Broadcast the color over every pixel of the canvas.
        canvas[...] = color
        self.arr = canvas

    def get_image(self):
        """
        Returns:
            a fresh copy of the stored [h, w, 3] array, so callers may
            modify the result without affecting this source.
        """
        duplicate = np.copy(self.arr)
        return duplicate


class RandomColorSource(ImageSource):
    """
    Image source that yields a uniformly colored image whose color is
    re-drawn at random on every reset.
    """

    def __init__(self, shape):
        """
        Args:
            shape: [h, w]
        """
        self.shape = shape
        # Draw the initial color right away.
        self.reset()

    def reset(self):
        """ Sample a new color: three integers drawn from [0, 256). """
        self._color = np.random.randint(low=0, high=256, size=(3,))

    def get_image(self):
        """
        Returns:
            a newly allocated [h, w, 3] array filled with the current color.
        """
        height, width = self.shape[0], self.shape[1]
        canvas = np.zeros((height, width, 3))
        # Broadcast the current color over every pixel.
        canvas[...] = self._color
        return canvas


class NoiseSource(ImageSource):
    """
    Image source that yields freshly sampled random noise on each call.
    """

    def __init__(self, shape, strength=50):
        """
        Args:
            shape: [h, w]
            strength (int): the strength of noise, in range [0, 255]
        """
        self.shape = shape
        self.strength = strength

    def get_image(self):
        """
        Returns:
            an [h, w, 3] array of standard-normal samples scaled by
            ``strength``, with negative values replaced by 0.
        """
        height, width = self.shape[0], self.shape[1]
        scaled = np.random.randn(height, width, 3) * self.strength
        # Only the lower bound is enforced; values are not capped above.
        return np.maximum(scaled, 0)


class RandomImageSource(ImageSource):
    """
    Image source backed by a list of image files.

    One file is chosen at random on construction and on every reset; the
    same resized image is returned until the next reset.
    """

    def __init__(self, shape, filelist):
        """
        Args:
            shape: [h, w]
            filelist: a list of image files
        """
        # cv2.resize expects the target size as (width, height).
        self.shape_wh = shape[::-1]
        self.filelist = filelist
        self.reset()

    def reset(self):
        """
        Pick a random file from the list, load it and cache it resized.

        Raises:
            IOError: if the chosen file cannot be read as an image.
        """
        fname = np.random.choice(self.filelist)
        im = cv2.imread(fname, cv2.IMREAD_COLOR)
        if im is None:
            # cv2.imread signals failure by returning None instead of
            # raising; fail here with the offending file name rather than
            # with an opaque TypeError on the slicing below.
            raise IOError("Cannot read image file: {}".format(fname))
        # Reverse the channel axis (OpenCV loads BGR; this source yields RGB).
        im = im[:, :, ::-1]
        im = cv2.resize(im, self.shape_wh)
        self._im = im

    def get_image(self):
        """
        Returns:
            the cached image. The same array object is returned on every
            call (no copy is made).
        """
        return self._im


class RandomVideoSource(ImageSource):
    """
    Image source backed by a list of video files.

    One video is chosen at random and its frames are served in order, one
    per call; when the frames run out, another video is chosen at random.
    """

    def __init__(self, shape, filelist):
        """
        Args:
            shape: [h, w]
            filelist: a list of video files
        """
        # cv2.resize expects the target size as (width, height).
        self.shape_wh = shape[::-1]
        self.filelist = filelist
        self.reset()

    def reset(self):
        """ Load a randomly chosen video and rewind to its first frame. """
        fname = np.random.choice(self.filelist)
        self.frames = skvideo.io.vread(fname)
        self.frame_idx = 0

    def get_image(self):
        """
        Returns:
            the next frame of the current video, resized to the
            configured shape.
        """
        # All frames consumed: switch to a newly chosen video.
        if self.frame_idx >= self.frames.shape[0]:
            self.reset()
        # The frame keeps the channel order produced by skvideo.io.vread;
        # reversing the channel axis twice would be a no-op, so no
        # reversal is applied.
        im = self.frames[self.frame_idx]
        self.frame_idx += 1
        return cv2.resize(im, self.shape_wh)