[mlpack-svn] r10844 - in mlpack/trunk/src/mlpack/methods: . mvu

fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Dec 16 01:38:47 EST 2011


Author: rcurtin
Date: 2011-12-16 01:38:47 -0500 (Fri, 16 Dec 2011)
New Revision: 10844

Added:
   mlpack/trunk/src/mlpack/methods/mvu/mvu.cpp
   mlpack/trunk/src/mlpack/methods/mvu/mvu_main.cpp
Removed:
   mlpack/trunk/src/mlpack/methods/mvu/mvu_impl.hpp
Modified:
   mlpack/trunk/src/mlpack/methods/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/mvu/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/mvu/mvu.hpp
   mlpack/trunk/src/mlpack/methods/mvu/mvu_objective_function.cpp
   mlpack/trunk/src/mlpack/methods/mvu/mvu_objective_function.hpp
Log:
Clean up MVU code and add a main executable to it.


Modified: mlpack/trunk/src/mlpack/methods/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/CMakeLists.txt	2011-12-16 06:37:32 UTC (rev 10843)
+++ mlpack/trunk/src/mlpack/methods/CMakeLists.txt	2011-12-16 06:38:47 UTC (rev 10844)
@@ -9,7 +9,7 @@
   kmeans
   lars
   linear_regression
-  #mvu  # (currently known to not work)
+  mvu
   naive_bayes
   nca
   neighbor_search

Modified: mlpack/trunk/src/mlpack/methods/mvu/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/mvu/CMakeLists.txt	2011-12-16 06:37:32 UTC (rev 10843)
+++ mlpack/trunk/src/mlpack/methods/mvu/CMakeLists.txt	2011-12-16 06:38:47 UTC (rev 10844)
@@ -4,9 +4,9 @@
 # Anything not in this list will not be compiled into MLPACK.
 set(SOURCES
   mvu.hpp
-  mvu_impl.hpp
-#  mvu_objective_function.hpp
-#  mvu_objective_function.cpp
+  mvu.cpp
+  mvu_objective_function.hpp
+  mvu_objective_function.cpp
 )
 
 # Add directory name to sources.
@@ -18,9 +18,9 @@
 # the parent scope).
 set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
 
-add_executable(ncmvu
+add_executable(mvu
   mvu_main.cpp
 )
-target_link_libraries(ncmvu
+target_link_libraries(mvu
   mlpack
 )

Copied: mlpack/trunk/src/mlpack/methods/mvu/mvu.cpp (from rev 10803, mlpack/trunk/src/mlpack/methods/mvu/mvu_impl.hpp)
===================================================================
--- mlpack/trunk/src/mlpack/methods/mvu/mvu.cpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/mvu/mvu.cpp	2011-12-16 06:38:47 UTC (rev 10844)
@@ -0,0 +1,33 @@
+/**
+ * @file mvu.cpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the MVU class and its auxiliary objective function class.
+ */
+#include "mvu.hpp"
+#include "mvu_objective_function.hpp"
+
+#include <mlpack/core/optimizers/aug_lagrangian/aug_lagrangian.hpp>
+
+using namespace mlpack;
+using namespace mlpack::mvu;
+using namespace mlpack::optimization;
+
+MVU::MVU(const arma::mat& data) : data(data)
+{
+  // Nothing to do.
+}
+
+void MVU::Unfold(const size_t newDim,
+                 const size_t numNeighbors,
+                 arma::mat& outputData)
+{
+  MVUObjectiveFunction obj(data, newDim, numNeighbors);
+
+  // Set up Augmented Lagrangian method.
+  // Memory choice is arbitrary; this needs to be configurable.
+  AugLagrangian<MVUObjectiveFunction> aug(obj, 20);
+
+  outputData = obj.GetInitialPoint();
+  aug.Optimize(outputData, 0);
+}

Modified: mlpack/trunk/src/mlpack/methods/mvu/mvu.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/mvu/mvu.hpp	2011-12-16 06:37:32 UTC (rev 10843)
+++ mlpack/trunk/src/mlpack/methods/mvu/mvu.hpp	2011-12-16 06:38:47 UTC (rev 10844)
@@ -22,22 +22,20 @@
  * - dataset
  * - new dimensionality
  */
-template<typename LagrangianFunction>
 class MVU
 {
  public:
-  MVU(arma::mat& data_in); // probably needs arguments
+  MVU(const arma::mat& data_in);
 
-  bool Unfold(arma::mat& output_coordinates); // probably needs arguments
+  void Unfold(const size_t newDim,
+              const size_t numNeighbors,
+              arma::mat& output_coordinates);
 
  private:
-  arma::mat& data_;
-  LagrangianFunction f_;
+  const arma::mat& data;
 };
 
 }; // namespace mvu
 }; // namespace mlpack
 
