{ "cells": [ { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import matplotlib as mpl\n", "import numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Gaussian Mixture Models (GMMs)\n", "See the documentation at http://scikit-learn.org/stable/modules/mixture.html#gmm section 2.1.1." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from sklearn import datasets\n", "from sklearn.mixture import GaussianMixture" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For a full example with plots, see http://scikit-learn.org/stable/auto_examples/mixture/plot_gmm_covariances.html\n", "\n", "Note that although that example uses GMMs for classification, their basic use is clustering (or density estimation): the `.fit()` method does not need the labels.\n", "\n", "Here we just look at the main steps for fitting a model:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "#-get the data\n", "iris = datasets.load_iris()\n", "X = iris.data\n", "y = iris.target\n", "\n", "#-how many features are there?" 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GaussianMixture(covariance_type='diag', init_params='kmeans', max_iter=20,\n", " means_init=None, n_components=3, n_init=1, precisions_init=None,\n", " random_state=None, reg_covar=1e-06, tol=0.001, verbose=0,\n", " verbose_interval=10, warm_start=False, weights_init=None)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#-fit a model (example):\n", "n_classes = 3\n", "clst = GaussianMixture(n_components=n_classes, covariance_type='diag', init_params='kmeans', max_iter=20)\n", "clst.fit(X)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ 6.81209583 3.07212929 5.72666022 2.10669076]\n", " [ 5.006 3.418 1.464 0.244 ]\n", " [ 5.92793543 2.75046463 4.40762592 1.41444826]]\n", "[ 0.25188836 0.33333333 0.41477831]\n" ] } ], "source": [ "#-inspect the fitted parameters:\n", "print(clst.means_) # the matrix of component centers, one per row\n", "print(clst.weights_) # the mixing coefficients" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "#-you can assign the data points to a cluster:\n", "y_pred = clst.predict(X)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.90666666666666662" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#-and \"compare\" with the true labels:\n", "sum(y != y_pred)*1.0 / y.size" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Why is there such a mismatch between cluster labels and class labels? Inspect the two and try to find a way to put the class labels in correspondence with the cluster labels." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# K-means clustering\n", "\n", "See the documentation at\n", "http://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html#sklearn.cluster.KMeans\n", "\n", "Note that there is a parameter, `n_jobs`, which allows you to run the\n", "code on several CPUs. This can give a great speed-up for large data sets.\n", "\n", "The basic steps are as before:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.cluster import KMeans" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "KMeans(algorithm='auto', copy_x=True, init='k-means++', max_iter=300,\n", " n_clusters=3, n_init=10, n_jobs=1, precompute_distances='auto',\n", " random_state=None, tol=0.0001, verbose=0)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#-fit a model (example):\n", "n_classes = 3\n", "clst = KMeans(n_clusters=n_classes)\n", "clst.fit(X)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 5.006 , 3.418 , 1.464 , 0.244 ],\n", " [ 5.9016129 , 2.7483871 , 4.39354839, 1.43387097],\n", " [ 6.85 , 3.07368421, 5.74210526, 2.07105263]])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#-inspect the fitted parameters:\n", "clst.cluster_centers_ # the matrix of cluster centers, one per row" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ "#-you can assign the data points to a cluster:\n", "y_pred = clst.predict(X)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.10666666666666667" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#-and compare with the true labels:\n", "sum(y != y_pred)*1.0 
/ y.size ## Why does it not work as you would expect?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**NOTE:** For large data sets there is a mini-batch version of KMeans (`MiniBatchKMeans`), which converges much faster. See:\n", " http://scikit-learn.org/stable/modules/generated/sklearn.cluster.MiniBatchKMeans.html#sklearn.cluster.MiniBatchKMeans\n", "\n", "And see the discussion:\n", "http://scikit-learn.org/stable/auto_examples/cluster/plot_mini_batch_kmeans.html" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Application: color quantization\n", "An important application of KMeans is in image processing, for example re-quantizing the color levels of an image. To see this, run the example at http://scikit-learn.org/stable/auto_examples/cluster/plot_color_quantization.html\n", "\n", "### TODO: (if time allows)\n", "Use a different clustering method. Explore the following options:\n", "\n", "- mean shift: http://scikit-learn.org/stable/modules/generated/sklearn.cluster.MeanShift.html#sklearn.cluster.MeanShift\n", "- spectral clustering: http://scikit-learn.org/stable/modules/generated/sklearn.cluster.SpectralClustering.html#sklearn.cluster.SpectralClustering\n", "- hierarchical clustering (Ward; in recent scikit-learn versions this is `AgglomerativeClustering` with `linkage='ward'`): http://scikit-learn.org/0.16/modules/generated/sklearn.cluster.Ward.html" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.13" } }, "nbformat": 4, "nbformat_minor": 1 }