How to get the expected output shape from a unet model?

I have an image segmentation task where my input image shape is (140, 85, 95, 4) and the output label shape is (140, 85, 95). Below is my model:

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Input, Rescaling

num_classes = 4

my_model = tf.keras.Sequential([

    Input(shape = (85, 95, 4), name = 'image'),
    Rescaling(scale = 1./255),
    # Convolution stack: stride 1 with 'same' padding keeps height and width at 85 x 95
    Conv2D(filters = 64, kernel_size = 3, strides = 1, activation = 'relu', padding = 'same'),
    Conv2D(filters = 64, kernel_size = 3, activation = 'relu', padding = 'same'),
    Conv2D(filters = 128, kernel_size = 3, strides = 1, activation = 'relu', padding = 'same'),
    Conv2D(filters = 128, kernel_size = 3, activation = 'relu', padding = 'same'),
    Conv2D(filters = 256, kernel_size = 3, strides = 1, activation = 'relu', padding = 'same'),
    Conv2D(filters = 256, kernel_size = 3, activation = 'relu', padding = 'same'),

    # Transposed-convolution stack: also stride 1, so the spatial size is unchanged
    Conv2DTranspose(filters = 256, kernel_size = 3, activation = 'relu', padding = 'same'),
    Conv2DTranspose(filters = 256, kernel_size = 3, strides = 1, activation = 'relu', padding = 'same'),
    Conv2DTranspose(filters = 128, kernel_size = 3, activation = 'relu', padding = 'same'),
    Conv2DTranspose(filters = 128, kernel_size = 3, strides = 1, activation = 'relu', padding = 'same'),
    Conv2DTranspose(filters = 64, kernel_size = 3, activation = 'relu', padding = 'same'),
    Conv2DTranspose(filters = 64, kernel_size = 3, strides = 1, activation = 'relu', padding = 'same'),

    # Output layer: one channel of softmax probabilities per class
    Conv2D(filters = num_classes, kernel_size = 3, activation = 'softmax', padding = 'same')

])


After training, I tried predicting one image and the model produced a label with shape (140, 85, 95, 4) as the output but I want it to be (140, 85, 95) or (140, 85, 95, 1).

How can I fix this? Thank you.

Topic image-segmentation python

Category Data Science

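Your last layer uses 4 filters because num_classes is set to 4, so the output array has 4 channels in its last dimension, one per class. If you want only one channel, change the number of filters in the last convolutional layer to one (filters=1). Be aware that softmax over a single channel always returns 1, so that layer would also need a different activation such as sigmoid, and a single sigmoid channel can only separate two classes.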
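If the four classes need to be kept, another option is to leave the softmax layer as it is and collapse the class axis after prediction by taking the most probable class for each pixel. The sketch below is only an illustration: it uses NumPy and random values as a stand-in for the output of my_model.predict, and the names probs, labels and labels_with_channel are made up for this example.

```python
import numpy as np

# Stand-in for my_model.predict(images): random per-class probabilities with
# the shape reported in the question (batch, height, width, classes).
rng = np.random.default_rng(0)
scores = rng.normal(size=(140, 85, 95, 4))
exp_scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
probs = exp_scores / exp_scores.sum(axis=-1, keepdims=True)  # softmax over classes

# Pick the index of the most probable class for every pixel.
labels = np.argmax(probs, axis=-1)              # shape (140, 85, 95)

# Add a trailing channel axis if a (140, 85, 95, 1) array is preferred.
labels_with_channel = labels[..., np.newaxis]   # shape (140, 85, 95, 1)

print(labels.shape, labels_with_channel.shape)
```

With the real model, probs would be replaced by the array returned from my_model.predict, and the resulting labels hold integer class indices from 0 to num_classes - 1.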