-#include "mvu_impl.h"
-
 #endif

Deleted: mlpack/trunk/src/mlpack/methods/mvu/mvu_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/mvu/mvu_impl.hpp	2011-12-16 06:37:32 UTC (rev 10843)
+++ mlpack/trunk/src/mlpack/methods/mvu/mvu_impl.hpp	2011-12-16 06:38:47 UTC (rev 10844)
@@ -1,42 +0,0 @@
-/**
- * @file mvu_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of the MVU class and its auxiliary objective function class.
- */
-#ifndef __MLPACK_METHODS_MVU_IMPL_HPP
-#define __MLPACK_METHODS_MVU_IMPL_HPP
-
-// In case it has not been included.
-#include "mvu.hpp"
-
-#include <mlpack/core/optimizers/aug_lagrangian/aug_lagrangian.hpp>
-
-namespace mlpack {
-namespace mvu {
-
-template<typename LagrangianFunction>
-MVU<LagrangianFunction>::MVU(arma::mat& data_in) :
-    data_(data_in),
-    f_(data_)
-{
-  // Nothing to do.
-}
-
-template<typename LagrangianFunction>
-bool MVU<LagrangianFunction>::Unfold(arma::mat& output_coordinates)
-{
-  // Set up Augmented Lagrangian method.
-  // Memory choice is arbitrary; this needs to be configurable.
-  mlpack::optimization::AugLagrangian<LagrangianFunction> aug(f_, 20);
-
-  output_coordinates = f_.GetInitialPoint();
-  aug.Optimize(0, output_coordinates);
-
-  return true;
-}
-
-}; // namespace mvu
-}; // namespace mlpack
-
-#endif

