[mlpack-git] master: Add rvalue reference constructors. (10d7140)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Mon Oct 19 16:04:27 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/09cd0d67f2fdae252a8ab85324e71dbb4dfe0010...fecf1194c123ced12d56e7daad761c7b9aaac262
>---------------------------------------------------------------
commit 10d7140ef0b02b07c446e717caf445598a0795a8
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Oct 19 10:07:58 2015 -0400
Add rvalue reference constructors.
>---------------------------------------------------------------
10d7140ef0b02b07c446e717caf445598a0795a8
src/mlpack/core/tree/cover_tree/cover_tree.hpp | 22 +++
.../core/tree/cover_tree/cover_tree_impl.hpp | 169 +++++++++++++++++++++
src/mlpack/tests/tree_test.cpp | 21 +++
3 files changed, 212 insertions(+)
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree.hpp b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
index ca7bbbd..bbec68c 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
@@ -121,6 +121,28 @@ class CoverTree
const double base = 2.0);
/**
+ * Create the cover tree with the given dataset, taking ownership of the
+ * dataset by moving it into the tree. Optionally, set the base.
+ *
+ * The dataset is moved (not copied) into storage that the tree allocates and
+ * owns, so after this call the caller's object has been moved from and should
+ * not be relied on for its old contents.
+ *
+ * NOTE(review): this constructor can be called with a single argument and is
+ * not marked explicit, so an rvalue MatType converts implicitly to a
+ * CoverTree; consider making it explicit.
+ *
+ * @param dataset Dataset to build a tree on; it is moved from.
+ * @param base Base to use during tree building (default 2.0).
+ */
+ CoverTree(MatType&& dataset,
+ const double base = 2.0);
+
+ /**
+ * Create the cover tree with the given dataset and the given instantiated
+ * metric, taking ownership of the dataset by moving it into the tree.
+ * Optionally, set the base.
+ *
+ * @param dataset Dataset to build a tree on; it is moved from, so the
+ *     caller's object should not be relied on for its old contents.
+ * @param metric Instantiated metric to use during tree building. The tree
+ *     keeps a pointer to it (it neither copies it nor takes ownership), so
+ *     it should outlive the tree.
+ * @param base Base to use during tree building (default 2.0).
+ */
+ CoverTree(MatType&& dataset,
+ MetricType& metric,
+ const double base = 2.0);
+
+ /**
* Construct a child cover tree node. This constructor is not meant to be
* used externally, but it could be used to insert another node into a tree.
* This procedure uses only one vector for the near set, the far set, and the
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
index 0a8bafb..93d146e 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
@@ -197,6 +197,175 @@ template<
typename RootPointPolicy
>
CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
+    MatType&& data,
+    const double base) :
+    dataset(new MatType(std::move(data))),
+    point(RootPointPolicy::ChooseRoot(dataset)),
+    scale(INT_MAX),
+    base(base),
+    numDescendants(0),
+    parent(NULL),
+    parentDistance(0),
+    furthestDescendantDistance(0),
+    localMetric(true),
+    localDataset(true),
+    distanceComps(0)
+{
+  // Root-node constructor that takes ownership of a moved-in dataset
+  // (localDataset = true) and of a default-constructed metric that it
+  // allocates itself (localMetric = true). 'data' is moved from; the caller
+  // should not rely on its old contents afterwards.
+
+  // We need to create a metric. We'll just do it on the heap.
+  this->metric = new MetricType();
+
+  // With zero or one point there is nothing to build. An empty dataset must
+  // stop here too: below, dataset->n_cols - 1 would wrap around to a huge
+  // unsigned value and request an enormous allocation.
+  if (dataset->n_cols <= 1)
+    return;
+
+  // Kick off the building. Create the indices array and the distances array.
+  arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
+      dataset->n_cols - 1, dataset->n_cols - 1);
+  // This is now [1 2 3 ... n_cols - 1]. We must be sure that our point does
+  // not occur.
+  if (point != 0)
+    indices[point - 1] = 0; // Put 0 back into the set; remove what was there.
+
+  arma::vec distances(dataset->n_cols - 1);
+
+  // Build the initial distances.
+  ComputeDistances(point, indices, distances, dataset->n_cols - 1);
+
+  // Create the children.
+  size_t farSetSize = 0;
+  size_t usedSetSize = 0;
+  CreateChildren(indices, distances, dataset->n_cols - 1, farSetSize,
+      usedSetSize);
+
+  // If we ended up creating only one child, remove the implicit node.
+  while (children.size() == 1)
+  {
+    // Prepare to delete the implicit child node.
+    CoverTree* old = children[0];
+
+    // Now take its children and set their parent correctly.
+    children.erase(children.begin());
+    for (size_t i = 0; i < old->NumChildren(); ++i)
+    {
+      children.push_back(&(old->Child(i)));
+
+      // Set its parent correctly, and rebuild the statistic.
+      old->Child(i).Parent() = this;
+      old->Child(i).Stat() = StatisticType(old->Child(i));
+    }
+
+    // Remove all the children so they don't get erased.
+    old->Children().clear();
+
+    // Reduce our own scale.
+    scale = old->Scale();
+
+    // Now delete it.
+    delete old;
+  }
+
+  // Use the furthest descendant distance to determine the scale of the root
+  // node. If every point coincides with the root that distance is 0, and
+  // log(0) is -inf; converting -inf to int is undefined behavior. In that case
+  // use the lowest scale above INT_MIN. (Presumably INT_MIN is reserved for
+  // leaf nodes elsewhere in this file -- confirm.)
+  if (furthestDescendantDistance == 0.0)
+    scale = INT_MIN + 1;
+  else
+    scale = static_cast<int>(
+        ceil(log(furthestDescendantDistance) / log(base)));
+
+  // Initialize statistic.
+  stat = StatisticType(*this);
+
+  Log::Info << distanceComps << " distance computations during tree "
+      << "construction." << std::endl;
+}
+
+template<
+    typename MetricType,
+    typename StatisticType,
+    typename MatType,
+    typename RootPointPolicy
+>
+CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
+    MatType&& data,
+    MetricType& metric,
+    const double base) :
+    dataset(new MatType(std::move(data))),
+    point(RootPointPolicy::ChooseRoot(dataset)),
+    scale(INT_MAX),
+    base(base),
+    numDescendants(0),
+    parent(NULL),
+    parentDistance(0),
+    furthestDescendantDistance(0),
+    localMetric(false),
+    localDataset(true),
+    metric(&metric),
+    distanceComps(0)
+{
+  // Root-node constructor that takes ownership of a moved-in dataset
+  // (localDataset = true) but only borrows the caller's metric
+  // (localMetric = false; just its address is stored). 'data' is moved from;
+  // the caller should not rely on its old contents afterwards.
+
+  // With zero or one point there is nothing to build. An empty dataset must
+  // stop here too: below, dataset->n_cols - 1 would wrap around to a huge
+  // unsigned value and request an enormous allocation.
+  if (dataset->n_cols <= 1)
+    return;
+
+  // Kick off the building. Create the indices array and the distances array.
+  arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
+      dataset->n_cols - 1, dataset->n_cols - 1);
+  // This is now [1 2 3 ... n_cols - 1]. We must be sure that our point does
+  // not occur.
+  if (point != 0)
+    indices[point - 1] = 0; // Put 0 back into the set; remove what was there.
+
+  arma::vec distances(dataset->n_cols - 1);
+
+  // Build the initial distances.
+  ComputeDistances(point, indices, distances, dataset->n_cols - 1);
+
+  // Create the children.
+  size_t farSetSize = 0;
+  size_t usedSetSize = 0;
+  CreateChildren(indices, distances, dataset->n_cols - 1, farSetSize,
+      usedSetSize);
+
+  // If we ended up creating only one child, remove the implicit node.
+  while (children.size() == 1)
+  {
+    // Prepare to delete the implicit child node.
+    CoverTree* old = children[0];
+
+    // Now take its children and set their parent correctly.
+    children.erase(children.begin());
+    for (size_t i = 0; i < old->NumChildren(); ++i)
+    {
+      children.push_back(&(old->Child(i)));
+
+      // Set its parent correctly, and rebuild the statistic.
+      old->Child(i).Parent() = this;
+      old->Child(i).Stat() = StatisticType(old->Child(i));
+    }
+
+    // Remove all the children so they don't get erased.
+    old->Children().clear();
+
+    // Reduce our own scale.
+    scale = old->Scale();
+
+    // Now delete it.
+    delete old;
+  }
+
+  // Use the furthest descendant distance to determine the scale of the root
+  // node. If every point coincides with the root that distance is 0, and
+  // log(0) is -inf; converting -inf to int is undefined behavior. In that case
+  // use the lowest scale above INT_MIN. (Presumably INT_MIN is reserved for
+  // leaf nodes elsewhere in this file -- confirm.)
+  if (furthestDescendantDistance == 0.0)
+    scale = INT_MIN + 1;
+  else
+    scale = static_cast<int>(
+        ceil(log(furthestDescendantDistance) / log(base)));
+
+  // Initialize statistic.
+  stat = StatisticType(*this);
+
+  Log::Info << distanceComps << " distance computations during tree "
+      << "construction." << std::endl;
+}
+
+template<
+ typename MetricType,
+ typename StatisticType,
+ typename MatType,
+ typename RootPointPolicy
+>
+CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
const MatType& dataset,
const double base,
const size_t pointIndex,
diff --git a/src/mlpack/tests/tree_test.cpp b/src/mlpack/tests/tree_test.cpp
index b6ebe57..14b2f51 100644
--- a/src/mlpack/tests/tree_test.cpp
+++ b/src/mlpack/tests/tree_test.cpp
@@ -1914,6 +1914,27 @@ BOOST_AUTO_TEST_CASE(CoverTreeCopyConstructor)
BOOST_REQUIRE_EQUAL(c.Child(1).NumChildren(), d.Child(1).NumChildren());
}
+BOOST_AUTO_TEST_CASE(CoverTreeMoveDatasetTest)
+{
+  typedef StandardCoverTree<EuclideanDistance, EmptyStatistic, arma::mat>
+      TreeType;
+
+  // Constructor taking only a dataset. Keep a copy of the data so we can
+  // check that the tree really holds the moved-in points, and build a
+  // reference tree from the copy with the const-reference constructor.
+  arma::mat dataset = arma::randu<arma::mat>(3, 1000);
+  const arma::mat datasetCopy(dataset);
+
+  TreeType t(std::move(dataset));
+  TreeType tRef(datasetCopy);
+
+  // The tree must have taken the data: the source matrix is now empty.
+  BOOST_REQUIRE_EQUAL(dataset.n_elem, 0);
+  BOOST_REQUIRE_EQUAL(t.Dataset().n_rows, 3);
+  BOOST_REQUIRE_EQUAL(t.Dataset().n_cols, 1000);
+  for (size_t i = 0; i < datasetCopy.n_elem; ++i)
+    BOOST_REQUIRE_EQUAL(t.Dataset()[i], datasetCopy[i]);
+
+  // Building from moved data must give the same tree as building from a
+  // const reference to identical data.
+  BOOST_REQUIRE_EQUAL(t.NumChildren(), tRef.NumChildren());
+  BOOST_REQUIRE_EQUAL(t.Scale(), tRef.Scale());
+
+  // Constructor that also takes an instantiated metric.
+  EuclideanDistance ed;
+  dataset = arma::randu<arma::mat>(3, 1000);
+  const arma::mat datasetCopy2(dataset);
+
+  TreeType t2(std::move(dataset), ed);
+  TreeType t2Ref(datasetCopy2);
+
+  BOOST_REQUIRE_EQUAL(dataset.n_elem, 0);
+  BOOST_REQUIRE_EQUAL(t2.Dataset().n_rows, 3);
+  BOOST_REQUIRE_EQUAL(t2.Dataset().n_cols, 1000);
+  for (size_t i = 0; i < datasetCopy2.n_elem; ++i)
+    BOOST_REQUIRE_EQUAL(t2.Dataset()[i], datasetCopy2[i]);
+
+  BOOST_REQUIRE_EQUAL(t2.NumChildren(), t2Ref.NumChildren());
+  BOOST_REQUIRE_EQUAL(t2.Scale(), t2Ref.Scale());
+}
+
/**
* Make sure copy constructor works right for the binary space tree.
*/
More information about the mlpack-git
mailing list