[mlpack-svn] r16751 - in mlpack/trunk/src/mlpack: core/tree core/tree/cosine_tree methods methods/quic_svd tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Jul 3 09:08:12 EDT 2014


Author: siddharth.950
Date: Thu Jul  3 09:08:12 2014
New Revision: 16751

Log:
Adding QUIC-SVD.

Added:
   mlpack/trunk/src/mlpack/core/tree/cosine_tree/cosine_tree.cpp
   mlpack/trunk/src/mlpack/methods/quic_svd/
   mlpack/trunk/src/mlpack/methods/quic_svd/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/quic_svd/quic_svd.hpp
   mlpack/trunk/src/mlpack/methods/quic_svd/quic_svd_impl.hpp
   mlpack/trunk/src/mlpack/tests/quic_svd_test.cpp
Removed:
   mlpack/trunk/src/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp
Modified:
   mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt
   mlpack/trunk/src/mlpack/core/tree/cosine_tree/cosine_tree.hpp
   mlpack/trunk/src/mlpack/methods/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/CMakeLists.txt~
   mlpack/trunk/src/mlpack/tests/CMakeLists.txt
   mlpack/trunk/src/mlpack/tests/CMakeLists.txt~

Modified: mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt	(original)
+++ mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt	Thu Jul  3 09:08:12 2014
@@ -14,7 +14,7 @@
   binary_space_tree/traits.hpp
   bounds.hpp
   cosine_tree/cosine_tree.hpp
-  cosine_tree/cosine_tree_impl.hpp
+  cosine_tree/cosine_tree.cpp
   cover_tree/cover_tree.hpp
   cover_tree/cover_tree_impl.hpp
   cover_tree/first_point_is_root.hpp

