[mlpack-git] master, mlpack-1.0.x: Embedding nystroem method into kernel pca method. (44b26ee)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:53:44 EST 2015


Repository : https://github.com/mlpack/mlpack

On branches: master,mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 44b26ee7962b4cd7dffd96acecb23af8e465f9a8
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Sat Jul 19 19:29:11 2014 +0000

    Embedding nystroem method into kernel pca method.


>---------------------------------------------------------------

44b26ee7962b4cd7dffd96acecb23af8e465f9a8
 src/mlpack/methods/CMakeLists.txt                  |   1 +
 src/mlpack/methods/kernel_pca/CMakeLists.txt       |   2 +
 src/mlpack/methods/kernel_pca/kernel_pca.hpp       |  33 ++++---
 src/mlpack/methods/kernel_pca/kernel_pca_impl.hpp  | 100 ++++++---------------
 .../kernel_rules}/CMakeLists.txt                   |   4 +-
 .../kernel_pca/kernel_rules/naive_method.hpp       |  89 ++++++++++++++++++
 .../kernel_pca/kernel_rules/nystroem_method.hpp    |  71 +++++++++++++++
 .../init_rules => nystroem_method}/CMakeLists.txt  |   7 +-
 src/mlpack/tests/CMakeLists.txt                    |   1 +
 src/mlpack/tests/kernel_pca_test.cpp               |  70 +++++++++++++--
 10 files changed, 286 insertions(+), 92 deletions(-)

diff --git a/src/mlpack/methods/CMakeLists.txt b/src/mlpack/methods/CMakeLists.txt
index 1c0e61d..041e475 100644
--- a/src/mlpack/methods/CMakeLists.txt
+++ b/src/mlpack/methods/CMakeLists.txt
@@ -31,6 +31,7 @@ set(DIRS
   regularized_svd
   sparse_autoencoder
   sparse_coding
+  nystroem_method
 )
 
 foreach(dir ${DIRS})
diff --git a/src/mlpack/methods/kernel_pca/CMakeLists.txt b/src/mlpack/methods/kernel_pca/CMakeLists.txt
index c575af9..4b5fcb5 100644
--- a/src/mlpack/methods/kernel_pca/CMakeLists.txt
+++ b/src/mlpack/methods/kernel_pca/CMakeLists.txt
@@ -14,6 +14,8 @@ endforeach()
 # the parent scope).
 set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
 
+add_subdirectory(kernel_rules)
+
 add_executable(kernel_pca
   kernel_pca_main.cpp
 )
diff --git a/src/mlpack/methods/kernel_pca/kernel_pca.hpp b/src/mlpack/methods/kernel_pca/kernel_pca.hpp
index 98c23ca..fdf2ef4 100644
--- a/src/mlpack/methods/kernel_pca/kernel_pca.hpp
+++ b/src/mlpack/methods/kernel_pca/kernel_pca.hpp
@@ -1,6 +1,7 @@
 /**
  * @file kernel_pca.hpp
  * @author Ajinkya Kale
+ * @author Marcus Edel
  *
  * Defines the KernelPCA class to perform Kernel Principal Components Analysis
  * on the specified data set.
@@ -9,7 +10,7 @@
 #define __MLPACK_METHODS_KERNEL_PCA_KERNEL_PCA_HPP
 
 #include <mlpack/core.hpp>
-#include <mlpack/core/kernels/linear_kernel.hpp>
+#include <mlpack/methods/kernel_pca/kernel_rules/naive_method.hpp>
 
 namespace mlpack {
 namespace kpca {
@@ -27,7 +28,10 @@ namespace kpca {
  * files in mlpack/core/kernels/) and it is easy to write your own; see other
  * implementations for examples.
  */
-template <typename KernelType>
+template <
+  typename KernelType,
+  typename KernelRule = NaiveKernelRule<KernelType>
+>
 class KernelPCA
 {
  public:
@@ -38,6 +42,7 @@ class KernelPCA
    * much).
    *
    * @param kernel Kernel to be used for computation.
+   * @param centerTransformedData Center transformed data.
    */
   KernelPCA(const KernelType kernel = KernelType(),
             const bool centerTransformedData = false);
@@ -49,6 +54,21 @@ class KernelPCA
    * @param transformedData Matrix to output results into.
    * @param eigval KPCA eigenvalues will be written to this vector.
    * @param eigvec KPCA eigenvectors will be written to this matrix.
