[mlpack-git] master: Refactor so holding a dataset internally is possible. (997d03b)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Sun Oct 11 16:26:14 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/7e6fe3f1c25445dfd41c747839aa1fbf1c31a776...6e9a7465d7739e05e6b4aa650e1f87c45e9cd656

>---------------------------------------------------------------

commit 997d03baf19eaf73cb3f51b38d5c5b8b877d353a
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun Oct 11 05:18:32 2015 -0400

    Refactor so holding a dataset internally is possible.


>---------------------------------------------------------------

997d03baf19eaf73cb3f51b38d5c5b8b877d353a
 src/mlpack/core/tree/cover_tree/cover_tree.hpp     | 27 ++++++------
 .../core/tree/cover_tree/cover_tree_impl.hpp       | 48 +++++++++++++---------
 src/mlpack/tests/tree_test.cpp                     | 15 ++++---
 3 files changed, 50 insertions(+), 40 deletions(-)

diff --git a/src/mlpack/core/tree/cover_tree/cover_tree.hpp b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
index e5e78fc..1f27a60 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
@@ -191,7 +191,7 @@ class CoverTree
 
   /**
    * Create a cover tree from another tree.  Be careful!  This may use a lot of
-   * memory and take a lot of time.
+   * memory and take a lot of time.  This will also make a copy of the dataset.
    *
    * @param other Cover tree to copy from.
    */
@@ -215,7 +215,7 @@ class CoverTree
   using BreadthFirstDualTreeTraverser = DualTreeTraverser<RuleType>;
 
   //! Get a reference to the dataset.
-  const MatType& Dataset() const { return dataset; }
+  const MatType& Dataset() const { return *dataset; }
 
   //! Get the index of the point which this node represents.
   size_t Point() const { return point; }
@@ -335,7 +335,7 @@ class CoverTree
   //! Get the center of the node and store it in the given vector.
   void Center(arma::vec& center) const
   {
-    center = arma::vec(dataset.col(point));
+    center = arma::vec(dataset->col(point));
   }
 
   //! Get the instantiated metric.
@@ -343,38 +343,29 @@ class CoverTree
 
  private:
   //! Reference to the matrix which this tree is built on.
-  const MatType& dataset;
-
+  const MatType* dataset;
   //! Index of the point in the matrix which this node represents.
   size_t point;
-
   //! The list of children; the first is the self-child.
   std::vector<CoverTree*> children;
-
   //! Scale level of the node.
   int scale;
-
   //! The base used to construct the tree.
   double base;
-
   //! The instantiated statistic.
   StatisticType stat;
-
   //! The number of descendant points.
   size_t numDescendants;
-
   //! The parent node (NULL if this is the root of the tree).
   CoverTree* parent;
-
   //! Distance to the parent.
   double parentDistance;
-
   //! Distance to the furthest descendant.
   double furthestDescendantDistance;
-
   //! Whether or not we need to destroy the metric in the destructor.
   bool localMetric;
-
+  //! If true, we own the dataset and need to destroy it in the destructor.
+  bool localDataset;
   //! The metric used for this tree.
   MetricType* metric;
 
@@ -472,6 +463,12 @@ class CoverTree
    */
   std::string ToString() const;
 
+  /**
+   * Serialize the tree.
+   */
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
   size_t DistanceComps() const { return distanceComps; }
   size_t& DistanceComps() { return distanceComps; }
 
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
index cd634d7..e9212fc 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
@@ -27,7 +27,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
     const MatType& dataset,
     const double base,
     MetricType* metric) :
-    dataset(dataset),
+    dataset(&dataset),
     point(RootPointPolicy::ChooseRoot(dataset)),
     scale(INT_MAX),
     base(base),
@@ -36,6 +36,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
     parentDistance(0),
     furthestDescendantDistance(0),
     localMetric(metric == NULL),
+    localDataset(false),
     metric(metric),
     distanceComps(0)
 {
@@ -114,7 +115,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
     const MatType& dataset,
     MetricType& metric,
     const double base) :
-    dataset(dataset),
+    dataset(&dataset),
     point(RootPointPolicy::ChooseRoot(dataset)),
     scale(INT_MAX),
     base(base),
@@ -123,6 +124,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
     parentDistance(0),
     furthestDescendantDistance(0),
     localMetric(false),
+    localDataset(false),
     metric(&metric),
     distanceComps(0)
 {
@@ -207,7 +209,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
     size_t& farSetSize,
     size_t& usedSetSize,
     MetricType& metric) :
-    dataset(dataset),
+    dataset(&dataset),
     point(pointIndex),
     scale(scale),
     base(base),
@@ -216,6 +218,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
     parentDistance(parentDistance),
     furthestDescendantDistance(0),
     localMetric(false),
+    localDataset(false),
     metric(&metric),
     distanceComps(0)
 {
@@ -251,7 +254,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
     const double parentDistance,
     const double furthestDescendantDistance,
     MetricType* metric) :