Added: mlpack/trunk/src/mlpack/core/tree/cosine_tree/cosine_tree.cpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/core/tree/cosine_tree/cosine_tree.cpp	Thu Jul  3 09:08:12 2014
@@ -0,0 +1,425 @@
+/**
+ * @file cosine_tree.cpp
+ * @author Siddharth Agrawal
+ *
+ * Implementation of cosine tree.
+ */
+#include "cosine_tree.hpp"
+
+#include <boost/math/distributions/normal.hpp>
+
+namespace mlpack {
+namespace tree {
+
+CosineTree::CosineTree(const arma::mat& dataset) :
+    dataset(dataset),
+    parent(NULL),
+    right(NULL),
+    left(NULL),
+    numColumns(dataset.n_cols)
+{
+  // Every column of the dataset belongs to this node.
+  indices.resize(numColumns);
+  l2NormsSquared.zeros(numColumns);
+
+  // Record each column's own index and its squared L2 norm.
+  for (size_t col = 0; col < numColumns; ++col)
+  {
+    const double columnNorm = arma::norm(dataset.col(col), 2);
+    indices[col] = col;
+    l2NormsSquared(col) = columnNorm * columnNorm;
+  }
+
+  // Squared Frobenius norm of the node: sum of the squared column norms.
+  frobNormSquared = arma::accu(l2NormsSquared);
+
+  // Compute the centroid, then draw the splitting column by length-squared
+  // sampling.
+  CalculateCentroid();
+  splitPointIndex = ColumnSampleLS();
+}
+
+CosineTree::CosineTree(CosineTree& parentNode,
+                       const std::vector<size_t>& subIndices) :
+    dataset(parentNode.GetDataset()),
+    parent(&parentNode),
+    right(NULL),
+    left(NULL),
+    numColumns(subIndices.size())
+{
+  // 'subIndices' holds positions within the parent node, not dataset columns.
+  indices.resize(numColumns);
+  l2NormsSquared.zeros(numColumns);
+  // Translate each parent-local position into the dataset column index and
+  // reuse the squared norm the parent already stored for it.
+  for (size_t pos = 0; pos < numColumns; ++pos)
+  {
+    const size_t parentPos = subIndices[pos];
+    indices[pos] = parentNode.indices[parentPos];
+    l2NormsSquared(pos) = parentNode.l2NormsSquared(parentPos);
+  }
+
+  // Squared Frobenius norm of the node: sum of the squared column norms.
+  frobNormSquared = arma::accu(l2NormsSquared);
+
+  // Compute the centroid, then sample the column used for splitting.
+  CalculateCentroid();
+  splitPointIndex = ColumnSampleLS();
+}
+
+CosineTree::CosineTree(const arma::mat& dataset,
+                       const double epsilon,
+                       const double delta) :
+    dataset(dataset),
+    epsilon(epsilon),
+    delta(delta),
+    parent(NULL),
+    right(NULL),
+    left(NULL),
+    numColumns(dataset.n_cols)
+{
+  // Build the subspace basis by repeatedly splitting the highest-error node
+  // until the root's Monte Carlo error is at most epsilon * ||dataset||_F^2.
+  CosineNodeQueue treeQueue;
+
+  // Root node: holds every column, zero error, all-zero basis vector.
+  CosineTree root(dataset);
+  arma::vec tempVector = arma::zeros(dataset.n_rows);
+  root.L2Error(0);
+  root.BasisVector(tempVector);
+  treeQueue.push(&root);
+
+  // Initialize Monte Carlo error estimate for comparison.
+  double monteCarloError = root.FrobNormSquared();
+  while (monteCarloError > epsilon * root.FrobNormSquared())
+  {
+    // Pop node from queue with highest projection error and split it.
+    CosineTree* currentNode = treeQueue.top();
+    treeQueue.pop();
+    currentNode->CosineNodeSplit();
+
+    // CosineNodeSplit() creates no children for nodes with fewer than three
+    // columns: keep such a node's basis vector and stop refining here.
+    CosineTree* currentLeft = currentNode->Left();
+    CosineTree* currentRight = currentNode->Right();
+    if (currentLeft == NULL || currentRight == NULL)
+    {
+      treeQueue.push(currentNode);
+      break;
+    }
+
+    // Orthonormalize the children's centroids against the queued basis; the
+    // right child is additionally orthogonalized against the left one.
+    arma::vec lBasisVector, rBasisVector;
+    ModifiedGramSchmidt(treeQueue, currentLeft->Centroid(), lBasisVector);
+    ModifiedGramSchmidt(treeQueue, currentRight->Centroid(), rBasisVector,
+                        &lBasisVector);
+    currentLeft->BasisVector(lBasisVector);
+    currentRight->BasisVector(rBasisVector);
+
+    // Estimate the children's errors (each with both new basis vectors),
+    // enqueue them, then re-estimate the error of the whole dataset.
+    MonteCarloError(currentLeft, treeQueue, &lBasisVector, &rBasisVector);
+    MonteCarloError(currentRight, treeQueue, &lBasisVector, &rBasisVector);
+    treeQueue.push(currentLeft);
+    treeQueue.push(currentRight);
+    monteCarloError = MonteCarloError(&root, treeQueue);
+  }
+  // Construct the subspace basis from the current priority queue.
+  ConstructBasis(treeQueue);
+}
+
+void CosineTree::ModifiedGramSchmidt(CosineNodeQueue& treeQueue,
+                                     arma::vec& centroid,
+                                     arma::vec& newBasisVector,
+                                     arma::vec* addBasisVector)
+{
+  // Start from the centroid itself; the projections computed below are
+  // subtracted from this copy.
+  newBasisVector = centroid;
+
+  // Remove from the result the component of the centroid along the basis
+  // vector of every node currently held in the priority queue. Every
+  // projection is taken against the original centroid, not against the
+  // partially reduced vector.
+  for (CosineNodeQueue::const_iterator it = treeQueue.begin();
+       it != treeQueue.end(); ++it)
+  {
+    const arma::vec& queuedBasis = (*it)->BasisVector();
+    newBasisVector -= arma::dot(queuedBasis, centroid) * queuedBasis;
+  }
+
+  // An optional extra vector, which is not part of the queue, is treated the
+  // same way when it is supplied.
+  if (addBasisVector != NULL)
+  {
+    const double extraProjection = arma::dot(*addBasisVector, centroid);
+    newBasisVector -= *addBasisVector * extraProjection;
+  }
+
+  // Scale the result to unit length, unless its norm is zero.
+  const double length = arma::norm(newBasisVector, 2);
+  if (length != 0)
+    newBasisVector /= length;
+}
+
+double CosineTree::MonteCarloError(CosineTree* node,
+                                   CosineNodeQueue& treeQueue,
+                                   arma::vec* addBasisVector1,
+                                   arma::vec* addBasisVector2)
+{
+  // Estimates an upper bound on the squared error made when the columns of
+  // 'node' are projected onto the basis vectors of the queued nodes (plus the
+  // two optional vectors, used only when both are given). The bound is stored
+  // in the node through L2Error() and also returned.
+
+  // A node without columns, or whose columns are all zero, has nothing to
+  // reconstruct. Returning early also avoids log(0) and a division by a zero
+  // Frobenius norm in the sampling below.
+  if (node->NumColumns() == 0 || node->FrobNormSquared() == 0)
+  {
+    node->L2Error(0);
+    return 0;
+  }
+
+  // Sample O(log m) points from the input node's distribution.
+  // 'm' is the number of columns present in the node.
+  std::vector<size_t> sampledIndices;
+  arma::vec probabilities;
+  const size_t numSamples = log(node->NumColumns()) + 1;
+  node->ColumnSamplesLS(sampledIndices, probabilities, numSamples);
+
+  // Refer to the node's dataset without copying it; a by-value matrix here
+  // would duplicate the whole dataset on every call.
+  const arma::mat& dataset = node->GetDataset();
+
+  // The optional vectors only count when both are supplied.
+  const bool useExtra = addBasisVector1 && addBasisVector2;
+  const size_t projectionSize = treeQueue.size() + (useExtra ? 2 : 0);
+
+  // For each sample, calculate the weighted projection onto the current basis.
+  arma::vec weightedMagnitudes;
+  weightedMagnitudes.zeros(numSamples);
+  for (size_t i = 0; i < numSamples; i++)
+  {
+    // Coordinates of the sampled column with respect to the basis vectors.
+    arma::vec projection;
+    projection.zeros(projectionSize);
+
+    size_t k = 0;
+    for (CosineNodeQueue::const_iterator j = treeQueue.begin();
+         j != treeQueue.end(); ++j, ++k)
+    {
+      projection(k) = arma::dot(dataset.col(sampledIndices[i]),
+                                (*j)->BasisVector());
+    }
+    if (useExtra)
+    {
+      projection(k++) = arma::dot(dataset.col(sampledIndices[i]),
+                                  *addBasisVector1);
+      projection(k) = arma::dot(dataset.col(sampledIndices[i]),
+                                *addBasisVector2);
+    }
+
+    // Squared length of the projection, weighted by the inverse of the
+    // probability with which the column was sampled.
+    const double frobProjection = arma::norm(projection, "frob");
+    const double frobProjectionSquared = frobProjection * frobProjection;
+    weightedMagnitudes(i) = frobProjectionSquared / probabilities(i);
+  }
+
+  // Compute mean and standard deviation of the weighted samples.
+  const double mu = arma::mean(weightedMagnitudes);
+  const double sigma = arma::stddev(weightedMagnitudes);
+
+  // A normal distribution needs a non-zero deviation; without one, use the
+  // mean itself as the lower bound on the projected magnitude.
+  double lowerBound = mu;
+  if (sigma)
+  {
+    // Lower bound: the 'delta' quantile of the fitted normal distribution.
+    boost::math::normal dist(mu, sigma);
+    lowerBound = boost::math::quantile(dist, delta);
+  }
+
+  // Upper bound on the subspace reconstruction error.
+  const double errorBound = node->FrobNormSquared() - lowerBound;
+  node->L2Error(errorBound);
+  return errorBound;
+}
+
+void CosineTree::ConstructBasis(CosineNodeQueue& treeQueue)
+{
+  // The basis matrix gets one column per node left in the queue; every column
+  // has as many entries as the dataset has rows.
+  basis.zeros(dataset.n_rows, treeQueue.size());
+
+  // Walk the queue in its iteration order and copy the basis vector of every
+  // node into the next free column of the matrix. 'column' tracks that
+  // position alongside the iterator.
+  size_t column = 0;
+  for (CosineNodeQueue::const_iterator node = treeQueue.begin();
+       node != treeQueue.end(); ++node)
+  {
+    basis.col(column) = (*node)->BasisVector();
+    ++column;
+  }
+}
+
+void CosineTree::CosineNodeSplit()
+{
+  // Nodes with fewer than three columns are left unsplit; no children are
+  // created for them.
+  if (numColumns < 3)
+    return;
+
+  // Cosine of every column with respect to the splitting column.
+  arma::vec cosines;
+  CalculateCosines(cosines);
+
+  // The maximum is taken after every cosine that is not below 1 has been
+  // replaced by zero (element-wise product with the 'cosines < 1' mask); the
+  // minimum is taken over the unmodified cosines.
+  const double cosineMax = arma::max(cosines % (cosines < 1));
+  const double cosineMin = arma::min(cosines);
+
+  // Distribute the node-local column positions over the two children. A
+  // column goes to the left child when its cosine is at least as close to the
+  // maximum as to the minimum:
+  //     cos_max - cos(i) <= cos(i) - cos_min
+  std::vector<size_t> leftIndices, rightIndices;
+  for (size_t pos = 0; pos < numColumns; ++pos)
+  {
+    std::vector<size_t>& target =
+        (cosineMax - cosines(pos) <= cosines(pos) - cosineMin) ? leftIndices
+                                                               : rightIndices;
+    target.push_back(pos);
+  }
+
+  // Create the children; each one receives its columns as positions relative
+  // to this node.
+  left = new CosineTree(*this, leftIndices);
+  right = new CosineTree(*this, rightIndices);
+}
+
+void CosineTree::ColumnSamplesLS(std::vector<size_t>& sampledIndices,
+                                 arma::vec& probabilities,
+                                 size_t numSamples)
+{
+  // Cumulative length-squared distribution over the node's columns: entry 0
+  // is zero and entry i + 1 adds column i's share of the squared norm.
+  arma::vec cDistribution;
+  cDistribution.zeros(numColumns + 1);
+  for (size_t col = 0; col < numColumns; ++col)
+  {
+    const double share = l2NormsSquared(col) / frobNormSquared;
+    cDistribution(col + 1) = cDistribution(col) + share;
+  }
+
+  // Size the output containers for the requested number of samples.
+  sampledIndices.resize(numSamples);
+  probabilities.zeros(numSamples);
+
+  for (size_t sample = 0; sample < numSamples; ++sample)
+  {
+    // Draw a random value and locate it in the cumulative distribution.
+    const double randValue = arma::randu();
+    const size_t picked = BinarySearch(cDistribution, randValue, 0,
+                                       numColumns);
+
+    // Report the dataset column index and the probability of the drawn column.
+    sampledIndices[sample] = indices[picked];
+    probabilities(sample) = l2NormsSquared(picked) / frobNormSquared;
+  }
+}
+
+size_t CosineTree::ColumnSampleLS()
+{
+  // With fewer than two columns, position 0 is returned without sampling.
+  if (numColumns < 2)
+    return 0;
+
+  // Cumulative length-squared distribution over the node's columns: entry 0
+  // is zero and entry i + 1 adds column i's share of the squared Frobenius
+  // norm to entry i.
+  arma::vec cDistribution;
+  cDistribution.zeros(numColumns + 1);
+  for (size_t col = 0; col < numColumns; ++col)
+  {
+    const double share = l2NormsSquared(col) / frobNormSquared;
+    cDistribution(col + 1) = cDistribution(col) + share;
+  }
+
+  // Draw one random value and return the node-local position that the search
+  // over the whole distribution reports for it.
+  const double randValue = arma::randu();
+  const size_t first = 0;
+  const size_t last = numColumns;
+
+  return BinarySearch(cDistribution, randValue, first, last);
+}
+
+size_t CosineTree::BinarySearch(arma::vec& cDistribution,
+                                double value,
+                                size_t start,
+                                size_t end)
+{
+  // Finds the index i in [start, end) such that
+  //     cDistribution(i) < value <= cDistribution(i + 1).
+  // 'cDistribution' must be non-decreasing on [start, end]. Values at or below
+  // cDistribution(start) yield 'start', and values above cDistribution(end)
+  // yield 'end - 1', so the search always terminates with an in-range index
+  // even if rounding leaves the last cumulative entry slightly below 1.
+  size_t lower = start;
+  size_t upper = end;
+
+  // Invariant: the answer lies in [lower, upper). Each step halves the range;
+  // 'middle' is strictly between the bounds, so the loop cannot stall.
+  while (lower + 1 < upper)
+  {
+    const size_t middle = lower + (upper - lower) / 2;
+    if (value <= cDistribution(middle))
+      upper = middle;
+    else
+      lower = middle;
+  }
+
+  return lower;
+}
+
+void CosineTree::CalculateCosines(arma::vec& cosines)
+{
+  // Start from all zeros, so that columns skipped below keep a cosine of
+  // zero.
+  cosines.zeros(numColumns);
+
+  for (size_t pos = 0; pos < numColumns; ++pos)
+  {
+    // A column whose squared norm is zero is skipped and therefore keeps the
+    // zero it was initialized with.
+    if (l2NormsSquared(pos) == 0)
+      continue;
+
+    // Normalized dot product between the splitting column and the current
+    // column; both are looked up in the dataset through the node's index
+    // list.
+    cosines(pos) = arma::norm_dot(dataset.col(indices[splitPointIndex]),
+                                  dataset.col(indices[pos]));
+  }
+}
+
+void CosineTree::CalculateCentroid()
+{
+  // The centroid is the mean of the node's columns, one entry per dataset row.
+  centroid.zeros(dataset.n_rows);
+  for (size_t i = 0; i < numColumns; i++)
+    centroid += dataset.col(indices[i]);
+
+  // A node without columns keeps the zero centroid; dividing by zero would
+  // fill it with NaN.
+  if (numColumns > 0)
+    centroid /= numColumns;
+}
+
+}; // namespace tree
+}; // namespace mlpack

