K-Means Clustering in R and Python

Chris Grannan
9 min read · Feb 19, 2021


My coding experience up until recently has mostly been in Python. To expand my skill set, I have begun learning the R statistical programming language. For practice using R packages, I have been taking modeling tasks I am familiar with in Python and repeating them in R. For today's post, I want to run through one of these practice sessions, showing examples of both Python and R code.

Task Overview:

For this small project, I am using a dataset from Kaggle involving mall customers. The goal of the project is to segment the customers into definable groups for marketing analysis. There are five columns in this dataset: a unique customer ID, the gender of each customer, the age of each customer, the annual income of each customer (in thousands of US dollars), and the overall spending score of each customer (ranked from 1 to 100). To accomplish the segmentation, I used K-Means clustering via scikit-learn in Python and base R's kmeans() with tidyverse tooling in R. To determine the number of clusters, I used the elbow method.

Loading Libraries and Data:

Now that the task is defined, let's look at the code differences. Loading packages into either language is very simple. In R, simply call library(name_of_library), as in the example below, where we load tidyverse and readr.

library(tidyverse)
library(readr)

In Python, we use import to load libraries, as seen below. Here we import pandas under the alias pd, the pyplot module from matplotlib under the alias plt, seaborn under the alias sns (we will use it for plotting shortly), and the KMeans class from the cluster module of the sklearn library.

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import KMeans

Both languages provide very simple ways to load CSV files into dataframes as well. In R, we use the readr package to read the CSV file and assign the result to a variable using the '<-' operator.

df <- read_csv('Mall_Customers.csv')

In Python, we use pandas to read a CSV and convert it into a dataframe.

df = pd.read_csv('Mall_Customers.csv')
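Before moving on, it's worth a quick peek at the loaded data to confirm the column names. Here is a minimal check in Python; note that the rename step is an assumption about the raw Kaggle file, which may ship with headers like 'Annual Income (k$)' rather than the Annual_Income and Spending_Score names used throughout this post.

# Inspect the first rows and column types
print(df.head())
df.info()

# Assumption: if your copy of the Kaggle file uses the raw headers,
# rename them to match the column names used in the rest of this post
df = df.rename(columns={'Annual Income (k$)': 'Annual_Income',
                        'Spending Score (1-100)': 'Spending_Score'})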

Great. At this point we have our packages and data loaded for each language. Next, let’s go over how we can visualize the data.

Visualizations:

For my visualizations in R, I use ggplot2, which is included in the tidyverse library. The ggplot2 package is very easy to use: simply call ggplot(), then fill in the data you want to use and any aesthetic mapping using aes(). We then build up the plot in layers, adding each one with the '+' operator at the end of a line. Below is an example of how to create a scatter plot using the mall customers dataset. In this example, we tell ggplot that we want to use the dataframe 'df' as the source of data, and we map the x axis to annual income and the y axis to spending score. We also tell ggplot to assign each data point's color based on the gender column. We then create a new layer that tells ggplot which kind of plot we want to make, in this case a scatter plot via geom_point(). Finally, we create one more layer with all of our labels using the labs() command. For more details on what to pass to ggplot, take a look through the ggplot2 documentation.

ggplot(df, aes(x = annual_income,
               y = spending_score,
               color = Gender)) +
  geom_point() +
  labs(title = 'Annual Income by Spending Score',
       subtitle = 'Grouped by Gender',
       x = 'Annual Income', y = 'Spending Score')

To create the same plot in Python, we use the matplotlib.pyplot module (imported as plt). I like to use seaborn (imported as sns), a library that builds on top of matplotlib, to make my plots look a bit cleaner. In the example below, we create an empty figure of the size we want using the plt.figure() method. Next, we create a scatterplot using seaborn. Again, we set our x axis as annual income and our y axis as spending score. Here, we use the hue argument to color each dot according to gender. Then, we assign the title, x label, and y label. Finally, we call plt.show() to display the completed scatterplot.

fig = plt.figure(figsize=(10, 8))
sns.scatterplot(x = df.Annual_Income,
                y = df.Spending_Score,
                hue = df.Gender)
plt.title('Annual Income vs Spending Score')
plt.xlabel('Annual Income')
plt.ylabel('Spending Score')
plt.show()

Both languages produce clean, easy-to-understand scatterplots, and both plots are fairly easy to customize further with regard to fonts and colors. Overall, I find matplotlib a little more straightforward to use than ggplot2, but that is mostly due to my experience with it. Both languages make it fairly easy to create meaningful visualizations.

Modeling:

Now, let's get to the interesting part of the project: creating our clustering models. The process is the same in both languages. First, we fit a K-Means model for every number of clusters from 1 to 20 and record each model's within-cluster sum of squares. Once we have our 20 sum of squares values, we plot them as a line graph and look for a sharp change in slope. That gives us a good place to start when choosing the number of clusters. Then, we build models around that value and compare the results. For this dataset, we will look at the interaction between all three of our numerical variables (age, annual income, and spending score) to compare results. Let's look at how we go through this process in R.

Luckily for us, the tidyverse, along with the broom package for model summaries, contains everything we need to accomplish this. To start, we create our 'elbow plot' by plotting our sum of squares values. Below is a sample code block and the resulting plot.

In this code, we create a dataframe object using the tibble() function. This dataframe contains every number from 1 to 20. We then create a pipeline using the '%>%' operator. In the next line, we add a new column (mod) to our dataframe using the mutate() command. This new column is populated with a K-Means model for each row, using the clusters column as the number of clusters to create and our 'df' dataframe as the data source. We only use the age, annual income, and spending score columns from 'df' (selected as df[,3:5]). In the next line, we create another column (glanced) by calling glance() on each model. Calling glance() (from the broom package) creates a summary of a model, returned as a one-row dataframe. Next, we unnest the glanced column, effectively turning each glance object into several columns of values. Finally, we create a line plot of the clusters column against tot.withinss/totss, two of the values returned by glance(). Here, tot.withinss is the total within-cluster sum of squares and totss is the total sum of squares, so their ratio measures how compact the clusters are relative to the overall spread of the data.

# glance() comes from the broom package, which isn't attached by library(tidyverse)
library(broom)

tibble(clusters = 1:20) %>%
  mutate(mod = map(clusters, ~ kmeans(df[,3:5],
                                      centers = .x,
                                      nstart = 50))) %>%
  mutate(glanced = map(mod, glance)) %>%
  unnest(glanced) %>%
  ggplot(aes(x = clusters, y = tot.withinss / totss)) +
  geom_line()

This is quite a bit to unpack. To make it simpler: we have measured the compactness of the clusters and divided that by the spread of all of our data points. The resulting line graph shows how the clusters tighten as we add more of them. To find our ideal number of clusters, we look for the point where the slope of the line changes most dramatically, the 'elbow' of the graph. In this case, there is a big shift somewhere around 5 and 6 clusters. Now that we know where to start looking, we build one model using 5 clusters and another using 6 clusters. We will then plot those models and analyze which one better segments our customer base. Creating a model is very simple: just call kmeans(), then fill in the data, the number of clusters, and an nstart value. The nstart value tells kmeans how many random starting configurations to try; the best one is kept, so a higher value reduces the effect of random initialization on your model.

km_5_clusters <- kmeans(df[,3:5], centers = 5, nstart = 50)

Calling km_5_clusters will now print a summary of the model, including the center of each cluster, the number of points assigned to each cluster, a list of cluster labels, and the error values. From here, we simply plot the clusters on three scatterplots: one showing the interaction of income and spending, one showing age and income, and one showing age and spending. I won't go over the code creating those plots, but you can check it on my GitHub. Here are all three of those charts.

5 Clusters: Income vs Spending
5 Clusters: Age vs Income
5 Clusters: Age vs Spending

Next we create a model with 6 clusters and plot the results again. To accomplish this, we follow the same procedures as above, so I won’t go over the code, but here are the three resulting charts with 6 clusters.

6 Clusters: Income vs Spending
6 Clusters: Age vs Income
6 Clusters: Age vs Spending

Now, before we get into the analysis of the models, I want to go over the differences in creating a K-Means model in Python. For this, we use the KMeans class from the sklearn.cluster module. Creating the model is just as easy as the tidyverse version: we simply call KMeans() and supply the number of clusters we want. This creates a modeling object which we then fit to our data. In this instance, we use .fit_predict to fit the model and assign a cluster label to each data point.

model_5_clusters = KMeans(n_clusters=5)
label_5 = model_5_clusters.fit_predict(df[['Age',
                                           'Annual_Income',
                                           'Spending_Score']])

To check your inertia (sum of squares) value, simply call model_5_clusters.inertia_, and to check the centers of your clusters, call model_5_clusters.cluster_centers_. From here, we use the cluster labels to plot our scatterplots. I'm not going to show the resulting scatterplots here, since they are essentially the same as those created in R, but for a closer look at the code to create them, feel free to check out my project on GitHub.
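One piece we haven't shown in Python is the elbow computation itself. Here is a minimal sketch of how it might look with scikit-learn, mirroring the R pipeline above; the loop, the n_init value (scikit-learn's counterpart to R's nstart), and the final scatterplot are my own illustration rather than the exact code from the repo.

# Sketch: compute inertia (within-cluster sum of squares) for k = 1..20
features = df[['Age', 'Annual_Income', 'Spending_Score']]
inertias = []
for k in range(1, 21):
    # n_init plays the same role as nstart in R: try several random
    # starting configurations and keep the best result
    model = KMeans(n_clusters=k, n_init=50)
    model.fit(features)
    inertias.append(model.inertia_)

# Plot inertia against the number of clusters and look for the elbow
plt.plot(range(1, 21), inertias)
plt.xlabel('Number of clusters')
plt.ylabel('Inertia')
plt.show()

# Labels from fit_predict can color a scatterplot directly
sns.scatterplot(x=df.Annual_Income, y=df.Spending_Score, hue=label_5)
plt.show()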

Analysis:

Looking at the results of our clustering plots, we can see that using 5 clusters creates very distinct groups when comparing income and spending scores, but we get a very wide range of ages in each cluster. Moving to 6 clusters creates quite a bit of overlap in our middle segment on the income and spending scatterplot, but it breaks the age ranges down substantially when looking at the interaction between age and income and between age and spending. Overall, the segmentation is strongest at 6 clusters, as it separates the mall customers into 6 distinct groups while accounting for all three variables.
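To actually use the segmentation, the cluster labels need to be attached back to the customer data. As a final, hedged sketch (this step is my own illustration, not code from the original repo), here is one way to do it in Python with a 6-cluster model, fit the same way as the 5-cluster example above:

# Fit the chosen 6-cluster model and store each customer's segment
model_6_clusters = KMeans(n_clusters=6, n_init=50)
df['cluster'] = model_6_clusters.fit_predict(
    df[['Age', 'Annual_Income', 'Spending_Score']])

# Profile the segments: average age, income, and spending score per cluster
print(df.groupby('cluster')[['Age', 'Annual_Income', 'Spending_Score']].mean())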

Finally, my thoughts on using R versus Python. I think both languages are very easy to use for K-Means clustering, and both have libraries that produce polished, easy-to-read graphs. From my experience with this project, I would not recommend either language over the other, but I slightly prefer Python. I find graph formatting ever so slightly more intuitive with matplotlib, even if it requires a few more lines of code. I'm sure plenty of others would disagree with me. Regardless of the language you choose, you now have the knowledge and tools to create some clustering models.

Thanks for reading!

Resources:

For more information on tidyverse and all the packages contained within, check out their documentation.

For more information on KMeans with scikit-learn, take a look at the scikit-learn documentation.

To see the project in more detail, check out my GitHub repo.

Check out these links for more about matplotlib and seaborn.
