Complex balanced dataloading from multiple imbalanced datasets?
The Setting
- Let's suppose that I have an imbalanced dataset.
- For training purposes, I want to implement a dataloading scheme that samples from this dataset in a more balanced way.
- I want to leverage existing metadata for this purpose.
- Each instance in my dataset belongs to either category $A$ or category $B$. Each category is further subdivided into several subcategories, namely $A_1$, $A_2$, $A_3$, $A_4$, ..., $A_N$ for category $A$ and $B_1$, $B_2$, $B_3$, $B_4$, ..., $B_M$ for category $B$.
How I want the dataloading to work
My goal is to have the model learn to discriminate between $A$ and $B$ (top priority) and also between the subcategories within each category.
So I had the following in mind:
- I would like to have a separate dataset for each subcategory $A_1$, $A_2$, $A_3$, $A_4$, ..., $A_N$ and $B_1$, $B_2$, $B_3$, $B_4$, ..., $B_M$.
- Then I would like to have two meta-datasets, a meta-dataset A as a wrapper for $A_1$, $A_2$, $A_3$, $A_4$, ..., $A_N$, and a meta-dataset B as a wrapper for $B_1$, $B_2$, $B_3$, $B_4$, ..., $B_M$. A meta-dataset should be able to sample from the sub-datasets it contains in a more balanced way following some heuristic. Question: Is this possible? How can I do this?
- Finally, I would like to have a single meta-meta-dataset as a wrapper for meta-dataset $A$ and meta-dataset $B$. This meta-meta-dataset should be able to sample in a balanced way from meta-dataset $A$ and meta-dataset $B$. Question: Is this possible? How can I do this?
In other words, during my training loop, I want all my batches to be relatively balanced in terms of categories $A$ and $B$, and within each category I would like the subcategories to be sampled more or less uniformly as well.
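One possible reading of this two-level scheme is as a single set of per-sample weights: pick a category uniformly, then a subcategory uniformly within that category, then an instance uniformly within that subcategory. The sketch below computes such weights under that assumption. The function `hierarchical_weights` and the toy `labels` list are hypothetical names used only for illustration.

```python
from collections import Counter


def hierarchical_weights(labels):
    """Return one sampling weight per instance.

    labels: list of (category, subcategory) tuples, one per instance.
    Each weight is P(category) * P(subcategory | category) * P(instance | subcategory),
    with every factor uniform, so the weights sum to 1.
    """
    # Number of instances in each (category, subcategory) pair.
    subcat_sizes = Counter(labels)
    # Number of distinct subcategories inside each category.
    subcats_per_cat = Counter(cat for cat, _ in subcat_sizes)
    num_cats = len(subcats_per_cat)
    return [
        1.0 / (num_cats * subcats_per_cat[cat] * subcat_sizes[(cat, sub)])
        for cat, sub in labels
    ]


# Toy, heavily imbalanced data (assumed for illustration only).
labels = [("A", "A1")] * 90 + [("A", "A2")] * 10 + [("B", "B1")] * 5
weights = hierarchical_weights(labels)

# Total probability mass assigned to each top-level category.
mass_a = sum(w for w, (cat, _) in zip(weights, labels) if cat == "A")
mass_b = sum(w for w, (cat, _) in zip(weights, labels) if cat == "B")
```

Weights like these could be passed to `torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)`. That balances batches in expectation rather than exactly, and instances from small subcategories will be repeated often.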
Does anyone know how something like this can be done with very imbalanced datasets in PyTorch?
Note: It's very likely that I won't be able to fit all subcategories into a single batch, since there may be many subcategories and the GPU may not have enough memory for a batch that covers all of them. It's therefore okay if each batch samples only a subset of the subcategories of $A$ and a subset of the subcategories of $B$, as long as all subcategories are sampled more or less uniformly on average.
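If exact balance between the top-level categories is wanted in every batch, while subcategories are only balanced on average, one option is a custom batch sampler. The following is a minimal sketch under those assumptions, not a definitive implementation. The class `HierarchicalBatchSampler` and its parameters are hypothetical. It is written in plain Python and only relies on the fact that `torch.utils.data.DataLoader` accepts any iterable of index lists as its `batch_sampler` argument.

```python
import random
from collections import defaultdict


class HierarchicalBatchSampler:
    """Yield batches of dataset indices, balanced over categories.

    Batch slots cycle through the categories, so each batch is split
    evenly between them when batch_size is a multiple of the number of
    categories. Within a category, a subcategory is drawn uniformly,
    then an instance is drawn uniformly (with replacement).
    """

    def __init__(self, labels, batch_size, num_batches, seed=0):
        # Group instance indices by (category, subcategory).
        by_sub = defaultdict(list)
        for idx, key in enumerate(labels):
            by_sub[key].append(idx)
        # Group the per-subcategory index lists by category.
        by_cat = defaultdict(list)
        for (cat, _sub), idxs in sorted(by_sub.items()):
            by_cat[cat].append(idxs)
        # pools[c] is a list of index lists, one per subcategory of category c.
        self.pools = [by_cat[cat] for cat in sorted(by_cat)]
        self.batch_size = batch_size
        self.num_batches = num_batches
        self.rng = random.Random(seed)

    def __len__(self):
        return self.num_batches

    def __iter__(self):
        for _ in range(self.num_batches):
            batch = []
            for slot in range(self.batch_size):
                pool = self.pools[slot % len(self.pools)]  # alternate categories
                idxs = self.rng.choice(pool)               # uniform over subcategories
                batch.append(self.rng.choice(idxs))        # uniform over instances
            yield batch


# Toy data (assumed): indices 0..99 belong to A, indices 100..104 to B.
labels = [("A", "A1")] * 90 + [("A", "A2")] * 10 + [("B", "B1")] * 5
sampler = HierarchicalBatchSampler(labels, batch_size=8, num_batches=100)
batches = list(sampler)

# With a map-style dataset this could be used as:
#   loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)
```

Because only `batch_size` draws are made per batch, a batch covers at most `batch_size` subcategories, which matches the constraint that not every subcategory has to appear in every batch.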
Topic imbalanced-data pytorch class-imbalance deep-learning
Category Data Science