Added: mlpack/trunk/src/mlpack/methods/mvu/mvu_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/mvu/mvu_main.cpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/mvu/mvu_main.cpp	2011-12-16 06:38:47 UTC (rev 10844)
@@ -0,0 +1,74 @@
+/**
+ * @file mvu_main.cpp
+ * @author Ryan Curtin
+ *
+ * Executable for MVU.
+ */
+#include <mlpack/core.hpp>
+#include "mvu.hpp"
+
+PROGRAM_INFO("Maximum Variance Unfolding (MVU)", "This program implements "
+    "Maximum Variance Unfolding, a nonlinear dimensionality reduction "
+    "technique.  The method minimizes dimensionality by unfolding a manifold "
+    "such that the distances to the nearest neighbors of each point are held "
+    "constant.  For more information, see the following paper:\n"
+    "\n"
+    "@inproceedings{\n"
+    "  title = {An introduction to Nonlinear Dimensionality Reduction by \n"
+    "      Maximum Variance Unfolding},\n"
+    "  author = {Weinberger, K.Q. and Saul, L.K.},\n"
+    "  year = {2006},\n"
+    "  "
+    "}");
+
+PARAM_STRING_REQ("input_file", "Filename of input dataset.", "i");
+PARAM_INT_REQ("new_dim", "New dimensionality of dataset.", "d");
+
+PARAM_STRING("output_file", "Filename to save unfolded dataset to.", "o",
+    "output.csv");
+PARAM_INT("num_neighbors", "Number of nearest neighbors to consider while "
+    "unfolding.", "k", 5);
+
+using namespace mlpack;
+using namespace mlpack::mvu;
+using namespace arma;
+using namespace std;
+
+int main(int argc, char **argv)
+{
+  // Read from command line.
+  CLI::ParseCommandLine(argc, argv);
+
+  // Load input dataset.
+  const string inputFile = CLI::GetParam<string>("input_file");
+  mat data;
+  data::Load(inputFile, data, true);
+
+  // Verify that the requested dimensionality is valid.
+  const int newDim = CLI::GetParam<int>("new_dim");
+  if (newDim <= 0 || newDim > (int) data.n_rows)
+  {
+    Log::Fatal << "Invalid new dimensionality (" << newDim << ").  Must be "
+      << "between 1 and the input dataset dimensionality (" << data.n_rows
+      << ")." << std::endl;
+  }
+
+  // Verify that the number of neighbors is valid.
+  const int numNeighbors = CLI::GetParam<int>("num_neighbors");
+  if (numNeighbors <= 0 || numNeighbors > (int) data.n_cols)
+  {
+    Log::Fatal << "Invalid number of neighbors (" << numNeighbors << ").  Must "
+        << "be between 1 and the number of points in the input dataset ("
+        << data.n_cols << ")." << std::endl;
+  }
+
+  // Now run MVU.
+  MVU mvu(data);
+
+  mat output;
+  mvu.Unfold(newDim, numNeighbors, output);
+
+  // Save results to file.
+  const string outputFile = CLI::GetParam<string>("output_file");
+  data::Save(outputFile, output, true);
+}

Modified: mlpack/trunk/src/mlpack/methods/mvu/mvu_objective_function.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/mvu/mvu_objective_function.cpp	2011-12-16 06:37:32 UTC (rev 10843)
+++ mlpack/trunk/src/mlpack/methods/mvu/mvu_objective_function.cpp	2011-12-16 06:38:47 UTC (rev 10844)
@@ -6,8 +6,7 @@
  */
 #include "mvu_objective_function.hpp"
 
-#include <mlpack/neighbor_search/neighbor_search.h>
-#include <mlpack/fastica/lin_alg.h>
+#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
 
 using namespace mlpack;
 using namespace mlpack::mvu;
@@ -17,30 +16,23 @@
 MVUObjectiveFunction::MVUObjectiveFunction()
 {
   // Need to set initial point?  I guess this will be the initial matrix...
-  Log::Fatal << "Initialized MVUObjectiveFunction all wrong." << std::endl;
+  Log::Warn << "Don't use empty constructor for MVUObjectiveFunction()!"
+      << "MVU will fail." << std::endl;
 }
 
-MVUObjectiveFunction::MVUObjectiveFunction(arma::mat& initial_point) :
-    num_neighbors_(5)
+MVUObjectiveFunction::MVUObjectiveFunction(const arma::mat& initial_point,
+                                           const size_t newDim,
+                                           const size_t numNeighbors) :
+    numNeighbors(numNeighbors)
 {
   // We will calculate the nearest neighbors of this dataset.
   AllkNN allknn(initial_point);
 
-  allknn.ComputeNeighbors(neighbor_indices_, neighbor_distances_);
-//  NOTIFY("Neighbor indices: ");
-//  std::cout << neighbor_indices_;
-//  NOTIFY("Neighbor distances: ");
-//  std::cout << neighbor_distances_;
+  allknn.Search(numNeighbors, neighborIndices, neighborDistances);
 
   // Now shrink the point matrix to the correct target size.
-  int dimension = 2; // Get this from CLI: TODO.
-  initial_point_ = initial_point;
-  initial_point_.shed_rows(2, initial_point.n_rows - 1);
-
-  // Center the matrix, for shits and giggles.
-//  arma::mat tmp;
-//  linalg__private::Center(initial_point_, tmp);
-//  initial_point_ = tmp;
+  initialPoint = initial_point;
+  initialPoint.shed_rows(newDim, initial_point.n_rows - 1);
 }
 
 double MVUObjectiveFunction::Evaluate(const arma::mat& coordinates)
