SciKit-Learn Decision Tree Overfitting

We have a project that uses a few of the algorithms we have learned so far. I've been using scikit-learn to apply them, but when it comes to decision trees I keep getting the feeling that I'm overfitting.

I'm using a dataset about the weather, with characteristics such as city, state, month, year, wind direction, wind speed, etc., where the target variable is the average temperature for the day. I know this is hard to treat as classification, since temperature is essentially continuous, so I've simplified it: a prediction counts as correct if it is within 5 degrees of the actual value.

Here is a link to the CSV I'm using.

The following is my code:

import numpy as np
import pandas as pd
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()

address2 = 'C:/.../weather.csv'
weather = pd.read_csv(address2)

# Encode each categorical column as integer labels
cityCode = le.fit_transform(weather.iloc[:, 2])
windDirection = le.fit_transform(weather.iloc[:, 3])
month = le.fit_transform(weather.iloc[:, 8])
precip = le.fit_transform(weather.iloc[:, 9])
windSpeed = le.fit_transform(weather.iloc[:, 10])
state = le.fit_transform(weather.iloc[:, 11])
week = le.fit_transform(weather.iloc[:, 12])
year = le.fit_transform(weather.iloc[:, 13])

Xweather = list(zip(cityCode, windDirection, month, precip, windSpeed, state, week, year))
yweather = weather.iloc[:, 0]

yweather_test = train_test_split(Xweather, y, test_size = 0.2, random_state=413)

cWeather = tree.DecisionTreeClassifier()
cWeather.fit(Xweather_train, yweather_train)
accu_train_weather = np.sum(abs(cWeather.predict(Xweather_train) - yweather_train) <= 5) / float(yweather_train.size) * 100
accu_test_weather = np.sum(abs(cWeather.predict(Xweather_test) - yweather_test) <= 5) / float(yweather_test.size) * 100
print("Classification accuracy on training set:", accu_train_weather, "%")
print("Classification accuracy on test set:", accu_test_weather, "%")

The model consistently gets 100% accuracy on the training set, but only about 57% on the test set, which leads me to believe the tree is overfitting to the training data.

I know I'm not doing any pruning here, but even when I do, the test accuracy is at best the same as with the unpruned tree. By pruning I mean constraining the classifier with a maximum number of leaf nodes, a minimum number of samples per leaf, and a maximum depth.

Tags: overfitting, decision-trees, scikit-learn, python, machine-learning

Category: Data Science


It seems like the split is incorrect:

yweather_test = train_test_split(Xweather, y, test_size = 0.2, random_state=413)

Change it to:

xweather_train, xweather_test, yweather_train, yweather_test = train_test_split(
    Xweather, yweather, test_size = 0.2, random_state=413)

Then do the prediction:

y_pred = cWeather.predict(xweather_test)

Make the confusion matrix and see the result:

from sklearn.metrics import confusion_matrix    
cm = confusion_matrix(yweather_test, y_pred)

Note: I have suggested these corrections based only on your code; your data is not available at the link. I am assuming you have correctly converted the target variable from a continuous value to categorical ranges.


Predicting average temperature is a regression task, not classification. You should be using DecisionTreeRegressor instead. Temperature is a continuous value and you are treating it as a category by using a classifier.
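
As a minimal sketch of that change, assuming train/test variables named Xweather_train, Xweather_test, yweather_train, and yweather_test from a corrected split (as in the first answer):

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Fit on the raw temperatures instead of binned classes
reg = DecisionTreeRegressor(random_state=413)
reg.fit(Xweather_train, yweather_train)

# Reuse the question's "within 5 degrees counts as correct" criterion
preds = reg.predict(Xweather_test)
within5 = np.mean(np.abs(preds - yweather_test) <= 5) * 100
print("Within-5-degrees accuracy on test set:", within5, "%")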

Tinkering with the hyperparameters (maximum number of leaf nodes, minimum samples per leaf, maximum depth, etc.) is still important, since decision trees are inherently prone to overfitting. If you struggle to find good parameters yourself, you can try automated search methods such as GridSearchCV or RandomizedSearchCV in sklearn.
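
For instance, a sketch of a grid search over those same pruning parameters, using the same assumed variable names as above (the candidate values below are illustrative placeholders, not tuned recommendations):

from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeRegressor

param_grid = {
    "max_depth": [3, 5, 10, None],
    "min_samples_leaf": [1, 5, 20],
    "max_leaf_nodes": [None, 50, 200],
}

# 5-fold cross-validated search, scored by mean absolute error
search = GridSearchCV(DecisionTreeRegressor(random_state=413), param_grid,
                      cv=5, scoring="neg_mean_absolute_error")
search.fit(Xweather_train, yweather_train)
print(search.best_params_)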


The vanilla decision tree algorithm is prone to overfitting. That's largely why ensemble tree methods exist. The classics include Random Forests, AdaBoost, and Gradient Boosted Trees, all of which are implemented in sklearn.
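
Swapping one in is nearly a drop-in change. A sketch with a random forest, using a regressor per the previous answer and the same assumed variable names:

from sklearn.ensemble import RandomForestRegressor

# Averaging many randomized trees smooths out single-tree overfitting
forest = RandomForestRegressor(n_estimators=100, random_state=413)
forest.fit(Xweather_train, yweather_train)
print(forest.score(Xweather_test, yweather_test))  # R^2 on the test set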

There are other, more advanced variations and implementations outside sklearn, for example LightGBM and XGBoost.

If you must use a vanilla decision tree, reducing the dimensionality of your inputs might help curb the overfitting.
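
One simple way to try that is to drop features the tree itself finds unimportant. A sketch, again with the assumed variable names from above, where the 0.05 cutoff is an arbitrary illustration:

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Fit once to measure each feature's contribution to the splits
probe = DecisionTreeRegressor(random_state=413)
probe.fit(Xweather_train, yweather_train)

keep = probe.feature_importances_ > 0.05   # arbitrary threshold
Xtrain_small = np.asarray(Xweather_train)[:, keep]
Xtest_small = np.asarray(Xweather_test)[:, keep]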