+   * @param newDimension New dimension for the dataset.
+   */
+  void Apply(const arma::mat& data,
+             arma::mat& transformedData,
+             arma::vec& eigval,
+             arma::mat& eigvec,
+             const size_t newDimension);
+
+  /**
+   * Apply Kernel Principal Components Analysis to the provided data set.
+   *
+   * @param data Data matrix.
+   * @param transformedData Matrix to output results into.
+   * @param eigval KPCA eigenvalues will be written to this vector.
+   * @param eigvec KPCA eigenvectors will be written to this matrix.
    */
   void Apply(const arma::mat& data,
              arma::mat& transformedData,
@@ -91,7 +111,6 @@ class KernelPCA
   //! Return whether or not the transformed data is centered.
   bool& CenterTransformedData() { return centerTransformedData; }
    
-   
   // Returns a string representation of this object. 
   std::string ToString() const;
 
@@ -102,14 +121,6 @@ class KernelPCA
   //! run.
   bool centerTransformedData;
 
-  /**
-   * Construct the kernel matrix.
-   *
-   * @param data Input data points.
-   * @param kernelMatrix Matrix to store the constructed kernel matrix in.
-   */
-  void GetKernelMatrix(const arma::mat& data, arma::mat& kernelMatrix);
-
 }; // class KernelPCA
 
 }; // namespace kpca
diff --git a/src/mlpack/methods/kernel_pca/kernel_pca_impl.hpp b/src/mlpack/methods/kernel_pca/kernel_pca_impl.hpp
index 9e65845..e209b1a 100644
--- a/src/mlpack/methods/kernel_pca/kernel_pca_impl.hpp
+++ b/src/mlpack/methods/kernel_pca/kernel_pca_impl.hpp
@@ -3,7 +3,7 @@
  * @author Ajinkya Kale
  * @author Marcus Edel
  *
- * Implementation of KernelPCA class to perform Kernel Principal Components
+ * Implementation of Kernel PCA class to perform Kernel Principal Components
  * Analysis on the specified data set.
  */
 #ifndef __MLPACK_METHODS_KERNEL_PCA_KERNEL_PCA_IMPL_HPP
@@ -12,54 +12,26 @@
 // In case it hasn't already been included.
 #include "kernel_pca.hpp"
 
