Writing K-Means Clustering from Scratch the Hard Way
I have always felt that success in any field requires not only the ability to use tools that others have built, but also at least a basic knowledge of how those tools work under the hood. Data science is no exception. Those of us who work in Python have been richly blessed by the folks who put scikit-learn together, but how many of us have actually taken the time to replicate one of their functions from scratch, let alone the hard way? By the hard way, I mean researching the algorithm yourself, either on YouTube or, for the truly hardcore, by reading through a statistics textbook, then implementing it step by step without reference to any outside code. That is what I set out to do: write a functioning first-pass K-Means Clustering algorithm from scratch in Python by researching the algorithm and then translating the statistics into code.
After putting a wire frame for a KMeans class together, I began researching the algorithm itself, finding Josh Starmer’s StatQuest video on the topic to be the clearest resource. For the benefit of those unfamiliar with K-Means Clustering, I will explain its function and outline the process the algorithm goes through below.
As the name suggests, K-Means Clustering is a form of unsupervised machine learning used to identify groupings in a dataset. Unsupervised here means that the algorithm works on unlabeled data: it is used to identify structure rather than to predict a known outcome. An example would be examining tumor measurements from mammograms to see whether any clusters stand out, informing decisions on which patients need to be re-examined based on potential risk of malignancy.
How K-Means works is actually fairly straightforward:
- A number of data points (k) equal to the desired number of clusters are randomly chosen as centroids
- The distance between each centroid and every other point is calculated
- Data points are assigned to the cluster belonging to the nearest centroid
- The mean of the points in each cluster is calculated as that cluster’s new centroid (standard K-Means uses the arithmetic mean of each feature; my implementation uses the geometric mean)
- The distance between the new centroids and every other point is calculated
- Data points are reassigned to the cluster belonging to their new nearest centroid
This process is repeated until either the new centroids cease to differ from the previous ones within a certain level of tolerance or a maximum number of iterations is reached. The technical term for the former is convergence. Scikit-learn sets the default maximum number of iterations at 300, but in most cases the data allow for much faster convergence, so this limit is rarely reached.
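The steps and stopping rule above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the listing this article walks through: the function name kmeans_sketch and the parameters tol and seed are hypothetical, max_iter mirrors scikit-learn's default of 300, and the centroid update uses the arithmetic mean, as standard K-Means does.

```python
import numpy as np


def kmeans_sketch(X, k, max_iter=300, tol=1e-4, seed=0):
    """Illustrative K-Means loop; returns (centroids, labels, iterations run)."""
    rng = np.random.default_rng(seed)
    # Step 1: pick k distinct data points as the starting centroids.
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for iteration in range(1, max_iter + 1):
        # Steps 2-3: distance from every point to every centroid, then
        # assign each point to its nearest centroid.
        distances = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        labels = distances.argmin(axis=1)
        # Step 4: move each centroid to the arithmetic mean of its points
        # (an empty cluster keeps its previous centroid).
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        # Convergence check: stop once no centroid moved by more than tol.
        shift = np.linalg.norm(new_centroids - centroids, axis=1).max()
        centroids = new_centroids
        if shift <= tol:
            break
    return centroids, labels, iteration


# Tiny made-up dataset: two obvious pairs of points.
X = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
centroids, labels, n_iterations = kmeans_sketch(X, k=2)
print(centroids, labels, n_iterations)
```

Which cluster ends up numbered 0 depends on the random starting centroids, so only the grouping itself is meaningful.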
With this in mind, and not wanting to make this task harder than it already is, I settled on an n_iter hyper-parameter with a default of 10 instead.
Following the algorithm’s steps as outlined above, I created the fit method as shown in the code below. Owing to Python’s readability, I will not explicitly explain every line of this code; rather, I will point out some of the elements which may not be readily obvious to the uninitiated.
As my reader can see, treating the algorithm’s variables as class attributes allows each of them to be used, changed, and retrieved at the appropriate times, so that someone using the class can derive utility from them. If a programmer wanted to know the average distance of all points from each centroid, for instance, he or she could inspect the self.avg attribute after a given number of iterations and see how these average distances change as the algorithm converges.
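To make the attribute pattern concrete, here is a minimal sketch of a class that keeps its working state on self. It is an illustration rather than the original listing: only n_iter and avg are names taken from this article, while KMeansSketch, k, seed, centroids, and labels are assumptions, and the centroid update uses the arithmetic mean for simplicity.

```python
import numpy as np


class KMeansSketch:
    """Illustrative only; names other than n_iter and avg are assumptions."""

    def __init__(self, k=2, n_iter=10, seed=0):
        self.k = k
        self.n_iter = n_iter
        self.seed = seed
        self.centroids = None
        self.labels = None
        self.avg = []  # one array of per-centroid average distances per iteration

    def fit(self, X):
        rng = np.random.default_rng(self.seed)
        # Start from k distinct data points chosen at random.
        self.centroids = X[rng.choice(len(X), size=self.k, replace=False)]
        self.avg = []
        for _ in range(self.n_iter):
            distances = np.linalg.norm(
                X[:, None, :] - self.centroids[None, :, :], axis=2
            )
            self.labels = distances.argmin(axis=1)
            # Record the mean distance of all points from each centroid.
            self.avg.append(distances.mean(axis=0))
            # Update each centroid; an empty cluster keeps its old centroid.
            self.centroids = np.array([
                X[self.labels == j].mean(axis=0)
                if np.any(self.labels == j) else self.centroids[j]
                for j in range(self.k)
            ])
        return self.labels


X = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
model = KMeansSketch(k=2, n_iter=5)
labels = model.fit(X)
for step, averages in enumerate(model.avg, start=1):
    print(step, averages)
```

Because every intermediate value lives on the instance, the history in model.avg can be examined after fitting without changing the fit method itself.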
For those viewing this code who are familiar with Python but not with the NumPy or SciPy libraries, I will explicitly explain what each library function is doing. The fit method takes and returns a NumPy array.
- The NumPy array is a data structure similar to a Python list, but with attributes that make linear algebra operations possible where they would be painful with Python’s native list class.
- In line 24, the np.linalg.norm function used in the list comprehension is the equivalent of the distance formula you likely learned in middle school algebra, except that it works in any number of dimensions, which becomes more useful the more features a dataset has.
- In line 42, np.where() acts in much the same way that list.index() does in vanilla Python, except that it identifies the indices of all points with that value in the array rather than only the first.
- In line 46, the gmean function from the SciPy library calculates the geometric mean of each cluster’s points (the nth root of the product of n values, taken feature by feature, and defined only for positive values), which becomes the new centroid.
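The three operations in the list above can be tried in isolation. The values below are made up for illustration, and the geometric mean is computed with plain NumPy as the exponential of the mean logarithm, which for positive data gives the same result as scipy.stats.gmean.

```python
import numpy as np

point = np.array([0.0, 0.0, 0.0])
centroid = np.array([1.0, 2.0, 2.0])
# Euclidean distance in three dimensions: sqrt(1 + 4 + 4) = 3.0
distance = np.linalg.norm(point - centroid)

labels = np.array([0, 1, 0, 2, 1, 0])
# list.index returns only the first matching position ...
first_match = labels.tolist().index(0)
# ... while np.where returns every matching index (in a tuple, one array per axis).
all_matches = np.where(labels == 0)[0]

cluster = np.array([[1.0, 4.0], [4.0, 16.0]])
# Geometric mean per feature: the n-th root of the product of n values.
# For positive data this matches scipy.stats.gmean(cluster, axis=0).
geometric = np.exp(np.log(cluster).mean(axis=0))
# Arithmetic mean per feature, the update used by standard K-Means.
arithmetic = cluster.mean(axis=0)

print(distance, first_match, all_matches, geometric, arithmetic)
```

The last two results differ (roughly [2, 8] against [2.5, 10]), so the choice of mean changes where each centroid lands.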
Enough about the structure of the class itself. Let’s see how well it performs, first with randomized data, then with a real dataset.
As my reader can see, the clusters are fairly similar, though the scikit-learn clustering clearly comes out on top.
Passing a dataset of tumor measurements from mammograms through both the homemade K-Means algorithm and its scikit-learn equivalent netted the above results. Since the cluster numbers and their associated colors are arbitrary, these clusterings are identical for all practical purposes.
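Because cluster numbers are arbitrary, comparing two label arrays element by element can report a mismatch even when the groupings agree. One way to check agreement regardless of numbering is to compare which pairs of points share a cluster. The helper name same_partition and the label lists below are hypothetical and are not the results from the mammogram dataset.

```python
import numpy as np


def same_partition(labels_a, labels_b):
    """True when two labelings group the points identically, ignoring label names."""
    labels_a = np.asarray(labels_a)
    labels_b = np.asarray(labels_b)
    # Co-membership matrix: entry (i, j) is True when points i and j share a cluster.
    together_a = labels_a[:, None] == labels_a[None, :]
    together_b = labels_b[:, None] == labels_b[None, :]
    return bool(np.array_equal(together_a, together_b))


homemade = [0, 0, 1, 1, 2]   # made-up labels from one algorithm
reference = [2, 2, 0, 0, 1]  # same grouping, different cluster numbers
different = [0, 1, 1, 1, 2]  # a genuinely different grouping

print(same_partition(homemade, reference))
print(same_partition(homemade, different))
```

This check is all-or-nothing and builds an n-by-n matrix, so it suits small examples. For a graded similarity score on larger data, scikit-learn offers sklearn.metrics.adjusted_rand_score.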