[mlpack-svn] r16727 - in mlpack/trunk: . src/mlpack/core/tree src/mlpack/core/tree/cosine_tree src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Jun 27 04:28:07 EDT 2014


Author: siddharth.950
Date: Fri Jun 27 04:28:07 2014
New Revision: 16727

Log:
Adding new cosine_tree code.

Added:
   mlpack/trunk/src/mlpack/core/tree/cosine_tree/
   mlpack/trunk/src/mlpack/core/tree/cosine_tree/cosine_node.hpp
   mlpack/trunk/src/mlpack/core/tree/cosine_tree/cosine_node_impl.hpp
   mlpack/trunk/src/mlpack/core/tree/cosine_tree/cosine_tree.hpp
   mlpack/trunk/src/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp
   mlpack/trunk/src/mlpack/tests/cosine_tree_test.cpp
Modified:
   mlpack/trunk/CMakeLists.txt
   mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt
   mlpack/trunk/src/mlpack/tests/CMakeLists.txt
   mlpack/trunk/src/mlpack/tests/CMakeLists.txt~
   mlpack/trunk/src/mlpack/tests/tree_test.cpp

Modified: mlpack/trunk/CMakeLists.txt
==============================================================================
--- mlpack/trunk/CMakeLists.txt	(original)
+++ mlpack/trunk/CMakeLists.txt	Fri Jun 27 04:28:07 2014
@@ -170,36 +170,16 @@
 # Unfortunately this configuration variable is necessary and will need to be
 # updated as time goes on and new versions are released.
 set(Boost_ADDITIONAL_VERSIONS
-  "1.41" "1.41.0" "1.42" "1.42.0" "1.43" "1.43.0" "1.44" "1.44.0" "1.45.0"
-  "1.46.0" "1.46.1" "1.47.0" "1.48.0" "1.49.0")
+  "1.49.0" "1.50.0" "1.51.0" "1.52.0" "1.53.0" "1.54.0" "1.55.0")
 find_package(Boost
     COMPONENTS
       program_options
       unit_test_framework
+      random
     REQUIRED
 )
 include_directories(${Boost_INCLUDE_DIRS})
 
-# Save the actual link paths (because they will get overwritten if we discover
-# we need to find Boost.Random too).
-set(Boost_BACKUP_LIBRARIES ${Boost_LIBRARIES})
-
-# We need to include Boost.Random, but only if newer than 1.45 (as of 1.46 it
-# became a separate package with its own linkable library object).
-if(Boost_MAJOR_VERSION EQUAL 1 AND Boost_MINOR_VERSION GREATER 45)
-  find_package(Boost
-      COMPONENTS
-          random
-      REQUIRED
-  )
-
-  # Restore actual link locations of the other Boost libraries.
-  set(Boost_LIBRARIES ${Boost_LIBRARIES} ${Boost_BACKUP_LIBRARIES})
-
-  # This may be redundant.
-  include_directories(${Boost_INCLUDE_DIRS})
-
-endif(Boost_MAJOR_VERSION EQUAL 1 AND Boost_MINOR_VERSION GREATER 45)
 link_directories(${Boost_LIBRARY_DIRS})
 
 # On Windows, automatic linking is performed, so we don't need to worry about

Modified: mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt	(original)
+++ mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt	Fri Jun 27 04:28:07 2014
@@ -13,10 +13,10 @@
   binary_space_tree/single_tree_traverser_impl.hpp
   binary_space_tree/traits.hpp
   bounds.hpp
+  cosine_tree/cosine_node.hpp
+  cosine_tree/cosine_node_impl.hpp
   cosine_tree/cosine_tree_impl.hpp
   cosine_tree/cosine_tree.hpp
-  cosine_tree/cosine_tree_builder.hpp
-  cosine_tree/cosine_tree_builder_impl.hpp
   cover_tree/cover_tree.hpp
   cover_tree/cover_tree_impl.hpp
   cover_tree/first_point_is_root.hpp

