[mlpack-svn] r10313 - in mlpack/trunk/src/mlpack: core/kernels core/tree core/utilities methods/neighbor_search methods/neighbor_search/sort_policies tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Nov 17 12:44:31 EST 2011
Author: mamidon
Date: 2011-11-17 12:44:31 -0500 (Thu, 17 Nov 2011)
New Revision: 10313
Added:
mlpack/trunk/src/mlpack/core/tree/binary_space_tree_crtp.hpp
mlpack/trunk/src/mlpack/core/tree/binary_space_tree_impl_crtp.hpp
Modified:
mlpack/trunk/src/mlpack/core/kernels/lmetric.cpp
mlpack/trunk/src/mlpack/core/kernels/lmetric.hpp
mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt
mlpack/trunk/src/mlpack/core/tree/hrectbound.hpp
mlpack/trunk/src/mlpack/core/tree/hrectbound_impl.hpp
mlpack/trunk/src/mlpack/core/utilities/timers.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cc
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.h
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.h
mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort_impl.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort_impl.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/typedef.h
mlpack/trunk/src/mlpack/tests/allkfn_test.cpp
Log:
Mid-progress check-in. Support for arbitrary data types added; still need to
clean up NearestNeighbors, since it makes a few (too many) assumptions about
the matrix it'll be using.
Modified: mlpack/trunk/src/mlpack/core/kernels/lmetric.cpp
===================================================================
--- mlpack/trunk/src/mlpack/core/kernels/lmetric.cpp 2011-11-17 17:41:26 UTC (rev 10312)
+++ mlpack/trunk/src/mlpack/core/kernels/lmetric.cpp 2011-11-17 17:44:31 UTC (rev 10313)
@@ -11,7 +11,9 @@
// L1-metric specializations; the root doesn't matter.
template<>
-double LMetric<1, true>::Evaluate(const arma::vec& a, const arma::vec& b) {
+template<typename elem_type>
+double LMetric<1, true>::Evaluate(const arma::Col<elem_type>& a,
+ const arma::Col<elem_type>& b) {
double sum = 0;
for (size_t i = 0; i < a.n_elem; i++)
sum += fabs(a[i] - b[i]);
@@ -20,7 +22,9 @@
}
template<>
-double LMetric<1, false>::Evaluate(const arma::vec& a, const arma::vec& b) {
+template<typename elem_type>
+double LMetric<1, false>::Evaluate(const arma::Col<elem_type>& a,
+ const arma::Col<elem_type>& b) {
double sum = 0;
for (size_t i = 0; i < a.n_elem; i++)
sum += fabs(a[i] - b[i]);
@@ -30,7 +34,9 @@
// L2-metric specializations.
template<>
-double LMetric<2, true>::Evaluate(const arma::vec& a, const arma::vec& b) {
+template<typename elem_type>
+double LMetric<2, true>::Evaluate(const arma::Col<elem_type>& a,
+ const arma::Col<elem_type>& b) {
double sum = 0;
for (size_t i = 0; i < a.n_elem; i++)
sum += pow(a[i] - b[i], 2.0); // fabs() not necessary when squaring.
@@ -39,7 +45,9 @@
}
template<>
-double LMetric<2, false>::Evaluate(const arma::vec& a, const arma::vec& b) {
+template<typename elem_type>
+double LMetric<2, false>::Evaluate(const arma::Col<elem_type>& a,
+ const arma::Col<elem_type>& b) {
double sum = 0;
for (size_t i = 0; i < a.n_elem; i++)
sum += pow(a[i] - b[i], 2.0);
@@ -49,7 +57,9 @@
// L3-metric specialization (not very likely to be used, but just in case).
template<>
-double LMetric<3, true>::Evaluate(const arma::vec& a, const arma::vec& b) {
+template<typename elem_type>
+double LMetric<3, true>::Evaluate(const arma::Col<elem_type>& a,
+ const arma::Col<elem_type>& b) {
double sum = 0;
for (size_t i = 0; i < a.n_elem; i++)
sum += pow(fabs(a[i] - b[i]), 3.0);
@@ -58,7 +68,9 @@
}
template<>
-double LMetric<3, false>::Evaluate(const arma::vec& a, const arma::vec& b) {
+template<typename elem_type>
+double LMetric<3, false>::Evaluate(const arma::Col<elem_type>& a,
+ const arma::Col<elem_type>& b) {
double sum = 0;
for (size_t i = 0; i < a.n_elem; i++)
sum += pow(fabs(a[i] - b[i]), 3.0);
Modified: mlpack/trunk/src/mlpack/core/kernels/lmetric.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/kernels/lmetric.hpp 2011-11-17 17:41:26 UTC (rev 10312)
+++ mlpack/trunk/src/mlpack/core/kernels/lmetric.hpp 2011-11-17 17:44:31 UTC (rev 10313)
@@ -62,7 +62,9 @@
/**
* Computes the distance between two points.
*/
- static double Evaluate(const arma::vec& a, const arma::vec& b);
+ template<typename elem_type>
+ static double Evaluate(const arma::Col<elem_type>& a,
+ const arma::Col<elem_type>& b);
};
// Doxygen will not include this specialization.
@@ -72,8 +74,9 @@
// the unspecialized implementation of the one function is given below.
// Unspecialized implementation. This should almost never be used...
template<int t_pow, bool t_take_root>
-double LMetric<t_pow, t_take_root>::Evaluate(const arma::vec& a,
- const arma::vec& b) {
+template<typename elem_type>
+double LMetric<t_pow, t_take_root>::Evaluate(const arma::Col<elem_type>& a,
+ const arma::Col<elem_type>& b) {
double sum = 0;
for (size_t i = 0; i < a.n_elem; i++)
sum += pow(fabs(a[i] - b[i]), t_pow);
Modified: mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt 2011-11-17 17:41:26 UTC (rev 10312)
+++ mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt 2011-11-17 17:44:31 UTC (rev 10313)
@@ -5,6 +5,8 @@
set(SOURCES
binary_space_tree.hpp
binary_space_tree_impl.hpp
+ binary_space_tree_crtp.hpp
+ binary_space_tree_impl_crtp.hpp
bounds.hpp
dballbound.hpp
dballbound_impl.hpp
Added: mlpack/trunk/src/mlpack/core/tree/binary_space_tree_crtp.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/binary_space_tree_crtp.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree_crtp.hpp 2011-11-17 17:44:31 UTC (rev 10313)
@@ -0,0 +1,294 @@
+/**
+ * @file binary_space_tree_crtp.hpp
+ *
+ * Definition of generalized binary space partitioning tree (BinarySpaceTree).
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_HPP
+
+#include <mlpack/core.h>
+
+#include "statistic.hpp"
+
+namespace mlpack {
+namespace tree /** Trees and tree-building procedures. */ {
+
+// Register the "tree" parameter module and its integer "leaf_size" option.
+// The option is read back as CLI::GetParam<int>("tree/leaf_size") when nodes
+// are split. The trailing 20 is presumably the option's default value --
+// confirm against the PARAM_INT definition.
+PARAM_MODULE("tree", "Parameters for the binary space partitioning tree.");
+PARAM_INT("leaf_size", "Leaf size used during tree construction.", "tree", 20);
+
+/**
+ * A binary space partitioning tree, such as a kd-tree or a ball tree. Once the
+ * bound type and dataset type are defined, the tree will construct itself:
+ * call a constructor with the dataset to build the tree on, and the entire
+ * tree will be built.
+ *
+ * This particular tree does not allow growth, so you cannot add or delete nodes
+ * from it. If you need to add or delete a node, the better procedure is to
+ * rebuild the tree entirely.
+ *
+ * This tree takes one parameter, the maximum number of points allowed in a
+ * leaf. You can set it at runtime with --tree/leaf_size [leaf_size]. You can
+ * also set it in your program using CLI:
+ *
+ * @code
+ * CLI::GetParam<int>("tree/leaf_size") = target_leaf_size;
+ * @endcode
+ *
+ * The dataset stores one point per column; each dimension is a row.
+ *
+ * NOTE(review): each node owns its children through raw pointers that are
+ * deleted in the destructor, but no copy constructor or copy assignment
+ * operator is declared. Copying a node (e.g. passing it by value) would make
+ * two nodes delete the same children. Do not copy nodes.
+ *
+ * NOTE(review): this header's include guard and class name look identical to
+ * those of binary_space_tree.hpp (not visible here) -- confirm. If so, only
+ * one of the two headers can be used in a translation unit.
+ *
+ * @tparam T1 Armadillo matrix type holding the dataset. Constructors accept
+ *     it through its arma::Base<typename T1::elem_type, T1> base class.
+ * @tparam Bound The bound used for each node. The valid types of bounds and
+ *     the necessary skeleton interface for this class can be found in
+ *     bounds.hpp.
+ * @tparam Statistic Extra data contained in the node (EmptyStatistic by
+ *     default). See statistic.hpp for the necessary skeleton interface.
+ */
+template<typename T1,
+ typename Bound,
+ typename Statistic = EmptyStatistic>
+class BinarySpaceTree {
+ private:
+ //! The left child node (NULL for a leaf); owned and deleted by this node.
+ BinarySpaceTree *left_;
+ //! The right child node (NULL for a leaf); owned and deleted by this node.
+ BinarySpaceTree *right_;
+ //! The index of the first point in the dataset contained in this node (and
+ //! its children).
+ size_t begin_;
+ //! The number of points of the dataset contained in this node (and its
+ //! children).
+ size_t count_;
+ //! The bound object for this node.
+ Bound bound_;
+ //! Any extra data contained in the node.
+ Statistic stat_;
+
+ public:
+ /**
+ * Construct this as the root node of a binary space tree using the given
+ * dataset. This will modify the ordering of the points in the dataset!
+ *
+ * @param data Dataset to create tree from. This will be modified!
+ */
+ BinarySpaceTree(arma::Base<typename T1::elem_type, T1>& data);
+
+ /**
+ * Construct this as the root node of a binary space tree using the given
+ * dataset. This will modify the ordering of points in the dataset! A
+ * mapping from the new point indices to the old point indices is filled.
+ * old_from_new is resized to the number of points and does not need to be
+ * initialized by the caller.
+ *
+ * @param data Dataset to create tree from. This will be modified!
+ * @param old_from_new Vector which will be filled with the old positions for
+ * each new point.
+ */
+ BinarySpaceTree(arma::Base<typename T1::elem_type, T1>& data, std::vector<size_t>& old_from_new);
+
+ /**
+ * Construct this as the root node of a binary space tree using the given
+ * dataset. This will modify the ordering of points in the dataset! Both a
+ * mapping from new point indices to old point indices and the inverse
+ * mapping, from old point indices to new point indices, are filled. Both
+ * vectors are resized by the constructor.
+ *
+ * @param data Dataset to create tree from. This will be modified!
+ * @param old_from_new Vector which will be filled with the old positions for
+ * each new point.
+ * @param new_from_old Vector which will be filled with the new positions for
+ * each old point.
+ */
+ BinarySpaceTree(arma::Base<typename T1::elem_type, T1>& data,
+ std::vector<size_t>& old_from_new,
+ std::vector<size_t>& new_from_old);
+
+ /**
+ * Construct this node on a subset of the given matrix, starting at column
+ * begin_in and using count_in points. The ordering of that subset of points
+ * will be modified! This is used for recursive tree-building when no index
+ * mapping is being maintained.
+ *
+ * @param data Dataset to create tree from. This will be modified!
+ * @param begin_in Index of point to start tree construction with.
+ * @param count_in Number of points to use to construct tree.
+ */
+ BinarySpaceTree(arma::Base<typename T1::elem_type, T1>& data,
+ size_t begin_in,
+ size_t count_in);
+
+ /**
+ * Construct this node on a subset of the given matrix, starting at column
+ * begin_in and using count_in points. The ordering of that subset of points
+ * will be modified! This is used for recursive tree-building when an index
+ * mapping is being maintained.
+ *
+ * The mapping from new point indices to old point indices is updated. The
+ * vector must already be initialized with one entry per column of data; an
+ * assert checks the size, but it cannot check the contents.
+ *
+ * @param data Dataset to create tree from. This will be modified!
+ * @param begin_in Index of point to start tree construction with.
+ * @param count_in Number of points to use to construct tree.
+ * @param old_from_new Vector which will be filled with the old positions for
+ * each new point.
+ */
+ BinarySpaceTree(arma::Base<typename T1::elem_type, T1>& data,
+ size_t begin_in,
+ size_t count_in,
+ std::vector<size_t>& old_from_new);
+
+ /**
+ * Construct this node on a subset of the given matrix, starting at column
+ * begin_in and using count_in points. The ordering of that subset of points
+ * will be modified!
+ *
+ * The mapping from new point indices to old point indices is updated. The
+ * vector must already be initialized with one entry per column of data; an
+ * assert checks the size, but it cannot check the contents. After splitting,
+ * new_from_old is resized to the number of columns of data and filled as the
+ * inverse of old_from_new over all columns.
+ *
+ * @param data Dataset to create tree from. This will be modified!
+ * @param begin_in Index of point to start tree construction with.
+ * @param count_in Number of points to use to construct tree.
+ * @param old_from_new Vector which will be filled with the old positions for
+ * each new point.
+ * @param new_from_old Vector which will be filled with the new positions for
+ * each old point.
+ */
+ BinarySpaceTree(arma::Base<typename T1::elem_type, T1>& data,
+ size_t begin_in,
+ size_t count_in,
+ std::vector<size_t>& old_from_new,
+ std::vector<size_t>& new_from_old);
+
+ /**
+ * Create an empty tree node.
+ */
+ BinarySpaceTree();
+
+ /**
+ * Deletes this node, deallocating the memory for the children and calling
+ * their destructors in turn. This will invalidate any pointers or references
+ * to any nodes which are children of this one.
+ */
+ ~BinarySpaceTree();
+
+ /**
+ * Find a node in this tree by its begin and count (const).
+ *
+ * Every node is uniquely identified by these two numbers.
+ * This is useful for communicating position over the network,
+ * when pointers would be invalid.
+ *
+ * @param begin_q The begin() of the node to find.
+ * @param count_q The count() of the node to find.
+ * @return The found node, or NULL if not found.
+ */
+ const BinarySpaceTree* FindByBeginCount(size_t begin_q,
+ size_t count_q) const;
+
+ /**
+ * Find a node in this tree by its begin and count.
+ *
+ * Every node is uniquely identified by these two numbers.
+ * This is useful for communicating position over the network,
+ * when pointers would be invalid.
+ *
+ * @param begin_q The begin() of the node to find.
+ * @param count_q The count() of the node to find.
+ * @return The found node, or NULL if not found.
+ */
+ BinarySpaceTree* FindByBeginCount(size_t begin_q, size_t count_q);
+
+ //! Return the bound object for this node.
+ const Bound& bound() const;
+ //! Return the bound object for this node.
+ Bound& bound();
+
+ //! Return the statistic object for this node.
+ const Statistic& stat() const;
+ //! Return the statistic object for this node.
+ Statistic& stat();
+
+ //! Return whether or not this node is a leaf (true if it has no children).
+ bool is_leaf() const;
+
+ /**
+ * Gets the left child of this node (NULL if this node is a leaf).
+ */
+ BinarySpaceTree *left() const;
+
+ /**
+ * Gets the right child of this node (NULL if this node is a leaf).
+ */
+ BinarySpaceTree *right() const;
+
+ /**
+ * Gets the index of the beginning point of this subset.
+ */
+ size_t begin() const;
+
+ /**
+ * Gets the index one beyond the last index in the subset.
+ */
+ size_t end() const;
+
+ /**
+ * Gets the number of points in this subset.
+ */
+ size_t count() const;
+
+ //! Print the point range of this node, then recursively that of its
+ //! children, to stdout.
+ void Print() const;
+
+ private:
+ /**
+ * Splits the current node, assigning its left and right children recursively.
+ * The node is left as a leaf if it holds at most tree/leaf_size points or if
+ * all of its points coincide.
+ *
+ * @param data Dataset which we are using.
+ */
+ void SplitNode(arma::Base<typename T1::elem_type, T1>& data);
+
+ /**
+ * Splits the current node, assigning its left and right children recursively.
+ * Also updates the mapping from new point indices to old point indices.
+ *
+ * @param data Dataset which we are using.
+ * @param old_from_new Vector holding permuted indices.
+ */
+ void SplitNode(arma::Base<typename T1::elem_type, T1>& data,
+ std::vector<size_t>& old_from_new);
+
+ /**
+ * Find the index to split on for this node, given that we are splitting in
+ * the given split dimension on the specified split value.
+ *
+ * @param data Dataset which we are using.
+ * @param split_dim Dimension of dataset to split on.
+ * @param split_val Value to split on, in the given split dimension.
+ * @return Index of the first column of this node whose value in split_dim is
+ * not less than split_val after reordering.
+ */
+ size_t GetSplitIndex(arma::Base<typename T1::elem_type, T1>& data,
+ int split_dim, double split_val);
+
+ /**
+ * Find the index to split on for this node, given that we are splitting in
+ * the given split dimension on the specified split value. Also updates the
+ * mapping from new point indices to old point indices.
+ *
+ * @param data Dataset which we are using.
+ * @param split_dim Dimension of dataset to split on.
+ * @param split_val Value to split on, in the given split dimension.
+ * @param old_from_new Vector holding permuted indices.
+ * @return Index of the first column of this node whose value in split_dim is
+ * not less than split_val after reordering.
+ */
+ size_t GetSplitIndex(arma::Base<typename T1::elem_type, T1>& data,
+ int split_dim, double split_val, std::vector<size_t>& old_from_new);
+
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+// Include implementation.
+#include "binary_space_tree_impl_crtp.hpp"
+
+#endif
Added: mlpack/trunk/src/mlpack/core/tree/binary_space_tree_impl_crtp.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/binary_space_tree_impl_crtp.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree_impl_crtp.hpp 2011-11-17 17:44:31 UTC (rev 10313)
@@ -0,0 +1,481 @@
+/**
+ * @file binary_space_tree_impl_crtp.hpp
+ *
+ * Implementation of generalized space partitioning tree.
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_IMPL_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_IMPL_HPP
+
+// In case it wasn't included already for some reason.
+#include "binary_space_tree_crtp.hpp"
+
+#include <mlpack/core/io/cli.hpp>
+#include <mlpack/core/io/log.hpp>
+
+namespace mlpack {
+namespace tree {
+
+// Each of these constructor overloads is kept as a separate function so that
+// callers who do not need the old_from_new / new_from_old index mappings do
+// not pay for maintaining those std::vectors.
+/**
+ * Root constructor: builds the whole tree over every column of data. The
+ * columns of data are reordered in place while the tree is built.
+ */
+template<typename T1, typename Bound, typename Statistic>
+BinarySpaceTree<T1, Bound, Statistic>::BinarySpaceTree(
+    arma::Base<typename T1::elem_type, T1>& data) :
+    left_(NULL),
+    right_(NULL),
+    begin_(0),                      // The root starts at the first column...
+    count_(data.get_ref().n_cols),  // ...and covers every column.
+    bound_(data.get_ref().n_rows),
+    stat_() {
+  SplitNode(data);  // Creates the children (if any), recursively.
+}
+
+/**
+ * Root constructor that also tracks where each point came from. The columns
+ * of data are reordered in place. old_from_new is resized to the number of
+ * points and reset to the identity mapping before splitting, so the caller
+ * need not initialize it.
+ */
+template<typename T1, typename Bound, typename Statistic>
+BinarySpaceTree<T1, Bound, Statistic>::BinarySpaceTree(
+    arma::Base<typename T1::elem_type, T1>& data,
+    std::vector<size_t>& old_from_new) :
+    left_(NULL),
+    right_(NULL),
+    begin_(0),
+    count_(data.get_ref().n_cols),
+    bound_(data.get_ref().n_rows),
+    stat_() {
+  // Start from the identity permutation; SplitNode() is then responsible for
+  // updating it.
+  const size_t n_points = data.get_ref().n_cols;
+  old_from_new.resize(n_points);
+  for (size_t col = 0; col < n_points; ++col)
+    old_from_new[col] = col;
+
+  SplitNode(data, old_from_new);
+}
+
+/**
+ * Root constructor that tracks the point permutation in both directions. The
+ * columns of data are reordered in place. old_from_new starts as the identity
+ * and is updated by SplitNode(). new_from_old is then filled as its inverse.
+ * Both vectors are resized here.
+ */
+template<typename T1, typename Bound, typename Statistic>
+BinarySpaceTree<T1, Bound, Statistic>::BinarySpaceTree(
+    arma::Base<typename T1::elem_type, T1>& data,
+    std::vector<size_t>& old_from_new,
+    std::vector<size_t>& new_from_old) :
+    left_(NULL),
+    right_(NULL),
+    begin_(0),
+    count_(data.get_ref().n_cols),
+    bound_(data.get_ref().n_rows),
+    stat_() {
+  // Swapping columns never changes the number of points, so this count stays
+  // valid across SplitNode().
+  const size_t n_points = data.get_ref().n_cols;
+
+  old_from_new.resize(n_points);
+  for (size_t col = 0; col < n_points; ++col)
+    old_from_new[col] = col;
+
+  SplitNode(data, old_from_new);
+
+  // Invert the permutation.
+  new_from_old.resize(n_points);
+  for (size_t new_idx = 0; new_idx < n_points; ++new_idx)
+    new_from_old[old_from_new[new_idx]] = new_idx;
+}
+
+/**
+ * Builds a node over columns [begin_in, begin_in + count_in) of data, plus any
+ * children below it. SplitNode() uses this for recursive construction when no
+ * index mapping is being tracked.
+ */
+template<typename T1, typename Bound, typename Statistic>
+BinarySpaceTree<T1, Bound, Statistic>::BinarySpaceTree(
+    arma::Base<typename T1::elem_type, T1>& data,
+    size_t begin_in,
+    size_t count_in) :
+    left_(NULL), right_(NULL),
+    begin_(begin_in), count_(count_in),
+    bound_(data.get_ref().n_rows),
+    stat_() {
+  SplitNode(data);
+}
+
+/**
+ * Builds a node over columns [begin_in, begin_in + count_in) of data, plus any
+ * children below it, while updating the caller's old_from_new mapping.
+ * SplitNode() uses this for recursive construction when a mapping is tracked.
+ * old_from_new must already hold one entry per column of data.
+ */
+template<typename T1, typename Bound, typename Statistic>
+BinarySpaceTree<T1, Bound, Statistic>::BinarySpaceTree(
+    arma::Base<typename T1::elem_type, T1>& data,
+    size_t begin_in,
+    size_t count_in,
+    std::vector<size_t>& old_from_new) :
+    left_(NULL), right_(NULL),
+    begin_(begin_in), count_(count_in),
+    bound_(data.get_ref().n_rows),
+    stat_() {
+  // Only the size of the mapping can be sanity-checked here, not its contents.
+  assert(old_from_new.size() == data.get_ref().n_cols);
+
+  SplitNode(data, old_from_new);
+}
+
+/**
+ * Builds a node over columns [begin_in, begin_in + count_in) of data, plus any
+ * children below it, while updating old_from_new. Afterwards new_from_old is
+ * resized to the number of columns of data and filled, over all columns, as
+ * the inverse of old_from_new.
+ *
+ * old_from_new must already hold one entry per column of data; only its size
+ * is checked, with assert.
+ */
+template<typename T1, typename Bound, typename Statistic>
+BinarySpaceTree<T1, Bound, Statistic>::BinarySpaceTree(
+    arma::Base<typename T1::elem_type, T1>& data,
+    size_t begin_in,
+    size_t count_in,
+    std::vector<size_t>& old_from_new,
+    std::vector<size_t>& new_from_old) :
+    left_(NULL),
+    right_(NULL),
+    begin_(begin_in),
+    count_(count_in),
+    bound_(data.get_ref().n_rows),
+    stat_() {
+  // Hopefully the vector is initialized correctly! We can't check that
+  // entirely but we can do a minor sanity check.
+  assert(old_from_new.size() == data.get_ref().n_cols);
+
+  // Perform the actual splitting.
+  SplitNode(data, old_from_new);
+
+  // Map the new_from_old indices correctly (the inverse permutation).
+  const size_t n_points = data.get_ref().n_cols;
+  new_from_old.resize(n_points);
+  for (size_t i = 0; i < n_points; i++)
+    new_from_old[old_from_new[i]] = i;
+}
+
+/**
+ * Creates an empty node: no children, no points, and a value-initialized
+ * bound and statistic.
+ */
+template<typename T1, typename Bound, typename Statistic>
+BinarySpaceTree<T1, Bound, Statistic>::BinarySpaceTree() :
+    left_(NULL), right_(NULL),
+    begin_(0), count_(0),
+    bound_(), stat_() {
+}
+
+/**
+ * Destroys this node together with its whole subtree: the left child is
+ * deleted first, then the right one. Any pointer or reference to a descendant
+ * of this node becomes invalid.
+ */
+template<typename T1, typename Bound, typename Statistic>
+BinarySpaceTree<T1, Bound, Statistic>::~BinarySpaceTree() {
+  // Deleting a NULL pointer is a no-op, so leaves need no special case.
+  delete left_;
+  delete right_;
+}
+
+/**
+ * Find a node in this tree by its begin and count (const version).
+ *
+ * Every node is uniquely identified by these two numbers, which makes them
+ * usable for communicating a position (e.g. over the network) where pointers
+ * would be meaningless. At every visited node, Log::Assert checks that
+ * begin_q >= begin() and count_q <= count().
+ *
+ * @param begin_q The begin() of the node to find.
+ * @param count_q The count() of the node to find.
+ * @return The found node, or NULL if no node matches.
+ */
+template<typename T1, typename Bound, typename Statistic>
+const BinarySpaceTree<T1, Bound, Statistic>*
+BinarySpaceTree<T1, Bound, Statistic>::FindByBeginCount(size_t begin_q,
+                                                        size_t count_q) const {
+  const BinarySpaceTree* node = this;
+  for (;;) {
+    mlpack::Log::Assert(begin_q >= node->begin_);
+    mlpack::Log::Assert(count_q <= node->count_);
+
+    if (node->begin_ == begin_q && node->count_ == count_q)
+      return node;
+    if (node->is_leaf())
+      return NULL;
+
+    // Descend into the child whose index range can contain begin_q.
+    node = (begin_q < node->right_->begin_) ? node->left_ : node->right_;
+  }
+}
+
+/**
+ * Find a node in this tree by its begin and count (non-const version).
+ *
+ * Every node is uniquely identified by these two numbers, which makes them
+ * usable for communicating a position (e.g. over the network) where pointers
+ * would be meaningless. At every visited node, Log::Assert checks that
+ * begin_q >= begin() and count_q <= count().
+ *
+ * @param begin_q The begin() of the node to find.
+ * @param count_q The count() of the node to find.
+ * @return The found node, or NULL if no node matches.
+ */
+template<typename T1, typename Bound, typename Statistic>
+BinarySpaceTree<T1, Bound, Statistic>*
+BinarySpaceTree<T1, Bound, Statistic>::FindByBeginCount(size_t begin_q,
+                                                        size_t count_q) {
+  BinarySpaceTree* node = this;
+  for (;;) {
+    mlpack::Log::Assert(begin_q >= node->begin_);
+    mlpack::Log::Assert(count_q <= node->count_);
+
+    if (node->begin_ == begin_q && node->count_ == count_q)
+      return node;
+    if (node->is_leaf())
+      return NULL;
+
+    // Descend into the child whose index range can contain begin_q.
+    node = (begin_q < node->right_->begin_) ? node->left_ : node->right_;
+  }
+}
+
+//! Read-only access to the bound of this node.
+template<typename T1, typename Bound, typename Statistic>
+const Bound& BinarySpaceTree<T1, Bound, Statistic>::bound() const
+{
+  return this->bound_;
+}
+
+//! Mutable access to the bound of this node.
+template<typename T1, typename Bound, typename Statistic>
+Bound& BinarySpaceTree<T1, Bound, Statistic>::bound()
+{
+  return this->bound_;
+}
+
+//! Read-only access to the statistic stored in this node.
+template<typename T1, typename Bound, typename Statistic>
+const Statistic& BinarySpaceTree<T1, Bound, Statistic>::stat() const
+{
+  return this->stat_;
+}
+
+//! Mutable access to the statistic stored in this node.
+template<typename T1, typename Bound, typename Statistic>
+Statistic& BinarySpaceTree<T1, Bound, Statistic>::stat()
+{
+  return this->stat_;
+}
+
+//! A node is a leaf exactly when it has no left child; SplitNode() always
+//! creates both children together.
+template<typename T1, typename Bound, typename Statistic>
+bool BinarySpaceTree<T1, Bound, Statistic>::is_leaf() const {
+  return left_ == NULL;
+}
+
+/**
+ * Returns the left child of this node, or NULL if this node is a leaf.
+ */
+template<typename T1, typename Bound, typename Statistic>
+BinarySpaceTree<T1, Bound, Statistic>*
+BinarySpaceTree<T1, Bound, Statistic>::left() const {
+  // TODO: Const correctness (a const node hands out a mutable child).
+  return this->left_;
+}
+
+/**
+ * Returns the right child of this node, or NULL if this node is a leaf.
+ */
+template<typename T1, typename Bound, typename Statistic>
+BinarySpaceTree<T1, Bound, Statistic>*
+BinarySpaceTree<T1, Bound, Statistic>::right() const {
+  // TODO: Const correctness (a const node hands out a mutable child).
+  return this->right_;
+}
+
+/**
+ * Returns the dataset column index of the first point held by this node.
+ */
+template<typename T1, typename Bound, typename Statistic>
+size_t BinarySpaceTree<T1, Bound, Statistic>::begin() const
+{
+  return this->begin_;
+}
+
+/**
+ * Returns the column index one past the last point held by this node.
+ */
+template<typename T1, typename Bound, typename Statistic>
+size_t BinarySpaceTree<T1, Bound, Statistic>::end() const
+{
+  return this->begin() + this->count();
+}
+
+/**
+ * Returns how many points this node (including its children) holds.
+ */
+template<typename T1, typename Bound, typename Statistic>
+size_t BinarySpaceTree<T1, Bound, Statistic>::count() const
+{
+  return this->count_;
+}
+
+/**
+ * Prints, to stdout, the range of point indices held by this node, then
+ * recursively does the same for the left subtree and then the right subtree.
+ */
+template<typename T1, typename Bound, typename Statistic>
+void BinarySpaceTree<T1, Bound, Statistic>::Print() const {
+  // %d does not match a size_t argument on LP64 platforms, which is undefined
+  // behavior in a varargs call. Print through unsigned long with %lu instead.
+  printf("node: %lu to %lu: %lu points total\n",
+      static_cast<unsigned long>(begin_),
+      static_cast<unsigned long>(begin_ + count_ - 1),
+      static_cast<unsigned long>(count_));
+  if (!is_leaf()) {
+    left_->Print();
+    right_->Print();
+  }
+}
+
+/**
+ * Grows this node's bound to enclose its points. Then, unless the node should
+ * remain a leaf, it partitions the points around the middle of the widest
+ * bound dimension and recursively builds both children.
+ *
+ * The node stays a leaf when any of these holds:
+ * - it holds at most tree/leaf_size points;
+ * - its bound has no positive width in any dimension (all points coincide,
+ *   or the data has zero dimensions);
+ * - the split value would leave one side empty.
+ *
+ * @param data Dataset which we are using; columns [begin(), end()) are
+ *     reordered.
+ */
+template<typename T1, typename Bound, typename Statistic>
+void BinarySpaceTree<T1, Bound, Statistic>::SplitNode(
+    arma::Base<typename T1::elem_type, T1>& data) {
+  // Expand the bound of this node to contain all of its points.
+  // (This should eventually be a single function of Bound.)
+  for (size_t i = begin_; i < (begin_ + count_); i++)
+    bound_ |= data.get_ref().unsafe_col(i);
+
+  // Small enough to be a leaf: don't split.
+  if (count_ <= static_cast<size_t>(CLI::GetParam<int>("tree/leaf_size")))
+    return;
+
+  // Choose the dimension in which the bound is widest.
+  size_t split_dim = data.get_ref().n_rows;  // Invalid (max_dim + 1) until set.
+  double max_width = -1;
+  for (size_t d = 0; d < data.get_ref().n_rows; d++) {
+    const double width = bound_[d].width();
+    if (width > max_width) {
+      max_width = width;
+      split_dim = d;
+    }
+  }
+
+  // Nothing to split on. Either every point is identical (max_width == 0), or
+  // there are no dimensions at all (max_width == -1 and split_dim is still
+  // invalid). This must be checked before bound_[split_dim] is read below.
+  if (max_width <= 0)
+    return;
+
+  // Split in the middle of that dimension.
+  const double split_val = bound_[split_dim].mid();
+
+  // Reorder the points. Those whose value in dimension split_dim is less than
+  // split_val end up in [begin_, split_col); the rest end up in
+  // [split_col, begin_ + count_).
+  const size_t split_col = GetSplitIndex(data, split_dim, split_val);
+
+  // One side can be empty if split_val does not fall strictly inside the
+  // range of the points (e.g. through floating-point rounding). The non-empty
+  // child would then hold exactly this node's points, split them the same way
+  // again, and recurse forever. Stay a leaf instead.
+  if (split_col == begin_ || split_col == begin_ + count_)
+    return;
+
+  // Build the children; their constructors continue the splitting.
+  left_ = new BinarySpaceTree<T1, Bound, Statistic>(data, begin_,
+      split_col - begin_);
+  right_ = new BinarySpaceTree<T1, Bound, Statistic>(data, split_col,
+      begin_ + count_ - split_col);
+}
+
+/**
+ * Same as SplitNode(data), but also keeps old_from_new in step with every
+ * column that is moved: old_from_new[i] is the original index of the point
+ * now in column i.
+ *
+ * The node stays a leaf when any of these holds:
+ * - it holds at most tree/leaf_size points;
+ * - its bound has no positive width in any dimension (all points coincide,
+ *   or the data has zero dimensions);
+ * - the split value would leave one side empty.
+ *
+ * @param data Dataset which we are using; columns [begin(), end()) are
+ *     reordered.
+ * @param old_from_new Vector holding permuted indices.
+ */
+template<typename T1, typename Bound, typename Statistic>
+void BinarySpaceTree<T1, Bound, Statistic>::SplitNode(
+    arma::Base<typename T1::elem_type, T1>& data,
+    std::vector<size_t>& old_from_new) {
+  // Expand the bound of this node to contain all of its points.
+  // (This should eventually be a single function of Bound.)
+  for (size_t i = begin_; i < (begin_ + count_); i++)
+    bound_ |= data.get_ref().unsafe_col(i);
+
+  // Small enough to be a leaf: don't split.
+  if (count_ <= static_cast<size_t>(CLI::GetParam<int>("tree/leaf_size")))
+    return;
+
+  // Choose the dimension in which the bound is widest.
+  size_t split_dim = data.get_ref().n_rows;  // Invalid (max_dim + 1) until set.
+  double max_width = -1;
+  for (size_t d = 0; d < data.get_ref().n_rows; d++) {
+    const double width = bound_[d].width();
+    if (width > max_width) {
+      max_width = width;
+      split_dim = d;
+    }
+  }
+
+  // Nothing to split on. Either every point is identical (max_width == 0), or
+  // there are no dimensions at all (max_width == -1 and split_dim is still
+  // invalid). This must be checked before bound_[split_dim] is read below.
+  if (max_width <= 0)
+    return;
+
+  // Split in the middle of that dimension.
+  const double split_val = bound_[split_dim].mid();
+
+  // Reorder the points. Those whose value in dimension split_dim is less than
+  // split_val end up in [begin_, split_col); the rest end up in
+  // [split_col, begin_ + count_). old_from_new is updated to match.
+  const size_t split_col = GetSplitIndex(data, split_dim, split_val,
+      old_from_new);
+
+  // One side can be empty if split_val does not fall strictly inside the
+  // range of the points (e.g. through floating-point rounding). The non-empty
+  // child would then hold exactly this node's points, split them the same way
+  // again, and recurse forever. Stay a leaf instead.
+  if (split_col == begin_ || split_col == begin_ + count_)
+    return;
+
+  // Build the children; their constructors continue the splitting.
+  left_ = new BinarySpaceTree<T1, Bound, Statistic>(data, begin_,
+      split_col - begin_, old_from_new);
+  right_ = new BinarySpaceTree<T1, Bound, Statistic>(data, split_col,
+      begin_ + count_ - split_col, old_from_new);
+}
+
+/**
+ * Reorders columns [begin(), end()) of data in place so that the points whose
+ * value in dimension split_dim is less than split_val come first. The points
+ * whose value is greater than or equal to split_val follow them.
+ *
+ * @param data Dataset which we are using.
+ * @param split_dim Dimension of dataset to split on.
+ * @param split_val Value to split on, in the given split dimension.
+ * @return Index of the first column of the second group. This equals end()
+ *     if every point is less than split_val.
+ */
+template<typename T1, typename Bound, typename Statistic>
+size_t BinarySpaceTree<T1, Bound, Statistic>::GetSplitIndex(
+    arma::Base<typename T1::elem_type, T1>& data,
+    int split_dim,
+    double split_val) {
+  // arma::Base::get_ref() only yields a const reference, on which
+  // swap_cols() cannot be called. T1 derives from
+  // arma::Base<elem_type, T1> (CRTP), so downcast to reach the mutable
+  // matrix itself.
+  T1& matrix = static_cast<T1&>(data);
+
+  // Invariant:
+  // - [begin_, left) is known to be < split_val;
+  // - [right, begin_ + count_) is known to be >= split_val;
+  // - [left, right) is still unexamined.
+  // Keeping right one past the last unexamined column means it can never
+  // underflow. Testing left < right before reading an element means no
+  // column outside this node is ever read.
+  size_t left = begin_;
+  size_t right = begin_ + count_;
+
+  for (;;) {
+    while ((left < right) && (matrix(split_dim, left) < split_val))
+      left++;
+    while ((left < right) && (matrix(split_dim, right - 1) >= split_val))
+      right--;
+
+    if (left >= right)
+      break;
+
+    // Column left belongs on the right and column right - 1 belongs on the
+    // left: one swap fixes both.
+    matrix.swap_cols(left, right - 1);
+  }
+
+  assert(left == right);
+  return left;
+}
+
+/**
+ * Reorders columns [begin(), end()) of data in place so that the points whose
+ * value in dimension split_dim is less than split_val come first. The points
+ * whose value is greater than or equal to split_val follow them. Every column
+ * swap is mirrored in old_from_new.
+ *
+ * @param data Dataset which we are using.
+ * @param split_dim Dimension of dataset to split on.
+ * @param split_val Value to split on, in the given split dimension.
+ * @param old_from_new Vector holding permuted indices.
+ * @return Index of the first column of the second group. This equals end()
+ *     if every point is less than split_val.
+ */
+template<typename T1, typename Bound, typename Statistic>
+size_t BinarySpaceTree<T1, Bound, Statistic>::GetSplitIndex(
+    arma::Base<typename T1::elem_type, T1>& data,
+    int split_dim,
+    double split_val,
+    std::vector<size_t>& old_from_new) {
+  // Swap columns of the caller's matrix itself. Copying get_ref() into a
+  // local T1 would swap columns of a throwaway copy (at the cost of copying
+  // the whole dataset) while old_from_new was still permuted. get_ref() only
+  // gives a const reference, so downcast through the CRTP base instead.
+  T1& matrix = static_cast<T1&>(data);
+
+  // Invariant:
+  // - [begin_, left) is known to be < split_val;
+  // - [right, begin_ + count_) is known to be >= split_val;
+  // - [left, right) is still unexamined.
+  // Keeping right one past the last unexamined column means it can never
+  // underflow. Testing left < right before reading an element means no
+  // column outside this node is ever read.
+  size_t left = begin_;
+  size_t right = begin_ + count_;
+
+  for (;;) {
+    while ((left < right) && (matrix(split_dim, left) < split_val))
+      left++;
+    while ((left < right) && (matrix(split_dim, right - 1) >= split_val))
+      right--;
+
+    if (left >= right)
+      break;
+
+    // Column left belongs on the right and column right - 1 belongs on the
+    // left: one swap fixes both.
+    matrix.swap_cols(left, right - 1);
+
+    // Keep the index mapping in step with the column swap.
+    const size_t t = old_from_new[left];
+    old_from_new[left] = old_from_new[right - 1];
+    old_from_new[right - 1] = t;
+  }
+
+  assert(left == right);
+  return left;
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif
Modified: mlpack/trunk/src/mlpack/core/tree/hrectbound.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/hrectbound.hpp 2011-11-17 17:41:26 UTC (rev 10312)
+++ mlpack/trunk/src/mlpack/core/tree/hrectbound.hpp 2011-11-17 17:44:31 UTC (rev 10313)
@@ -73,7 +73,8 @@
/**
* Calculates minimum bound-to-point squared distance.
*/
- double MinDistance(const arma::vec& point) const;
+ template<typename elem_type>
+ double MinDistance(const arma::Col<elem_type>& point) const;
/**
* Calculates minimum bound-to-bound squared distance.
@@ -85,7 +86,8 @@
/**
* Calculates maximum bound-to-point squared distance.
*/
- double MaxDistance(const arma::vec& point) const;
+ template<typename elem_type>
+ double MaxDistance(const arma::Col<elem_type>& point) const;
/**
* Computes maximum distance.
@@ -105,7 +107,8 @@
/**
* Expands this region to include a new point.
*/
- HRectBound& operator|=(const arma::vec& vector);
+ template<typename elem_type>
+ HRectBound& operator|=(const arma::Col<elem_type>& vector);
/**
* Expands this region to encompass another bound.
Modified: mlpack/trunk/src/mlpack/core/tree/hrectbound_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/hrectbound_impl.hpp 2011-11-17 17:41:26 UTC (rev 10312)
+++ mlpack/trunk/src/mlpack/core/tree/hrectbound_impl.hpp 2011-11-17 17:44:31 UTC (rev 10313)
@@ -121,7 +121,8 @@
* Calculates minimum bound-to-point squared distance.
*/
template<int t_pow>
-double HRectBound<t_pow>::MinDistance(const arma::vec& point) const {
+template<typename elem_type>
+double HRectBound<t_pow>::MinDistance(const arma::Col<elem_type>& point) const {
assert(point.n_elem == dim_);
double sum = 0;
@@ -181,7 +182,8 @@
* Calculates maximum bound-to-point squared distance.
*/
template<int t_pow>
-double HRectBound<t_pow>::MaxDistance(const arma::vec& point) const {
+template<typename elem_type>
+double HRectBound<t_pow>::MaxDistance(const arma::Col<elem_type>& point) const {
double sum = 0;
assert(point.n_elem == dim_);
@@ -287,7 +289,8 @@
* Expands this region to include a new point.
*/
template<int t_pow>
-HRectBound<t_pow>& HRectBound<t_pow>::operator|=(const arma::vec& vector) {
+template<typename elem_type>
+HRectBound<t_pow>& HRectBound<t_pow>::operator|=(const arma::Col<elem_type>& vector) {
Log::Assert(vector.n_elem == dim_);
for (size_t i = 0; i < dim_; i++) {
Modified: mlpack/trunk/src/mlpack/core/utilities/timers.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/utilities/timers.hpp 2011-11-17 17:41:26 UTC (rev 10312)
+++ mlpack/trunk/src/mlpack/core/utilities/timers.hpp 2011-11-17 17:44:31 UTC (rev 10313)
@@ -31,7 +31,7 @@
*
* @param timerName The name of the timer in question.
*/
- static timeval GetTimer(const char* timerName);
+ static timeval Get(const char* timerName);
/*
* Prints the specified timer. If it took longer than a minute to complete
@@ -39,7 +39,7 @@
*
* @param timerName The name of the timer in question.
*/
- static void PrintTimer(const char* timerName);
+ static void Print(const char* timerName);
/*
* Initializes a timer, available like a normal value specified on
@@ -47,7 +47,7 @@
*
* @param timerName The name of the timer in question.
*/
- static void StartTimer(const char* timerName);
+ static void Start(const char* timerName);
/*
* Halts the timer, and replaces it's value with
@@ -55,7 +55,7 @@
*
* @param timerName The name of the timer in question.
*/
- static void StopTimer(const char* timerName);
+ static void Stop(const char* timerName);
private:
static std::map<std::string, timeval> timers;
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cc
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cc 2011-11-17 17:41:26 UTC (rev 10312)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cc 2011-11-17 17:44:31 UTC (rev 10313)
@@ -45,7 +45,7 @@
string reference_file = CLI::GetParam<string>("reference_file");
string output_file = CLI::GetParam<string>("output_file");
- arma::mat reference_data;
+ arma::vec reference_data;
arma::Mat<size_t> neighbors;
arma::mat distances;
@@ -74,7 +74,7 @@
if (CLI::GetParam<string>("query_file") != "") {
string query_file = CLI::GetParam<string>("query_file");
- arma::mat query_data;
+ arma::vec query_data;
if (!data::Load(query_file.c_str(), query_data))
Log::Fatal << "Query file " << query_file << " not found" << endl;
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.h
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.h 2011-11-17 17:41:26 UTC (rev 10312)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.h 2011-11-17 17:44:31 UTC (rev 10313)
@@ -9,7 +9,7 @@
#include <mlpack/core.h>
#include <mlpack/core/tree/bounds.hpp>
-#include <mlpack/core/tree/binary_space_tree.hpp>
+#include <mlpack/core/tree/binary_space_tree_crtp.hpp>
#include <vector>
#include <string>
@@ -54,7 +54,8 @@
* @tparam Kernel The kernel function; see kernel::ExampleKernel.
* @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
*/
-template<typename Kernel = mlpack::kernel::SquaredEuclideanDistance,
+template<typename T1 = arma::mat,
+ typename Kernel = mlpack::kernel::SquaredEuclideanDistance,
typename SortPolicy = NearestNeighborSort>
class NeighborSearch {
@@ -79,13 +80,13 @@
* Simple typedef for the trees, which use a bound and a QueryStat (to store
* distances for each node). The bound should be configurable...
*/
- typedef tree::BinarySpaceTree<bound::HRectBound<2>, QueryStat> TreeType;
+ typedef tree::BinarySpaceTree<T1, bound::HRectBound<2>, QueryStat> TreeType;
private:
//! Reference dataset.
- arma::mat references_;
+ arma::Base<typename T1::elem_type, T1> references_;
//! Query dataset (may not be given).
- arma::mat queries_;
+ arma::Base<typename T1::elem_type, T1> queries_;
//! Instantiation of kernel.
Kernel kernel_;
@@ -134,7 +135,8 @@
* process! Defaults to false.
* @param kernel An optional instance of the Kernel class.
*/
- NeighborSearch(arma::mat& queries_in, arma::mat& references_in,
+ NeighborSearch(arma::Base<typename T1::elem_type, T1>& queries_in,
+ arma::Base<typename T1::elem_type, T1>&references_in,
bool alias_matrix = false, Kernel kernel = Kernel());
/**
@@ -151,8 +153,8 @@
* process! Defaults to false.
* @param kernel An optional instance of the Kernel class.
*/
- NeighborSearch(arma::mat& references_in, bool alias_matrix = false,
- Kernel kernel = Kernel());
+ NeighborSearch(arma::Base<typename T1::elem_type, T1>& references_in,
+ bool alias_matrix = false, Kernel kernel = Kernel());
/**
* Delete the NeighborSearch object. The tree is the only member we are
@@ -214,7 +216,8 @@
* @param reference_node Reference node.
* @param best_dist_so_far Best distance to a node so far -- used for pruning.
*/
- void ComputeSingleNeighborsRecursion_(size_t point_id, arma::vec& point,
+ void ComputeSingleNeighborsRecursion_(size_t point_id,
+ arma::Col<typename T1::elem_type>& point,
TreeType* reference_node,
double& best_dist_so_far);
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.h
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.h 2011-11-17 17:41:26 UTC (rev 10312)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.h 2011-11-17 17:44:31 UTC (rev 10313)
@@ -13,15 +13,13 @@
// We call an advanced constructor of arma::mat which allows us to alias a
// matrix (if the user has asked for that).
-template<typename Kernel, typename SortPolicy>
-NeighborSearch<Kernel, SortPolicy>::NeighborSearch(arma::mat& queries_in,
- arma::mat& references_in,
+template<typename T1, typename Kernel, typename SortPolicy>
+NeighborSearch<T1, Kernel, SortPolicy>::NeighborSearch(arma::Base<typename T1::elem_type, T1>& queries_in,
+ arma::Base<typename T1::elem_type, T1>& references_in,
bool alias_matrix,
Kernel kernel) :
- references_(references_in.memptr(), references_in.n_rows,
- references_in.n_cols, !alias_matrix),
- queries_(queries_in.memptr(), queries_in.n_rows, queries_in.n_cols,
- !alias_matrix),
+ references_(references_in), //Need to figure out how to push alias
+ queries_(queries_in),
kernel_(kernel),
naive_(CLI::GetParam<bool>("neighbor_search/naive_mode")),
dual_mode_(!(naive_ || CLI::GetParam<bool>("neighbor_search/single_mode"))),
@@ -33,20 +31,20 @@
// Get the leaf size; naive ensures that the entire tree is one node
if (naive_)
CLI::GetParam<int>("tree/leaf_size") =
- std::max(queries_.n_cols, references_.n_cols);
+ std::max(queries_.get_ref().n_cols, references_.get_ref().n_cols);
// K-nearest neighbors initialization
knns_ = CLI::GetParam<int>("neighbor_search/k");
// Initialize the list of nearest neighbor candidates
- neighbor_indices_.set_size(knns_, queries_.n_cols);
+ neighbor_indices_.set_size(knns_, queries_.get_ref().n_cols);
// Initialize the vector of upper bounds for each point.
- neighbor_distances_.set_size(knns_, queries_.n_cols);
+ neighbor_distances_.set_size(knns_, queries_.get_ref().n_cols);
neighbor_distances_.fill(SortPolicy::WorstDistance());
// We'll time tree building
- Timers::StartTimer("neighbor_search/tree_building");
+ Timers::Start("neighbor_search/tree_building");
// This call makes each tree from a matrix, leaf size, and two arrays
// that record the permutation of the data points
@@ -54,20 +52,16 @@
reference_tree_ = new TreeType(references_, old_from_new_references_);
// Stop the timer we started above
- Timers::StopTimer("neighbor_search/tree_building");
+ Timers::Stop("neighbor_search/tree_building");
}
// We call an advanced constructor of arma::mat which allows us to alias a
// matrix (if the user has asked for that).
-template<typename Kernel, typename SortPolicy>
-NeighborSearch<Kernel, SortPolicy>::NeighborSearch(arma::mat& references_in,
+template<typename T1, typename Kernel, typename SortPolicy>
+NeighborSearch<T1, Kernel, SortPolicy>::NeighborSearch(arma::Base<typename T1::elem_type, T1>& references_in,
bool alias_matrix,
Kernel kernel) :
- references_(references_in.memptr(), references_in.n_rows,
- references_in.n_cols, !alias_matrix),
- queries_(references_.memptr(), references_.n_rows, references_.n_cols,
- false),
- kernel_(kernel),
+ references_(references_in), queries_(references_), kernel_(kernel),
naive_(CLI::GetParam<bool>("neighbor_search/naive_mode")),
dual_mode_(!(naive_ || CLI::GetParam<bool>("neighbor_search/single_mode"))),
number_of_prunes_(0) {
@@ -75,20 +69,20 @@
// Get the leaf size from the module
if (naive_)
CLI::GetParam<int>("tree/leaf_size") =
- std::max(queries_.n_cols, references_.n_cols);
+ std::max(queries_.get_ref().n_cols, references_.get_ref().n_cols);
// K-nearest neighbors initialization
knns_ = CLI::GetParam<int>("neighbor_search/k");
// Initialize the list of nearest neighbor candidates
- neighbor_indices_.set_size(knns_, references_.n_cols);
+ neighbor_indices_.set_size(knns_, references_.get_ref().n_cols);
// Initialize the vector of upper bounds for each point.
- neighbor_distances_.set_size(knns_, references_.n_cols);
+ neighbor_distances_.set_size(knns_, references_.get_ref().n_cols);
neighbor_distances_.fill(SortPolicy::WorstDistance());
// We'll time tree building
- Timers::StartTimer("neighbor_search/tree_building");
+ Timers::Start("neighbor_search/tree_building");
// This call makes each tree from a matrix, leaf size, and two arrays
// that record the permutation of the data points
@@ -97,15 +91,15 @@
reference_tree_ = new TreeType(references_, old_from_new_references_);
// Stop the timer we started above
- Timers::StopTimer("neighbor_search/tree_building");
+ Timers::Stop("neighbor_search/tree_building");
}
/**
* The tree is the only member we are responsible for deleting. The others will
* take care of themselves.
*/
-template<typename Kernel, typename SortPolicy>
-NeighborSearch<Kernel, SortPolicy>::~NeighborSearch() {
+template<typename T1, typename Kernel, typename SortPolicy>
+NeighborSearch<T1, Kernel, SortPolicy>::~NeighborSearch() {
if (reference_tree_ != query_tree_)
delete reference_tree_;
if (query_tree_ != NULL)
@@ -115,8 +109,8 @@
/**
* Performs exhaustive computation between two leaves.
*/
-template<typename Kernel, typename SortPolicy>
-void NeighborSearch<Kernel, SortPolicy>::ComputeBaseCase_(
+template<typename T1, typename Kernel, typename SortPolicy>
+void NeighborSearch<T1, Kernel, SortPolicy>::ComputeBaseCase_(
TreeType* query_node,
TreeType* reference_node) {
// Used to find the query node's new upper bound
@@ -128,10 +122,11 @@
query_index < query_node->end(); query_index++) {
// Get the query point from the matrix
- arma::vec query_point = queries_.unsafe_col(query_index);
+ arma::Col<typename T1::elem_type> query_point =
+ queries_.get_ref().unsafe_col(query_index);
- double query_to_node_distance =
- SortPolicy::BestPointToNodeDistance(query_point, reference_node);
+ double query_to_node_distance = SortPolicy::BestPointToNodeDistance
+ (query_point, reference_node);
if (SortPolicy::IsBetter(query_to_node_distance,
neighbor_distances_(knns_ - 1, query_index))) {
@@ -142,7 +137,8 @@
// Confirm that points do not identify themselves as neighbors
// in the monochromatic case
if (reference_node != query_node || reference_index != query_index) {
- arma::vec reference_point = references_.unsafe_col(reference_index);
+ arma::Col<typename T1::elem_type> reference_point =
+ references_.get_ref().unsafe_col(reference_index);
double distance = kernel_.Evaluate(query_point, reference_point);
@@ -174,8 +170,8 @@
/**
* The recursive function for dual tree
*/
-template<typename Kernel, typename SortPolicy>
-void NeighborSearch<Kernel, SortPolicy>::ComputeDualNeighborsRecursion_(
+template<typename T1, typename Kernel, typename SortPolicy>
+void NeighborSearch<T1, Kernel, SortPolicy>::ComputeDualNeighborsRecursion_(
TreeType* query_node,
TreeType* reference_node,
double lower_bound) {
@@ -288,10 +284,10 @@
} // ComputeDualNeighborsRecursion_
-template<typename Kernel, typename SortPolicy>
-void NeighborSearch<Kernel, SortPolicy>::ComputeSingleNeighborsRecursion_(
+template<typename T1, typename Kernel, typename SortPolicy>
+void NeighborSearch<T1, Kernel, SortPolicy>::ComputeSingleNeighborsRecursion_(
size_t point_id,
- arma::vec& point,
+ arma::Col<typename T1::elem_type>& point,
TreeType* reference_node,
double& best_dist_so_far) {
@@ -302,9 +298,11 @@
reference_index < reference_node->end(); reference_index++) {
// Confirm that points do not identify themselves as neighbors
// in the monochromatic case
- if (!(references_.memptr() == queries_.memptr() &&
+ // SpMat does NOT currently implement memptr
+ if (!(references_.get_ref().memptr() == queries_.get_ref().memptr() &&
reference_index == point_id)) {
- arma::vec reference_point = references_.unsafe_col(reference_index);
+ arma::Col<typename T1::elem_type> reference_point =
+ references_.get_ref().unsafe_col(reference_index);
double distance = kernel_.Evaluate(point, reference_point);
@@ -322,10 +320,10 @@
best_dist_so_far = neighbor_distances_(knns_ - 1, point_id);
} else {
// We'll order the computation by distance.
- double left_distance = SortPolicy::BestPointToNodeDistance(point,
- reference_node->left());
- double right_distance = SortPolicy::BestPointToNodeDistance(point,
- reference_node->right());
+ double left_distance = SortPolicy::BestPointToNodeDistance
+ (point, reference_node->left());
+ double right_distance = SortPolicy::BestPointToNodeDistance
+ (point, reference_node->right());
// Recurse in the best direction first.
if (SortPolicy::IsBetter(left_distance, right_distance)) {
@@ -361,12 +359,12 @@
* Computes the best neighbors and stores them in resulting_neighbors and
* distances.
*/
-template<typename Kernel, typename SortPolicy>
-void NeighborSearch<Kernel, SortPolicy>::ComputeNeighbors(
+template<typename T1, typename Kernel, typename SortPolicy>
+void NeighborSearch<T1, Kernel, SortPolicy>::ComputeNeighbors(
arma::Mat<size_t>& resulting_neighbors,
arma::mat& distances) {
- Timers::StartTimer("neighbor_search/computing_neighbors");
+ Timers::Start("neighbor_search/computing_neighbors");
if (naive_) {
// Run the base case computation on all nodes
if (query_tree_)
@@ -385,19 +383,23 @@
reference_tree_));
}
} else {
- size_t chunk = queries_.n_cols / 10;
+ size_t chunk = queries_.get_ref().n_cols / 10;
for(size_t i = 0; i < 10; i++) {
for(size_t j = 0; j < chunk; j++) {
- arma::vec point = queries_.unsafe_col(i * chunk + j);
+ arma::Col<typename T1::elem_type> point =
+ queries_.get_ref().unsafe_col(i * chunk + j);
+
double best_dist_so_far = SortPolicy::WorstDistance();
ComputeSingleNeighborsRecursion_(i * chunk + j, point,
reference_tree_, best_dist_so_far);
}
}
- for(size_t i = 0; i < queries_.n_cols % 10; i++) {
- size_t ind = (queries_.n_cols / 10) * 10 + i;
- arma::vec point = queries_.unsafe_col(ind);
+ for(size_t i = 0; i < queries_.get_ref().n_cols % 10; i++) {
+ size_t ind = (queries_.get_ref().n_cols / 10) * 10 + i;
+ arma::Col<typename T1::elem_type> point =
+ queries_.get_ref().unsafe_col(ind);
+
double best_dist_so_far = SortPolicy::WorstDistance();
ComputeSingleNeighborsRecursion_(ind, point, reference_tree_,
best_dist_so_far);
@@ -405,7 +407,7 @@
}
}
- Timers::StopTimer("neighbor_search/computing_neighbors");
+ Timers::Stop("neighbor_search/computing_neighbors");
// We need to initialize the results list before filling it
resulting_neighbors.set_size(neighbor_indices_.n_rows,
@@ -440,8 +442,8 @@
* @param neighbor Index of reference point which is being inserted.
* @param distance Distance from query point to reference point.
*/
-template<typename Kernel, typename SortPolicy>
-void NeighborSearch<Kernel, SortPolicy>::InsertNeighbor(size_t query_index,
+template<typename T1, typename Kernel, typename SortPolicy>
+void NeighborSearch<T1, Kernel, SortPolicy>::InsertNeighbor(size_t query_index,
size_t pos,
size_t neighbor,
double distance) {
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp 2011-11-17 17:41:26 UTC (rev 10312)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp 2011-11-17 17:44:31 UTC (rev 10313)
@@ -64,8 +64,8 @@
* this is the maximum distance between the tree node and the point using the
* given distance function.
*/
- template<typename TreeType>
- static double BestPointToNodeDistance(const arma::vec& query_point,
+ template<typename elem_type, typename TreeType>
+ static double BestPointToNodeDistance(const arma::Col<elem_type>& query_point,
const TreeType* reference_node);
/**
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort_impl.hpp 2011-11-17 17:41:26 UTC (rev 10312)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort_impl.hpp 2011-11-17 17:44:31 UTC (rev 10313)
@@ -22,9 +22,9 @@
return query_node->bound().MaxDistance(reference_node->bound());
}
-template<typename TreeType>
+template<typename elem_type, typename TreeType>
double FurthestNeighborSort::BestPointToNodeDistance(
- const arma::vec& point,
+ const arma::Col<elem_type>& point,
const TreeType* reference_node) {
// This is not implemented yet for the general case because the trees do not
// accept arbitrary distance metrics.
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp 2011-11-17 17:41:26 UTC (rev 10312)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp 2011-11-17 17:44:31 UTC (rev 10313)
@@ -68,8 +68,8 @@
* this is the minimum distance between the tree node and the point using the
* given distance function.
*/
- template<typename TreeType>
- static double BestPointToNodeDistance(const arma::vec& query_point,
+ template<typename elem_type, typename TreeType>
+ static double BestPointToNodeDistance(const arma::Col<elem_type>& query_point,
const TreeType* reference_node);
/**
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort_impl.hpp 2011-11-17 17:41:26 UTC (rev 10312)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort_impl.hpp 2011-11-17 17:44:31 UTC (rev 10313)
@@ -22,9 +22,9 @@
return query_node->bound().MinDistance(reference_node->bound());
}
-template<typename TreeType>
+template<typename elem_type, typename TreeType>
double NearestNeighborSort::BestPointToNodeDistance(
- const arma::vec& point,
+ const arma::Col<elem_type>& point,
const TreeType* reference_node) {
// This is not implemented yet for the general case because the trees do not
// accept arbitrary distance metrics.
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/typedef.h
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/typedef.h 2011-11-17 17:41:26 UTC (rev 10312)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/typedef.h 2011-11-17 17:44:31 UTC (rev 10313)
@@ -26,7 +26,7 @@
* neighbors. Squared distances are used because they are slightly faster than
* non-squared distances (they have one fewer call to sqrt()).
*/
-typedef NeighborSearch<kernel::SquaredEuclideanDistance, NearestNeighborSort>
+typedef NeighborSearch<arma::mat, kernel::SquaredEuclideanDistance, NearestNeighborSort>
AllkNN;
/**
@@ -35,7 +35,7 @@
* neighbors. Squared distances are used because they are slightly faster than
* non-squared distances (they have one fewer call to sqrt()).
*/
-typedef NeighborSearch<kernel::SquaredEuclideanDistance, FurthestNeighborSort>
+typedef NeighborSearch<arma::mat, kernel::SquaredEuclideanDistance, FurthestNeighborSort>
AllkFN;
}; // namespace neighbor
Modified: mlpack/trunk/src/mlpack/tests/allkfn_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/allkfn_test.cpp 2011-11-17 17:41:26 UTC (rev 10312)
+++ mlpack/trunk/src/mlpack/tests/allkfn_test.cpp 2011-11-17 17:44:31 UTC (rev 10313)
@@ -10,6 +10,11 @@
using namespace mlpack;
using namespace mlpack::neighbor;
+typedef double ELEM;  // Element type exercised by this test (was a macro).
+typedef arma::Mat<ELEM> CONTAINER;  // Dataset container type under test.
+
+typedef NeighborSearch<CONTAINER, kernel::SquaredEuclideanDistance, FurthestNeighborSort> bAllkFN;
+
BOOST_AUTO_TEST_SUITE(AllkFNTest);
/**
@@ -23,7 +28,7 @@
BOOST_AUTO_TEST_CASE(exhaustive_synthetic_test)
{
// Set up our data.
- arma::mat data(1, 11);
+ CONTAINER data(1, 11);
data[0] = 0.05; // Row addressing is unnecessary (they are all 0).
data[1] = 0.35;
data[2] = 0.15;
@@ -41,21 +46,23 @@
CLI::GetParam<int>("neighbor_search/k") = 10;
for (int i = 0; i < 3; i++)
{
- AllkFN* allkfn;
- arma::mat data_mutable = data;
+ //AllkFN* allkfn;
+ bAllkFN* allkfn;
+
+ arma::Col<ELEM> data_mutable = data;
switch(i)
{
case 0: // Use the dual-tree method.
- allkfn = new AllkFN(data_mutable);
+ allkfn = new bAllkFN(data_mutable);
break;
case 1: // Use the single-tree method.
CLI::GetParam<bool>("neighbor_search/single_mode") = true;
- allkfn = new AllkFN(data_mutable);
+ allkfn = new bAllkFN(data_mutable);
break;
case 2: // Use the naive method.
CLI::GetParam<bool>("neighbor_search/single_mode") = false;
CLI::GetParam<bool>("neighbor_search/naive_mode") = true;
- allkfn = new AllkFN(data_mutable);
+ allkfn = new bAllkFN(data_mutable);
break;
}
More information about the mlpack-svn
mailing list