Implementing class weighting in Faster R-CNN

I have a dataset of around 45,000 screenshots, each paired with a UI tree that lists the on-screen elements (their control types and bounding boxes):

The dataset is highly imbalanced, with the Button element heavily overrepresented:

When training on my local machine on a tiny subset of the data (900 screenshots for training, 100 for testing) for 10 epochs, my results aren't bad:

I trained the model on Azure ML with 25,000 screenshots for 13 epochs (which took about 3 days) and my results were actually worse, with most elements being misclassified:

I suspect this is due to the class imbalance and am wondering what the best way to combat it might be. I was thinking I could either omit many of the overrepresented elements or use class weighting; however, I'm unsure how to implement class weighting in PyTorch.

This is my training loop:

# Train for an epoch
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    """Run one training epoch of a detection model and return its metric logger.

    Each ``(images, targets)`` batch is moved to ``device`` and passed through
    ``model(images, targets)``, which in training mode returns a dict of loss
    tensors. The summed loss is back-propagated and the optimizer stepped.
    During epoch 0 a warm-up scheduler (``warmup_lr_scheduler`` with factor
    1/1000 over ``min(1000, len(data_loader) - 1)`` iterations) is stepped
    after every successful batch.

    A batch whose training step raises an ``Exception`` is reported and skipped
    (best effort, e.g. to ride out a CUDA out-of-memory error). Its memory is
    released and training continues with the next batch.

    Args:
        model: detection model; called as ``model(images, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        data_loader: iterable of ``(images, targets)`` batches; must support ``len``.
        device: device the batch tensors are moved to.
        epoch: zero-based epoch index; used for the log header and warm-up.
        print_freq: logging interval passed to ``MetricLogger.log_every``.

    Returns:
        The ``MetricLogger`` holding this epoch's smoothed losses and learning rate.

    Raises:
        SystemExit: if the reduced loss is not finite (raised by ``sys.exit(1)``,
            which the per-batch ``except Exception`` does not catch).
    """
    model.train()
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1. / 1000
        warmup_iters = min(1000, len(data_loader) - 1)

        lr_scheduler = warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)

    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        step_failed = False
        try:
            losses_reduced, loss_dict_reduced = _train_step(model, optimizer, images, targets, device)
            # Only advance the warm-up schedule after a successful optimizer step.
            if lr_scheduler is not None:
                lr_scheduler.step()
        except Exception as e:
            print(f"Caught Exception: {e}")
            step_failed = True

        # Drop the loop's references to this batch.
        del images, targets

        if step_failed:
            # This runs outside the ``except`` block, where ``e`` and its traceback
            # (which pins _train_step's frame and its device tensors) have already
            # been released, so the memory can actually be reclaimed. The
            # original slept for 10000 s here, stalling training for ~2.8 hours
            # per failed batch.
            gc.collect()
            torch.cuda.empty_cache()
            continue

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    return metric_logger


def _train_step(model, optimizer, images, targets, device):
    """Do forward, backward and optimizer step for one batch.

    Returns ``(losses_reduced, loss_dict_reduced)`` for logging. This is kept
    as a separate function so that, if it raises, the device copies of the
    batch and the loss graph live only in this frame and are freed along with
    the exception.
    """
    # ``np.math`` was merely an alias of this stdlib module and was removed in NumPy 2.0.
    import math

    images = [image.to(device) for image in images]
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

    # In training mode the model returns a dict of named loss tensors.
    loss_dict = model(images, targets)
    losses = sum(loss for loss in loss_dict.values())

    # reduce losses over all GPUs for logging purposes
    loss_dict_reduced = reduce_dict(loss_dict)
    losses_reduced = sum(loss for loss in loss_dict_reduced.values())
    loss_value = losses_reduced.item()

    if not math.isfinite(loss_value):
        print("Loss is {}, stopping training".format(loss_value))
        print(loss_dict_reduced)
        sys.exit(1)

    optimizer.zero_grad()
    losses.backward()
    optimizer.step()

    return losses_reduced, loss_dict_reduced

And my data loader:

# Build the training and test datasets; they share the screenshot directory
# and control-type configuration and differ only in file list and transforms.
train_dataset = ScreenshotDataset(
    filenames=training_dataset_files,
    image_path=path_to_screenshots,
    target_control_types=CONTROL_TYPES,
    control_type_transformations=CONTROL_TYPE_TRANSFORMATIONS,
    transformations=train_transformations,
)
test_dataset = ScreenshotDataset(
    filenames=testing_dataset_files,
    image_path=path_to_screenshots,
    target_control_types=CONTROL_TYPES,
    control_type_transformations=CONTROL_TYPE_TRANSFORMATIONS,
    transformations=test_transformations,
)

# Training loader: shuffled batches of 4, loaded by 4 worker processes.
data_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)

# Evaluation loader: one image per batch, fixed order, loaded in the main process.
data_loader_test = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn,
)

...

