Early stopping with class weights / sample weights

I'm performing a classification of imbalanced multiclass data using a Neural Network in the TensorFlow framework. Therefore, I'm applying class weights.

I would like to apply early stopping to reduce overfitting. My concern is that, because of the class weights, the cost of the validation set used for early stopping will be calculated differently from the cost of the training set, so early stopping will not work correctly: the cost of the validation set could be biased toward the classes that are over-represented in the data.

My questions are:

  1. Is the concern expressed above correct?
  2. If the answer is yes, is it possible to apply class weights or sample weights to the validation set in TensorFlow, so that the training and validation costs are calculated in the same way?
  3. If it's not possible in TensorFlow, is it possible in PyTorch or other frameworks?
  4. Perhaps there are other solutions to the expressed concern?

My current relevant piece of code is:

from tensorflow.keras.callbacks import ModelCheckpoint

model_checkpoint_callback = None
checkpoint_filepath = r'C:\Users\User\PycharmProjects\models\SUAI\nn_checkpoint'
if early_stopping:
    # Save the weights of the epoch with the lowest validation loss
    model_checkpoint_callback = ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=True,
        monitor='val_loss',
        mode='min',
        save_best_only=True
    )

history = nn.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=n_epochs,
    callbacks=[model_checkpoint_callback] if early_stopping else None,
    batch_size=batch_size,
    verbose=verbose,
    class_weight=class_weights,
    sample_weight=sample_weight
)

if early_stopping:
    # Roll back to the weights of the best epoch after training
    nn.load_weights(checkpoint_filepath)
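
For context, class_weights above is a dictionary mapping class indices to weights; one common way to build it (shown here only as a sketch, assuming scikit-learn is available and y_train is a 1-D array of integer class labels) is:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Build a {class_index: weight} dict inversely proportional to class frequency.
classes = np.unique(y_train)
weights = compute_class_weight(class_weight='balanced', classes=classes, y=y_train)
class_weights = {int(c): float(w) for c, w in zip(classes, weights)}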

Topic early-stopping keras tensorflow neural-network python

Category Data Science


You seem to believe that early stopping involves some kind of comparison between the training and validation loss, and that belief leads you down incorrect and invalid paths. But this is not the case. Consider the documentation of EarlyStopping in TensorFlow Keras:

Stop training when a monitored metric has stopped improving.

tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=0, patience=0, verbose=0,
    mode='auto', baseline=None, restore_best_weights=False)

Here (as in your case) the monitored metric is val_loss, and it is the only metric being monitored. In other words, when monitor='val_loss' (the most common usage), early stopping does not know or care about the training loss at all; it only watches the validation loss and stops training as soon as that stops improving, in order to avoid overfitting.
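
For illustration, here is a minimal sketch of such a callback wired into the snippet from the question (it reuses the question's variable names such as nn, X_train, y_train, X_val, y_val, n_epochs and class_weights; the patience value is just an illustrative choice):

from tensorflow.keras.callbacks import EarlyStopping

# Early stopping watches only the validation loss; the class weights applied
# to the training data play no role in this decision.
early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0,
    patience=10,               # illustrative value
    mode='min',
    restore_best_weights=True  # roll back to the best epoch when training stops
)

history = nn.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),   # validation loss is computed unweighted
    epochs=n_epochs,
    class_weight=class_weights,       # affects the training loss only
    callbacks=[early_stopping_callback]
)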

Having clarified this, it should hopefully be apparent that any discussion of weights (class or sample), and any concern that not taking them into account on the validation set creates some "asymmetry" leading to incorrect results, is actually irrelevant.

the cost of the validation set could be biased toward the classes that are over-represented in the data.

This is indeed the case, but it is the only meaningful way of computing the loss on anything other than the training set during training, and it is also the situation the model will face once deployed in the real world (where no class labels or class weights are provided in advance); the discussion in Why you shouldn't upsample before cross validation, although superficially about a different aspect of the class-imbalance problem, may actually be helpful here.

To wrap up: a validation set is supposed to be as close as possible to the unseen data the model will encounter after deployment; that means no weights and no artificial upsampling whatsoever. That is how the performance of your model in the real world will be assessed.
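
As an illustration, a real-world-style check on the validation set could look like this (a sketch reusing the names from the question's snippet, and assuming y_val holds integer class labels and the network's output is a softmax over the classes):

import numpy as np
from sklearn.metrics import classification_report

# Score the validation set exactly as unseen data would be scored:
# no class weights, no sample weights, no resampling.
results = nn.evaluate(X_val, y_val, verbose=0)
print('unweighted validation loss:', results[0] if isinstance(results, list) else results)

# Per-class precision/recall reveals whether the majority classes dominate.
y_pred = np.argmax(nn.predict(X_val), axis=1)
print(classification_report(y_val, y_pred))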