@@ -52,7 +44,7 @@
   // AugLagrangian use).
   double objective = 0;
 
-  for (int i = 0; i < coordinates.n_cols; i++)
+  for (size_t i = 0; i < coordinates.n_cols; i++)
     objective -= dot(coordinates.unsafe_col(i), coordinates.unsafe_col(i));
 
   return objective;
@@ -63,16 +55,10 @@
 {
   // Our objective, f(R) = sum_{ij} (R^T R)_ij, is differentiable into
   //   f'(R) = 2 * R.
-  Log::Info << "Coordinates are:\n" << coordinates;
-
   gradient = 2 * coordinates;
-
-  Log::Info << "Calculated gradient is\n" << gradient;
-
-  std::cout << gradient;
 }
 
-double MVUObjectiveFunction::EvaluateConstraint(int index,
+double MVUObjectiveFunction::EvaluateConstraint(const size_t index,
                                                 const arma::mat& coordinates)
 {
   if (index == 0)
@@ -83,12 +69,6 @@
     // This is a naive implementation; we may be able to improve upon it
     // significantly by avoiding the actual calculation of the Gram matrix
     // (R^T * R).
-    if (accu(trans(coordinates) * coordinates) > 0)
-    {
-      Log::Debug << "Constraint 0 is nonzero: " <<
-          accu(trans(coordinates) * coordinates) << std::endl;
-    }
-
     return accu(trans(coordinates) * coordinates);
   }
 
@@ -100,37 +80,18 @@
   //   (R^T R)_ii - 2 (R^T R)_ij + (R^T R)_jj - || x_i - x_j ||^2 = 0
   //
   // We will get the i and j values from the given index.
-  int i = floor(((double) (index - 1)) / (double) num_neighbors_);
-  int j = neighbor_indices_[index - 1]; // Retrieve index of this neighbor.
+  int i = floor(((double) (index - 1)) / (double) numNeighbors);
+  int j = neighborIndices[index - 1]; // Retrieve index of this neighbor.
 
   // (R^T R)_ij = R.col(i) * R.col(j)  (dot product)
   double rrt_ii = dot(coordinates.col(i), coordinates.col(i));
   double rrt_ij = dot(coordinates.col(i), coordinates.col(j));
   double rrt_jj = dot(coordinates.col(j), coordinates.col(j));
 
-  // We must remember the actual distance between points.
-//  NOTIFY("Index %d: i is %d, j is %d", index, i, j);
-//  NOTIFY("Neighbor %d: distance of %lf", index - 1,
-//      neighbor_distances_[index - 1]);
-//  NOTIFY("rrt_ii: %lf; rrt_ij: %lf; rrt_jj: %lf", rrt_ii, rrt_ij, rrt_jj);
-//  NOTIFY("r_i: ");
-//  std::cout << coordinates.col(i);
-//  NOTIFY("r_j: ");
-//  std::cout << coordinates.col(j);
-//  NOTIFY("LHS: %lf", (rrt_ii - 2 * rrt_ij + rrt_jj));
-
-  if (((rrt_ii - 2 * rrt_ij + rrt_jj) -
-      neighbor_distances_[index - 1]) > 1e-5)
-  {
-    Log::Debug << "Constraint " << index << " is nonzero: " <<
-        ((rrt_ii - 2 * rrt_ij + rrt_jj) - neighbor_distances_[index - 1])
-        << std::endl;
-  }
-
-  return ((rrt_ii - 2 * rrt_ij + rrt_jj) - neighbor_distances_[index - 1]);
+  return ((rrt_ii - 2 * rrt_ij + rrt_jj) - neighborDistances[index - 1]);
 }
 