-#include <iostream>
-
 namespace mlpack {
 namespace kpca {
 
-template <typename KernelType>
-arma::mat GetKernelMatrix(KernelType kernel, arma::mat transData);
-
-template <typename KernelType>
-KernelPCA<KernelType>::KernelPCA(const KernelType kernel,
+template <typename KernelType, typename KernelRule>
+KernelPCA<KernelType, KernelRule>::KernelPCA(const KernelType kernel,
                                  const bool centerTransformedData) :
       kernel(kernel),
       centerTransformedData(centerTransformedData)
 { }
 
 //! Apply Kernel Principal Component Analysis to the provided data set.
-template <typename KernelType>
-void KernelPCA<KernelType>::Apply(const arma::mat& data,
+template <typename KernelType, typename KernelRule>
+void KernelPCA<KernelType, KernelRule>::Apply(const arma::mat& data,
                                   arma::mat& transformedData,
                                   arma::vec& eigval,
-                                  arma::mat& eigvec)
+                                  arma::mat& eigvec,
+                                  const size_t newDimension)
 {
-  // Construct the kernel matrix.
-  arma::mat kernelMatrix;
-  GetKernelMatrix(data, kernelMatrix);
-
-  // For PCA the data has to be centered, even if the data is centered.  But it
-  // is not guaranteed that the data, when mapped to the kernel space, is also
-  // centered. Since we actually never work in the feature space we cannot
-  // center the data. So, we perform a "psuedo-centering" using the kernel
-  // matrix.
-  arma::rowvec rowMean = arma::sum(kernelMatrix, 0) / kernelMatrix.n_cols;
-  kernelMatrix.each_col() -= arma::sum(kernelMatrix, 1) / kernelMatrix.n_cols;
-  kernelMatrix.each_row() -= rowMean;
-  kernelMatrix += arma::sum(rowMean) / kernelMatrix.n_cols;
-
-  // Eigendecompose the centered kernel matrix.
-  arma::eig_sym(eigval, eigvec, kernelMatrix);
-
-  // Swap the eigenvalues since they are ordered backwards (we need largest to
-  // smallest).
-  for (size_t i = 0; i < floor(eigval.n_elem / 2.0); ++i)
-    eigval.swap_rows(i, (eigval.n_elem - 1) - i);
-
-  // Flip the coefficients to produce the same effect.
-  eigvec = arma::fliplr(eigvec);
-
-  transformedData = eigvec.t() * kernelMatrix;
+  KernelRule::ApplyKernelMatrix(data, transformedData, eigval,
+                                eigvec, newDimension, kernel);
 
   // Center the transformed data, if the user asked for it.
   if (centerTransformedData)
@@ -71,58 +43,42 @@ void KernelPCA<KernelType>::Apply(const arma::mat& data,
 }
 
 //! Apply Kernel Principal Component Analysis to the provided data set.
-template <typename KernelType>
-void KernelPCA<KernelType>::Apply(const arma::mat& data,
+template <typename KernelType, typename KernelRule>
+void KernelPCA<KernelType, KernelRule>::Apply(const arma::mat& data,
+                                  arma::mat& transformedData,
+                                  arma::vec& eigval,
+                                  arma::mat& eigvec)
+{
+  Apply(data, transformedData, eigval, eigvec, data.n_cols);
+}
+
+//! Apply Kernel Principal Component Analysis to the provided data set.
+template <typename KernelType, typename KernelRule>
+void KernelPCA<KernelType, KernelRule>::Apply(const arma::mat& data,
                                   arma::mat& transformedData,
                                   arma::vec& eigVal)
 {
   arma::mat coeffs;
-  Apply(data, transformedData, eigVal, coeffs);
+  Apply(data, transformedData, eigVal, coeffs, data.n_cols);
 }
 
 //! Use KPCA for dimensionality reduction.
-template <typename KernelType>
-void KernelPCA<KernelType>::Apply(arma::mat& data, const size_t newDimension)
+template <typename KernelType, typename KernelRule>
+void KernelPCA<KernelType, KernelRule>::Apply(arma::mat& data,
+                                    const size_t newDimension)
 {
   arma::mat coeffs;
   arma::vec eigVal;
 
-  Apply(data, data, eigVal, coeffs);
+  Apply(data, data, eigVal, coeffs, newDimension);
 
   if (newDimension < coeffs.n_rows && newDimension > 0)
     data.shed_rows(newDimension, data.n_rows - 1);
 }
 
-//! Construct the kernel matrix.
-template <typename KernelType>
-void KernelPCA<KernelType>::GetKernelMatrix(const arma::mat& data,
-                                            arma::mat& kernelMatrix)
-{
-  // Resize the kernel matrix to the right size.
-  kernelMatrix.set_size(data.n_cols, data.n_cols);
-
-  // Note that we only need to calculate the upper triangular part of the kernel
-  // matrix, since it is symmetric.  This helps minimize the number of kernel
-  // evaluations.
-  for (size_t i = 0; i < data.n_cols; ++i)
-  {
-    for (size_t j = i; j < data.n_cols; ++j)
-    {
-      // Evaluate the kernel on these two points.
-      kernelMatrix(i, j) = kernel.Evaluate(data.unsafe_col(i),
-                                           data.unsafe_col(j));
-    }
-  }
-
-  // Copy to the lower triangular part of the matrix.
-  for (size_t i = 1; i < data.n_cols; ++i)
-    for (size_t j = 0; j < i; ++j)
-      kernelMatrix(i, j) = kernelMatrix(j, i);
-}
-
 // Returns a String of the Object
-template <typename KernelType>
-std::string KernelPCA<KernelType>::ToString() const
+template <typename KernelType, typename KernelRule>
+std::string KernelPCA<KernelType, KernelRule>::ToString() const
 {
   std::ostringstream convert;
   convert << "KernelPCA [" << this << "]" << std::endl;
diff --git a/src/mlpack/methods/perceptron/initialization_methods/CMakeLists.txt b/src/mlpack/methods/kernel_pca/kernel_rules/CMakeLists.txt
similarity index 91%
copy from src/mlpack/methods/perceptron/initialization_methods/CMakeLists.txt
copy to src/mlpack/methods/kernel_pca/kernel_rules/CMakeLists.txt
index d5d9c31..7093fbf 100644
--- a/src/mlpack/methods/perceptron/initialization_methods/CMakeLists.txt
+++ b/src/mlpack/methods/kernel_pca/kernel_rules/CMakeLists.txt
@@ -1,8 +1,8 @@
 # Define the files we need to compile
 # Anything not in this list will not be compiled into MLPACK.
 set(SOURCES
-  random_init.hpp
-  zero_init.hpp
+  nystroem_method.hpp
+  naive_method.hpp
 )
 
 # Add directory name to sources.
diff --git a/src/mlpack/methods/kernel_pca/kernel_rules/naive_method.hpp b/src/mlpack/methods/kernel_pca/kernel_rules/naive_method.hpp
new file mode 100644
index 0000000..7a97f34
--- /dev/null
+++ b/src/mlpack/methods/kernel_pca/kernel_rules/naive_method.hpp
@@ -0,0 +1,89 @@
+/**
+ * @file naive_method.hpp
+ * @author Ajinkya Kale
+ *
+ * Use the naive method to construct the kernel matrix.
+ */
+
+#ifndef __MLPACK_METHODS_KERNEL_PCA_NAIVE_METHOD_HPP
+#define __MLPACK_METHODS_KERNEL_PCA_NAIVE_METHOD_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace kpca {
+
+template<typename KernelType>
+class NaiveKernelRule
+{
+ public:
+  public:
+    /**
+     * Construct the exact kernel matrix using the naive method.
+     *
+     * @param data Input data points.
+     * @param transformedData Matrix to output results into.
+     * @param eigval KPCA eigenvalues will be written to this vector.
+     * @param eigvec KPCA eigenvectors will be written to this matrix.
+     * @param rank Unused by the naive method (no approximation is done).
+     * @param kernel Kernel to be used for computation.
+     */
+    static void ApplyKernelMatrix(const arma::mat& data,
+                                  arma::mat& transformedData,
+                                  arma::vec& eigval,
+                                  arma::mat& eigvec,
+                                  const size_t /* unused */,
+                                  KernelType kernel = KernelType())
+  {
+    // Construct the kernel matrix.
+    arma::mat kernelMatrix;
+    // Resize the kernel matrix to the right size.
+    kernelMatrix.set_size(data.n_cols, data.n_cols);
+
+    // Note that we only need to calculate the upper triangular part of the 
+    // kernel matrix, since it is symmetric. This helps minimize the number of
+    // kernel evaluations.
+    for (size_t i = 0; i < data.n_cols; ++i)
+    {
+      for (size_t j = i; j < data.n_cols; ++j)
+      {
+        // Evaluate the kernel on these two points.
+        kernelMatrix(i, j) = kernel.Evaluate(data.unsafe_col(i),
+                                             data.unsafe_col(j));
+      }
+    }
+
+    // Copy to the lower triangular part of the matrix.
+    for (size_t i = 1; i < data.n_cols; ++i)
+      for (size_t j = 0; j < i; ++j)
+        kernelMatrix(i, j) = kernelMatrix(j, i);
+
+    // For PCA the data has to be centered, even if the data is centered. But it
+    // is not guaranteed that the data, when mapped to the kernel space, is also
+    // centered. Since we actually never work in the feature space we cannot
+    // center the data. So, we perform a "pseudo-centering" using the kernel
+    // matrix.
+    arma::rowvec rowMean = arma::sum(kernelMatrix, 0) / kernelMatrix.n_cols;
+    kernelMatrix.each_col() -= arma::sum(kernelMatrix, 1) / kernelMatrix.n_cols;
+    kernelMatrix.each_row() -= rowMean;
+    kernelMatrix += arma::sum(rowMean) / kernelMatrix.n_cols;
+
+    // Eigendecompose the centered kernel matrix.
+    arma::eig_sym(eigval, eigvec, kernelMatrix);
+
+    // Swap the eigenvalues since they are ordered backwards (we need largest to
+    // smallest).
+    for (size_t i = 0; i < floor(eigval.n_elem / 2.0); ++i)
+      eigval.swap_rows(i, (eigval.n_elem - 1) - i);
+
+    // Flip the coefficients to produce the same effect.
+    eigvec = arma::fliplr(eigvec);
+
+    transformedData = eigvec.t() * kernelMatrix;
+  }
+};
+
+}; // namespace kpca
+}; // namespace mlpack
+
+#endif
diff --git a/src/mlpack/methods/kernel_pca/kernel_rules/nystroem_method.hpp b/src/mlpack/methods/kernel_pca/kernel_rules/nystroem_method.hpp
new file mode 100644
index 0000000..34bcf62
--- /dev/null
+++ b/src/mlpack/methods/kernel_pca/kernel_rules/nystroem_method.hpp
@@ -0,0 +1,71 @@
+/**
+ * @file nystroem_method.hpp
+ * @author Marcus Edel
+ *
+ * Use the Nystroem method for approximating a kernel matrix.
+ */
+
+#ifndef __MLPACK_METHODS_KERNEL_PCA_NYSTROEM_METHOD_HPP
+#define __MLPACK_METHODS_KERNEL_PCA_NYSTROEM_METHOD_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/methods/nystroem_method/kmeans_selection.hpp>
+#include <mlpack/methods/nystroem_method/nystroem_method.hpp>
+
+namespace mlpack {
+namespace kpca {
+
+template<
+  typename KernelType,
+  typename PointSelectionPolicy = kernel::KMeansSelection<> 
+>
+class NystroemKernelRule
+{
+  public:
+    /**
+     * Construct the kernel matrix approximation using the nystroem method.
+     *
+     * @param data Input data points.
+     * @param transformedData Matrix to output results into.
+     * @param eigval KPCA eigenvalues will be written to this vector.
+     * @param eigvec KPCA eigenvectors will be written to this matrix.
+     * @param rank Rank to be used for matrix approximation.
+     * @param kernel Kernel to be used for computation.
+     */
+    static void ApplyKernelMatrix(const arma::mat& data,
+                                  arma::mat& transformedData,
+                                  arma::vec& eigval,
+                                  arma::mat& eigvec,
+                                  const size_t rank,
+                                  KernelType kernel = KernelType())
+    {
+      arma::mat G, v;
+      kernel::NystroemMethod<KernelType, PointSelectionPolicy> nm(data, kernel,
+                                                        rank);
+      nm.Apply(G);
+      transformedData = G.t() * G;
+
+      // For PCA the data has to be centered, even if the data is centered. But
+      // it is not guaranteed that the data, when mapped to the kernel space, is
+      // also centered. Since we actually never work in the feature space we 
+      // cannot center the data. So, we perform a "pseudo-centering" using the
+      // kernel matrix.
+      arma::rowvec rowMean = arma::sum(transformedData, 0) / 
+          transformedData.n_cols;
+      transformedData.each_col() -= arma::sum(transformedData, 1) /
+          transformedData.n_cols;
+      transformedData.each_row() -= rowMean;
+      transformedData += arma::sum(rowMean) / transformedData.n_cols;
+
+      // Decompose the centered kernel matrix using an SVD.
+      arma::svd(eigvec, eigval, v, transformedData);
+      eigval %= eigval / (data.n_cols - 1);
+
+      transformedData = eigvec.t() * G.t();
+    }
+};
+
+}; // namespace kpca
+}; // namespace mlpack
+
+#endif
diff --git a/src/mlpack/methods/amf/init_rules/CMakeLists.txt b/src/mlpack/methods/nystroem_method/CMakeLists.txt
similarity index 77%
copy from src/mlpack/methods/amf/init_rules/CMakeLists.txt
copy to src/mlpack/methods/nystroem_method/CMakeLists.txt
index a31d281..5b3f5d7 100644
--- a/src/mlpack/methods/amf/init_rules/CMakeLists.txt
+++ b/src/mlpack/methods/nystroem_method/CMakeLists.txt
@@ -1,8 +1,11 @@
 # Define the files we need to compile
 # Anything not in this list will not be compiled into MLPACK.
 set(SOURCES
-  random_init.hpp
-  random_acol_init.hpp
+  nystroem_method.hpp
+  nystroem_method_impl.hpp
+  ordered_selection.hpp
+  random_selection.hpp
+  kmeans_selection.hpp
 )
 
 # Add directory name to sources.
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index 897520b..5200932 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -52,6 +52,7 @@ add_executable(mlpack_test
   union_find_test.cpp
   svd_batch_test.cpp
   svd_incremental_test.cpp
+  nystroem_method_test.cpp
 )
 # Link dependencies of test executable.
 target_link_libraries(mlpack_test
diff --git a/src/mlpack/tests/kernel_pca_test.cpp b/src/mlpack/tests/kernel_pca_test.cpp
index 5c8f874..e1459a0 100644
--- a/src/mlpack/tests/kernel_pca_test.cpp
+++ b/src/mlpack/tests/kernel_pca_test.cpp
@@ -5,9 +5,9 @@
  * Test file for Kernel PCA.
  */
 #include <mlpack/core.hpp>
-#include <mlpack/core/kernels/linear_kernel.hpp>
 #include <mlpack/core/kernels/gaussian_kernel.hpp>
 #include <mlpack/methods/kernel_pca/kernel_pca.hpp>
+#include <mlpack/methods/kernel_pca/kernel_rules/nystroem_method.hpp>
 
 #include <boost/test/unit_test.hpp>
 #include "old_boost_test_definitions.hpp"
@@ -25,7 +25,7 @@ using namespace arma;
  * If KernelPCA is working right, then it should turn a circle dataset into a
  * linearly separable dataset in one dimension (which is easy to check).
  */
-BOOST_AUTO_TEST_CASE(CircleTransformationTest)
+BOOST_AUTO_TEST_CASE(CircleTransformationTestNaive)
 {
   // The dataset, which will have three concentric rings in three dimensions.
   arma::mat dataset;
@@ -56,10 +56,8 @@ BOOST_AUTO_TEST_CASE(CircleTransformationTest)
     dataset(2, i) += 5.0 * (dataset(2, i) / pointNorm);
   }
 
-  data::Save("circle.csv", dataset);
-
   // Now we have a dataset; we will use the GaussianKernel to perform KernelPCA
-  // to take it down to one dimension.
+  // using the naive method to take it down to one dimension.
   KernelPCA<GaussianKernel> p;
   p.Apply(dataset, 1);
 
@@ -85,4 +83,66 @@ BOOST_AUTO_TEST_CASE(CircleTransformationTest)
   BOOST_REQUIRE_EQUAL(ranges[1].Contains(ranges[2]), false);
 }
 
+/**
+ * If KernelPCA is working right, then it should turn a circle dataset into a
+ * linearly separable dataset in one dimension (which is easy to check).
+ */
+BOOST_AUTO_TEST_CASE(CircleTransformationTestNystroem)
+{
+  // The dataset, which will have three concentric rings in three dimensions.
+  arma::mat dataset;
+
+  // Now, there are 750 points centered at the origin with unit variance.
+  dataset.randn(3, 750);
+  dataset *= 0.05;
+
+  // Take the second 250 points and spread them away from the origin.
+  for (size_t i = 250; i < 500; ++i)
+  {
+    // Push the point away from the origin by 2.
+    const double pointNorm = norm(dataset.col(i), 2);
+
+    dataset(0, i) += 2.0 * (dataset(0, i) / pointNorm);
+    dataset(1, i) += 2.0 * (dataset(1, i) / pointNorm);
+    dataset(2, i) += 2.0 * (dataset(2, i) / pointNorm);
+  }
+
+  // Take the last 250 points and spread them away from the origin.
+  for (size_t i = 500; i < 750; ++i)
+  {
+    // Push the point away from the origin by 5.
+    const double pointNorm = norm(dataset.col(i), 2);
+
+    dataset(0, i) += 5.0 * (dataset(0, i) / pointNorm);
+    dataset(1, i) += 5.0 * (dataset(1, i) / pointNorm);
+    dataset(2, i) += 5.0 * (dataset(2, i) / pointNorm);
+  }
+
+  // Now we have a dataset; we will use the GaussianKernel to perform KernelPCA
+  // using the Nystroem method to take it down to one dimension.
+  KernelPCA<GaussianKernel, NystroemKernelRule<GaussianKernel> > p;
+  p.Apply(dataset, 1);
+
+  // Get the ranges of each "class".  These are all initialized as empty ranges
+  // containing no points.
+  Range ranges[3];
+  ranges[0] = Range();
+  ranges[1] = Range();
+  ranges[2] = Range();
+
+  // Expand the ranges to hold all of the points in the class.
+  for (size_t i = 0; i < 250; ++i)
+    ranges[0] |= dataset(0, i);
+  for (size_t i = 250; i < 500; ++i)
+    ranges[1] |= dataset(0, i);
+  for (size_t i = 500; i < 750; ++i)
+    ranges[2] |= dataset(0, i);
+
+  // None of these ranges should overlap -- the classes should be linearly
+  // separable.
+  BOOST_REQUIRE_EQUAL(ranges[0].Contains(ranges[1]), false);
+  BOOST_REQUIRE_EQUAL(ranges[0].Contains(ranges[2]), false);
+  BOOST_REQUIRE_EQUAL(ranges[1].Contains(ranges[2]), false);
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list