After playing around with tf.data.Dataset.map operations, I found the answer was easier than expected: I simply had to preprocess the data and put the labels for each output of the model under a different key of a dictionary.
First, I create a dataset from the TFRecords file:
dataset = tf.data.TFRecordDataset(tfrecords_file)
Next, I parse the data from the file:
def parser(protobuff_message):
    # Schema of one serialized example in the TFRecords file
    feature = {'image/encoded': tf.io.FixedLenFeature((), tf.string),
               'image/shape': tf.io.FixedLenFeature((3,), tf.int64),
               'age': tf.io.FixedLenFeature((), tf.int64),
               'gender': tf.io.FixedLenFeature((), tf.int64),
               'ethnicity': tf.io.FixedLenFeature((), tf.int64),
               }
    # Parse the serialized message against the schema above
    return tf_util.parse_pb_message(protobuff_message, feature)
dataset = dataset.map(parser).map(process_example)
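The helper tf_util.parse_pb_message is not shown in this answer. As a rough sketch, assuming it behaves like tf.io.parse_single_example, the parsing step could look like the code below. The names parse_example and make_example are hypothetical and exist only to make the snippet runnable on its own; process_example is not covered here.

```python
import tensorflow as tf

# Same feature schema as in the parser above.
FEATURES = {
    'image/encoded': tf.io.FixedLenFeature((), tf.string),
    'image/shape': tf.io.FixedLenFeature((3,), tf.int64),
    'age': tf.io.FixedLenFeature((), tf.int64),
    'gender': tf.io.FixedLenFeature((), tf.int64),
    'ethnicity': tf.io.FixedLenFeature((), tf.int64),
}


def parse_example(serialized):
    """Hypothetical stand-in for tf_util.parse_pb_message (an assumption)."""
    return tf.io.parse_single_example(serialized, FEATURES)


def make_example(image_bytes, shape, age, gender, ethnicity):
    """Builds one serialized tf.train.Example, only to exercise the parser."""
    def _int64(values):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

    feature = {
        'image/encoded': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[image_bytes])),
        'image/shape': _int64(shape),
        'age': _int64([age]),
        'gender': _int64([gender]),
        'ethnicity': _int64([ethnicity]),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()


serialized = make_example(b'raw-bytes', [2, 2, 3], age=4, gender=1, ethnicity=0)
parsed = parse_example(tf.constant(serialized))

# The result is a dict of tensors keyed by feature name.
print({name: tensor.numpy() for name, tensor in parsed.items()})
```

The parsed result is a dictionary, so a later map step still has to turn it into the tuple that the preprocessing function expects.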
At this point, we have a standard dataset that we can batch, shuffle, augment, or otherwise transform as needed. Finally, before feeding the data into the model, we have to transform it to fit the model's requirements. The code below shows an example of both input and label preprocessing. Previously, I concatenated all the labels; now I create a dictionary with the names of the model's outputs as keys.
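# These are methods of the class that builds the input pipeline (hence self).
def preprocess_input_fn(self):
    def _preprocess_input(image, image_shape, age, gender, ethnicity):
        image = self.preprocess_image(image)
        labels = self.preprocess_labels(age, gender, ethnicity)
        return image, labels
    return _preprocess_input

def preprocess_image(self, image):
    image = tf.cast(image, tf.float32)
    # self.image_size is an assumed attribute holding the model's input
    # (height, width); the target size was left out of the original call
    image = tf.image.resize(image, self.image_size)
    image = (image / 127.5) - 1.0  # scale pixel values from [0, 255] to [-1, 1]
    return image

def preprocess_labels(self, age, gender, ethnicity):
    gender = tf.one_hot(gender, 2)
    ethnicity = tf.one_hot(ethnicity, self.ethnic_groups)
    age = tf.one_hot(age, self.age_groups)
    # The keys must match the names of the model's output layers
    return {'Gender': gender, 'Ethnicity': ethnicity, 'Age': age}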
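To check that the labels really come out as a dictionary, it can help to inspect the dataset's element_spec after mapping. The snippet below is a minimal sketch on fake integer labels; ETHNIC_GROUPS and AGE_GROUPS are made-up values standing in for self.ethnic_groups and self.age_groups.

```python
import tensorflow as tf

ETHNIC_GROUPS = 5  # hypothetical stand-in for self.ethnic_groups
AGE_GROUPS = 8     # hypothetical stand-in for self.age_groups


def preprocess_labels(age, gender, ethnicity):
    # One one-hot vector per model output, keyed by output layer name.
    return {
        'Gender': tf.one_hot(gender, 2),
        'Ethnicity': tf.one_hot(ethnicity, ETHNIC_GROUPS),
        'Age': tf.one_hot(age, AGE_GROUPS),
    }


# Three fake examples: integer class ids for each attribute.
raw = tf.data.Dataset.from_tensor_slices((
    tf.constant([3, 7, 0], dtype=tf.int64),  # age group
    tf.constant([0, 1, 1], dtype=tf.int64),  # gender
    tf.constant([4, 2, 0], dtype=tf.int64),  # ethnicity
))
labelled = raw.map(preprocess_labels)

# element_spec describes the structure each dataset element will have.
for name, spec in labelled.element_spec.items():
    print(name, spec.shape, spec.dtype)

first = next(iter(labelled))
```

Each element is now a dict with one entry per output, which is the structure Keras matches against the output layer names.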
In my model, Gender, Ethnicity and Age are the names of the final layers, so the model is defined with three outputs:
model = Model(inputs=inputs,
              outputs=[gender, ethnic_group, age_group])
Now I can use a dataset to fit the model by applying the preprocessing function first:
data = dataset.map(self.preprocess_input_fn())
model.fit(data, epochs=...)
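For reference, here is a small end-to-end sketch of the same idea on random data. It is not my actual model: the tiny architecture, the class counts, the image size and the choice of categorical cross-entropy for every output are all assumptions made to keep the example short. What matters is that the output layer names match the label dictionary keys and that the dataset is batched before fitting.

```python
import tensorflow as tf

ETHNIC_GROUPS, AGE_GROUPS = 5, 8  # hypothetical class counts
N = 16                            # number of fake examples

# Random images already scaled to [-1, 1] and random one-hot labels.
images = tf.random.uniform((N, 32, 32, 3), minval=-1.0, maxval=1.0)
labels = {
    'Gender': tf.one_hot(
        tf.random.uniform((N,), maxval=2, dtype=tf.int32), 2),
    'Ethnicity': tf.one_hot(
        tf.random.uniform((N,), maxval=ETHNIC_GROUPS, dtype=tf.int32),
        ETHNIC_GROUPS),
    'Age': tf.one_hot(
        tf.random.uniform((N,), maxval=AGE_GROUPS, dtype=tf.int32),
        AGE_GROUPS),
}
data = tf.data.Dataset.from_tensor_slices((images, labels)).shuffle(N).batch(4)

# Toy model: a shared feature vector feeding three named output heads.
inputs = tf.keras.Input(shape=(32, 32, 3))
features = tf.keras.layers.GlobalAveragePooling2D()(inputs)
gender = tf.keras.layers.Dense(2, activation='softmax', name='Gender')(features)
ethnic_group = tf.keras.layers.Dense(
    ETHNIC_GROUPS, activation='softmax', name='Ethnicity')(features)
age_group = tf.keras.layers.Dense(
    AGE_GROUPS, activation='softmax', name='Age')(features)
model = tf.keras.Model(inputs=inputs,
                       outputs=[gender, ethnic_group, age_group])

# Losses can also be given per output, keyed by the same layer names.
model.compile(
    optimizer='adam',
    loss={
        'Gender': 'categorical_crossentropy',
        'Ethnicity': 'categorical_crossentropy',
        'Age': 'categorical_crossentropy',
    },
)

history = model.fit(data, epochs=1, verbose=0)
print(model.output_names)
print(sorted(history.history))
```

Because the labels are looked up by name, the order of the keys in the dictionary does not have to match the order of the outputs list.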