Modified: mlpack/trunk/src/mlpack/core/tree/cosine_tree/cosine_tree.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/cosine_tree/cosine_tree.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/cosine_tree/cosine_tree.hpp	Thu Jul  3 09:08:12 2014
@@ -243,7 +243,4 @@
 }; // namespace tree
 }; // namespace mlpack
 
-// Include implementation.
-#include "cosine_tree_impl.hpp"
-
 #endif

Modified: mlpack/trunk/src/mlpack/methods/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/methods/CMakeLists.txt	(original)
+++ mlpack/trunk/src/mlpack/methods/CMakeLists.txt	Thu Jul  3 09:08:12 2014
@@ -23,6 +23,7 @@
 #  lmf
   pca
   perceptron
+  quic_svd
   radical
   range_search
   rann

Modified: mlpack/trunk/src/mlpack/methods/CMakeLists.txt~
==============================================================================
--- mlpack/trunk/src/mlpack/methods/CMakeLists.txt~	(original)
+++ mlpack/trunk/src/mlpack/methods/CMakeLists.txt~	Thu Jul  3 09:08:12 2014
@@ -2,6 +2,7 @@
 set(DIRS
   amf
   cf
+  decision_stump
   det
   emst
   fastmks
@@ -21,6 +22,8 @@
   nmf
 #  lmf
   pca
+  perceptron
+#  quic_svd
   radical
   range_search
   rann

Added: mlpack/trunk/src/mlpack/methods/quic_svd/CMakeLists.txt
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/methods/quic_svd/CMakeLists.txt	Thu Jul  3 09:08:12 2014
@@ -0,0 +1,15 @@
+# Define the files we need to compile.
+# Anything not in this list will not be compiled into MLPACK.
+set(SOURCES
+  quic_svd.hpp
+  quic_svd_impl.hpp
+)
+
+# Add directory name to sources.
+set(DIR_SRCS)
+foreach(file ${SOURCES})
+  set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
+endforeach()
+# Append sources (with directory name) to list of all MLPACK sources (used at
+# the parent scope).
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)

