Modifying U-Net implementation for smaller image size
I'm implementing the U-Net model per the published paper here. This is my model so far:
```python
from tensorflow.keras.layers import (Concatenate, Conv2D, Cropping2D, Input,
                                     MaxPool2D, UpSampling2D)
from tensorflow.keras.models import Model

IMAGE_SIZE = (572, 572)  # assumed value, matching the 572x572 input described below

def create_unet_model(image_size = IMAGE_SIZE):
    # Input layer is a 572,572 colour image
    input_layer = Input(shape=(image_size) + (3,))
    # Begin Downsampling
    # Block 1
    conv_1 = Conv2D(64, 3, activation = 'relu')(input_layer)
    conv_2 = Conv2D(64, 3, activation = 'relu')(conv_1)
    max_pool_1 = MaxPool2D(strides=2)(conv_2)
    # Block 2
    conv_3 = Conv2D(128, 3, activation = 'relu')(max_pool_1)
    conv_4 = Conv2D(128, 3, activation = 'relu')(conv_3)
    max_pool_2 = MaxPool2D(strides=2)(conv_4)
    # Block 3
    conv_5 = Conv2D(256, 3, activation = 'relu')(max_pool_2)
    conv_6 = Conv2D(256, 3, activation = 'relu')(conv_5)
    max_pool_3 = MaxPool2D(strides=2)(conv_6)
    # Block 4
    conv_7 = Conv2D(512, 3, activation = 'relu')(max_pool_3)
    conv_8 = Conv2D(512, 3, activation = 'relu')(conv_7)
    max_pool_4 = MaxPool2D(strides=2)(conv_8)
    # Begin Upsampling
    # Block 5
    conv_9 = Conv2D(1024, 3, activation = 'relu')(max_pool_4)
    conv_10 = Conv2D(1024, 3, activation = 'relu')(conv_9)
    upsample_1 = UpSampling2D()(conv_10)
    # Copy and Crop
    conv_8_cropped = Cropping2D(cropping=4)(conv_8)
    merge_1 = Concatenate()([conv_8_cropped, upsample_1])
    # Block 6
    conv_11 = Conv2D(512, 3, activation = 'relu')(merge_1)
    conv_12 = Conv2D(512, 3, activation = 'relu')(conv_11)
    upsample_2 = UpSampling2D()(conv_12)
    # Copy and Crop
    conv_6_cropped = Cropping2D(cropping=16)(conv_6)
    merge_2 = Concatenate()([conv_6_cropped, upsample_2])
    # Block 7
    conv_13 = Conv2D(256, 3, activation = 'relu')(merge_2)
    conv_14 = Conv2D(256, 3, activation = 'relu')(conv_13)
    upsample_3 = UpSampling2D()(conv_14)
    # Copy and Crop
    conv_4_cropped = Cropping2D(cropping=40)(conv_4)
    merge_3 = Concatenate()([conv_4_cropped, upsample_3])
    # Block 8
    conv_15 = Conv2D(128, 3, activation = 'relu')(merge_3)
    conv_16 = Conv2D(128, 3, activation = 'relu')(conv_15)
    upsample_4 = UpSampling2D()(conv_16)
    # Connect layers
    conv_2_cropped = Cropping2D(cropping=88)(conv_2)
    merge_4 = Concatenate()([conv_2_cropped, upsample_4])
    # Block 9
    conv_17 = Conv2D(64, 3, activation = 'relu')(merge_4)
    conv_18 = Conv2D(64, 3, activation = 'relu')(conv_17)
    # Output layer
    output_layer = Conv2D(1, 1, activation='sigmoid')(conv_18)
    # Define the model
    unet = Model(input_layer, output_layer)
    return unet
```
The cropping is implemented as specified in this answer and is specific to 572x572 images.
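To see where those crop sizes come from, the side length can be traced through the network by hand: every unpadded (valid) 3x3 convolution removes 1 pixel from each border, so 2 from each side length, 2x2 pooling halves it (rounding down), and upsampling doubles it. The sketch below does that arithmetic in plain Python. `unet_sizes` is a hypothetical helper, and it assumes exactly the layout of `create_unet_model` (four pooling levels, two valid 3x3 convolutions per block).

```python
def unet_sizes(input_size, depth=4):
    """Trace the side length of a square input through an unpadded U-Net.

    Assumes the layout of create_unet_model: two valid 3x3 convolutions per
    block, 2x2 max pooling with stride 2, and 2x upsampling. Returns a list of
    (skip_size, upsampled_size) pairs, deepest level first, plus the output size.
    """
    size = input_size
    skips = []
    # Contracting path: each valid 3x3 conv trims 1 px per border (2 per axis).
    for _ in range(depth):
        size -= 4                # two convolutions
        skips.append(size)       # this tensor feeds the skip connection
        size //= 2               # 'valid' 2x2 pooling floors odd sizes
    size -= 4                    # bottleneck convolutions
    pairs = []
    # Expanding path: upsample, pair with the matching skip, then two convs.
    for skip in reversed(skips):
        size *= 2
        pairs.append((skip, size))
        size -= 4
    return pairs, size


for n in (572, 256):
    pairs, out = unet_sizes(n)
    print(n, pairs, out)
```

For 572 it yields the pairs (64, 56), (136, 104), (280, 200), (568, 392) and an output of 388: the differences 8, 32, 80 and 176 are twice the per-side crops of 4, 16, 40 and 88 used in the model, and 388x388 is the output size given in the paper. For 256 it yields (24, 16), (57, 24), (122, 40), (252, 72) and a 68x68 output.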
Unfortunately, this implementation causes a ResourceExhaustedError during training:
```
Exception has occurred: ResourceExhaustedError
OOM when allocating tensor with shape[32,64,392,392] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[node model/cropping2d_3/strided_slice (defined at c:\main.py:74) ]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.
 [Op:__inference_train_function_3026]

Function call stack:
train_function

  File C:\main.py, line 74, in main
    unet_model.fit(train_images, epochs=epochs, validation_data=validation_images, callbacks=CALLBACKS)
  File C:\main.py, line 276, in <module>
    main()
```
My GPU is a GeForce RTX 2070 Super 8GB.
I verified that the image size is the cause by reproducing the error with another U-Net implementation that I know works.
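A back-of-the-envelope check on the OOM message: the tensor that failed to allocate is float32 ("type float"), so its size is the product of its dimensions times 4 bytes. Reading the dimensions as batch, channels, height and width is an interpretation (the message does not label them). The sketch below, with a hypothetical helper `tensor_bytes`, compares that tensor with the same layer (`cropping2d_3`, output `(None, 72, 72, 64)`) in the 256x256 model. It sizes a single tensor only; training also keeps other activations and their gradients, so real usage is considerably higher.

```python
from math import prod

def tensor_bytes(shape, bytes_per_element=4):
    """Memory for one dense tensor; float32 ("type float") is 4 bytes/element."""
    return prod(shape) * bytes_per_element

# Shape from the OOM message, read as [batch, channels, height, width], 572x572 input.
big = tensor_bytes((32, 64, 392, 392))
# The same layer (cropping2d_3) at 256x256: (None, 72, 72, 64) with batch 32.
small = tensor_bytes((32, 72, 72, 64))

print(f"572x572: {big / 2**30:.2f} GiB")    # 572x572: 1.17 GiB
print(f"256x256: {small / 2**20:.1f} MiB")  # 256x256: 40.5 MiB
```

Activation memory scales with height x width x batch size, which is why a smaller input eases the pressure on an 8 GB card.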
To work around this issue, I'm trying to reduce the image size, e.g. to 256x256. I've changed the Cropping2D layers to crop to the expected size at each level:
```python
# Copy and Crop: 24 -> 16
conv_8_cropped = Cropping2D(cropping=4)(conv_8)
merge_1 = Concatenate()([conv_8_cropped, upsample_1])
# Copy and Crop: 57 -> 24 (odd difference, so the crop is asymmetric)
conv_6_cropped = Cropping2D(cropping=((17,16),(17,16)))(conv_6)
merge_2 = Concatenate()([conv_6_cropped, upsample_2])
# Copy and Crop: 122 -> 40
conv_4_cropped = Cropping2D(cropping=41)(conv_4)
merge_3 = Concatenate()([conv_4_cropped, upsample_3])
# Copy and Crop: 252 -> 72
conv_2_cropped = Cropping2D(cropping=90)(conv_2)
merge_4 = Concatenate()([conv_2_cropped, upsample_4])
```
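The crop amounts follow one rule: remove the difference between the skip tensor and the upsampled tensor, split as evenly as possible between the two edges. A sketch of that rule, where `cropping_arg` is a hypothetical helper (not part of Keras) that returns the `((top, bottom), (left, right))` tuple form Cropping2D accepts:

```python
def cropping_arg(skip_size, upsampled_size):
    """Hypothetical helper: build a Cropping2D `cropping` argument that
    centre-crops a square skip tensor down to the upsampled tensor's size.

    Keras reads it as ((top, bottom), (left, right)). When the difference is
    odd, this sketch gives the extra pixel to the top and left edges.
    """
    total = skip_size - upsampled_size
    if total < 0:
        raise ValueError("skip tensor is smaller than the upsampled tensor")
    before = total - total // 2
    after = total // 2
    return ((before, after), (before, after))


# (skip, upsampled) sizes from the comments in the 256x256 cropping code
for skip, up in [(24, 16), (57, 24), (122, 40), (252, 72)]:
    print(skip, up, cropping_arg(skip, up))
```

This reproduces the hand-written values: ((4, 4), (4, 4)), ((17, 16), (17, 16)), ((41, 41), (41, 41)) and ((90, 90), (90, 90)). A symmetric tuple such as ((4, 4), (4, 4)) is equivalent to `cropping=4`.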
Updated model summary:
```
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 256, 256, 3) 0
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 254, 254, 64) 1792        input_1[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 252, 252, 64) 36928       conv2d[0][0]
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 126, 126, 64) 0           conv2d_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 124, 124, 128 73856       max_pooling2d[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 122, 122, 128 147584      conv2d_2[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 61, 61, 128)  0           conv2d_3[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 59, 59, 256)  295168      max_pooling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 57, 57, 256)  590080      conv2d_4[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 28, 28, 256)  0           conv2d_5[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 26, 26, 512)  1180160     max_pooling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 24, 24, 512)  2359808     conv2d_6[0][0]
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 12, 12, 512)  0           conv2d_7[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 10, 10, 1024) 4719616     max_pooling2d_3[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 8, 8, 1024)   9438208     conv2d_8[0][0]
__________________________________________________________________________________________________
cropping2d (Cropping2D)         (None, 16, 16, 512)  0           conv2d_7[0][0]
__________________________________________________________________________________________________
up_sampling2d (UpSampling2D)    (None, 16, 16, 1024) 0           conv2d_9[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 16, 16, 1536) 0           cropping2d[0][0]
                                                                 up_sampling2d[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 14, 14, 512)  7078400     concatenate[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 12, 12, 512)  2359808     conv2d_10[0][0]
__________________________________________________________________________________________________
cropping2d_1 (Cropping2D)       (None, 24, 24, 256)  0           conv2d_5[0][0]
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 24, 24, 512)  0           conv2d_11[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 24, 24, 768)  0           cropping2d_1[0][0]
                                                                 up_sampling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 22, 22, 256)  1769728     concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 20, 20, 256)  590080      conv2d_12[0][0]
__________________________________________________________________________________________________
cropping2d_2 (Cropping2D)       (None, 40, 40, 128)  0           conv2d_3[0][0]
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, 40, 40, 256)  0           conv2d_13[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 40, 40, 384)  0           cropping2d_2[0][0]
                                                                 up_sampling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 38, 38, 128)  442496      concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 36, 36, 128)  147584      conv2d_14[0][0]
__________________________________________________________________________________________________
cropping2d_3 (Cropping2D)       (None, 72, 72, 64)   0           conv2d_1[0][0]
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D)  (None, 72, 72, 128)  0           conv2d_15[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 72, 72, 192)  0           cropping2d_3[0][0]
                                                                 up_sampling2d_3[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 70, 70, 64)   110656      concatenate_3[0][0]
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 68, 68, 64)   36928       conv2d_16[0][0]
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 68, 68, 1)    65          conv2d_17[0][0]
==================================================================================================
Total params: 31,378,945
Trainable params: 31,378,945
Non-trainable params: 0
```
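The Param # column can be checked by hand: a Conv2D layer with a k x k kernel, C_in input channels and F filters has k*k*C_in*F weights plus F biases. The sketch below, using a hypothetical helper `conv2d_params`, sums this over the 19 Conv2D layers in the summary. None of the terms involves the image size, so moving from 572x572 to 256x256 shrinks the Output Shape column but leaves the weight count unchanged.

```python
def conv2d_params(kernel, in_channels, filters, use_bias=True):
    """Weights in a Conv2D layer: kernel*kernel*in_channels per filter, plus biases."""
    return kernel * kernel * in_channels * filters + (filters if use_bias else 0)

# (kernel, in_channels, filters) for each Conv2D in the summary, in order.
layers = [
    (3, 3, 64), (3, 64, 64),
    (3, 64, 128), (3, 128, 128),
    (3, 128, 256), (3, 256, 256),
    (3, 256, 512), (3, 512, 512),
    (3, 512, 1024), (3, 1024, 1024),
    (3, 1536, 512), (3, 512, 512),   # 512 skip + 1024 upsampled channels
    (3, 768, 256), (3, 256, 256),    # 256 + 512
    (3, 384, 128), (3, 128, 128),    # 128 + 256
    (3, 192, 64), (3, 64, 64),       # 64 + 128
    (1, 64, 1),                      # 1x1 sigmoid output layer
]

total = sum(conv2d_params(*layer) for layer in layers)
print(total)  # 31378945
```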
The updated model compiles fine but fails at training time with:
```
Exception has occurred: InvalidArgumentError
 Incompatible shapes: [32,68,68] vs. [32,256,256]
	 [[node Equal (defined at c:\main.py:74) ]] [Op:__inference_train_function_3026]

Function call stack:
train_function
```
Does anyone know why the shapes are so incorrect at runtime and how I can fix them?
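The `Equal` node combines two tensors element-wise, and TensorFlow follows NumPy's broadcasting rules: aligned from the right, each pair of dimensions must be equal or one of them must be 1. A NumPy sketch of that check on the two shapes from the error, with a hypothetical helper `shapes_broadcastable`:

```python
import numpy as np

def shapes_broadcastable(shape_a, shape_b):
    """Return True if two shapes can be combined element-wise under
    NumPy broadcasting rules."""
    try:
        np.broadcast_shapes(shape_a, shape_b)
        return True
    except ValueError:
        return False

# The two shapes reported in the InvalidArgumentError (batch size 32).
print(shapes_broadcastable((32, 68, 68), (32, 256, 256)))  # False
print(shapes_broadcastable((32, 68, 68), (32, 68, 68)))    # True
```

68 vs 256 is neither equal nor 1, so no element-wise comparison is possible until both tensors have the same height and width.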
Update: image loading, as part of a custom Sequence implementation:
```python
# Excerpt from the custom Sequence (self, i and the *_paths / *_array lists are defined there)
source_image = load_img(source_image_paths[i], target_size=self.image_size, color_mode='grayscale')
target_image = load_img(target_image_paths[i], target_size=self.image_size, color_mode='grayscale')
# Start classes at 0
target_image = np.array(target_image) - 1
target_image_array.append(target_image)
source_image_array.append(np.array(source_image))
```
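A small illustration of what this loader produces, assuming the label files store class ids starting at 1 (the data itself isn't shown). A grayscale image converted with `np.array` is a 2-D `uint8` array of shape (height, width) with no channel axis; subtracting 1 keeps the `uint8` dtype; and stacking 32 such masks gives `(32, 256, 256)`, the same shape as the second tensor in the error message. The arrays below are synthetic stand-ins, not real images.

```python
import numpy as np

# Synthetic stand-in for one label image after load_img(..., color_mode='grayscale')
# and np.array(): a 2-D uint8 array of shape (height, width), no channel axis.
# Assumes class ids are stored as 1, 2, 3 in the files.
target_image = np.array([[1, 2], [3, 1]], dtype=np.uint8)

shifted = target_image - 1               # "Start classes at 0"
print(shifted.dtype, shifted.tolist())   # uint8 [[0, 1], [2, 0]]

# A batch of 32 such masks at 256x256 (zeros used as placeholder content).
batch = np.stack([np.zeros((256, 256), dtype=np.uint8) for _ in range(32)])
print(batch.shape)                       # (32, 256, 256)
```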
Tags: image-segmentation, reshape, cnn, keras, tensorflow
Category: Data Science