Merge remote-tracking branch 'origin/master'
# Conflicts: # utils/tools.py
This commit is contained in:
		
							
								
								
									
										7
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -1,6 +1 @@ | ||||
| /.idea/ | ||||
| # my own stuff | ||||
|  | ||||
| /data | ||||
| /.idea | ||||
| /ml_lib | ||||
| .idea | ||||
|   | ||||
| @@ -4,11 +4,11 @@ from pathlib import Path | ||||
| from typing import Union | ||||
|  | ||||
| import torch | ||||
| from performer_pytorch import FastAttention | ||||
|  | ||||
| from torch import nn | ||||
| from torch.nn import functional as F | ||||
|  | ||||
| from einops import rearrange | ||||
| from einops import rearrange, repeat | ||||
|  | ||||
| import sys | ||||
| sys.path.append(str(Path(__file__).parent)) | ||||
| @@ -262,7 +262,10 @@ class Attention(nn.Module): | ||||
|             mask = F.pad(mask.flatten(1), (1, 0), value=True) | ||||
|             assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' | ||||
|             mask = mask[:, None, :] * mask[:, :, None] | ||||
|             dots.masked_fill_(~mask, mask_value) | ||||
|             mask = repeat(mask, 'b n d -> b h n d', h=h)             # My addition | ||||
|  | ||||
|             #dots.masked_fill_(~mask, mask_value) | ||||
|             dots.masked_fill_(mask, mask_value) | ||||
|             del mask | ||||
|  | ||||
|         attn = dots.softmax(dim=-1) | ||||
|   | ||||
| @@ -26,6 +26,7 @@ def parse_comandline_args_add_defaults(filepath, overrides=None): | ||||
|     parser = ArgumentParser() | ||||
|     parser.add_argument('--model_name', type=str) | ||||
|     parser.add_argument('--data_name', type=str) | ||||
|     parser.add_argument('--seed', type=str) | ||||
|  | ||||
|     # Load Defaults from _parameters.ini file | ||||
|     config = configparser.ConfigParser() | ||||
| @@ -46,9 +47,11 @@ def parse_comandline_args_add_defaults(filepath, overrides=None): | ||||
|     overrides = overrides or dict() | ||||
|     default_data = overrides.get('data_name', None) or new_defaults['data_name'] | ||||
|     default_model = overrides.get('model_name', None) or new_defaults['model_name'] | ||||
|     default_seed = overrides.get('seed', None) or new_defaults['seed'] | ||||
|  | ||||
|     data_name = args.__dict__.get('data_name', None) or default_data | ||||
|     model_name = args.__dict__.get('model_name', None) or default_model | ||||
|     found_seed = args.__dict__.get('seed', None) or default_seed | ||||
|  | ||||
|     new_defaults.update({key: auto_cast(val) for key, val in config[model_name].items()}) | ||||
|  | ||||
| @@ -72,7 +75,7 @@ def parse_comandline_args_add_defaults(filepath, overrides=None): | ||||
|  | ||||
|     if overrides is not None and isinstance(overrides, (Mapping, Dict)): | ||||
|         args.update(**overrides) | ||||
|     return args, found_data_class, found_model_class | ||||
|     return args, found_data_class, found_model_class, found_seed | ||||
|  | ||||
|  | ||||
| def is_jsonable(x): | ||||
|   | ||||
| @@ -29,6 +29,7 @@ def fix_all_random_seeds(seed): | ||||
|     np.random.seed(seed) | ||||
|     torch.manual_seed(seed) | ||||
|     random.seed(seed) | ||||
|     print(f'Seed is now fixed: "{seed}".') | ||||
|  | ||||
|  | ||||
| def write_to_shelve(file_path, value): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Steffen Illium
					Steffen Illium