Added: mlpack/trunk/src/mlpack/methods/quic_svd/quic_svd.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/methods/quic_svd/quic_svd.hpp	Thu Jul  3 09:08:12 2014
@@ -0,0 +1,70 @@
+/**
+ * @file quic_svd.hpp
+ * @author Siddharth Agrawal
+ *
+ * An implementation of QUIC-SVD.
+ */
+#ifndef __MLPACK_METHODS_QUIC_SVD_QUIC_SVD_HPP
+#define __MLPACK_METHODS_QUIC_SVD_QUIC_SVD_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/tree/cosine_tree/cosine_tree.hpp>
+
+namespace mlpack {
+namespace svd {
+
+/**
+ * Approximate singular value decomposition of a matrix, computed in a
+ * subspace whose basis is obtained from a cosine tree (QUIC-SVD).
+ */
+class QUIC_SVD
+{
+ public:
+  /**
+   * Runs the whole algorithm: builds a CosineTree to obtain the subspace
+   * basis, then calls ExtractSVD() to fill the output matrices.
+   *
+   * @param dataset Matrix for which SVD is calculated; stored by reference.
+   * @param u First unitary matrix (output).
+   * @param v Second unitary matrix (output).
+   * @param sigma Diagonal matrix of singular values (output).
+   * @param epsilon Error tolerance fraction for calculated subspace.
+   * @param delta Cumulative probability for Monte Carlo error lower bound.
+   */
+  QUIC_SVD(const arma::mat& dataset,
+           arma::mat& u,
+           arma::mat& v,
+           arma::mat& sigma,
+           const double epsilon = 0.03,
+           const double delta = 0.1);
+
+  /**
+   * This function uses the vector subspace created using a cosine tree to
+   * calculate an approximate SVD of the original matrix.
+   *
+   * @param u First unitary matrix.
+   * @param v Second unitary matrix.
+   * @param sigma Diagonal matrix of singular values.
+   */
+  void ExtractSVD(arma::mat& u,
+                  arma::mat& v,
+                  arma::mat& sigma);
+
+ private:
+  //! Matrix for which cosine tree is constructed.
+  const arma::mat& dataset;
+  //! Error tolerance fraction for calculated subspace.
+  double epsilon;
+  //! Cumulative probability for Monte Carlo error lower bound.
+  double delta;
+  //! Subspace basis of the input dataset.
+  arma::mat basis;
+};
+
+}; // namespace svd
+}; // namespace mlpack
+
+// Include implementation.
+#include "quic_svd_impl.hpp"
+
+#endif