class ScreenshotDataset(torch.utils.data.Dataset):
    """Object-detection dataset of UI screenshots and their UI-tree annotations.

    Each item pairs the screenshot ``image_path/<filename>`` with the JSON UI
    tree stored next to it (same name, ``.png`` replaced by ``.json``). Elements
    whose (optionally transformed) ``ControlType`` is in ``target_control_types``
    become ground-truth boxes. Labels are 1-based positions in that list, which
    leaves 0 free for the background class that torchvision detection models
    reserve.
    """

    def __init__(self, filenames, image_path, target_control_types, control_type_transformations, transformations=None):
        """
        Args:
            filenames: screenshot file names (``.png``) relative to ``image_path``.
            image_path: directory holding both the screenshots and their JSON trees.
            target_control_types: ordered control-type names to detect.
            control_type_transformations: mapping used to rename control types
                (e.g. Pane -> Window) before they are matched against the targets.
            transformations: optional callable applied to the loaded RGB image.
        """
        self.filenames = filenames
        self.image_path = image_path
        self.target_control_types = target_control_types
        self.control_type_transformations = control_type_transformations

        # control type name -> integer label (1-based)
        self.target_control_indexs = self.__create_target_control_index__(target_control_types)

        self.transforms = transformations

    # Map control types to labels
    def __create_target_control_index__(self, target_controls):
        """Return ``{control_type: label}`` with labels numbered from 1 in list order."""
        return {control: label for label, control in enumerate(target_controls, 1)}

    @staticmethod
    def get_annotations(boxes, ui_element, target_control_types, control_type_transformations,
                        screen_width=1280, screen_height=1024):
        """Recursively append the valid target boxes of ``ui_element``'s subtree to ``boxes``.

        Each appended entry is a dict with ``xmin``, ``ymin``, ``xmax``, ``ymax``
        and ``name`` (the possibly renamed control type). The tree is visited
        pre-order, parent before children. ``ui_element`` itself is not modified.

        Args:
            boxes: list that receives the annotations (mutated in place).
            ui_element: dict with ``ControlType``, ``BoundingRectangle``
                (``Left``/``Top``/``Width``/``Height``) and ``Children``.
            target_control_types: control types to keep.
            control_type_transformations: renaming map applied before filtering.
            screen_width, screen_height: screenshot size in pixels. The defaults
                reproduce the original hard-coded 1280x1024.
        """
        rect = ui_element['BoundingRectangle']
        left, top = rect['Left'], rect['Top']
        width, height = rect['Width'], rect['Height']
        element_right = left + width
        element_bottom = top + height

        # If the boundary is on (or one pixel past) the edge, snap it within the limit
        if element_right in (screen_width, screen_width + 1):
            element_right = screen_width - 1
        if element_bottom in (screen_height, screen_height + 1):
            element_bottom = screen_height - 1

        # Convert control types e.g. Pane -> Window, using a local so that the
        # caller's tree is left untouched.
        control_type = ui_element['ControlType']
        if control_type in control_type_transformations:
            control_type = control_type_transformations[control_type]

        element = {
            'xmin': left if left > 0 else 1,  # Left, clamped to at least 1
            'ymin': top if top > 0 else 1,    # Top, clamped to at least 1
            'xmax': element_right,            # Right
            'ymax': element_bottom,           # Bottom
            'name': control_type,
        }

        # Keep only non-empty elements lying strictly inside the screen,
        # of a target type, and with a well-ordered (non-degenerate) box.
        has_size = height > 0 and width > 0
        inside_screen = 1 < element_right < screen_width and 1 < element_bottom < screen_height
        well_ordered = element['xmin'] < element['xmax'] and element['ymin'] < element['ymax']
        if has_size and inside_screen and control_type in target_control_types and well_ordered:
            boxes.append(element)

        # Get children if any
        for child_element in ui_element['Children']:
            ScreenshotDataset.get_annotations(boxes, child_element, target_control_types,
                                              control_type_transformations, screen_width, screen_height)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        ``target`` holds ``boxes`` (float32, shape ``(N, 4)``; N may be 0),
        ``labels`` (int64, ``(N,)``), ``image_id``, ``area`` and ``iscrowd``.
        """
        # Load images and annotations
        filename = self.filenames[idx]
        img_path = os.path.join(self.image_path, filename)
        json_path = os.path.join(self.image_path, filename.replace('.png', '.json'))

        # Load JSON
        image_annotations = []
        with open(json_path, encoding="utf-8") as json_file:
            ui_element = json.load(json_file)

        ScreenshotDataset.get_annotations(image_annotations, ui_element['Children'][0],
                                          self.target_control_types, self.control_type_transformations)

        boxes = [[a['xmin'], a['ymin'], a['xmax'], a['ymax']] for a in image_annotations]
        if not boxes:
            print("Help!")

        labels = [self.target_control_indexs[a['name']] for a in image_annotations]

        # reshape(-1, 4) keeps boxes 2-D when there are no annotations;
        # torch.as_tensor([]) alone is 1-D and would make boxes[:, 3] raise.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)

        # get area for evaluation with the COCO metric, to separate the
        # metric scores between small, medium and large boxes.
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        # unique id
        im_id = torch.tensor([idx])

        # suppose all instances are not crowd (torchvision specific)
        iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": im_id,
            "area": area,
            "iscrowd": iscrowd,
        }

        # Load the image (closing the file) + optional augmentation
        with Image.open(img_path) as screenshot:
            img = screenshot.convert("RGB")
        if self.transforms is not None:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        return len(self.filenames)

Any advice or direction is appreciated!

Topic object-detection pytorch weighted-data class-imbalance

Category Data Science

About

Geeks Mental is a community that publishes articles and tutorials about Web, Android, Data Science, new techniques and Linux security.