I have a 4-column DataFrame extracted from the iris dataset. I use k-means to find 3 clusters and then plot them for every possible combination of 2 columns.
However, something seems to be wrong with the output: the cluster centers are not placed at the centers of the clusters. I have provided examples of the output. Only cluster_1 seems OK; the other 3 look completely wrong.
How can I best fix my clustering? This is the sample code I am using:
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import itertools

df = pd.read_csv('iris.csv')
df_columns = ['column_a', 'column_b', 'column_c', 'column_d']

n_clusters = 3
kmeans = KMeans(n_clusters=n_clusters, init='k-means++', max_iter=200)
kmeans = kmeans.fit(df)
centroids = kmeans.cluster_centers_
cluster_labels = kmeans.labels_

for i in itertools.combinations(df_columns, 2):
    fig, ax = plt.subplots(figsize=(12, 8))
    fig = plt.figure()
    ax.scatter(df[i[0]].values, df[i[1]].values, c=cluster_labels, cmap='viridis', edgecolor='k', s=20, alpha=0.5)
    ax.scatter(centroids[:, 0], centroids[:, 1], s=20, c='black', marker='*')
    plt.show()
```
Dataset used:
```
column_a,column_b,column_c,column_d
5.1,3.5,1.4,0.2
4.9,3.0,1.4,0.2
4.7,3.2,1.3,0.2
4.6,3.1,1.5,0.2
5.0,3.6,1.4,0.2
5.4,3.9,1.7,0.4
4.6,3.4,1.4,0.3
5.0,3.4,1.5,0.2
4.4,2.9,1.4,0.2
4.9,3.1,1.5,0.1
5.4,3.7,1.5,0.2
4.8,3.4,1.6,0.2
4.8,3.0,1.4,0.1
4.3,3.0,1.1,0.1
5.8,4.0,1.2,0.2
5.7,4.4,1.5,0.4
5.4,3.9,1.3,0.4
5.1,3.5,1.4,0.3
5.7,3.8,1.7,0.3
5.1,3.8,1.5,0.3
5.4,3.4,1.7,0.2
5.1,3.7,1.5,0.4
4.6,3.6,1.0,0.2
5.1,3.3,1.7,0.5
4.8,3.4,1.9,0.2
5.0,3.0,1.6,0.2
5.0,3.4,1.6,0.4
5.2,3.5,1.5,0.2
5.2,3.4,1.4,0.2
4.7,3.2,1.6,0.2
4.8,3.1,1.6,0.2
5.4,3.4,1.5,0.4
5.2,4.1,1.5,0.1
5.5,4.2,1.4,0.2
4.9,3.1,1.5,0.1
5.0,3.2,1.2,0.2
5.5,3.5,1.3,0.2
4.9,3.1,1.5,0.1
4.4,3.0,1.3,0.2
5.1,3.4,1.5,0.2
5.0,3.5,1.3,0.3
4.5,2.3,1.3,0.3
4.4,3.2,1.3,0.2
5.0,3.5,1.6,0.6
5.1,3.8,1.9,0.4
4.8,3.0,1.4,0.3
5.1,3.8,1.6,0.2
4.6,3.2,1.4,0.2
5.3,3.7,1.5,0.2
5.0,3.3,1.4,0.2
7.0,3.2,4.7,1.4
6.4,3.2,4.5,1.5
6.9,3.1,4.9,1.5
5.5,2.3,4.0,1.3
6.5,2.8,4.6,1.5
5.7,2.8,4.5,1.3
6.3,3.3,4.7,1.6
4.9,2.4,3.3,1.0
6.6,2.9,4.6,1.3
5.2,2.7,3.9,1.4
5.0,2.0,3.5,1.0
5.9,3.0,4.2,1.5
6.0,2.2,4.0,1.0
6.1,2.9,4.7,1.4
5.6,2.9,3.6,1.3
6.7,3.1,4.4,1.4
5.6,3.0,4.5,1.5
5.8,2.7,4.1,1.0
6.2,2.2,4.5,1.5
5.6,2.5,3.9,1.1
5.9,3.2,4.8,1.8
6.1,2.8,4.0,1.3
6.3,2.5,4.9,1.5
6.1,2.8,4.7,1.2
6.4,2.9,4.3,1.3
6.6,3.0,4.4,1.4
6.8,2.8,4.8,1.4
6.7,3.0,5.0,1.7
6.0,2.9,4.5,1.5
5.7,2.6,3.5,1.0
5.5,2.4,3.8,1.1
5.5,2.4,3.7,1.0
5.8,2.7,3.9,1.2
6.0,2.7,5.1,1.6
5.4,3.0,4.5,1.5
6.0,3.4,4.5,1.6
6.7,3.1,4.7,1.5
6.3,2.3,4.4,1.3
5.6,3.0,4.1,1.3
5.5,2.5,4.0,1.3
5.5,2.6,4.4,1.2
6.1,3.0,4.6,1.4
5.8,2.6,4.0,1.2
5.0,2.3,3.3,1.0
5.6,2.7,4.2,1.3
5.7,3.0,4.2,1.2
5.7,2.9,4.2,1.3
6.2,2.9,4.3,1.3
5.1,2.5,3.0,1.1
5.7,2.8,4.1,1.3
6.3,3.3,6.0,2.5
5.8,2.7,5.1,1.9
7.1,3.0,5.9,2.1
6.3,2.9,5.6,1.8
6.5,3.0,5.8,2.2
7.6,3.0,6.6,2.1
4.9,2.5,4.5,1.7
7.3,2.9,6.3,1.8
6.7,2.5,5.8,1.8
7.2,3.6,6.1,2.5
6.5,3.2,5.1,2.0
6.4,2.7,5.3,1.9
6.8,3.0,5.5,2.1
5.7,2.5,5.0,2.0
5.8,2.8,5.1,2.4
6.4,3.2,5.3,2.3
6.5,3.0,5.5,1.8
7.7,3.8,6.7,2.2
7.7,2.6,6.9,2.3
6.0,2.2,5.0,1.5
6.9,3.2,5.7,2.3
5.6,2.8,4.9,2.0
7.7,2.8,6.7,2.0
6.3,2.7,4.9,1.8
6.7,3.3,5.7,2.1
7.2,3.2,6.0,1.8
6.2,2.8,4.8,1.8
6.1,3.0,4.9,1.8
6.4,2.8,5.6,2.1
7.2,3.0,5.8,1.6
7.4,2.8,6.1,1.9
7.9,3.8,6.4,2.0
6.4,2.8,5.6,2.2
6.3,2.8,5.1,1.5
6.1,2.6,5.6,1.4
7.7,3.0,6.1,2.3
6.3,3.4,5.6,2.4
6.4,3.1,5.5,1.8
6.0,3.0,4.8,1.8
6.9,3.1,5.4,2.1
6.7,3.1,5.6,2.4
6.9,3.1,5.1,2.3
5.8,2.7,5.1,1.9
6.8,3.2,5.9,2.3
6.7,3.3,5.7,2.5
6.7,3.0,5.2,2.3
6.3,2.5,5.0,1.9
6.5,3.0,5.2,2.0
6.2,3.4,5.4,2.3
5.9,3.0,5.1,1.8
```
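To reproduce the problem without the CSV file, a minimal sketch could build the same four columns from scikit-learn's bundled copy of the iris data, assuming scikit-learn is installed. The column names are the ones used in the question; the bundled copy lists the features in the same order as the rows above (sepal length, sepal width, petal length, petal width), though published copies of iris may differ from this CSV in a value or two.

```python
import pandas as pd
from sklearn.datasets import load_iris

# Assumption: scikit-learn's bundled iris data stands in for iris.csv.
# Its feature columns come in the same order as the CSV in the question.
df_columns = ['column_a', 'column_b', 'column_c', 'column_d']
df = pd.DataFrame(load_iris().data, columns=df_columns)

print(df.shape)  # (150, 4)
print(df.head(3))
```

The resulting `df` can replace the `pd.read_csv('iris.csv')` line in the question's code.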
Answer
You compute the clusters in four dimensions. Note that this means the centroids are four-dimensional points too. Then you plot two-dimensional projections of the clusters. So when you plot the centroids, you have to pick out the same two dimensions that you just used for the scatterplot of the individual points. Your code always plots columns 0 and 1 of `centroids`, which is why only the first plot (column_a vs. column_b) looks right.
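Slicing the centroid array with the same column indices gives the center of each projected cluster because a mean is computed column by column. A small NumPy check of that claim, using hypothetical random data and an arbitrary grouping as stand-ins for the iris rows and the k-means labels (the identity holds for any grouping):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in data: 150 points in 4 dimensions, split into 3 groups.
X = rng.normal(size=(150, 4))
labels = rng.integers(0, 3, size=150)

# 4-D centroids: one row per group, one column per feature.
centroids = np.array([X[labels == k].mean(axis=0) for k in range(3)])

i, j = 1, 3  # the two columns shown in one scatterplot

wrong = centroids[:, [0, 1]]  # what the question plots for every pair
right = centroids[:, [i, j]]  # same two columns as the scatterplot

# Centers of the 2-D projected points, computed directly.
projected = np.array([X[labels == k][:, [i, j]].mean(axis=0) for k in range(3)])

print(centroids.shape)                # (3, 4)
print(np.allclose(right, projected))  # True
```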
```python
for i, j in itertools.combinations([0, 1, 2, 3], 2):
    fig, ax = plt.subplots(figsize=(12, 8))
    # Project every point onto columns i and j.
    ax.scatter(df.iloc[:, i], df.iloc[:, j], c=cluster_labels, cmap='viridis', edgecolor='k', s=20, alpha=0.5)
    # Project the 4-D centroids onto the same two columns.
    ax.scatter(centroids[:, i], centroids[:, j], s=20, c='black', marker='*')
    plt.show()
```
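Putting it together, a self-contained sketch of the fixed loop, under a few assumptions not in the original: scikit-learn's bundled iris data replaces iris.csv, `random_state` and `n_init` are fixed for reproducibility, and all six column pairs go into one 2-by-3 grid rendered with the off-screen Agg backend and saved to a hypothetical `iris_kmeans_pairs.png` instead of six separate windows.

```python
import itertools

import matplotlib
matplotlib.use('Agg')  # assumption: render off-screen and save to a file
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris

# Assumption: scikit-learn's bundled iris data stands in for iris.csv.
df_columns = ['column_a', 'column_b', 'column_c', 'column_d']
df = pd.DataFrame(load_iris().data, columns=df_columns)

# Fit once on all four columns; random_state and n_init added for reproducibility.
kmeans = KMeans(n_clusters=3, init='k-means++', n_init=10, max_iter=200, random_state=0)
cluster_labels = kmeans.fit_predict(df)
centroids = kmeans.cluster_centers_  # shape (3, 4): one 4-D point per cluster

pairs = list(itertools.combinations(range(len(df_columns)), 2))  # 6 column pairs
fig, axes = plt.subplots(2, 3, figsize=(15, 9))

for ax, (i, j) in zip(axes.ravel(), pairs):
    # Points: project every row onto columns i and j.
    ax.scatter(df.iloc[:, i], df.iloc[:, j], c=cluster_labels, cmap='viridis',
               edgecolor='k', s=20, alpha=0.5)
    # Centroids: project onto the SAME two columns.
    ax.scatter(centroids[:, i], centroids[:, j], s=200, c='red', marker='*')
    ax.set_xlabel(df_columns[i])
    ax.set_ylabel(df_columns[j])

fig.tight_layout()
fig.savefig('iris_kmeans_pairs.png')  # hypothetical output filename
```

The stars are drawn larger and in red only so they stand out against the viridis colors. In an interactive session, the `matplotlib.use('Agg')` line can be dropped and `savefig` replaced with `plt.show()`.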