-    dataset(dataset),
+    dataset(&dataset),
     point(pointIndex),
     scale(scale),
     base(base),
@@ -260,6 +263,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
     parentDistance(parentDistance),
     furthestDescendantDistance(furthestDescendantDistance),
     localMetric(metric == NULL),
+    localDataset(false),
     metric(metric),
     distanceComps(0)
 {
@@ -279,7 +283,7 @@ template<
 >
 CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
     const CoverTree& other) :
-    dataset(other.dataset),
+    dataset((other.parent == NULL) ? new MatType(*other.dataset) : NULL),
     point(other.point),
     scale(other.scale),
     base(other.base),
@@ -289,6 +293,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
     parentDistance(other.parentDistance),
     furthestDescendantDistance(other.furthestDescendantDistance),
     localMetric(false),
+    localDataset(other.parent == NULL),
     metric(other.metric),
     distanceComps(0)
 {
@@ -297,6 +302,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
   {
     children.push_back(new CoverTree(other.Child(i)));
     children[i]->Parent() = this;
+    children[i]->dataset = this->dataset;
   }
 }
 
@@ -315,6 +321,10 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::~CoverTree()
   // Delete the local metric, if necessary.
   if (localMetric)
     delete metric;
+
+  // Delete the local dataset, if necessary.
+  if (localDataset)
+    delete dataset;
 }
 
 //! Return the number of descendant points.