-void MVUObjectiveFunction::GradientConstraint(int index,
+void MVUObjectiveFunction::GradientConstraint(const size_t index,
                                               const arma::mat& coordinates,
                                               arma::mat& gradient)
 {
@@ -151,8 +112,6 @@
     // We can see that we can separate this out into two distinct sums, for each
     // row and column, so we can loop first over the columns and then over the
     // rows to assemble the entire gradient matrix.
-    //for (int i = 0; i < coordinates.n_cols; i++)
-    //  gradient.col(i) += accu(coordinates.col(i)); // sum_i (R_xi)
     arma::mat ones(gradient.n_cols, gradient.n_cols);
     gradient = coordinates * ones;
 
@@ -163,8 +122,8 @@
   //  (R^T R)_ii - 2 (R^T R)_ij + (R^T R)_jj = || x_i - x_j ||^2
   //
   // We will get the i and j values from the given index.
-  int i = floor(((double) (index - 1)) / (double) num_neighbors_);
-  int j = neighbor_indices_[index - 1];
+  int i = floor(((double) (index - 1)) / (double) numNeighbors);
+  int j = neighborIndices[index - 1];
 
   // The gradient matrix for the nearest neighbor constraint (i, j) is zero
   // except for column i, which is equal to 2 (R.col(i) - R.col(j)) and also

Modified: mlpack/trunk/src/mlpack/methods/mvu/mvu_objective_function.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/mvu/mvu_objective_function.hpp	2011-12-16 06:37:32 UTC (rev 10843)
+++ mlpack/trunk/src/mlpack/methods/mvu/mvu_objective_function.hpp	2011-12-16 06:38:47 UTC (rev 10844)
@@ -18,8 +18,8 @@
  *     the Twenty-First National Conference on Artificial Intelligence
  *     (AAAI-06), 2006.
  */
-#ifndef __MLPACK_METHODS_MVU_MVU_OBJECTIVE_FUNCTION_H
-#define __MLPACK_METHODS_MVU_MVU_OBJECTIVE_FUNCTION_H
+#ifndef __MLPACK_METHODS_MVU_MVU_OBJECTIVE_FUNCTION_HPP
+#define __MLPACK_METHODS_MVU_MVU_OBJECTIVE_FUNCTION_HPP
 
 #include <mlpack/core.hpp>
 
@@ -50,28 +50,30 @@
 {
  public:
   MVUObjectiveFunction();
-  MVUObjectiveFunction(arma::mat& initial_point);
+  MVUObjectiveFunction(const arma::mat& initial_point,
+                       const size_t newDim,
+                       const size_t numNeighbors);
 
   double Evaluate(const arma::mat& coordinates);
   void Gradient(const arma::mat& coordinates, arma::mat& gradient);
 
-  int NumConstraints() const { return num_neighbors_ * initial_point_.n_cols; }
+  size_t NumConstraints() const { return numNeighbors * initialPoint.n_cols; }
 
-  double EvaluateConstraint(int index, const arma::mat& coordinates);
-  void GradientConstraint(int index,
+  double EvaluateConstraint(const size_t index, const arma::mat& coordinates);
+  void GradientConstraint(const size_t index,
                           const arma::mat& coordinates,
                           arma::mat& gradient);
 
-  const arma::mat& GetInitialPoint() { return initial_point_; }
+  const arma::mat& GetInitialPoint() const { return initialPoint; }
 
  private:
-  arma::mat initial_point_;
-  int num_neighbors_;
+  arma::mat initialPoint;
+  size_t numNeighbors;
 
   // These hold the output of the nearest neighbors computation (done in the
   // constructor).
-  arma::Col<index_t> neighbor_indices_;
-  arma::vec neighbor_distances_;
+  arma::Mat<size_t> neighborIndices;
+  arma::mat neighborDistances;
 };
 
 }; // namespace mvu




More information about the mlpack-svn mailing list