Added: mlpack/trunk/src/mlpack/methods/quic_svd/quic_svd_impl.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/methods/quic_svd/quic_svd_impl.hpp	Thu Jul  3 09:08:12 2014
@@ -0,0 +1,82 @@
+/**
+ * @file quic_svd_impl.hpp
+ * @author Siddharth Agrawal
+ *
+ * An implementation of QUIC-SVD.
+ */
+#ifndef __MLPACK_METHODS_QUIC_SVD_QUIC_SVD_IMPL_HPP
+#define __MLPACK_METHODS_QUIC_SVD_QUIC_SVD_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "quic_svd.hpp"
+
+using namespace mlpack::tree;
+
+namespace mlpack {
+namespace svd {
+
+QUIC_SVD::QUIC_SVD(const arma::mat& dataset,
+                   arma::mat& u,
+                   arma::mat& v,
+                   arma::mat& sigma,
+                   const double epsilon,
+                   const double delta) :
+    dataset(dataset),
+    epsilon(epsilon),
+    delta(delta)
+{
+  // Since columns are sampled in the implementation, the tree is built on the
+  // dataset itself only when it has more columns than rows, else on its
+  // transpose. The transpose is a named object so it outlives the tree.
+  const bool useOriginal = dataset.n_cols > dataset.n_rows;
+  arma::mat transposed;
+  if (!useOriginal)
+    transposed = dataset.t();
+
+  // The tree is only needed to produce the basis, so it is an automatic
+  // object that is released when the constructor returns.
+  CosineTree ctree(useOriginal ? dataset : transposed, epsilon, delta);
+  ctree.GetFinalBasis(basis);
+  // Extract the SVD of the original dataset in the obtained subspace.
+  ExtractSVD(u, v, sigma);
+}
+
+void QUIC_SVD::ExtractSVD(arma::mat& u,
+                          arma::mat& v,
+                          arma::mat& sigma)
+{
+  // Project onto the basis: the transposed dataset is used when the dataset
+  // has more columns than rows, the dataset itself otherwise.
+  const bool wideDataset = dataset.n_cols > dataset.n_rows;
+  arma::mat projectedMat;
+  if (wideDataset)
+    projectedMat = dataset.t() * basis;
+  else
+    projectedMat = dataset * basis;
+
+  // SVD of the projection multiplied by its own transpose from the left.
+  arma::mat uBar, vBar;
+  arma::vec sigmaBar;
+  const arma::mat projectedMatSquared = projectedMat.t() * projectedMat;
+  arma::svd(uBar, sigmaBar, vBar, projectedMatSquared);
+
+  // Assemble the approximate SVD: the singular values are the square roots
+  // of the ones computed above.
+  v = basis * vBar;
+  sigma = arma::sqrt(diagmat(sigmaBar));
+  u = projectedMat * vBar * sigma.i();
+
+  // For a dataset with more columns than rows the factors above belong to
+  // the transposed dataset, so the two outer matrices trade places.
+  if (wideDataset)
+  {
+    const arma::mat previousU = u;
+    u = v;
+    v = previousU;
+  }
+}
+
+}; // namespace svd
+}; // namespace mlpack
+
+#endif

