New Dataset Generator, How to differentiate the loss function?
This commit is contained in:
@ -2,15 +2,10 @@ import multiprocessing as mp
|
||||
import pickle
|
||||
import shelve
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from tqdm import trange
|
||||
|
||||
from lib.objects.map import Map
|
||||
from lib.utils.parallel import run_n_in_parallel
|
||||
|
||||
|
||||
class Generator:
|
||||
@ -109,7 +104,7 @@ class Generator:
|
||||
trajectory=trajectory,
|
||||
labels=labels)
|
||||
if 'map' not in f:
|
||||
f['map'] = dict(map=self.map, name=f'map_{self.map.name}')
|
||||
f['map'] = dict(map=self.map, name=self.map.name)
|
||||
|
||||
@staticmethod
|
||||
def _remove_unequal(hom_dict):
|
||||
|
Reference in New Issue
Block a user