Modifying U-Net implementation for smaller image size

I'm implementing the U-Net model per the published paper here. This is my model so far:

from tensorflow.keras.layers import (Input, Conv2D, MaxPool2D, UpSampling2D,
                                     Cropping2D, Concatenate)
from tensorflow.keras.models import Model

def create_unet_model(image_size = IMAGE_SIZE):

    # Input layer is a 572x572 colour image
    input_layer = Input(shape=image_size + (3,))

    # Begin Downsampling

    # Block 1
    conv_1 = Conv2D(64, 3, activation = 'relu')(input_layer)
    conv_2 = Conv2D(64, 3, activation = 'relu')(conv_1)

    max_pool_1 = MaxPool2D(strides=2)(conv_2)

    # Block 2
    conv_3 = Conv2D(128, 3, activation = 'relu')(max_pool_1)
    conv_4 = Conv2D(128, 3, activation = 'relu')(conv_3)

    max_pool_2 = MaxPool2D(strides=2)(conv_4)

    # Block 3
    conv_5 = Conv2D(256, 3, activation = 'relu')(max_pool_2)
    conv_6 = Conv2D(256, 3, activation = 'relu')(conv_5)

    max_pool_3 = MaxPool2D(strides=2)(conv_6)

    # Block 4
    conv_7 = Conv2D(512, 3, activation = 'relu')(max_pool_3)
    conv_8 = Conv2D(512, 3, activation = 'relu')(conv_7)

    max_pool_4 = MaxPool2D(strides=2)(conv_8)

    # Begin Upsampling

    # Block 5
    conv_9 = Conv2D(1024, 3, activation = 'relu')(max_pool_4)
    conv_10 = Conv2D(1024, 3, activation = 'relu')(conv_9)

    upsample_1 = UpSampling2D()(conv_10)

    # Copy and Crop
    conv_8_cropped = Cropping2D(cropping=4)(conv_8)
    merge_1 = Concatenate()([conv_8_cropped, upsample_1])

    # Block 6
    conv_11 = Conv2D(512, 3, activation = 'relu')(merge_1)
    conv_12 = Conv2D(512, 3, activation = 'relu')(conv_11)

    upsample_2 = UpSampling2D()(conv_12)

    # Copy and Crop
    conv_6_cropped = Cropping2D(cropping=16)(conv_6)
    merge_2 = Concatenate()([conv_6_cropped, upsample_2])

    # Block 7
    conv_13 = Conv2D(256, 3, activation = 'relu')(merge_2)
    conv_14 = Conv2D(256, 3, activation = 'relu')(conv_13)
    upsample_3 = UpSampling2D()(conv_14)

    # Copy and Crop
    conv_4_cropped = Cropping2D(cropping=40)(conv_4)
    merge_3 = Concatenate()([conv_4_cropped, upsample_3])

    # Block 8
    conv_15 = Conv2D(128, 3, activation = 'relu')(merge_3)
    conv_16 = Conv2D(128, 3, activation = 'relu')(conv_15)
    upsample_4 = UpSampling2D()(conv_16)

    # Connect layers
    conv_2_cropped = Cropping2D(cropping=88)(conv_2)
    merge_4 = Concatenate()([conv_2_cropped, upsample_4])

    # Block 9
    conv_17 = Conv2D(64, 3, activation = 'relu')(merge_4)
    conv_18 = Conv2D(64, 3, activation = 'relu')(conv_17)

    # Output layer
    output_layer = Conv2D(1, 1, activation='sigmoid')(conv_18)

    # Define the model
    unet = Model(input_layer, output_layer)
    
    return unet

The cropping is implemented as specified in this answer and is specific to 572x572 images.
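For reference, the crop amounts can be derived by tracing the feature-map sizes through the unpadded ('valid') convolutions and the 2x2 max-pools. A quick sketch (not part of the model, just a sanity check of the numbers):

def unet_crop_sizes(input_size):
    # Contracting path: two 3x3 'valid' convs (each removes 2 pixels), then a 2x2 max-pool.
    size = input_size
    skips = []                       # spatial sizes of conv_2, conv_4, conv_6, conv_8
    for _ in range(4):
        size -= 4
        skips.append(size)
        size //= 2
    size -= 4                        # bottleneck convs (conv_9, conv_10)

    # Expanding path: upsample x2, crop the skip tensor to match, then two 3x3 convs.
    crops = []
    for skip in reversed(skips):
        size *= 2
        total = skip - size          # rows/cols to remove from the skip tensor
        crops.append((total // 2, total - total // 2))
        size -= 4
    return size, crops               # final output size, per-skip (before, after) crops

print(unet_crop_sizes(572))   # (388, [(4, 4), (16, 16), (40, 40), (88, 88)])
print(unet_crop_sizes(256))   # (68, [(4, 4), (16, 17), (41, 41), (90, 90)])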

Unfortunately this implementation causes a ResourceExhaustedError:

Exception has occurred: ResourceExhaustedError
 OOM when allocating tensor with shape[32,64,392,392] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
     [[node model/cropping2d_3/strided_slice (defined at c:\main.py:74) ]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.
 [Op:__inference_train_function_3026]

Function call stack:
train_function
  File "C:\main.py", line 74, in main
    unet_model.fit(train_images, epochs=epochs, validation_data=validation_images, callbacks=CALLBACKS)
  File "C:\main.py", line 276, in <module>
    main()

My GPU is a GeForce RTX 2070 Super 8GB.

I verified that the image size was the cause by reproducing the error in another U-Net implementation that I know works.
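For a rough sense of scale, just the single activation named in the OOM message already needs over a gigabyte as float32, before gradients and the other activations are counted:

# [32, 64, 392, 392] float32 tensor from the error message
print(32 * 64 * 392 * 392 * 4 / 1024**3)   # ~1.17 GiB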

To work around this, I'm trying to lower the image size, e.g. to 256x256. I've changed the Cropping2D layers to crop to the expected size at each merge (the values follow from the size-tracing sketch above):

# Copy and Crop: 24 -> 16
conv_8_cropped = Cropping2D(cropping=4)(conv_8)
merge_1 = Concatenate()([conv_8_cropped, upsample_1])

# Copy and Crop: 57 -> 24
conv_6_cropped = Cropping2D(cropping=((17,16),(17,16)))(conv_6)
merge_2 = Concatenate()([conv_6_cropped, upsample_2])

# Copy and Crop: 122 -> 40
conv_4_cropped = Cropping2D(cropping=41)(conv_4)
merge_3 = Concatenate()([conv_4_cropped, upsample_3])

# Copy and Crop: 252 -> 72
conv_2_cropped = Cropping2D(cropping=90)(conv_2)
merge_4 = Concatenate()([conv_2_cropped, upsample_4])

Updated model summary:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 256, 256, 3) 0
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 254, 254, 64) 1792        input_1[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 252, 252, 64) 36928       conv2d[0][0]
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 126, 126, 64) 0           conv2d_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 124, 124, 128 73856       max_pooling2d[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 122, 122, 128 147584      conv2d_2[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 61, 61, 128)  0           conv2d_3[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 59, 59, 256)  295168      max_pooling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 57, 57, 256)  590080      conv2d_4[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 28, 28, 256)  0           conv2d_5[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 26, 26, 512)  1180160     max_pooling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 24, 24, 512)  2359808     conv2d_6[0][0]
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 12, 12, 512)  0           conv2d_7[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 10, 10, 1024) 4719616     max_pooling2d_3[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 8, 8, 1024)   9438208     conv2d_8[0][0]
__________________________________________________________________________________________________
cropping2d (Cropping2D)         (None, 16, 16, 512)  0           conv2d_7[0][0]
__________________________________________________________________________________________________
up_sampling2d (UpSampling2D)    (None, 16, 16, 1024) 0           conv2d_9[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 16, 16, 1536) 0           cropping2d[0][0]
                                                                 up_sampling2d[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 14, 14, 512)  7078400     concatenate[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 12, 12, 512)  2359808     conv2d_10[0][0]
__________________________________________________________________________________________________
cropping2d_1 (Cropping2D)       (None, 24, 24, 256)  0           conv2d_5[0][0]
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 24, 24, 512)  0           conv2d_11[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 24, 24, 768)  0           cropping2d_1[0][0]
                                                                 up_sampling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 22, 22, 256)  1769728     concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 20, 20, 256)  590080      conv2d_12[0][0]
__________________________________________________________________________________________________
cropping2d_2 (Cropping2D)       (None, 40, 40, 128)  0           conv2d_3[0][0]
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, 40, 40, 256)  0           conv2d_13[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 40, 40, 384)  0           cropping2d_2[0][0]
                                                                 up_sampling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 38, 38, 128)  442496      concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 36, 36, 128)  147584      conv2d_14[0][0]
__________________________________________________________________________________________________
cropping2d_3 (Cropping2D)       (None, 72, 72, 64)   0           conv2d_1[0][0]
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D)  (None, 72, 72, 128)  0           conv2d_15[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 72, 72, 192)  0           cropping2d_3[0][0]
                                                                 up_sampling2d_3[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 70, 70, 64)   110656      concatenate_3[0][0]
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 68, 68, 64)   36928       conv2d_16[0][0]
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 68, 68, 1)    65          conv2d_17[0][0]
==================================================================================================
Total params: 31,378,945
Trainable params: 31,378,945
Non-trainable params: 0

This compiles fine but fails at train time with:

Exception has occurred: InvalidArgumentError
 Incompatible shapes: [32,68,68] vs. [32,256,256]
     [[node Equal (defined at c:\main.py:74) ]] [Op:__inference_train_function_3026]

Function call stack:
train_function

Does anyone know why the shapes are mismatched at runtime and how I can fix this?

Update: image loading as part of my custom Sequence implementation:

source_image = load_img(source_image_paths[i], target_size=self.image_size, color_mode='grayscale')
target_image = load_img(target_image_paths[i], target_size=self.image_size, color_mode='grayscale')

# Start classes at 0
target_image = np.array(target_image) - 1

target_image_array.append(target_image)
source_image_array.append(np.array(source_image))

Topic image-segmentation reshape cnn keras tensorflow

Category Data Science


The error is telling you that the model's output is 68x68 while your target images are 256x256, so the shapes being compared by the loss don't match: with unpadded convolutions the network shrinks a 256x256 input down to 68x68.

You can use the Keras image-preprocessing API, in particular the smart_resize function, to resize the images (or masks) to the expected number of pixels.

Something like this:

from tensorflow.keras.preprocessing.image import smart_resize

target_size = (256,256)
image_resized = smart_resize(image_original, size=target_size, interpolation='bilinear')
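For example, if you choose to bring the target masks down to the model's 68x68 output, it could look something like this (load_mask is just an illustrative helper, not code from your Sequence):

import numpy as np
from tensorflow.keras.preprocessing.image import load_img, smart_resize

def load_mask(path, output_size=(68, 68)):
    # Load a mask and resize it to the model's output size so the loss
    # compares like-for-like shapes.
    mask = np.array(load_img(path, color_mode='grayscale'), dtype='float32')
    mask = mask[..., np.newaxis]     # smart_resize expects a channel axis
    # 'nearest' avoids blending class labels; 'bilinear' is the default
    return smart_resize(mask, output_size, interpolation='nearest')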
