-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
77 lines (51 loc) · 3.69 KB
/
main.py
File metadata and controls
77 lines (51 loc) · 3.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from MultiResGenerator import MultiResGenerator
# Placeholder data locations -- replace with real paths before running.
img_folder = "path_to_img_folder/"
gt_folder = "path_to_label_folder/"
output_folder = "path_to_output_folder/"

# Patch-related hyperparameters: every list below holds one entry per scale.
# With these values the first scale sees the largest field of view (1024x1024 -> 64x64 model input)
# and the last one the smallest (128x128 -> 128x128 model input).
real_patch_size = [[1024, 1024], [512, 512], [256, 256], [128, 128], [128, 128]]  # Size of the field of view to consider for each scale
model_patch_size = [[64, 64], [64, 64], [64, 64], [64, 64], [128, 128]]  # Size of the patch (model input) for each scale
stride = [0, 256, 128, 64, 64]  # Stride of the field of view used to draw patches
augmentation = [["flip_vertical", "flip_horizontal", "rotation90", "rotation180"], ["none"], ["none"], ["none"], ["none"]]  # Augmentation to apply for each scale
model_type = ["ModeSeekingcGAN", "RefinementcGAN", "RefinementcGAN", "RefinementcGAN", "RefinementcGAN"]  # Type of model to use for each scale

# Other hyperparameters
lbd = [1, 10, 10, 10, 10]  # Loss balance parameter lbd (depends on the type of model)
itrs = [6000, 6000, 6000, 6000, 6000]  # Number of iterations to run for each scale
lr = [0.0001, 0.0001, 0.0001, 0.0001, 0.0001]  # Learning rate to use for each scale
batch_size = [32, 32, 32, 32, 32]  # Batch size for each scale
nz = 12  # Size of the random vector used for ModeSeekingcGAN

# Derive the number of scales from the per-scale lists instead of hard-coding it, so the value
# passed to the model and the index of the "last scale" can never drift out of sync with them.
number_of_scales = len(real_patch_size)

# Fail fast on a misconfigured list rather than deep inside a long training run.
for _setting_name, _setting_values in (
    ("model_patch_size", model_patch_size),
    ("stride", stride),
    ("augmentation", augmentation),
    ("model_type", model_type),
    ("lbd", lbd),
    ("itrs", itrs),
    ("lr", lr),
    ("batch_size", batch_size),
):
    if len(_setting_values) != number_of_scales:
        raise ValueError(
            "{} has {} entries but real_patch_size defines {} scales".format(
                _setting_name, len(_setting_values), number_of_scales
            )
        )

### TRAINING ALL ###
# Create the cascade model
model = MultiResGenerator(output_folder, number_of_scales, real_patch_size, model_patch_size, model_type, itrs, augmentation, stride, lr, lbd, nz, batch_size)
# Load the images and labels. In our experience, to generate images that are not histogram-equalized,
# it is crucial to provide both the equalized and the non-equalized versions of the original images.
# Images and labels must share the same file names. Optional: pass a file_list argument (a list of
# file names) to load only some of the images contained in a folder.
model.load_images([img_folder, img_folder, gt_folder], data_type = ["img", "img", "gt"], dataset_name = "train", equalize = [True, False, False], data_range = [-1, 1])
model.run_training_all("train")  # Train all scales one-by-one
last_scale = number_of_scales - 1
model.write_images(last_scale, "train", "gen_" + str(last_scale))  # Write the images generated at the last scale
# Alternative workflows, kept disabled inside the string literal below (it is evaluated and discarded).
# To use one, move its section out of the string. Otherwise the "TRAINING ALL" section above still
# runs first, because it is unconditional module-level code.
"""
### TRAINING ONLY SOME STEPS ###
# Create the cascade model
model = MultiResGenerator(output_folder, number_of_scales, real_patch_size, model_patch_size, model_type, itrs, augmentation, stride, lr, lbd, nz, batch_size)
# Load the images and labels
model.load_images([img_folder, img_folder, gt_folder], data_type = ["img", "img", "gt"], dataset_name = "train", equalize = [True, False, False], data_range = [-1, 1])
# Load all the previously trained models
model.load_state_dicts()
start_scale = 2  # Train from scale 2
for i in range(0, start_scale):
    model.load_generated_image(i, "train")  # Load the images generated at previous scales
for i in range(start_scale, number_of_scales):
    model.run_training(i, "train", write = True)
    #model.write_images(i, "train", "gen_" + str(i))  # Write images generated at the intermediate scales
model.write_images(number_of_scales - 1, "train", "gen_" + str(number_of_scales - 1))  # Write the images generated at the last scale

### EVALUATION ###
# Folder placeholder ends with "/" like every other folder path in this script.
test_gt_folder = "path_to_test_label_folder/"
# No augmentation at evaluation time
augmentation = [["none"], ["none"], ["none"], ["none"], ["none"]]
model = MultiResGenerator(output_folder, number_of_scales, real_patch_size, model_patch_size, model_type, itrs, augmentation, stride, lr, lbd, nz, batch_size, n_img_channels = 2)
model.load_images([test_gt_folder], ["gt"], "test", [False], data_range = [-1, 1])
# Load the full cascade weights
model.load_state_dicts()
for i in range(number_of_scales):
    model.run_evaluation(i, "test", write = False)
model.write_images(number_of_scales - 1, "test", "gen_" + str(number_of_scales - 1))
"""