Added: mlpack/trunk/src/mlpack/core/tree/cosine_tree/cosine_node.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/core/tree/cosine_tree/cosine_node.hpp	Fri Jun 27 04:28:07 2014
@@ -0,0 +1,178 @@
+/**
+ * @file cosine_node.hpp
+ * @author Siddharth Agrawal
+ *
+ * Definition of Cosine Node.
+ */
+ 
+#ifndef __MLPACK_CORE_TREE_COSINE_TREE_COSINE_NODE_HPP
+#define __MLPACK_CORE_TREE_COSINE_TREE_COSINE_NODE_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace tree {
+
+class CosineNode
+{
+ public:
+ 
+  /**
+   * CosineNode constructor for the root node of the tree. It initializes the
+   * necessary variables required for splitting of the node, and building the
+   * tree further. It takes a reference to the input matrix and calculates the
+   * relevant variables using it.
+   *
+   * @param dataset Matrix for which cosine tree is constructed.
+   */
+  CosineNode(const arma::mat& dataset);
+  
+  /**
+   * CosineNode constructor for nodes other than the root node of the tree. It
+   * takes in a reference to the parent node and a list of positions that
+   * selects which of the parent's columns are included in the node. The
+   * relevant variables are calculated just like in the constructor above.
+   *
+   * @param parentNode Reference to the parent CosineNode.
+   * @param subIndices Positions within the parent's column list to include.
+   */
+  CosineNode(CosineNode& parentNode, const std::vector<size_t>& subIndices);
+  
+  /**
+   * This function splits the CosineNode into two children based on the cosines
+   * of the columns contained in the node, with respect to the sampled splitting
+   * point. The function also calls the CosineNode constructor for the children.
+   */
+  void CosineNodeSplit();
+  
+  /**
+   * Sample 'numSamples' points from the Length-Squared distribution of the
+   * CosineNode. The function uses 'l2NormsSquared' to calculate the cumulative
+   * probability distribution of the column vectors. The sampling is based on
+   * randomly generated values in the range [0, 1].
+   */
+  void ColumnSamplesLS(std::vector<size_t>& sampledIndices, 
+                       arma::vec& probabilities, size_t numSamples);
+  
+  /**
+   * Sample a point from the Length-Squared distribution of the CosineNode. The
+   * function uses 'l2NormsSquared' to calculate the cumulative probability
+   * distribution of the column vectors. The sampling is based on a randomly
+   * generated value in the range [0, 1].
+   */
+  size_t ColumnSampleLS();
+  
+  /**
+   * Sample a column based on the cumulative Length-Squared distribution of the
+   * CosineNode, and a randomly generated value in the range [0, 1]. Binary
+   * search is more efficient than searching linearly for the same. This leads
+   * to a significant speedup when there are many columns to choose from and
+   * when a number of samples are to be drawn from the distribution.
+   *
+   * @param cDistribution Cumulative LS distribution of columns in the node.
+   * @param value Randomly generated value in the range [0, 1].
+   * @param start Starting index of the distribution interval to search in.
+   * @param end Ending index of the distribution interval to search in.
+   */
+  size_t BinarySearch(arma::vec& cDistribution, double value, size_t start,
+                      size_t end);
+  
+  /**
+   * Calculate cosines of the columns present in the node, with respect to the
+   * sampled splitting point. The calculated cosine values are useful for
+   * splitting the node into its children.
+   *
+   * @param cosines Vector to store the cosine values in.
+   */
+  void CalculateCosines(arma::vec& cosines);
+  
+  /**
+   * Calculate centroid of the columns present in the node. The calculated
+   * centroid is used as a basis vector for the cosine tree being constructed.
+   */
+  void CalculateCentroid();
+  
+  //! Get a reference to the dataset matrix.
+  const arma::mat& GetDataset() const { return dataset; }
+  
+  //! Get the indices of columns in the node.
+  std::vector<size_t>& VectorIndices() { return indices; }
+  
+  //! Set the Monte Carlo error.
+  void L2Error(const double error) { this->l2Error = error; }
+  
+  //! Get the Monte Carlo error.
+  double L2Error() const { return l2Error; }
+  
+  //! Get a reference to the centroid vector.
+  arma::vec& Centroid() { return centroid; }
+  
+  //! Set the basis vector of the node.
+  void BasisVector(arma::vec& bVector) { this->basisVector = bVector; }
+  
+  //! Get the basis vector of the node.
+  arma::vec& BasisVector() { return basisVector; }
+  
+  //! Get pointer to the left child of the node (NULL until the node is split).
+  CosineNode* Left() { return left; }
+  
+  //! Get pointer to the right child of the node (NULL until the node is split).
+  CosineNode* Right() { return right; }
+  
+  //! Get number of columns of input matrix in the node.
+  size_t NumColumns() const { return numColumns; }
+  
+  //! Get the Frobenius norm squared of columns in the node.
+  double FrobNormSquared() const { return frobNormSquared; }
+  
+  //! Get the dataset column index of the split point of the node.
+  size_t SplitPointIndex() const { return indices[splitPointIndex]; }
+ 
+ private:
+  //! Matrix for which cosine tree is constructed.
+  const arma::mat& dataset;
+  //! Parent of the node.
+  CosineNode* parent;
+  //! Right child of the node.
+  CosineNode* right;
+  //! Left child of the node.
+  CosineNode* left;
+  //! Indices of columns of input matrix in the node.
+  std::vector<size_t> indices;
+  //! L2-norm squared of columns in the node.
+  arma::vec l2NormsSquared;
+  //! Centroid of columns of input matrix in the node.
+  arma::vec centroid;
+  //! Orthonormalized basis vector of the node.
+  arma::vec basisVector;
+  //! Position within 'indices' of the split point (not a dataset column).
+  size_t splitPointIndex;
+  //! Number of columns of input matrix in the node.
+  size_t numColumns;
+  //! Monte Carlo error for this node.
+  double l2Error;
+  //! Frobenius norm squared of columns in the node.
+  double frobNormSquared;
+  
+  // Friend class to facilitate construction of priority queue.
+  friend class CompareCosineNode;
+};
+
+class CompareCosineNode
+{
+ public:
+ 
+  // Priority queue ordering: a node with smaller l2Error compares as "less".
+  bool operator() (const CosineNode* a, const CosineNode* b) const
+  {
+    return a->l2Error < b->l2Error;
+  }
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+// Include implementation.
+#include "cosine_node_impl.hpp"
+
+#endif

Added: mlpack/trunk/src/mlpack/core/tree/cosine_tree/cosine_node_impl.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/core/tree/cosine_tree/cosine_node_impl.hpp	Fri Jun 27 04:28:07 2014
@@ -0,0 +1,230 @@
+/**
+ * @file cosine_node_impl.hpp
+ * @author Siddharth Agrawal
+ *
+ * Implementation of cosine node.
+ */
+#ifndef __MLPACK_CORE_TREE_COSINE_TREE_COSINE_NODE_IMPL_HPP
+#define __MLPACK_CORE_TREE_COSINE_TREE_COSINE_NODE_IMPL_HPP
+
+// In case it wasn't included already for some reason.
+#include "cosine_node.hpp"
+
+namespace mlpack {
+namespace tree {
+
+CosineNode::CosineNode(const arma::mat& dataset) :
+    dataset(dataset),
+    parent(NULL),
+    right(NULL),
+    left(NULL),
+    numColumns(dataset.n_cols)
+{  
+  // Initialize sizes of column indices and l2 norms.
+  indices.resize(numColumns);
+  l2NormsSquared.zeros(numColumns);
+  
+  // The root holds every column: position i maps to dataset column i.
+  for(size_t i = 0; i < numColumns; i++)
+  {
+    indices[i] = i;
+    double l2Norm = arma::norm(dataset.col(i), 2);
+    l2NormsSquared(i) = l2Norm * l2Norm;
+  }
+  
+  // Squared Frobenius norm of the node: the sum of the squared column norms.
+  frobNormSquared = arma::accu(l2NormsSquared);
+  
+  // Calculate the centroid, then sample the split point (length-squared).
+  CalculateCentroid();
+  
+  splitPointIndex = ColumnSampleLS();
+}
+
+CosineNode::CosineNode(CosineNode& parentNode,
+                       const std::vector<size_t>& subIndices) :
+    dataset(parentNode.GetDataset()),
+    parent(&parentNode),
+    right(NULL),
+    left(NULL),
+    numColumns(subIndices.size())
+{
+  // Initialize sizes of column indices and l2 norms.
+  indices.resize(numColumns);
+  l2NormsSquared.zeros(numColumns);
+  
+  // subIndices are positions in the parent's arrays; reuse its cached norms.
+  for(size_t i = 0; i < numColumns; i++)
+  {
+    indices[i] = parentNode.indices[subIndices[i]];
+    l2NormsSquared(i) = parentNode.l2NormsSquared(subIndices[i]);
+  }
+  
+  // Squared Frobenius norm of the node: the sum of the squared column norms.
+  frobNormSquared = arma::accu(l2NormsSquared);
+  
+  // Calculate the centroid, then sample the split point (length-squared).
+  CalculateCentroid();
+  
+  splitPointIndex = ColumnSampleLS();
+}
+
+void CosineNode::CosineNodeSplit()
+{
+  //! With fewer than three columns the node is not split (children stay NULL).
+  if(numColumns < 3) return;
+  
+  //! Calculate cosines with respect to the splitting point.
+  arma::vec cosines;
+  CalculateCosines(cosines);
+  
+  //! Compute maximum and minimum cosines; entries not below 1 count as 0.
+  double cosineMax, cosineMin;
+  cosineMax = arma::max(cosines % (cosines < 1));
+  cosineMin = arma::min(cosines);
+  
+  std::vector<size_t> leftIndices, rightIndices;
+  
+  // Split columns into left and right children. The splitting condition for the
+  // column to be in the left child is as follows:
+  //       cos_max - cos(i) <= cos(i) - cos_min
+  for(size_t i = 0; i < numColumns; i++)
+  {
+    if(cosineMax - cosines(i) <= cosines(i) - cosineMin)
+    {
+      leftIndices.push_back(i);
+    }
+    else
+    {
+      rightIndices.push_back(i);
+    }
+  }
+  
+  // Create the children from positions within this node's column list.
+  left = new CosineNode(*this, leftIndices);
+  right = new CosineNode(*this, rightIndices);
+}
+
+void CosineNode::ColumnSamplesLS(std::vector<size_t>& sampledIndices,
+                                 arma::vec& probabilities,
+                                 size_t numSamples)
+{
+  // cDistribution(k) accumulates the probabilities of the first k columns.
+  arma::vec cDistribution;
+  cDistribution.zeros(numColumns + 1);
+  
+  // Calculate cumulative length-squared distribution for the node.
+  for(size_t i = 0; i < numColumns; i++)
+  {
+    cDistribution(i+1) = cDistribution(i) + l2NormsSquared(i) / frobNormSquared;
+  }
+  
+  // Initialize sizes of the 'sampledIndices' and 'probabilities' vectors.
+  sampledIndices.resize(numSamples);
+  probabilities.zeros(numSamples);
+  
+  for(size_t i = 0; i < numSamples; i++)
+  {
+    // Generate a random value for sampling.
+    double randValue = arma::randu();
+    size_t start = 0, end = numColumns, searchIndex;
+    
+    // Store the sampled dataset column index and its sampling probability.
+    searchIndex = BinarySearch(cDistribution, randValue, start, end);
+    sampledIndices[i] = indices[searchIndex];
+    probabilities(i) = l2NormsSquared(searchIndex) / frobNormSquared;
+  }
+}
+
+size_t CosineNode::ColumnSampleLS()
+{
+  // With fewer than two columns, position 0 is the only possible sample.
+  if(numColumns < 2)
+  {
+    return 0;
+  }
+
+  // cDistribution(k) accumulates the probabilities of the first k columns.
+  arma::vec cDistribution;
+  cDistribution.zeros(numColumns + 1);
+  
+  // Calculate cumulative length-squared distribution for the node.
+  for(size_t i = 0; i < numColumns; i++)
+  {
+    cDistribution(i+1) = cDistribution(i) + l2NormsSquared(i) / frobNormSquared;
+  }
+  
+  // Generate a random value for sampling.
+  double randValue = arma::randu();
+  size_t start = 0, end = numColumns;
+  
+  // Return the sampled position within this node's column list.
+  return BinarySearch(cDistribution, randValue, start, end);
+}
+
+size_t CosineNode::BinarySearch(arma::vec& cDistribution,
+                                double value,
+                                size_t start,
+                                size_t end)
+{
+  size_t pivot = (start + end) / 2;
+  
+  // If pivot is zero, first point is the sampled point.
+  if(!pivot)
+  {
+    return pivot;
+  }
+  
+  // Recursive binary search; a value on a bucket's lower edge must go left.
+  if(value > cDistribution(pivot - 1) && value <= cDistribution(pivot))
+  {
+    return (pivot - 1);
+  }
+  else if(value <= cDistribution(pivot - 1))
+  {
+    return BinarySearch(cDistribution, value, start, pivot - 1);
+  }
+  else
+  {
+    return BinarySearch(cDistribution, value, pivot + 1, end);
+  }
+}
+
+void CosineNode::CalculateCosines(arma::vec& cosines)
+{
+  // Initialize cosine vector as a vector of zeros.
+  cosines.zeros(numColumns);
+  
+  for(size_t i = 0; i < numColumns; i++)
+  {
+    // A column with zero norm gets a cosine of zero. Otherwise, take the
+    // normalized dot product of the column with the split point column.
+    if(l2NormsSquared(i) == 0)
+    {
+      cosines(i) = 0;
+    }
+    else
+    {
+      cosines(i) = arma::norm_dot(dataset.col(indices[splitPointIndex]),
+                                  dataset.col(indices[i]));
+    }
+  }
+}
+
+void CosineNode::CalculateCentroid()
+{
+  // Initialize centroid as vector of zeros.
+  centroid.zeros(dataset.n_rows);
+  
+  // Average the node's columns; an empty node keeps the zero vector (no 0/0).
+  for(size_t i = 0; i < numColumns; i++)
+  {
+    centroid += dataset.col(indices[i]);
+  }
+  if(numColumns) centroid /= numColumns;
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif

Added: mlpack/trunk/src/mlpack/core/tree/cosine_tree/cosine_tree.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/core/tree/cosine_tree/cosine_tree.hpp	Fri Jun 27 04:28:07 2014
@@ -0,0 +1,103 @@
+/**
+ * @file cosine_tree.hpp
+ * @author Siddharth Agrawal
+ *
+ * Definition of Cosine Tree.
+ */
+ 
+#ifndef __MLPACK_CORE_TREE_COSINE_TREE_COSINE_TREE_HPP
+#define __MLPACK_CORE_TREE_COSINE_TREE_COSINE_TREE_HPP
+
+#include <mlpack/core.hpp>
+#include <boost/heap/priority_queue.hpp>
+
+#include "cosine_node.hpp"
+
+namespace mlpack {
+namespace tree {
+
+class CosineTree
+{
+ public:
+ 
+  // Type definition for CosineNode priority queue.
+  typedef boost::heap::priority_queue<CosineNode*,
+      boost::heap::compare<CompareCosineNode> > CosineNodeQueue;
+ 
+  /**
+   * Construct the CosineTree and the basis for the given matrix, and passed
+   * 'epsilon' and 'delta' parameters. The CosineTree is constructed by
+   * splitting nodes in the direction of maximum error, stored using a priority
+   * queue. Basis vectors are added from the left and right children of the
+   * split node. The basis vector from a node is the orthonormalized centroid of
+   * its columns. The splitting continues till the Monte Carlo estimate of the
+   * projection error on the obtained subspace is at most 'epsilon' times the
+   * squared Frobenius norm of the input matrix.
+   *
+   * @param dataset Matrix for which the CosineTree is constructed.
+   * @param epsilon Error tolerance fraction for calculated subspace.
+   * @param delta Cumulative probability for Monte Carlo error lower bound.
+   */
+  CosineTree(const arma::mat& dataset,
+             const double epsilon,
+             const double delta);
+  
+  /**
+   * Calculates the orthonormalization of the passed centroid, with respect to
+   * the current vector subspace.
+   *
+   * @param treeQueue Priority queue of cosine nodes.
+   * @param centroid Centroid of the node being added to the basis.
+   * @param newBasisVector Orthonormalized centroid of the node.
+   * @param addBasisVector Optional extra basis vector to project out, or NULL.
+   */
+  void ModifiedGramSchmidt(CosineNodeQueue& treeQueue,
+                           arma::vec& centroid,
+                           arma::vec& newBasisVector,
+                           arma::vec* addBasisVector = NULL);
+  
+  /**
+   * Estimates the squared error of the projection of the input node's matrix
+   * onto the current vector subspace. A normal distribution is fit using
+   * weighted norms of projections of samples drawn from the input node's matrix
+   * columns. The error (returned and stored in the node) is the node's squared
+   * Frobenius norm minus the lower bound of the normal distribution.
+   *
+   * @param node Node for which Monte Carlo estimate is calculated.
+   * @param treeQueue Priority queue of cosine nodes.
+   * @param addBasisVector1 Optional first additional basis vector, or NULL.
+   * @param addBasisVector2 Optional second one; used only if both are given.
+   */
+  double MonteCarloError(CosineNode* node,
+                         CosineNodeQueue& treeQueue,
+                         arma::vec* addBasisVector1 = NULL,
+                         arma::vec* addBasisVector2 = NULL);
+  
+  /**
+   * Constructs the final basis matrix, after the cosine tree construction.
+   *
+   * @param treeQueue Priority queue of cosine nodes.
+   */
+  void ConstructBasis(CosineNodeQueue& treeQueue);
+  
+  //! Copies the basis of the constructed subspace into 'finalBasis'.
+  void GetFinalBasis(arma::mat& finalBasis) { finalBasis = basis; }
+  
+ private:
+  //! Matrix for which cosine tree is constructed.
+  const arma::mat& dataset;
+  //! Error tolerance fraction for calculated subspace.
+  double epsilon;
+  //! Cumulative probability for Monte Carlo error lower bound.
+  double delta;
+  //! Subspace basis of the input dataset.
+  arma::mat basis;
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+// Include implementation.
+#include "cosine_tree_impl.hpp"
+
+#endif

Added: mlpack/trunk/src/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp	Fri Jun 27 04:28:07 2014
@@ -0,0 +1,222 @@
+/**
+ * @file cosine_tree_impl.hpp
+ * @author Siddharth Agrawal
+ *
+ * Implementation of cosine tree.
+ */
+#ifndef __MLPACK_CORE_TREE_COSINE_TREE_COSINE_TREE_IMPL_HPP
+#define __MLPACK_CORE_TREE_COSINE_TREE_COSINE_TREE_IMPL_HPP
+
+// In case it wasn't included already for some reason.
+#include "cosine_tree.hpp"
+
+#include <boost/math/distributions/normal.hpp>
+
+namespace mlpack {
+namespace tree {
+
+CosineTree::CosineTree(const arma::mat& dataset,
+                       const double epsilon,
+                       const double delta) :
+    dataset(dataset),
+    epsilon(epsilon),
+    delta(delta)
+{
+  // Declare the cosine tree priority queue.
+  CosineNodeQueue treeQueue;
+  
+  // Define root node of the tree and add it to the queue.
+  CosineNode root(dataset);
+  arma::vec tempVector = arma::zeros(dataset.n_rows);
+  root.L2Error(0);
+  root.BasisVector(tempVector);
+  treeQueue.push(&root);
+  
+  // Initialize Monte Carlo error estimate for comparison.
+  double monteCarloError = root.FrobNormSquared();
+  
+  while(monteCarloError > epsilon * root.FrobNormSquared())
+  {
+    // Pop node from queue with highest projection error.
+    CosineNode* currentNode;
+    currentNode = treeQueue.top();
+    treeQueue.pop();
+    
+    // Split the node into left and right children.
+    currentNode->CosineNodeSplit();
+    
+    // A node too small to split keeps NULL children: put it back so its basis
+    // vector is kept, and stop refining instead of dereferencing NULL.
+    CosineNode *currentLeft, *currentRight;
+    currentLeft = currentNode->Left();
+    currentRight = currentNode->Right();
+    if(!currentLeft || !currentRight) { treeQueue.push(currentNode); break; }
+    
+    // Calculate basis vectors of left and right children.
+    arma::vec lBasisVector, rBasisVector;
+    
+    ModifiedGramSchmidt(treeQueue, currentLeft->Centroid(), lBasisVector);
+    ModifiedGramSchmidt(treeQueue, currentRight->Centroid(), rBasisVector,
+                        &lBasisVector);
+    
+    // Add basis vectors to their respective nodes.
+    currentLeft->BasisVector(lBasisVector);
+    currentRight->BasisVector(rBasisVector);
+    
+    // Calculate Monte Carlo error estimates for child nodes.
+    MonteCarloError(currentLeft, treeQueue, &lBasisVector, &rBasisVector);
+    MonteCarloError(currentRight, treeQueue, &lBasisVector, &rBasisVector);
+    
+    // Push child nodes into the priority queue.
+    treeQueue.push(currentLeft);
+    treeQueue.push(currentRight);
+    
+    // Calculate Monte Carlo error estimate for the root node.
+    monteCarloError = MonteCarloError(&root, treeQueue);
+  }
+  
+  // Construct the subspace basis from the current priority queue.
+  ConstructBasis(treeQueue);
+}
+
+void CosineTree::ModifiedGramSchmidt(CosineNodeQueue& treeQueue,
+                                     arma::vec& centroid,
+                                     arma::vec& newBasisVector,
+                                     arma::vec* addBasisVector)
+{
+  // Set new basis vector to centroid.
+  newBasisVector = centroid;
+
+  // Variables for iterating through the priority queue.
+  CosineNode *currentNode;
+  CosineNodeQueue::const_iterator i = treeQueue.begin();
+
+  // For every basis vector held by a queued node, subtract the projection of
+  // the original centroid onto it from the new basis vector.
+  for(; i != treeQueue.end(); i++)
+  {
+    currentNode = *i;
+    
+    double projection = arma::dot(currentNode->BasisVector(), centroid);
+    newBasisVector -= projection * currentNode->BasisVector();
+  }
+  
+  // If additional basis vector is passed, take it into account.
+  if(addBasisVector)
+  {
+    double projection = arma::dot(*addBasisVector, centroid);
+    newBasisVector -= *addBasisVector * projection;
+  }
+  
+  // Normalize the result; a zero vector is left unchanged to avoid 0/0.
+  if(arma::norm(newBasisVector, 2))
+    newBasisVector /= arma::norm(newBasisVector, 2);
+}
+
+double CosineTree::MonteCarloError(CosineNode* node,
+                                   CosineNodeQueue& treeQueue,
+                                   arma::vec* addBasisVector1,
+                                   arma::vec* addBasisVector2)
+{
+  std::vector<size_t> sampledIndices;
+  arma::vec probabilities;
+  
+  // Sample O(log m) points from the input node's distribution.
+  // 'm' is the number of columns present in the node.
+  size_t numSamples = log(node->NumColumns()) + 1;  
+  node->ColumnSamplesLS(sampledIndices, probabilities, numSamples);
+  
+  // Bind a reference to the original dataset; copying the matrix is wasteful.
+  const arma::mat& dataset = node->GetDataset();
+  
+  // Initialize weighted projection magnitudes as zeros.
+  arma::vec weightedMagnitudes;
+  weightedMagnitudes.zeros(numSamples);
+  
+  // Set size of projection vector, depending on whether additional basis
+  // vectors are passed.
+  size_t projectionSize;
+  if(addBasisVector1 && addBasisVector2)
+    projectionSize = treeQueue.size() + 2;
+  else
+    projectionSize = treeQueue.size();
+  
+  // For each sample, calculate the weighted projection onto the current basis.
+  for(size_t i = 0; i < numSamples; i++)
+  {
+    // Initialize projection as a vector of zeros.
+    arma::vec projection;
+    projection.zeros(projectionSize);
+
+    CosineNode *currentNode;
+    CosineNodeQueue::const_iterator j = treeQueue.begin();
+  
+    size_t k = 0;
+    // Compute the projection of the sampled vector onto the existing subspace.
+    for(; j != treeQueue.end(); j++, k++)
+    {
+      currentNode = *j;
+    
+      projection(k) = arma::dot(dataset.col(sampledIndices[i]),
+                                currentNode->BasisVector());
+    }
+    // If two additional vectors are passed, take their projections.
+    if(addBasisVector1 && addBasisVector2)
+    {
+      projection(k++) = arma::dot(dataset.col(sampledIndices[i]),
+                                  *addBasisVector1);
+      projection(k) = arma::dot(dataset.col(sampledIndices[i]),
+                                *addBasisVector2);
+    }
+    
+    // Calculate the Frobenius norm squared of the projected vector.
+    double frobProjection = arma::norm(projection, "frob");
+    double frobProjectionSquared = frobProjection * frobProjection;
+    
+    // Calculate the weighted projection magnitude.
+    weightedMagnitudes(i) = frobProjectionSquared / probabilities(i);
+  }
+  
+  // Compute mean and standard deviation of the weighted samples.
+  double mu = arma::mean(weightedMagnitudes);
+  double sigma = arma::stddev(weightedMagnitudes);
+  
+  if(!sigma)
+  {
+    node->L2Error(node->FrobNormSquared() - mu);
+    return (node->FrobNormSquared() - mu);
+  }
+  
+  // Fit a normal distribution using the calculated statistics, and calculate a
+  // lower bound on the magnitudes for the passed 'delta' parameter.
+  boost::math::normal dist(mu, sigma);
+  double lowerBound = boost::math::quantile(dist, delta);
+  
+  // Upper bound on the subspace reconstruction error.
+  node->L2Error(node->FrobNormSquared() - lowerBound);
+  
+  return (node->FrobNormSquared() - lowerBound);
+}
+
+void CosineTree::ConstructBasis(CosineNodeQueue& treeQueue)
+{
+  // Allocate one zeroed basis column per node currently in the queue.
+  basis.zeros(dataset.n_rows, treeQueue.size());
+  
+  // Variables for iterating through the priority queue.
+  CosineNode *currentNode;
+  CosineNodeQueue::const_iterator i = treeQueue.begin();
+  
+  // Copy each queued node's basis vector into the next basis column.
+  size_t j = 0;
+  for(; i != treeQueue.end(); i++, j++)
+  {
+    currentNode = *i;
+    basis.col(j) = currentNode->BasisVector();
+  }
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif

Modified: mlpack/trunk/src/mlpack/tests/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/tests/CMakeLists.txt	(original)
+++ mlpack/trunk/src/mlpack/tests/CMakeLists.txt	Fri Jun 27 04:28:07 2014
@@ -8,6 +8,7 @@
   aug_lagrangian_test.cpp
   cf_test.cpp
   cli_test.cpp
+  cosine_tree_test.cpp
   decision_stump_test.cpp
   det_test.cpp
   distribution_test.cpp

Modified: mlpack/trunk/src/mlpack/tests/CMakeLists.txt~
==============================================================================
--- mlpack/trunk/src/mlpack/tests/CMakeLists.txt~	(original)
+++ mlpack/trunk/src/mlpack/tests/CMakeLists.txt~	Fri Jun 27 04:28:07 2014
@@ -8,6 +8,7 @@
   aug_lagrangian_test.cpp
   cf_test.cpp
   cli_test.cpp
+  decision_stump_test.cpp
   det_test.cpp
   distribution_test.cpp
   emst_test.cpp
@@ -33,6 +34,7 @@
   nca_test.cpp
   nmf_test.cpp
   pca_test.cpp
+  perceptron_test.cpp
   radical_test.cpp
   range_search_test.cpp
   save_restore_utility_test.cpp

Added: mlpack/trunk/src/mlpack/tests/cosine_tree_test.cpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/tests/cosine_tree_test.cpp	Fri Jun 27 04:28:07 2014
@@ -0,0 +1,186 @@
+/**
+ * @file cosine_tree_test.cpp
+ * @author Siddharth Agrawal
+ *
+ * Test file for CosineTree class.
+ */
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/tree/cosine_tree/cosine_tree.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+BOOST_AUTO_TEST_SUITE(CosineTreeTest);
+
+using namespace mlpack;
+using namespace mlpack::tree;
+
+/**
+ * Builds a cosine tree with epsilon = 1 and requires the final basis to have a
+ * single column, i.e. that the root node was not split any further.
+ */
+BOOST_AUTO_TEST_CASE(CosineTreeNoSplit)
+{
+  // Dataset dimensions and the epsilon/delta values passed to the tree.
+  const size_t numRows = 10;
+  const size_t numCols = 15;
+  const double epsilon = 1;
+  const double delta = 0.1;
+
+  // Random numRows x numCols dataset.
+  arma::mat data = arma::randu(numRows, numCols);
+  
+  // Build the tree from the dataset and the constants above (epsilon is one),
+  // then fetch its final basis.
+  CosineTree ctree(data, epsilon, delta);
+  arma::mat basis;
+  ctree.GetFinalBasis(basis);
+  
+  // With epsilon equal to one no split is expected, so the basis should hold
+  // exactly one vector, the one coming from the root node.
+  BOOST_REQUIRE_EQUAL(basis.n_cols, 1);
+}
+
+/**
+ * Checks CosineNode::CosineNodeSplit() by splitting nodes depth first on a
+ * random dataset and verifying that every split obeys the split condition.
+ */
+BOOST_AUTO_TEST_CASE(CosineNodeCosineSplit)
+{
+  // Initialize constants required for the test.
+  const size_t numRows = 500;
+  const size_t numCols = 1000;
+  
+  // Make a random dataset and the root node built over it.
+  arma::mat data = arma::randu(numRows, numCols);
+  CosineNode root(data);
+  
+  // Stack of nodes still to be split, for a depth first walk of the tree.
+  std::vector<CosineNode*> nodeStack;
+  nodeStack.push_back(&root);
+  
+  // Process nodes until the stack is empty.
+  while(nodeStack.size())
+  {
+    // Take the node on top of the stack, split it, and remove it.
+    CosineNode *currentNode, *currentLeft, *currentRight;
+    currentNode = nodeStack.back();
+    currentNode->CosineNodeSplit();
+    nodeStack.pop_back();
+    
+    // Obtain pointers to the children of the node.
+    currentLeft = currentNode->Left();
+    currentRight = currentNode->Right();
+    
+    // Only nodes that have both children after the split are checked.
+    if(currentLeft && currentRight)
+    {
+      // Push the child nodes on to the stack.
+      nodeStack.push_back(currentLeft);
+      nodeStack.push_back(currentRight);
+      
+      // Dataset column selected by the popped node's split point index.
+      arma::vec splitPoint = data.col(currentNode->SplitPointIndex());
+      
+      // Column indices of the child nodes.
+      std::vector<size_t> leftIndices, rightIndices;
+      leftIndices = currentLeft->VectorIndices();
+      rightIndices = currentRight->VectorIndices();
+      
+      // The children together must hold as many columns as the popped node.
+      BOOST_REQUIRE_EQUAL(currentNode->NumColumns(), leftIndices.size() +
+          rightIndices.size());
+      
+      // Cosine of every column with the split point: left child, then right.
+      arma::vec cosines;
+      cosines.zeros(currentNode->NumColumns());
+      
+      size_t i, j, k;
+      for(i = 0; i < leftIndices.size(); i++)
+      {
+        cosines(i) = arma::norm_dot(data.col(leftIndices[i]), splitPoint);
+      }
+      for(j = 0, k = i; j < rightIndices.size(); j++, k++)
+      {
+        cosines(k) = arma::norm_dot(data.col(rightIndices[j]), splitPoint);
+      }
+      
+      // Split condition: left columns are nearer to cosineMax than cosineMin,
+      // right columns the reverse. Entries not below 1 are zeroed before max.
+      double cosineMax = arma::max(cosines % (cosines < 1));
+      double cosineMin = arma::min(cosines);
+      
+      for(i = 0; i < leftIndices.size(); i++)
+      {
+        BOOST_CHECK_LT(cosineMax - cosines(i), cosines(i) - cosineMin);
+      }
+      for(j = 0, k = i; j < rightIndices.size(); j++, k++)
+      {
+        BOOST_CHECK_GT(cosineMax - cosines(k), cosines(k) - cosineMin);
+      }
+    }
+  }
+}
+
+/**
+ * Checks CosineTree::ModifiedGramSchmidt() by growing a basis from random
+ * columns and requiring each new vector to be orthogonal to the earlier ones.
+ */
+BOOST_AUTO_TEST_CASE(CosineTreeModifiedGramSchmidt)
+{
+  // Dataset dimensions and the epsilon/delta values passed to the tree.
+  const size_t numRows = 100;
+  const size_t numCols = 50;
+  const double epsilon = 1;
+  const double delta = 0.1;
+  
+  // Random numRows x numCols dataset.
+  arma::mat data = arma::randu(numRows, numCols);
+  
+  // Empty basis queue, plus a tree used only to call ModifiedGramSchmidt().
+  CosineTree::CosineNodeQueue basisQueue;
+  CosineTree dummyTree(data, epsilon, delta);
+  
+  for(size_t i = 0; i < numCols; i++)
+  {
+    // Heap-allocated node that will carry this iteration's basis vector.
+    CosineNode* basisNode;
+    basisNode = new CosineNode(data);
+    
+    // Use the columns of the dataset as random centroids.
+    arma::vec centroid = data.col(i);
+    arma::vec newBasisVector;
+    
+    // Orthogonalize the centroid against the vectors already in the queue.
+    dummyTree.ModifiedGramSchmidt(basisQueue, centroid, newBasisVector);   
+    
+    // Require a near-zero dot product with every basis vector in the queue.
+    CosineTree::CosineNodeQueue::const_iterator j = basisQueue.begin();
+    CosineNode* currentNode;
+    
+    for(; j != basisQueue.end(); j++)
+    {
+      currentNode = *j;
+      BOOST_REQUIRE_SMALL(arma::dot(currentNode->BasisVector(), newBasisVector),
+                          1e-5);
+    }
+    
+    // Store the vector in the node, give it a random L2 error, and queue it.
+    basisNode->BasisVector(newBasisVector);
+    basisNode->L2Error(arma::randu());
+    basisQueue.push(basisNode);
+  }
+  
+  // Pop and delete the numCols nodes that were allocated and pushed above.
+  for(size_t i = 0; i < numCols; i++)
+  {
+    CosineNode* currentNode;
+    currentNode = basisQueue.top();
+    basisQueue.pop();
+    
+    delete currentNode;
+  }
+}
+
+BOOST_AUTO_TEST_SUITE_END();

Modified: mlpack/trunk/src/mlpack/tests/tree_test.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/tests/tree_test.cpp	(original)
+++ mlpack/trunk/src/mlpack/tests/tree_test.cpp	Fri Jun 27 04:28:07 2014
@@ -8,8 +8,6 @@
 #include <mlpack/core/tree/binary_space_tree/binary_space_tree.hpp>
 #include <mlpack/core/metrics/lmetric.hpp>
 #include <mlpack/core/tree/cover_tree/cover_tree.hpp>
-#include <mlpack/core/tree/cosine_tree/cosine_tree.hpp>
-#include <mlpack/core/tree/cosine_tree/cosine_tree_builder.hpp>
 
 #include <queue>
 #include <stack>
@@ -1915,196 +1913,4 @@
   CheckDescendants(&tree);
 }
 
-/*
- * Make sure that constructor for cosine tree is working.
- */
-BOOST_AUTO_TEST_CASE(CosineTreeConstructorTest)
-{
-  // Create test data.
-  arma::mat data = arma::randu<arma::mat>(5, 5);
-  arma::rowvec centroid = arma::randu<arma::rowvec>(1, 5);
-  arma::vec probabilities = arma::randu<arma::vec>(5, 1);
-
-  // Creating a cosine tree.
-  CosineTree ct(data, centroid, probabilities);
-
-  const arma::mat& dataRet = ct.Data();
-  const arma::rowvec& centroidRet = ct.Centroid();
-  const arma::vec& probabilitiesRet = ct.Probabilities();
-
-  // Check correctness of dimensionality of data matrix.
-  BOOST_REQUIRE_EQUAL(data.n_cols, dataRet.n_rows);
-  BOOST_REQUIRE_EQUAL(data.n_rows, dataRet.n_cols);
-
-  // Check the data matrix.
-  for (size_t i = 0; i < data.n_cols; i++)
-    for (size_t j = 0; j < data.n_rows; j++)
-      BOOST_REQUIRE_CLOSE((double) dataRet(j, i), (double) data(i, j), 1e-5);
-
-  // Check correctness of dimensionality of centroid.
-  BOOST_REQUIRE_EQUAL(centroid.n_cols, centroidRet.n_cols);
-  BOOST_REQUIRE_EQUAL(centroid.n_rows, centroidRet.n_rows);
-
-  // Check centroid.
-  for (size_t i = 0; i < centroid.n_cols; i++)
-    BOOST_REQUIRE_CLOSE((double) centroidRet(0, i), (double) centroid(0,i),
-        1e-5);
-
-  // Check correctness of dimentionality of sampling probabilities.
-  BOOST_REQUIRE_EQUAL(probabilities.n_cols, probabilitiesRet.n_cols);
-  BOOST_REQUIRE_EQUAL(probabilities.n_rows, probabilitiesRet.n_rows);
-
-  // Check sampling probabilities.
-  for (size_t i = 0; i < probabilities.n_rows; i++)
-    BOOST_REQUIRE_CLOSE((double) probabilitiesRet(i, 0), (double)
-        probabilities(i, 0), 1e-5);
-
-  // Check pointers of children nodes.
-  BOOST_REQUIRE(ct.Right() == NULL);
-  BOOST_REQUIRE(ct.Left() == NULL);
-}
-
-/**
- * Make sure that CTNode function in Cosine tree builder is working.
- */
-BOOST_AUTO_TEST_CASE(CosineTreeEmptyConstructorTest)
-{
-  // Create a tree through the empty constructor.
-  CosineTree ct;
-
-  // Check to make sure it has no children.
-  BOOST_REQUIRE(ct.Right() == NULL);
-  BOOST_REQUIRE(ct.Left() == NULL);
-}
-
-/**
- * Make sure that CTNode function in CosineTreeBuilder is working.
- * This test just validates the dimentionality and data.
- */
-BOOST_AUTO_TEST_CASE(CosineTreeBuilderCTNodeTest)
-{
-  // Create dummy test data.
-  arma::mat data = arma::randu<arma::mat>(5, 5);
-
-  // Create a cosine tree builder object.
-  CosineTreeBuilder builder;
-
-  // Create a cosine tree object.
-  CosineTree ct;
-
-  // Use the builder to create the tree.
-  builder.CTNode(data, ct);
-
-  const arma::mat& dataRet = ct.Data();
-  const arma::rowvec& centroidRet = ct.Centroid();
-  const arma::vec& probabilitiesRet = ct.Probabilities();
-
-  // Check correctness of dimentionality of data.
-  BOOST_REQUIRE_EQUAL(data.n_cols, dataRet.n_cols);
-  BOOST_REQUIRE_EQUAL(data.n_rows, dataRet.n_rows);
-
-  // Check data.
-  for (size_t i = 0; i < data.n_cols; i++)
-    for (size_t j = 0; j < data.n_rows; j++)
-      BOOST_REQUIRE_CLOSE((double) dataRet(j, i), (double) data(i, j), 1e-5);
-
-  // Check correctness of dimensionality of centroid.
-  BOOST_REQUIRE_EQUAL(data.n_rows, centroidRet.n_cols);
-  BOOST_REQUIRE_EQUAL(1, centroidRet.n_rows);
-
-  // Check correctness of dimensionality of sampling probabilities.
-  BOOST_REQUIRE_EQUAL(1, probabilitiesRet.n_cols);
-  BOOST_REQUIRE_EQUAL(data.n_rows, probabilitiesRet.n_rows);
-
-  // Check pointers of children nodes.
-  BOOST_REQUIRE(ct.Right() == NULL);
-  BOOST_REQUIRE(ct.Left() == NULL);
-
-}
-
-/**
- * Make sure that the centroid is calculated correctly when the cosine tree is
- * built.
- */
-BOOST_AUTO_TEST_CASE(CosineTreeBuilderCentroidTest)
-{
-  // Create dummy test data.
-  arma::mat data;
-  data << 1.0 << 2.0 << 3.0 << arma::endr
-       << 4.0 << 2.0 << 3.0 << arma::endr
-       << 2.5 << 3.0 << 2.0 << arma::endr;
-
-  // Expected centroid.
-  arma::vec c;
-  c << 2.0 << 3.0 << 2.5 << arma::endr;
-
-  // Build the cosine tree.
-  CosineTreeBuilder builder;
-  CosineTree ct;
-  builder.CTNode(data, ct);
-
-  // Get the centroid.
-  arma::rowvec centroid = ct.Centroid();
-
-  // Check correctness of the centroid.
-  BOOST_REQUIRE_CLOSE((double) c(0, 0), (double) centroid(0, 0), 1e-5);
-  BOOST_REQUIRE_CLOSE((double) c(1, 0), (double) centroid(0, 1), 1e-5);
-  BOOST_REQUIRE_CLOSE((double) c(2, 0), (double) centroid(0, 2), 1e-5);
-}
-
-/**
- * Make sure that the sampling probabilities are calculated correctly when the
- * cosine tree is built.
- */
-BOOST_AUTO_TEST_CASE(CosineTreeBuilderProbabilitiesTest)
-{
-  // Create dummy test data.
-  arma::mat data;
-  data << 100.0 <<   2.0 <<   3.0 << arma::endr
-       << 400.0 <<   2.0 <<   3.0 << arma::endr
-       << 200.5 <<   3.0 <<   2.0 << arma::endr;
-
-  // Expected sample probability.
-  arma::vec p;
-  p << 0.999907 << 0.00899223 << 0.0102295 << arma::endr;
-
-  // Create the cosine tree.
-  CosineTreeBuilder builder;
-  CosineTree ct;
-  builder.CTNode(data, ct);
-
-  // Get the probabilities.
-  const arma::vec& probabilities = ct.Probabilities();
-
-  // Check correctness of sampling probabilities.
-  BOOST_REQUIRE_CLOSE((double) p(0, 0), (double) probabilities(0, 0), 1e-4);
-  BOOST_REQUIRE_CLOSE((double) p(1, 0), (double) probabilities(1, 0), 1e-4);
-  BOOST_REQUIRE_CLOSE((double) p(2, 0), (double) probabilities(2, 0), 1e-4);
-}
-
-/**
- * Make sure that the cosine tree builder is splitting nodes.
- */
-BOOST_AUTO_TEST_CASE(CosineTreeBuilderCTNodeSplitTest)
-{
-  // Create dummy test data.
-  arma::mat data;
-  data << 100.0 <<   2.0 <<   3.0 << arma::endr
-       << 400.0 <<   2.0 <<   3.0 << arma::endr
-       << 200.5 <<   3.0 <<   2.0 << arma::endr;
-
-  // Build a cosine tree root node, and then split it.
-  CosineTreeBuilder builder;
-  CosineTree root, left, right;
-  builder.CTNode(data, root);
-  builder.CTNodeSplit(root, left, right);
-
-  // Ensure that there is no data loss.
-  BOOST_REQUIRE_EQUAL((left.NumPoints() + right.NumPoints()), root.NumPoints());
-
-  // Ensure that the dimensionality is correct.
-  BOOST_REQUIRE_EQUAL(left.Data().n_cols, data.n_cols);
-  BOOST_REQUIRE_EQUAL(right.Data().n_cols, data.n_cols);
-}
-
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-svn mailing list