Modified: mlpack/trunk/src/mlpack/tests/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/tests/CMakeLists.txt	(original)
+++ mlpack/trunk/src/mlpack/tests/CMakeLists.txt	Thu Jul  3 09:08:12 2014
@@ -36,6 +36,7 @@
   nmf_test.cpp
   pca_test.cpp
   perceptron_test.cpp
+  quic_svd_test.cpp
   radical_test.cpp
   range_search_test.cpp
   rectangle_tree_test.cpp

Modified: mlpack/trunk/src/mlpack/tests/CMakeLists.txt~
==============================================================================
--- mlpack/trunk/src/mlpack/tests/CMakeLists.txt~	(original)
+++ mlpack/trunk/src/mlpack/tests/CMakeLists.txt~	Thu Jul  3 09:08:12 2014
@@ -8,6 +8,7 @@
   aug_lagrangian_test.cpp
   cf_test.cpp
   cli_test.cpp
+  cosine_tree_test.cpp
   decision_stump_test.cpp
   det_test.cpp
   distribution_test.cpp
@@ -35,8 +36,11 @@
   nmf_test.cpp
   pca_test.cpp
   perceptron_test.cpp
+#  quic_svd_test.cpp
   radical_test.cpp
   range_search_test.cpp
+  rectangle_tree_test.cpp
+  sa_test.cpp
   save_restore_utility_test.cpp
   sgd_test.cpp
   sort_policy_test.cpp

