Implementing class weighting in Faster RCNN
I have a dataset of around 45,000 screenshots, each paired with a UI tree that lists the element types and bounding boxes of the UI elements on screen.
The dataset is highly imbalanced, with the button element being heavily overrepresented.
When training on my local machine on a tiny subset of the data (900 screenshots for training, 100 for testing) for 10 epochs, my results aren't bad.
I then trained the model on Azure ML with 25,000 screenshots for 13 epochs (which took about 3 days), and my results were actually worse, with most elements being misclassified.
I suspect this is due to the class imbalance and was wondering what the best way to combat it might be. I was thinking I could omit many of the overrepresented elements, or use class weighting; however, I'm unsure how to implement that in PyTorch.
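For the class-weighting idea, the furthest I've gotten is pre-computing inverse-frequency weights per label, roughly like the sketch below (compute_class_weights is a hypothetical helper of mine; it iterates the ScreenshotDataset shown further down, so it reloads every image and is slow). What I don't know is where such a weight tensor would actually plug into torchvision's Faster R-CNN.

import torch
from collections import Counter

def compute_class_weights(dataset, num_classes):
    # num_classes includes the background class at index 0 (torchvision convention)
    counts = Counter()
    for _, target in dataset:  # calls __getitem__ for every screenshot, so this is slow
        counts.update(target['labels'].tolist())
    weights = torch.ones(num_classes)  # background keeps weight 1.0
    total = sum(counts.values())
    for label, count in counts.items():
        weights[label] = total / (num_classes * count)  # inverse frequency
    return weights

# e.g. (assuming label 0 is background, so one extra class):
# class_weights = compute_class_weights(train_dataset, num_classes=len(CONTROL_TYPES) + 1)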
This is my training loop:
# Train for an epoch
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    model.train()
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1. / 1000
        warmup_iters = min(1000, len(data_loader) - 1)
        lr_scheduler = warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)
    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        try:
            # FOR GPU
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            # Train the model
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            loss_value = losses_reduced.item()
            if not np.math.isfinite(loss_value):
                print("Loss is {}, stopping training".format(loss_value))
                print(loss_dict_reduced)
                sys.exit(1)
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            if lr_scheduler is not None:
                lr_scheduler.step()
            # Try to free memory
            del images
            del targets
            metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
            metric_logger.update(lr=optimizer.param_groups[0]['lr'])
        except Exception as e:
            print(f"Caught Exception: {e}")
            time.sleep(10000)
            gc.collect()
            torch.cuda.empty_cache()
    return metric_logger
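The only concrete lead I've found for plugging class weights in is that the classification part of loss_dict above comes from fastrcnn_loss in torchvision.models.detection.roi_heads, which uses plain F.cross_entropy. So one idea (completely untested, and I'm not sure it's sane) would be to monkey-patch that function with a weighted version before training, something like:

import torch
import torch.nn.functional as F
import torchvision.models.detection.roi_heads as roi_heads

_original_fastrcnn_loss = roi_heads.fastrcnn_loss

def make_weighted_fastrcnn_loss(class_weights):
    # class_weights: 1-D tensor with one weight per class (index 0 = background)
    def weighted_fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
        # reuse torchvision's implementation for the box-regression loss and only
        # replace the (unweighted) classification loss it returns
        _, box_loss = _original_fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
        classification_loss = F.cross_entropy(
            class_logits,
            torch.cat(labels, dim=0),
            weight=class_weights.to(class_logits.device),
        )
        return classification_loss, box_loss
    return weighted_fastrcnn_loss

# before training (class_weights being the hypothetical tensor from the sketch above):
# roi_heads.fastrcnn_loss = make_weighted_fastrcnn_loss(class_weights)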
And my data loader:
# Create Datasets
train_dataset = ScreenshotDataset(training_dataset_files, path_to_screenshots, CONTROL_TYPES, CONTROL_TYPE_TRANSFORMATIONS, train_transformations)
test_dataset = ScreenshotDataset(testing_dataset_files, path_to_screenshots, CONTROL_TYPES, CONTROL_TYPE_TRANSFORMATIONS, test_transformations)
# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=4, shuffle=True, num_workers=4, collate_fn=collate_fn)
data_loader_test = torch.utils.data.DataLoader(
    test_dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=collate_fn)
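For the other idea (thinning out the over-represented elements), I wondered whether, instead of deleting annotations, I could oversample screenshots containing rare element types by swapping shuffle=True above for a WeightedRandomSampler. A rough, untested sketch, reusing the hypothetical class_weights from earlier and weighting each screenshot by the rarest element it contains (my own heuristic):

from torch.utils.data import WeightedRandomSampler

sample_weights = []
for _, target in train_dataset:  # again iterates __getitem__, so it is slow
    # weight each screenshot by its rarest element type
    sample_weights.append(class_weights[target['labels']].max().item())

sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
data_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=4, sampler=sampler, num_workers=4, collate_fn=collate_fn)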
...
class ScreenshotDataset(torch.utils.data.Dataset):
    def __init__(self, filenames, image_path, target_control_types, control_type_transformations, transformations=None):
        self.filenames = filenames
        self.image_path = image_path
        self.target_control_types = target_control_types
        self.control_type_transformations = control_type_transformations
        self.target_control_indexs = self.__create_target_control_index__(target_control_types)
        self.transforms = transformations

    # Map control types to labels
    def __create_target_control_index__(self, target_controls):
        control_index = {}
        for i, control in enumerate(target_controls, 1):
            control_index[control] = i
        return control_index
    @staticmethod
    def get_annotations(boxes, ui_element, target_control_types, control_type_transformations):
        element_right = ui_element['BoundingRectangle']['Left'] + ui_element['BoundingRectangle']['Width']
        element_bottom = ui_element['BoundingRectangle']['Top'] + ui_element['BoundingRectangle']['Height']
        # If the boundary is on the edge, snap it within the limit
        if element_right == 1281 or element_right == 1280:
            element_right = 1279
        if element_bottom == 1024 or element_bottom == 1025:
            element_bottom = 1023
        # Convert control types e.g. Pane -> Window
        if ui_element['ControlType'] in control_type_transformations:
            ui_element['ControlType'] = control_type_transformations[ui_element['ControlType']]
        element = {
            'xmin': ui_element['BoundingRectangle']['Left'] if ui_element['BoundingRectangle']['Left'] > 0 else 1,  # Left
            'ymin': ui_element['BoundingRectangle']['Top'] if ui_element['BoundingRectangle']['Top'] > 0 else 1,  # Top
            'xmax': element_right,  # Right
            'ymax': element_bottom,  # Bottom
            'name': ui_element['ControlType']
        }
        # Save to list
        # Check element size is correct
        if ui_element['BoundingRectangle']['Height'] > 0 and ui_element['BoundingRectangle']['Width'] > 0 and element_right < 1280 and element_right > 1 and element_bottom > 1 and element_bottom < 1024:
            # Check control type is one of the targets
            if ui_element['ControlType'] in target_control_types:
                # Check boundaries are correct
                if element['xmin'] < element['xmax'] and element['ymin'] < element['ymax']:
                    boxes.append(element)
        # Get children if any
        for child_element in ui_element['Children']:
            ScreenshotDataset.get_annotations(boxes, child_element, target_control_types, control_type_transformations)
    def __getitem__(self, idx):
        # Load images and annotations
        img_path = os.path.join(self.image_path, self.filenames[idx])
        json_path = os.path.join(self.image_path, self.filenames[idx].replace('.png', '.json'))
        # Load JSON
        image_annotations = []
        with open(json_path, encoding='utf-8') as json_file:
            ui_element = json.load(json_file)
            ScreenshotDataset.get_annotations(image_annotations, ui_element['Children'][0], self.target_control_types, self.control_type_transformations)
        # Create boxes list
        boxes = [
            [annotation['xmin'], annotation['ymin'], annotation['xmax'], annotation['ymax']]
            for annotation in image_annotations
        ]
        if len(boxes) == 0:
            print('Help!')
        # Create labels list
        labels = [self.target_control_indexs[annotation['name']] for annotation in image_annotations]
        # convert everything into a torch.Tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        # get area for evaluation with the COCO metric, to separate the
        # metric scores between small, medium and large boxes.
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # unique id
        im_id = torch.tensor([idx])
        # suppose all instances are not crowd (torchvision specific)
        iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)
        # Create target dictionary
        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': im_id,
            'area': area,
            'iscrowd': iscrowd,
        }
        # Load the image + augment
        img = Image.open(img_path).convert('RGB')
        img = self.transforms(img)
        return img, target

    def __len__(self):
        return len(self.filenames)
Any advice or direction is appreciated!
Topic: object-detection, pytorch, weighted-data, class-imbalance
Category: Data Science