33 lines
1.2 KiB
Python
33 lines
1.2 KiB
Python
from abc import ABC, abstractmethod
|
|
import numpy as np
|
|
|
|
from typing import Tuple
|
|
|
|
|
|
class Task(ABC):
    """Base class for tasks that hand out batches of samples.

    The base class only stores the shape settings and the batch size;
    ``get_samples`` is meant to be overridden by subclasses.
    """

    def __init__(self, input_shape, output_shape, **kwargs):
        """Store the shape settings and the batch size.

        Args:
            input_shape: Shape setting for the inputs, stored unchanged.
            output_shape: Shape setting for the targets, stored unchanged.
            **kwargs: Optional settings. Only ``batchsize`` (default ``100``)
                is read here; any other key is ignored.
        """
        # No explicit duplicate-argument guard is needed: ``input_shape`` and
        # ``output_shape`` are named parameters, so they can never end up in
        # ``kwargs`` -- Python itself raises TypeError ("got multiple values
        # for argument ...") before this body runs.
        self.input_shape = input_shape
        self.output_shape = output_shape
        self.batchsize = kwargs.get('batchsize', 100)

    # NOTE(review): the class derives from ABC but this method is not marked
    # with @abstractmethod, so Task itself stays instantiable -- confirm
    # whether that is intended before tightening it.
    def get_samples(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return one batch as a pair of arrays.

        Raises:
            NotImplementedError: Always; subclasses supply the sampling.
        """
        raise NotImplementedError
|
|
|
|
|
|
class TaskAdditionOfN(Task):
    """Task whose target is the sum of the first ``n`` entries of each input.

    Every input is zero except for its first ``n`` positions along axis 1,
    which hold standard-normal noise scaled by 0.5.
    """

    def __init__(self, n: int, input_shape=(4,), output_shape=1, **kwargs):
        """Configure how many leading input entries are random and summed.

        Args:
            n: Number of leading entries (along axis 1) filled with noise.
            input_shape: Per-sample input shape; ``input_shape[0]`` bounds
                ``n``. ``get_samples`` fills a ``(batchsize, n)`` slice, so a
                one-dimensional shape is assumed.
            output_shape: Forwarded to ``Task`` and stored unchanged.
            **kwargs: Forwarded to ``Task`` (e.g. ``batchsize``).

        Raises:
            ValueError: If ``n`` is negative.
            AssertionError: If ``n`` exceeds ``input_shape[0]``.
        """
        # A negative n used to pass construction and only blow up later in
        # get_samples with numpy's "negative dimensions" ValueError; reject it
        # here with the same exception type and a clearer message.
        if n < 0:
            raise ValueError(f'n must be non-negative, got n={n}.')
        assert n <= input_shape[0], f'You cannot Add more values (n={n}) than your input is long (in={input_shape}).'
        # input_shape / output_shape are named parameters, so they can never
        # also appear in kwargs (Python raises TypeError for duplicates);
        # forward them explicitly instead of mutating kwargs.
        super().__init__(input_shape=input_shape, output_shape=output_shape, **kwargs)
        self.n = n

    def get_samples(self) -> Tuple[np.ndarray, np.ndarray]:
        """Draw one batch of inputs and their sums.

        Returns:
            ``(x, y)`` where ``x`` has shape ``(batchsize, *input_shape)`` and
            ``y`` is ``x`` summed over axis 1. Only the first ``n`` entries of
            each row are non-zero, so ``y`` is the sum of those ``n`` values.
        """
        x = np.zeros((self.batchsize, *self.input_shape))
        # Uses numpy's global RNG, so np.random.seed controls reproducibility.
        x[:, :self.n] = np.random.standard_normal((self.batchsize, self.n)) * 0.5
        # NOTE(review): output_shape is not used to shape y; for a 1-D
        # input_shape y comes out as (batchsize,) -- confirm callers expect it.
        y = np.sum(x, axis=1)
        return x, y
|