Added: mlpack/trunk/src/mlpack/tests/quic_svd_test.cpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/tests/quic_svd_test.cpp	Thu Jul  3 09:08:12 2014
@@ -0,0 +1,42 @@
+/**
+ * @file quic_svd_test.cpp
+ * @author Siddharth Agrawal
+ *
+ * Test file for QUIC-SVD class.
+ */
+
+#include <mlpack/core.hpp>
+#include <mlpack/methods/quic_svd/quic_svd.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+BOOST_AUTO_TEST_SUITE(QUICSVDTest);
+
+using namespace mlpack;
+using namespace mlpack::svd;
+
+/**
+ * The reconstruction error of the obtained SVD should be small.
+ */
+BOOST_AUTO_TEST_CASE(QUICSVDReconstructionError)
+{
+  // Read the test matrix from disk.
+  arma::mat dataset;
+  data::Load("test_data_3_1000.csv", dataset);
+
+  // Run QUIC-SVD with the default epsilon and delta; the constructor fills
+  // the three factor matrices.
+  arma::mat u, v, sigma;
+  QUIC_SVD quicsvd(dataset, u, v, sigma);
+
+  // Multiply the factors back together.
+  const arma::mat reconstruct = u * sigma * v.t();
+
+  // Frobenius norm of the difference, relative to that of the dataset.
+  const double relativeError = arma::norm(dataset - reconstruct, "frob") /
+                               arma::norm(dataset, "frob");
+  BOOST_REQUIRE_SMALL(relativeError, 1e-5);
+}
+
+BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-svn mailing list