@@ -373,7 +383,7 @@ double CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
     MinDistance(const CoverTree* other) const
 {
   // Every cover tree node will contain points up to base^(scale + 1) away.
-  return std::max(metric->Evaluate(dataset.col(point),
+  return std::max(metric->Evaluate(dataset->col(point),
       other->Dataset().col(other->Point())) -
       furthestDescendantDistance - other->FurthestDescendantDistance(), 0.0);
 }
@@ -401,7 +411,7 @@ template<
 double CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
     MinDistance(const arma::vec& other) const
 {
-  return std::max(metric->Evaluate(dataset.col(point), other) -
+  return std::max(metric->Evaluate(dataset->col(point), other) -
       furthestDescendantDistance, 0.0);
 }
 
@@ -426,7 +436,7 @@ template<
 double CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
     MaxDistance(const CoverTree* other) const
 {
-  return metric->Evaluate(dataset.col(point),
+  return metric->Evaluate(dataset->col(point),
       other->Dataset().col(other->Point())) +
       furthestDescendantDistance + other->FurthestDescendantDistance();
 }
@@ -454,7 +464,7 @@ template<
 double CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
     MaxDistance(const arma::vec& other) const
 {
-  return metric->Evaluate(dataset.col(point), other) +
+  return metric->Evaluate(dataset->col(point), other) +
       furthestDescendantDistance;
 }
 
@@ -480,7 +490,7 @@ template<
 math::Range CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
     RangeDistance(const CoverTree* other) const
 {
-  const double distance = metric->Evaluate(dataset.col(point),
+  const double distance = metric->Evaluate(dataset->col(point),
       other->Dataset().col(other->Point()));
 
   math::Range result;
@@ -523,7 +533,7 @@ template<
 math::Range CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
     RangeDistance(const arma::vec& other) const
 {
-  const double distance = metric->Evaluate(dataset.col(point), other);
+  const double distance = metric->Evaluate(dataset->col(point), other);
 
   return math::Range(distance - furthestDescendantDistance,
                      distance + furthestDescendantDistance);
@@ -575,7 +585,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CreateChildren(
     // Make the self child at the lowest possible level.
     // This should not modify farSetSize or usedSetSize.
     size_t tempSize = 0;
-    children.push_back(new CoverTree(dataset, base, point, INT_MIN, this, 0,
+    children.push_back(new CoverTree(*dataset, base, point, INT_MIN, this, 0,
         indices, distances, 0, tempSize, usedSetSize, *metric));
     distanceComps += children.back()->DistanceComps();
 
@@ -583,7 +593,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CreateChildren(
     for (size_t i = 0; i < nearSetSize; ++i)
     {
       // farSetSize and usedSetSize will not be modified.
-      children.push_back(new CoverTree(dataset, base, indices[i],
+      children.push_back(new CoverTree(*dataset, base, indices[i],
           INT_MIN, this, distances[i], indices, distances, 0, tempSize,
           usedSetSize, *metric));
       distanceComps += children.back()->DistanceComps();
@@ -615,7 +625,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CreateChildren(
   // Build the self child (recursively).
   size_t childFarSetSize = nearSetSize - childNearSetSize;
   size_t childUsedSetSize = 0;
-  children.push_back(new CoverTree(dataset, base, point, nextScale, this, 0,
+  children.push_back(new CoverTree(*dataset, base, point, nextScale, this, 0,
       indices, distances, childNearSetSize, childFarSetSize, childUsedSetSize,
       *metric));
   // Don't double-count the self-child (so, subtract one).
@@ -673,7 +683,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CreateChildren(
     if ((nearSetSize == 1) && (farSetSize == 0))
     {
       size_t childNearSetSize = 0;
-      children.push_back(new CoverTree(dataset, base, indices[0], nextScale,
+      children.push_back(new CoverTree(*dataset, base, indices[0], nextScale,
           this, distances[0], indices, distances, childNearSetSize, farSetSize,
           usedSetSize, *metric));
       distanceComps += children.back()->DistanceComps();
@@ -715,7 +725,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CreateChildren(
 
     // Build this child (recursively).
     childUsedSetSize = 1; // Mark self point as used.
-    children.push_back(new CoverTree(dataset, base, indices[0], nextScale,
+    children.push_back(new CoverTree(*dataset, base, indices[0], nextScale,
         this, distances[0], childIndices, childDistances, childNearSetSize,
         childFarSetSize, childUsedSetSize, *metric));
     numDescendants += children.back()->NumDescendants();
@@ -816,8 +826,8 @@ void CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
   distanceComps += pointSetSize;
   for (size_t i = 0; i < pointSetSize; ++i)
   {
-    distances[i] = metric->Evaluate(dataset.col(pointIndex),
-        dataset.col(indices[i]));
+    distances[i] = metric->Evaluate(dataset->col(pointIndex),
+        dataset->col(indices[i]));
   }
 }
 
@@ -1125,7 +1135,7 @@ std::string CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
 {
   std::ostringstream convert;
   convert << "CoverTree [" << this << "]" << std::endl;
-  convert << "  dataset: " << &dataset << std::endl;
+  convert << "  dataset: " << dataset << std::endl;
   convert << "  point: " << point << std::endl;
   convert << "  scale: " << scale << std::endl;
   convert << "  base: " << base << std::endl;
diff --git a/src/mlpack/tests/tree_test.cpp b/src/mlpack/tests/tree_test.cpp
index 782e10d..5f6b1c4 100644
--- a/src/mlpack/tests/tree_test.cpp
+++ b/src/mlpack/tests/tree_test.cpp
@@ -1831,8 +1831,9 @@ BOOST_AUTO_TEST_CASE(CoverTreeCopyConstructor)
 
   TreeType d = c;
 
-  // Check that everything is the same.
-  BOOST_REQUIRE_EQUAL(c.Dataset().memptr(), d.Dataset().memptr());
+  // Check that everything is the same, except the dataset, which should have
+  // been copied.
+  BOOST_REQUIRE_NE(c.Dataset().memptr(), d.Dataset().memptr());
   BOOST_REQUIRE_CLOSE(c.Base(), d.Base(), 1e-50);
   BOOST_REQUIRE_EQUAL(c.Point(), d.Point());
   BOOST_REQUIRE_EQUAL(c.Scale(), d.Scale());
@@ -1850,8 +1851,9 @@ BOOST_AUTO_TEST_CASE(CoverTreeCopyConstructor)
   BOOST_REQUIRE_EQUAL(d.Child(1).Parent(), &d);
 
   // Check that the children are okay.
-  BOOST_REQUIRE_EQUAL(c.Child(0).Dataset().memptr(),
-                      d.Child(0).Dataset().memptr());
+  BOOST_REQUIRE_NE(c.Child(0).Dataset().memptr(),
+                   d.Child(0).Dataset().memptr());
+  BOOST_REQUIRE_EQUAL(c.Child(0).Dataset().memptr(), c.Dataset().memptr());
   BOOST_REQUIRE_CLOSE(c.Child(0).Base(), d.Child(0).Base(), 1e-50);
   BOOST_REQUIRE_EQUAL(c.Child(0).Point(), d.Child(0).Point());
   BOOST_REQUIRE_EQUAL(c.Child(0).Scale(), d.Child(0).Scale());
@@ -1860,8 +1862,9 @@ BOOST_AUTO_TEST_CASE(CoverTreeCopyConstructor)
                       d.Child(0).FurthestDescendantDistance());
   BOOST_REQUIRE_EQUAL(c.Child(0).NumChildren(), d.Child(0).NumChildren());
 
-  BOOST_REQUIRE_EQUAL(c.Child(1).Dataset().memptr(),
-                      d.Child(1).Dataset().memptr());
+  BOOST_REQUIRE_NE(c.Child(1).Dataset().memptr(),
+                   d.Child(1).Dataset().memptr());
+  BOOST_REQUIRE_EQUAL(c.Child(1).Dataset().memptr(), c.Dataset().memptr());
   BOOST_REQUIRE_CLOSE(c.Child(1).Base(), d.Child(1).Base(), 1e-50);
   BOOST_REQUIRE_EQUAL(c.Child(1).Point(), d.Child(1).Point());
   BOOST_REQUIRE_EQUAL(c.Child(1).Scale(), d.Child(1).Scale());



More information about the mlpack-git mailing list