Is it possible to cluster data according to a target?

I was wondering whether there exist techniques to cluster data according to a target. For example, suppose we want to find groups of customers likely to churn:

  • Target is churn.
  • We want to find clusters exhibiting the same behaviour with respect to whether they are likely to churn (or not). Therefore, variables that do not explain churn behaviour should not influence how clusters are built.

I have done this analysis the following way:

  1. Predict target (e.g. using a Random Forest) and retrieve "most important features" (from feature importance analysis).
  2. Cluster samples with selected features (e.g. using k-means).
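The two steps above can be sketched as follows; the data, the number of selected features (`top_k`) and the number of clusters are illustrative, not from the original post.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

# Toy stand-in for the customer/churn data.
X, y = make_classification(n_samples=500, n_features=10,
                           n_informative=4, random_state=0)

# Step 1: fit a Random Forest and keep the most important features.
rf = RandomForestClassifier(n_estimators=200, random_state=0).fit(X, y)
top_k = 4
top_features = np.argsort(rf.feature_importances_)[::-1][:top_k]

# Step 2: cluster the samples on those selected features only.
X_sel = StandardScaler().fit_transform(X[:, top_features])
labels = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(X_sel)
```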

However, I am afraid the clustering technique used in the 2nd step might not capture the behaviours found in the 1st step that explain churn (suppose there is a complex interaction in some trees of the RF; this interaction might not be caught by the k-means algorithm).

I was thinking of another way of doing this by using a neural network:

  1. Predict target using a neural network with several layers, and for each sample retrieve activations from a given layer.
  2. Cluster samples with their activations.

If the performance of the neural network is good and if the layer from which activations are retrieved is carefully chosen (not too close to the input or the output layer), I suppose the clusters could show customers displaying the same behaviour explaining the target.
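A minimal sketch of this idea, using scikit-learn's `MLPClassifier` and replaying the forward pass by hand to recover the activations of a middle hidden layer (the network shape and the clustering settings are assumptions for illustration):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans

X, y = make_classification(n_samples=500, n_features=10, random_state=0)
X = StandardScaler().fit_transform(X)

# Step 1: train a small MLP on the target.
mlp = MLPClassifier(hidden_layer_sizes=(32, 16, 8), activation="relu",
                    max_iter=500, random_state=0).fit(X, y)

def hidden_activations(model, X, layer):
    """Replay the ReLU forward pass up to (and including) `layer`."""
    a = X
    for i in range(layer + 1):
        a = np.maximum(a @ model.coefs_[i] + model.intercepts_[i], 0)
    return a

# Middle hidden layer (index 1, i.e. the 16-unit layer).
acts = hidden_activations(mlp, X, layer=1)

# Step 2: cluster the samples on their activations.
labels = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(acts)
```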

I did not find any articles describing this approach. Has anyone dealt with the same issue, or does anyone have other ideas?

Topic predictor-importance predictive-modeling clustering

Category Data Science


I have been thinking of using Shapley values to cluster predictions. Indeed, for samples with the same prediction, the contributions of each feature leading to that prediction can differ. This information can be captured with techniques that decompose the prediction into "contributions" (or a proxy of contributions), such as Shapley values.

So clustering data according to a target could be done following these three steps:

  1. train a supervised ML model (e.g. a random forest)
  2. extract the Shapley values for every sample
  3. cluster samples using their Shapley values

A quick search on Google led me to the same idea in Christoph Molnar's well-known book, which reassures me about this approach.


One approach I would try is supervised dimension reduction (UMAP, for example: https://umap-learn.readthedocs.io/en/latest/supervised.html) followed by a clustering approach (such as HDBSCAN: https://hdbscan.readthedocs.io/en/latest/how_hdbscan_works.html). This allows you to perform clustering that includes a supervised dimension. Be cautious: I have found that UMAP can 'overfit', in the sense that it may produce clean 'groups' on training data and very different structure on test data.


You can train a decision tree with your features and target. Then take the leaf of the tree with the highest target rate; the constraints on your features along the path to that leaf define your segment.
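A sketch of this idea with scikit-learn (tree depth and leaf size are assumptions; `export_text` prints the feature constraints that define each leaf):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_text

X, y = make_classification(n_samples=500, n_features=6, random_state=0)

tree = DecisionTreeClassifier(max_depth=3, min_samples_leaf=30,
                              random_state=0).fit(X, y)

# Leaf each sample falls into, then the target rate per leaf.
leaf_ids = tree.apply(X)
rates = {leaf: y[leaf_ids == leaf].mean() for leaf in np.unique(leaf_ids)}
best_leaf = max(rates, key=rates.get)
segment = leaf_ids == best_leaf   # boolean mask of the segment

# The feature constraints defining each leaf can be read off the tree:
print(export_text(tree))
```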


Predict target (e.g. using a Random Forest) and retrieve "most important features" (from feature importance analysis). Cluster samples with selected features (e.g. using k-means).

You must also scale based on variable importance.

However, I am afraid the clustering technique used in the 2nd step might not catch behaviours found in the 1st step which might explain churn (suppose there is a complex interaction in some trees in the RF, this interaction might not be caught in the k-means algorithm).

Scaling based on variable importance will help with this, though I am not sure it fully solves it. Take a conditional XOR based on two variables: it divides the plane into four even squares, with one class occupying the two diagonally opposite squares. Scaling does not exactly explain what is happening, but it does show it. To see into multidimensional space, use a hierarchical clustering dendrogram and colour each end point by its resulting class.
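A minimal sketch of the importance-weighted scaling idea: standardize the features, then weight each one by its Random Forest importance so that important variables dominate the k-means distance metric (data and cluster count are illustrative):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans

X, y = make_classification(n_samples=400, n_features=8, random_state=0)

rf = RandomForestClassifier(n_estimators=200, random_state=0).fit(X, y)

# Standardize, then weight each feature by its importance so that
# important variables dominate the Euclidean distances used by k-means.
X_std = StandardScaler().fit_transform(X)
X_weighted = X_std * rf.feature_importances_
labels = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(X_weighted)
```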

Look into random forest "proximity plots" (Section 15.3.3 in The Elements of Statistical Learning).
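The proximity matrix behind those plots can be computed directly: the proximity of two samples is the fraction of trees that place them in the same leaf. A small sketch (sample sizes are illustrative; note the intermediate comparison array grows as n² × n_trees):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=200, n_features=6, random_state=0)
rf = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)

# Leaf index of every sample in every tree: shape (n_samples, n_trees).
leaves = rf.apply(X)

# Proximity = fraction of trees placing each pair in the same leaf.
prox = (leaves[:, None, :] == leaves[None, :, :]).mean(axis=2)

# 1 - prox is a distance that can feed clustering or MDS for plotting.
dist = 1.0 - prox
```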


Use constrained clustering.

This allows you to set up "must link" and "cannot link" constraints.

Then you can cluster your data such that no cluster contains both 'churn' and 'non churn' entries by setting "cannot link" constraints.

I'm just not aware of any good implementations.
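In the special case where the only constraints are churn/non-churn "cannot link" pairs, there is a trivial workaround that needs no dedicated library: cluster each target class separately, which automatically satisfies every such constraint. A sketch (data and per-class cluster counts are assumptions):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.cluster import KMeans

X, y = make_classification(n_samples=400, n_features=8, random_state=0)

# Cluster churners and non-churners separately; by construction no
# cluster can mix the two classes.
labels = np.empty(len(y), dtype=int)
offset = 0
for cls in np.unique(y):
    mask = y == cls
    km = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X[mask])
    labels[mask] = km.labels_ + offset
    offset += km.n_clusters
```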


First of all, your approaches are smart and creative, but a few remarks:

  1. The question is not well defined. Clustering using the target is somewhat paradoxical, though I understand your point. The problem is that ignoring this paradox may hurt your analysis, as explained in the next point.
  2. Following point 1, you classify your points based on the target and try to find dense subgroups there. That will not work if you include the target itself in the analysis; this is the confusion arising from the paradoxical definition I mentioned above.
  3. You have your targets, so divide your data according to them and analyze each subset separately, keeping an eye on their interactions.

Second the suggestions:

Subgroup Discovery

This is the right, solid, and difficult way to find such interesting subgroups. Adapting it to your use case may require modifying the algorithm, which makes it harder still, but it is certainly worth at least a look.

Creative Way

Partition your data according to the target. You will end up with several disjoint subsets of data. Then run a statistical analysis of the association of variables within and between those subsets. For instance, if a variable really contributes to the target, the distribution of its values should differ significantly between classes (ANOVA).
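A sketch of the per-variable test with SciPy: for each feature, compare its distribution across the target classes with a one-way ANOVA (the data and the 0.05 threshold are illustrative):

```python
import numpy as np
from scipy.stats import f_oneway
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=400, n_features=6,
                           n_informative=3, random_state=0)

# One ANOVA per feature: does its distribution shift between classes?
pvalues = []
for j in range(X.shape[1]):
    groups = [X[y == cls, j] for cls in np.unique(y)]
    pvalues.append(f_oneway(*groups).pvalue)

# Features with large p-values show no class-dependent shift and are
# candidates for removal.
uninformative = [j for j, p in enumerate(pvalues) if p > 0.05]
```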

This also helps to remove variables that do not contribute to the target (by doing this 1. you reduce the complexity of the data and the analysis and improve interpretability by removing them, and 2. you have already found which variables are NOT contributing, which is part of your answer).

PS: I just improvised. Please try it and let me know if it works or not so we can think of another solution :)
