[mlpack-svn] r10046 - in mlpack/trunk/src/mlpack: . core core/tree core/utilities
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Oct 26 20:08:59 EDT 2011
Author: nslagle
Date: 2011-10-26 20:08:59 -0400 (Wed, 26 Oct 2011)
New Revision: 10046
Added:
mlpack/trunk/src/mlpack/core/tree/hrectbound_impl.h
mlpack/trunk/src/mlpack/core/tree/spacetree.h
mlpack/trunk/src/mlpack/core/tree/spacetree_impl.h
mlpack/trunk/src/mlpack/core/tree/statistic.h
mlpack/trunk/src/mlpack/core/utilities/
mlpack/trunk/src/mlpack/core/utilities/save_restore_utility.cpp
mlpack/trunk/src/mlpack/core/utilities/save_restore_utility.hpp
mlpack/trunk/src/mlpack/core/utilities/save_restore_utility_impl.hpp
mlpack/trunk/src/mlpack/core/utilities/save_restore_utility_test.cpp
Removed:
mlpack/trunk/src/mlpack/core/model/
mlpack/trunk/src/mlpack/core/utilities/model.hpp
mlpack/trunk/src/mlpack/core/utilities/save_restore_model.cpp
mlpack/trunk/src/mlpack/core/utilities/save_restore_model.hpp
mlpack/trunk/src/mlpack/core/utilities/save_restore_model_impl.hpp
mlpack/trunk/src/mlpack/core/utilities/save_restore_model_test.cpp
Modified:
mlpack/trunk/src/mlpack/core.h
mlpack/trunk/src/mlpack/core/CMakeLists.txt
mlpack/trunk/src/mlpack/core/tree/hrectbound_impl.hpp
mlpack/trunk/src/mlpack/core/tree/spacetree.hpp
mlpack/trunk/src/mlpack/core/tree/spacetree_impl.hpp
mlpack/trunk/src/mlpack/core/tree/statistic.hpp
mlpack/trunk/src/mlpack/core/utilities/CMakeLists.txt
Log:
mlpack/trunk/src/mlpack: rearrange and restructure the save/restore mechanism; clean up some code
Modified: mlpack/trunk/src/mlpack/core/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/core/CMakeLists.txt 2011-10-26 22:38:11 UTC (rev 10045)
+++ mlpack/trunk/src/mlpack/core/CMakeLists.txt 2011-10-27 00:08:59 UTC (rev 10046)
@@ -6,9 +6,9 @@
io
kernels
math
- model
optimizers
tree
+ utilities
)
foreach(dir ${DIRS})
Copied: mlpack/trunk/src/mlpack/core/tree/hrectbound_impl.h (from rev 9909, mlpack/trunk/src/mlpack/core/tree/hrectbound_impl.h)
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/hrectbound_impl.h (rev 0)
+++ mlpack/trunk/src/mlpack/core/tree/hrectbound_impl.h 2011-10-27 00:08:59 UTC (rev 10046)
@@ -0,0 +1,351 @@
+/**
+ * @file tree/hrectbound_impl.h
+ *
+ * Implementation of hyper-rectangle bound policy class.
+ * Template parameter t_pow is the metric to use; use 2 for Euclidean (L2).
+ *
+ * @experimental
+ */
+#ifndef __TREE_HRECTBOUND_IMPL_H
+#define __TREE_HRECTBOUND_IMPL_H
+
+#include <math.h>
+
+#include "../math/math_lib.h"
+
+// In case it has not been included yet.
+#include "hrectbound.h"
+
+namespace mlpack {
+namespace bound {
+
+/**
+ * Empty constructor: creates a zero-dimensional bound that owns no storage.
+ */
+template<int t_pow>
+HRectBound<t_pow>::HRectBound() :
+    dim_(0),
+    bounds_(NULL) { /* nothing to do */ }
+
+/**
+ * Initializes to specified dimensionality with each dimension the empty
+ * set.
+ *
+ * @param dimension Number of dimensions of the bound.
+ */
+template<int t_pow>
+HRectBound<t_pow>::HRectBound(size_t dimension) :
+    dim_(dimension),
+    bounds_(new Range[dim_]) { /* nothing to do */ }
+
+/**
+ * Copy constructor: deep-copies the other bound's ranges so that the two
+ * objects never share (and later double-delete) the same array.
+ */
+template<int t_pow>
+HRectBound<t_pow>::HRectBound(const HRectBound& other) :
+    dim_(other.dim()),
+    bounds_(new Range[dim_]) {
+  // Copy other bounds over.
+  for (size_t i = 0; i < dim_; i++)
+    bounds_[i] = other[i];
+}
+
+/**
+ * Assignment operator: deep-copies the other bound's ranges.
+ *
+ * The new array is filled before the old one is released, so self-assignment
+ * (a = a) is safe and a failed allocation leaves this object unchanged.
+ */
+template<int t_pow>
+HRectBound<t_pow>& HRectBound<t_pow>::operator=(const HRectBound& other) {
+  if (this == &other)
+    return *this;
+
+  // We can't just copy the bounds_ pointer like the default assignment
+  // operator would!
+  const size_t new_dim = other.dim();
+  Range* new_bounds = new Range[new_dim];
+  for (size_t i = 0; i < new_dim; i++)
+    new_bounds[i] = other[i];
+
+  delete[] bounds_; // delete[] on a NULL pointer is a no-op.
+  dim_ = new_dim;
+  bounds_ = new_bounds;
+
+  return *this;
+}
+
+/**
+ * Destructor: releases the array of ranges.
+ */
+template<int t_pow>
+HRectBound<t_pow>::~HRectBound() {
+  if (bounds_)
+    delete[] bounds_;
+}
+
+/**
+ * Resets all dimensions to the empty set.
+ */
+template<int t_pow>
+void HRectBound<t_pow>::Clear() {
+  for (size_t i = 0; i < dim_; i++) {
+    bounds_[i] = Range();
+  }
+}
+
+/**
+ * Gets the range for a particular dimension (returned by value).
+ */
+template<int t_pow>
+const Range HRectBound<t_pow>::operator[](size_t i) const {
+  return bounds_[i];
+}
+
+/**
+ * Sets the range for the given dimension (returns a modifiable reference).
+ */
+template<int t_pow>
+Range& HRectBound<t_pow>::operator[](size_t i) {
+  return bounds_[i];
+}
+
+/**
+ * Calculates the centroid of the range, placing it into the given vector.
+ *
+ * @param centroid Vector which the centroid will be written to; it is resized
+ *     to the dimensionality of the bound if necessary.
+ */
+template<int t_pow>
+void HRectBound<t_pow>::Centroid(arma::vec& centroid) const {
+  // Set size correctly if necessary.
+  if (!(centroid.n_elem == dim_))
+    centroid.set_size(dim_);
+
+  for (size_t i = 0; i < dim_; i++) {
+    centroid(i) = bounds_[i].mid();
+  }
+}
+
+/**
+ * Calculates minimum bound-to-point squared distance (0 if the point lies
+ * within the bound).
+ */
+template<int t_pow>
+double HRectBound<t_pow>::MinDistance(const arma::vec& point) const {
+  assert(point.n_elem == dim_);
+
+  double sum = 0;
+  const Range* mbound = bounds_;
+
+  double lower, higher;
+  for (size_t d = 0; d < dim_; d++) {
+    lower = mbound->lo - point[d]; // Positive if point[d] < lo.
+    higher = point[d] - mbound->hi; // Positive if point[d] > hi.
+
+    // For a non-empty range (lo <= hi), lower + higher = lo - hi <= 0, so at
+    // most one of 'lower' or 'higher' is positive.  Since x + fabs(x) is 2x
+    // for positive x and 0 otherwise, the sum below is twice the gap between
+    // the point and the range; we then raise it to the power t_pow.
+    sum += pow((lower + fabs(lower)) + (higher + fabs(higher)), (double) t_pow);
+
+    // Move bound pointer.
+    mbound++;
+  }
+
+  // Now take the t_pow'th root (but make sure our result is squared); then
+  // divide by four to cancel out the constant of 2 (which has been squared
+  // now) that was introduced earlier.
+  return pow(sum, 2.0 / (double) t_pow) / 4.0;
+}
+
+/**
+ * Calculates minimum bound-to-bound squared distance (0 if the bounds
+ * overlap).
+ *
+ * Example: bound1.MinDistance(bound2) for minimum squared distance.
+ */
+template<int t_pow>
+double HRectBound<t_pow>::MinDistance(const HRectBound& other) const {
+  assert(dim_ == other.dim_);
+
+  double sum = 0;
+  const Range* mbound = bounds_;
+  const Range* obound = other.bounds_;
+
+  double lower, higher;
+  for (size_t d = 0; d < dim_; d++) {
+    lower = obound->lo - mbound->hi; // Positive if other lies above this.
+    higher = mbound->lo - obound->hi; // Positive if other lies below this.
+    // We invoke the following:
+    //   x + fabs(x) = max(x * 2, 0)
+    // and divide the resulting factor of 2 (squared, hence 4) back out after
+    // taking the root.
+    sum += pow((lower + fabs(lower)) + (higher + fabs(higher)), (double) t_pow);
+
+    // Move bound pointers.
+    mbound++;
+    obound++;
+  }
+
+  return pow(sum, 2.0 / (double) t_pow) / 4.0;
+}
+
+/**
+ * Calculates maximum bound-to-point squared distance.
+ */
+template<int t_pow>
+double HRectBound<t_pow>::MaxDistance(const arma::vec& point) const {
+  double sum = 0;
+
+  assert(point.n_elem == dim_);
+
+  for (size_t d = 0; d < dim_; d++) {
+    // Distance from the point to the farther end of the range.
+    double v = fabs(std::max(
+        point[d] - bounds_[d].lo,
+        bounds_[d].hi - point[d]));
+    sum += pow(v, (double) t_pow);
+  }
+
+  return pow(sum, 2.0 / (double) t_pow);
+}
+
+/**
+ * Calculates maximum bound-to-bound squared distance.
+ */
+template<int t_pow>
+double HRectBound<t_pow>::MaxDistance(const HRectBound& other) const {
+  double sum = 0;
+
+  assert(dim_ == other.dim_);
+
+  double v;
+  for (size_t d = 0; d < dim_; d++) {
+    v = fabs(std::max(
+        other.bounds_[d].hi - bounds_[d].lo,
+        bounds_[d].hi - other.bounds_[d].lo));
+    sum += pow(v, (double) t_pow); // v is non-negative.
+  }
+
+  return pow(sum, 2.0 / (double) t_pow);
+}
+
+/**
+ * Calculates minimum and maximum bound-to-bound squared distance.
+ */
+template<int t_pow>
+Range HRectBound<t_pow>::RangeDistance(const HRectBound& other) const {
+  double sum_lo = 0;
+  double sum_hi = 0;
+
+  assert(dim_ == other.dim_);
+
+  double v1, v2, v_lo, v_hi;
+  for (size_t d = 0; d < dim_; d++) {
+    v1 = other.bounds_[d].lo - bounds_[d].hi;
+    v2 = bounds_[d].lo - other.bounds_[d].hi;
+    // For non-empty ranges, at most one of v1 or v2 is positive.
+    if (v1 >= v2) {
+      v_hi = -v2; // Make it nonnegative.
+      v_lo = (v1 > 0) ? v1 : 0; // Force to be 0 if negative.
+    } else {
+      v_hi = -v1; // Make it nonnegative.
+      v_lo = (v2 > 0) ? v2 : 0; // Force to be 0 if negative.
+    }
+
+    sum_lo += pow(v_lo, (double) t_pow);
+    sum_hi += pow(v_hi, (double) t_pow);
+  }
+
+  return Range(pow(sum_lo, 2.0 / (double) t_pow),
+               pow(sum_hi, 2.0 / (double) t_pow));
+}
+
+/**
+ * Calculates minimum and maximum bound-to-point squared distance.
+ */
+template<int t_pow>
+Range HRectBound<t_pow>::RangeDistance(const arma::vec& point) const {
+  double sum_lo = 0;
+  double sum_hi = 0;
+
+  Log::Assert(point.n_elem == dim_);
+
+  double v1, v2, v_lo, v_hi;
+  for (size_t d = 0; d < dim_; d++) {
+    v1 = bounds_[d].lo - point[d]; // Negative if point[d] > lo.
+    v2 = point[d] - bounds_[d].hi; // Negative if point[d] < hi.
+    // At least one of v1 or v2 is non-positive.
+    if (v1 >= 0) { // point[d] <= bounds_[d].lo.
+      v_hi = -v2; // hi - point[d] is the larger distance.
+      v_lo = v1;
+    } else { // point[d] is between lo and hi, or greater than hi.
+      if (v2 >= 0) {
+        v_hi = -v1; // point[d] - lo is the larger distance.
+        v_lo = v2;
+      } else {
+        v_hi = -std::min(v1, v2); // Both are negative, but we need the larger.
+        v_lo = 0;
+      }
+    }
+
+    sum_lo += pow(v_lo, (double) t_pow);
+    sum_hi += pow(v_hi, (double) t_pow);
+  }
+
+  return Range(pow(sum_lo, 2.0 / (double) t_pow),
+               pow(sum_hi, 2.0 / (double) t_pow));
+}
+
+/**
+ * Expands this region to include a new point.
+ */
+template<int t_pow>
+HRectBound<t_pow>& HRectBound<t_pow>::operator|=(const arma::vec& vector) {
+  Log::Assert(vector.n_elem == dim_);
+
+  for (size_t i = 0; i < dim_; i++) {
+    bounds_[i] |= vector[i];
+  }
+
+  return *this;
+}
+
+/**
+ * Expands this region to encompass another bound.
+ */
+template<int t_pow>
+HRectBound<t_pow>& HRectBound<t_pow>::operator|=(const HRectBound& other) {
+  assert(other.dim_ == dim_);
+
+  for (size_t i = 0; i < dim_; i++) {
+    bounds_[i] |= other.bounds_[i];
+  }
+
+  return *this;
+}
+
+/**
+ * Determines if a point is within this bound.  The point must have the same
+ * dimensionality as the bound, since a longer point would index bounds_ past
+ * its end.
+ */
+template<int t_pow>
+bool HRectBound<t_pow>::Contains(const arma::vec& point) const {
+  Log::Assert(point.n_elem == dim_);
+
+  for (size_t i = 0; i < point.n_elem; i++) {
+    if (!bounds_[i].Contains(point(i))) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+}; // namespace bound
+}; // namespace mlpack
+
+#endif
Modified: mlpack/trunk/src/mlpack/core/tree/hrectbound_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/hrectbound_impl.hpp 2011-10-26 22:38:11 UTC (rev 10045)
+++ mlpack/trunk/src/mlpack/core/tree/hrectbound_impl.hpp 2011-10-27 00:08:59 UTC (rev 10046)
@@ -288,7 +288,7 @@
*/
template<int t_pow>
HRectBound<t_pow>& HRectBound<t_pow>::operator|=(const arma::vec& vector) {
- assert(vector.n_elem == dim_);
+ Log::Assert(vector.n_elem == dim_);
for (size_t i = 0; i < dim_; i++) {
bounds_[i] |= vector[i];
Copied: mlpack/trunk/src/mlpack/core/tree/spacetree.h (from rev 9909, mlpack/trunk/src/mlpack/core/tree/spacetree.h)
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/spacetree.h (rev 0)
+++ mlpack/trunk/src/mlpack/core/tree/spacetree.h 2011-10-27 00:08:59 UTC (rev 10046)
@@ -0,0 +1,227 @@
+/**
+ * @file spacetree.h
+ *
+ * Generalized space partitioning tree.
+ *
+ * @experimental
+ */
+
+#ifndef TREE_SPACETREE_H
+#define TREE_SPACETREE_H
+
+#include "statistic.h"
+
+#include <mlpack/core.h>
+#include <armadillo>
+
+namespace mlpack {
+namespace tree {
+
+PARAM_MODULE("tree", "Parameters for the binary space partitioning tree.");
+PARAM_INT("leaf_size", "Leaf size used during tree construction.", "tree", 20);
+
+/**
+ * A binary space partitioning tree, such as a KD-tree or a ball tree.  Once the
+ * bound and type of dataset is defined, the tree will construct itself.  Call
+ * the constructor with the dataset to build the tree on, and the entire tree
+ * will be built.
+ *
+ * This particular tree does not allow growth, so you cannot add or delete nodes
+ * from it.  If you need to add or delete a node, the better procedure is to
+ * rebuild the tree entirely.
+ *
+ * This tree does take one command line parameter, which is the leaf size to be
+ * used.  You can set this at runtime with --tree/leaf_size [leaf_size].  You
+ * can also set it in your program using CLI:
+ *
+ * @code
+ * CLI::GetParam<int>("tree/leaf_size") = target_leaf_size;
+ * @endcode
+ *
+ * NOTE(review): each node owns its children through raw pointers that the
+ * destructor deletes, but no copy constructor or assignment operator is
+ * declared, so copying a node shallow-copies left_/right_ and would lead to a
+ * double delete.  Pass trees by reference or pointer instead of copying them.
+ *
+ * @tparam Bound The bound used for each node.  The valid types of bounds and
+ *     the necessary skeleton interface for this class can be found in bounds/.
+ * @tparam Statistic Extra data contained in the node.  See statistic.h for
+ *     the necessary skeleton interface.  (The dataset type is currently fixed
+ *     to arma::mat.)
+ */
+template<typename Bound,
+         typename Statistic = EmptyStatistic>
+class BinarySpaceTree {
+ private:
+  BinarySpaceTree *left_;  ///< The left child node.
+  BinarySpaceTree *right_; ///< The right child node.
+  size_t begin_; ///< The first point in the dataset contained in this node.
+  size_t count_; ///< The count of points in the dataset contained in this node.
+  Bound bound_; ///< The bound object for this node.
+  Statistic stat_; ///< The extra data contained in the node.
+
+ public:
+  /**
+   * Construct this as the head node of a binary space tree using the given
+   * dataset.  This will modify the ordering of the points in the dataset!
+   *
+   * Optionally, pass in vectors which represent a mapping from the old
+   * dataset's point ordering to the new ordering, and vice versa.  The leaf
+   * size is read from the "tree/leaf_size" parameter.
+   *
+   * @param data Dataset to create tree from.
+   * @param old_from_new Vector which will be filled with the old positions for
+   *     each new point.
+   * @param new_from_old Vector which will be filled with the new positions for
+   *     each old point.
+   */
+  BinarySpaceTree(arma::mat& data);
+  BinarySpaceTree(arma::mat& data, std::vector<size_t>& old_from_new);
+  BinarySpaceTree(arma::mat& data,
+                  std::vector<size_t>& old_from_new,
+                  std::vector<size_t>& new_from_old);
+
+  /**
+   * Construct a node covering the points in columns
+   * [begin_in, begin_in + count_in) of the dataset, splitting recursively.
+   * If old_from_new is given, it must already have data.n_cols elements and
+   * holds the running permutation; new_from_old, if given, is filled in from
+   * old_from_new after splitting.
+   *
+   * @param data Dataset to create tree from.
+   * @param begin_in Index of the first point in this node.
+   * @param count_in Number of points in this node.
+   */
+  BinarySpaceTree(arma::mat& data,
+                  size_t begin_in,
+                  size_t count_in);
+  BinarySpaceTree(arma::mat& data,
+                  size_t begin_in,
+                  size_t count_in,
+                  std::vector<size_t>& old_from_new);
+  BinarySpaceTree(arma::mat& data,
+                  size_t begin_in,
+                  size_t count_in,
+                  std::vector<size_t>& old_from_new,
+                  std::vector<size_t>& new_from_old);
+
+  /**
+   * Create an empty node with no children and no points.
+   */
+  BinarySpaceTree();
+
+  /**
+   * Deletes this node, deallocating the memory for the children and calling
+   * their destructors in turn.  This will invalidate any pointers or
+   * references to any nodes which are children of this one.
+   */
+  ~BinarySpaceTree();
+
+  /**
+   * Find a node in this tree by its begin and count (const).
+   *
+   * Every node is uniquely identified by these two numbers.
+   * This is useful for communicating position over the network,
+   * when pointers would be invalid.
+   *
+   * @param begin_q The begin() of the node to find.
+   * @param count_q The count() of the node to find.
+   * @return The found node, or NULL.
+   */
+  const BinarySpaceTree* FindByBeginCount(size_t begin_q,
+                                          size_t count_q) const;
+
+  /**
+   * Find a node in this tree by its begin and count.
+   *
+   * Every node is uniquely identified by these two numbers.
+   * This is useful for communicating position over the network,
+   * when pointers would be invalid.
+   *
+   * @param begin_q The begin() of the node to find.
+   * @param count_q The count() of the node to find.
+   * @return The found node, or NULL.
+   */
+  BinarySpaceTree* FindByBeginCount(size_t begin_q, size_t count_q);
+
+  // TODO: Not const correct
+
+  //! Return the bound object for this node.
+  const Bound& bound() const;
+  //! Modify the bound object for this node.
+  Bound& bound();
+
+  //! Return the statistic object for this node.
+  const Statistic& stat() const;
+  //! Modify the statistic object for this node.
+  Statistic& stat();
+
+  //! Return whether or not this node is a leaf (has no children).
+  bool is_leaf() const;
+
+  /**
+   * Gets the left branch of the tree.
+   */
+  BinarySpaceTree *left() const;
+
+  /**
+   * Gets the right branch.
+   */
+  BinarySpaceTree *right() const;
+
+  /**
+   * Gets the index of the begin point of this subset.
+   */
+  size_t begin() const;
+
+  /**
+   * Gets the index one beyond the last index in the series.
+   */
+  size_t end() const;
+
+  /**
+   * Gets the number of points in this subset.
+   */
+  size_t count() const;
+
+  //! Print the index range of this node and every node below it to stdout.
+  void Print() const;
+
+ private:
+
+  /**
+   * Splits the current node, assigning its left and right children
+   * recursively.  The leaf size is read from the "tree/leaf_size" parameter.
+   *
+   * Optionally, return a list of the changed indices.
+   *
+   * @param data Dataset which we are using.
+   * @param old_from_new Vector holding permuted indices.
+   */
+  void SplitNode(arma::mat& data);
+  void SplitNode(arma::mat& data, std::vector<size_t>& old_from_new);
+
+  /**
+   * Find the index to split on for this node, given that we are splitting in
+   * the given split dimension on the specified split value.
+   *
+   * Optionally, return a list of the changed indices.
+   *
+   * @param data Dataset which we are using.
+   * @param split_dim Dimension of dataset to split on.
+   * @param split_val Value to split on, in the given split dimension.
+   * @param old_from_new Vector holding permuted indices.
+   */
+  size_t GetSplitIndex(arma::mat& data, int split_dim, double split_val);
+  size_t GetSplitIndex(arma::mat& data, int split_dim, double split_val,
+                       std::vector<size_t>& old_from_new);
+
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+// Include implementation.
+#include "spacetree_impl.h"
+
+#endif
Modified: mlpack/trunk/src/mlpack/core/tree/spacetree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/spacetree.hpp 2011-10-26 22:38:11 UTC (rev 10045)
+++ mlpack/trunk/src/mlpack/core/tree/spacetree.hpp 2011-10-27 00:08:59 UTC (rev 10046)
@@ -1,5 +1,5 @@
/**
- * @file spacetree.h
+ * @file spacetree.hpp
*
* Generalized space partitioning tree.
*
@@ -41,7 +41,7 @@
* the necessary skeleton interface for this class can be found in bounds/.
* @tparam TDataset The type of dataset (forced to be arma::mat for now).
* @tparam TStatistic Extra data contained in the node. See statistic.h for
- * the necessary skeleton interface.
+ * the necessary skeleton interface.
*/
template<typename Bound,
typename Statistic = EmptyStatistic>
@@ -87,7 +87,7 @@
size_t count_in,
std::vector<size_t>& old_from_new,
std::vector<size_t>& new_from_old);
-
+
BinarySpaceTree();
/***
@@ -110,7 +110,7 @@
*/
const BinarySpaceTree* FindByBeginCount(size_t begin_q,
size_t count_q) const;
-
+
/**
* Find a node in this tree by its begin and count (const).
*
@@ -123,9 +123,9 @@
* @return the found node, or NULL
*/
BinarySpaceTree* FindByBeginCount(size_t begin_q, size_t count_q);
-
+
// TODO: Not const correct
-
+
const Bound& bound() const;
Bound& bound();
@@ -153,12 +153,12 @@
* Gets the index one beyond the last index in the series.
*/
size_t end() const;
-
+
/**
* Gets the number of points in this subset.
*/
size_t count() const;
-
+
void Print() const;
private:
Copied: mlpack/trunk/src/mlpack/core/tree/spacetree_impl.h (from rev 9909, mlpack/trunk/src/mlpack/core/tree/spacetree_impl.h)
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/spacetree_impl.h (rev 0)
+++ mlpack/trunk/src/mlpack/core/tree/spacetree_impl.h 2011-10-27 00:08:59 UTC (rev 10046)
@@ -0,0 +1,509 @@
+/**
+ * @file spacetree_impl.h
+ *
+ * Implementation of generalized space partitioning tree.
+ *
+ * @experimental
+ */
+
+#ifndef TREE_SPACETREE_IMPL_H
+#define TREE_SPACETREE_IMPL_H
+
+// Try to prevent direct inclusion
+#ifndef TREE_SPACETREE_H
+#error "Do not include this header directly."
+#endif
+
+#include <mlpack/core/io/io.h>
+#include "../io/log.h"
+
+namespace mlpack {
+namespace tree {
+
+// Each of these overloads is kept as a separate function to keep the overhead
+// from the two std::vectors out, if possible.
+template<typename Bound, typename Statistic>
+BinarySpaceTree<Bound, Statistic>::BinarySpaceTree(arma::mat& data) :
+    left_(NULL),
+    right_(NULL),
+    begin_(0), /* This root node starts at index 0, */
+    count_(data.n_cols), /* and spans all of the dataset. */
+    bound_(data.n_rows),
+    stat_() {
+  // Do the actual splitting of this node.
+  SplitNode(data);
+}
+
+template<typename Bound, typename Statistic>
+BinarySpaceTree<Bound, Statistic>::BinarySpaceTree(
+    arma::mat& data,
+    std::vector<size_t>& old_from_new) :
+    left_(NULL),
+    right_(NULL),
+    begin_(0),
+    count_(data.n_cols),
+    bound_(data.n_rows),
+    stat_() {
+  // Initialize old_from_new correctly.
+  old_from_new.resize(data.n_cols);
+  for (size_t i = 0; i < data.n_cols; i++)
+    old_from_new[i] = i; // Fill with unharmed indices.
+
+  // Now do the actual splitting.
+  SplitNode(data, old_from_new);
+}
+
+template<typename Bound, typename Statistic>
+BinarySpaceTree<Bound, Statistic>::BinarySpaceTree(
+    arma::mat& data,
+    std::vector<size_t>& old_from_new,
+    std::vector<size_t>& new_from_old) :
+    left_(NULL),
+    right_(NULL),
+    begin_(0),
+    count_(data.n_cols),
+    bound_(data.n_rows),
+    stat_() {
+  // Initialize the old_from_new vector correctly.
+  old_from_new.resize(data.n_cols);
+  for (size_t i = 0; i < data.n_cols; i++)
+    old_from_new[i] = i; // Fill with unharmed indices.
+
+  // Now do the actual splitting.
+  SplitNode(data, old_from_new);
+
+  // Map the new_from_old indices correctly.
+  new_from_old.resize(data.n_cols);
+  for (size_t i = 0; i < data.n_cols; i++)
+    new_from_old[old_from_new[i]] = i;
+}
+
+template<typename Bound, typename Statistic>
+BinarySpaceTree<Bound, Statistic>::BinarySpaceTree(
+    arma::mat& data,
+    size_t begin_in,
+    size_t count_in) :
+    left_(NULL),
+    right_(NULL),
+    begin_(begin_in),
+    count_(count_in),
+    bound_(data.n_rows),
+    stat_() {
+  // Perform the actual splitting.
+  SplitNode(data);
+}
+
+template<typename Bound, typename Statistic>
+BinarySpaceTree<Bound, Statistic>::BinarySpaceTree(
+    arma::mat& data,
+    size_t begin_in,
+    size_t count_in,
+    std::vector<size_t>& old_from_new) :
+    left_(NULL),
+    right_(NULL),
+    begin_(begin_in),
+    count_(count_in),
+    bound_(data.n_rows),
+    stat_() {
+  // Hopefully the vector is initialized correctly!  We can't check that
+  // entirely but we can do a minor sanity check.
+  assert(old_from_new.size() == data.n_cols);
+
+  // Perform the actual splitting.
+  SplitNode(data, old_from_new);
+}
+
+template<typename Bound, typename Statistic>
+BinarySpaceTree<Bound, Statistic>::BinarySpaceTree(
+    arma::mat& data,
+    size_t begin_in,
+    size_t count_in,
+    std::vector<size_t>& old_from_new,
+    std::vector<size_t>& new_from_old) :
+    left_(NULL),
+    right_(NULL),
+    begin_(begin_in),
+    count_(count_in),
+    bound_(data.n_rows),
+    stat_() {
+  // Hopefully the vector is initialized correctly!  We can't check that
+  // entirely but we can do a minor sanity check.
+  assert(old_from_new.size() == data.n_cols);
+
+  // Perform the actual splitting.
+  SplitNode(data, old_from_new);
+
+  // Map the new_from_old indices correctly.
+  new_from_old.resize(data.n_cols);
+  for (size_t i = 0; i < data.n_cols; i++)
+    new_from_old[old_from_new[i]] = i;
+}
+
+template<typename Bound, typename Statistic>
+BinarySpaceTree<Bound, Statistic>::BinarySpaceTree() :
+    left_(NULL),
+    right_(NULL),
+    begin_(0),
+    count_(0),
+    bound_(),
+    stat_() {
+  // Nothing to do.
+}
+
+/**
+ * Deletes this node, deallocating the memory for the children and calling their
+ * destructors in turn.  This will invalidate any pointers or references to any
+ * nodes which are children of this one.
+ */
+template<typename Bound, typename Statistic>
+BinarySpaceTree<Bound, Statistic>::~BinarySpaceTree() {
+  if (left_)
+    delete left_;
+  if (right_)
+    delete right_;
+}
+
+/**
+ * Find a node in this tree by its begin and count (const).
+ *
+ * Every node is uniquely identified by these two numbers.
+ * This is useful for communicating position over the network,
+ * when pointers would be invalid.
+ *
+ * @param begin_q The begin() of the node to find.
+ * @param count_q The count() of the node to find.
+ * @return The found node, or NULL.
+ */
+template<typename Bound, typename Statistic>
+const BinarySpaceTree<Bound, Statistic>*
+BinarySpaceTree<Bound, Statistic>::FindByBeginCount(size_t begin_q,
+                                                    size_t count_q) const {
+
+  mlpack::Log::Assert(begin_q >= begin_);
+  mlpack::Log::Assert(count_q <= count_);
+  if (begin_ == begin_q && count_ == count_q)
+    return this;
+  else if (is_leaf())
+    return NULL;
+  else if (begin_q < right_->begin_)
+    return left_->FindByBeginCount(begin_q, count_q);
+  else
+    return right_->FindByBeginCount(begin_q, count_q);
+}
+
+/**
+ * Find a node in this tree by its begin and count.
+ *
+ * Every node is uniquely identified by these two numbers.
+ * This is useful for communicating position over the network,
+ * when pointers would be invalid.
+ *
+ * @param begin_q The begin() of the node to find.
+ * @param count_q The count() of the node to find.
+ * @return The found node, or NULL.
+ */
+template<typename Bound, typename Statistic>
+BinarySpaceTree<Bound, Statistic>*
+BinarySpaceTree<Bound, Statistic>::FindByBeginCount(size_t begin_q,
+                                                    size_t count_q) {
+
+  mlpack::Log::Assert(begin_q >= begin_);
+  mlpack::Log::Assert(count_q <= count_);
+  if (begin_ == begin_q && count_ == count_q)
+    return this;
+  else if (is_leaf())
+    return NULL;
+  else if (begin_q < right_->begin_)
+    return left_->FindByBeginCount(begin_q, count_q);
+  else
+    return right_->FindByBeginCount(begin_q, count_q);
+}
+
+template<typename Bound, typename Statistic>
+const Bound& BinarySpaceTree<Bound, Statistic>::bound() const {
+  return bound_;
+}
+
+template<typename Bound, typename Statistic>
+Bound& BinarySpaceTree<Bound, Statistic>::bound() {
+  return bound_;
+}
+
+template<typename Bound, typename Statistic>
+const Statistic& BinarySpaceTree<Bound, Statistic>::stat() const {
+  return stat_;
+}
+
+template<typename Bound, typename Statistic>
+Statistic& BinarySpaceTree<Bound, Statistic>::stat() {
+  return stat_;
+}
+
+template<typename Bound, typename Statistic>
+bool BinarySpaceTree<Bound, Statistic>::is_leaf() const {
+  return !left_;
+}
+
+/**
+ * Gets the left branch of the tree.
+ */
+template<typename Bound, typename Statistic>
+BinarySpaceTree<Bound, Statistic>*
+BinarySpaceTree<Bound, Statistic>::left() const {
+  // TODO: Const correctness
+  return left_;
+}
+
+/**
+ * Gets the right branch.
+ */
+template<typename Bound, typename Statistic>
+BinarySpaceTree<Bound, Statistic>*
+BinarySpaceTree<Bound, Statistic>::right() const {
+  // TODO: Const correctness
+  return right_;
+}
+
+/**
+ * Gets the index of the begin point of this subset.
+ */
+template<typename Bound, typename Statistic>
+size_t BinarySpaceTree<Bound, Statistic>::begin() const {
+  return begin_;
+}
+
+/**
+ * Gets the index one beyond the last index in the series.
+ */
+template<typename Bound, typename Statistic>
+size_t BinarySpaceTree<Bound, Statistic>::end() const {
+  return begin_ + count_;
+}
+
+/**
+ * Gets the number of points in this subset.
+ */
+template<typename Bound, typename Statistic>
+size_t BinarySpaceTree<Bound, Statistic>::count() const {
+  return count_;
+}
+
+/**
+ * Prints the index range and point count of this node and, recursively, of
+ * every node below it, to stdout.  The size_t members are printed with %zu;
+ * passing them for %d is undefined behavior wherever size_t is not int.
+ */
+template<typename Bound, typename Statistic>
+void BinarySpaceTree<Bound, Statistic>::Print() const {
+  printf("node: %zu to %zu: %zu points total\n",
+      begin_, begin_ + count_ - 1, count_);
+  if (!is_leaf()) {
+    left_->Print();
+    right_->Print();
+  }
+}
+
+template<typename Bound, typename Statistic>
+void BinarySpaceTree<Bound, Statistic>::SplitNode(arma::mat& data) {
+  // This should be a single function for Bound.
+  // We need to expand the bounds of this node properly.
+  for (size_t i = begin_; i < (begin_ + count_); i++)
+    bound_ |= data.unsafe_col(i);
+
+  // Now, check if we need to split at all.
+  if (count_ <= (size_t) CLI::GetParam<int>("tree/leaf_size"))
+    return; // We can't split this.
+
+  // Figure out which dimension to split on.
+  size_t split_dim = data.n_rows; // Indicate invalid by max_dim + 1.
+  double max_width = -1;
+
+  // Find the split dimension.
+  for (size_t d = 0; d < data.n_rows; d++) {
+    double width = bound_[d].width();
+
+    if (width > max_width) {
+      max_width = width;
+      split_dim = d;
+    }
+  }
+
+  // If every dimension has zero width, all these points are the same and we
+  // can't split.  This is checked before bound_[split_dim] is read because
+  // split_dim is still invalid when the data has no dimensions.
+  if (max_width <= 0)
+    return;
+
+  // Split in the middle of that dimension.
+  double split_val = bound_[split_dim].mid();
+
+  // Perform the actual splitting.  This will order the dataset such that
+  // points with value in dimension split_dim less than split_val are on the
+  // left of split_col, and points with value in dimension split_dim greater
+  // than or equal to split_val are on the right side of split_col.
+  size_t split_col = GetSplitIndex(data, split_dim, split_val);
+
+  // If rounding made split_val equal to the low end of the range (possible
+  // when the range is only a few ulps wide), every point lands on one side;
+  // an empty child plus an identical child would recurse forever, so stop.
+  if (split_col == begin_ || split_col == begin_ + count_)
+    return;
+
+  // Now that we know the split column, we will recursively split the children
+  // by calling their constructors (which perform this splitting process).
+  left_ = new BinarySpaceTree<Bound, Statistic>(data, begin_,
+      split_col - begin_);
+  right_ = new BinarySpaceTree<Bound, Statistic>(data, split_col,
+      begin_ + count_ - split_col);
+}
+
+template<typename Bound, typename Statistic>
+void BinarySpaceTree<Bound, Statistic>::SplitNode(
+    arma::mat& data,
+    std::vector<size_t>& old_from_new) {
+  // This should be a single function for Bound.
+  // We need to expand the bounds of this node properly.
+  for (size_t i = begin_; i < (begin_ + count_); i++)
+    bound_ |= data.unsafe_col(i);
+
+  // First, check if we need to split at all.
+  if (count_ <= (size_t) CLI::GetParam<int>("tree/leaf_size"))
+    return; // We can't split this.
+
+  // Figure out which dimension to split on.
+  size_t split_dim = data.n_rows; // Indicate invalid by max_dim + 1.
+  double max_width = -1;
+
+  // Find the split dimension.
+  for (size_t d = 0; d < data.n_rows; d++) {
+    double width = bound_[d].width();
+
+    if (width > max_width) {
+      max_width = width;
+      split_dim = d;
+    }
+  }
+
+  // If every dimension has zero width, all these points are the same and we
+  // can't split.  This is checked before bound_[split_dim] is read because
+  // split_dim is still invalid when the data has no dimensions.
+  if (max_width <= 0)
+    return;
+
+  // Split in the middle of that dimension.
+  double split_val = bound_[split_dim].mid();
+
+  // Perform the actual splitting.  This will order the dataset such that
+  // points with value in dimension split_dim less than split_val are on the
+  // left of split_col, and points with value in dimension split_dim greater
+  // than or equal to split_val are on the right side of split_col.
+  size_t split_col = GetSplitIndex(data, split_dim, split_val, old_from_new);
+
+  // If rounding made split_val equal to the low end of the range (possible
+  // when the range is only a few ulps wide), every point lands on one side;
+  // an empty child plus an identical child would recurse forever, so stop.
+  if (split_col == begin_ || split_col == begin_ + count_)
+    return;
+
+  // Now that we know the split column, we will recursively split the children
+  // by calling their constructors (which perform this splitting process).
+  left_ = new BinarySpaceTree<Bound, Statistic>(data, begin_,
+      split_col - begin_, old_from_new);
+  right_ = new BinarySpaceTree<Bound, Statistic>(data, split_col,
+      begin_ + count_ - split_col, old_from_new);
+}
+
+template<typename Bound, typename Statistic>
+size_t BinarySpaceTree<Bound, Statistic>::GetSplitIndex(
+    arma::mat& data,
+    int split_dim,
+    double split_val) {
+  // This method modifies the input dataset.  We loop both from the left and
+  // right sides of the points contained in this node.  The points less than
+  // split_val should be on the left side of the matrix, and the points greater
+  // than or equal to split_val should be on the right side of the matrix.
+  //
+  // 'right' is one past the last unexamined column, so it can never wrap
+  // around below begin_, and each bounds check happens before the column is
+  // read.
+  size_t left = begin_;
+  size_t right = begin_ + count_;
+
+  while (true) {
+    // See how many points on the left are correct.  When they are correct,
+    // increase the left counter accordingly.  When we encounter one that isn't
+    // correct, stop.  We will switch it later.
+    while ((left < right) && (data(split_dim, left) < split_val))
+      left++;
+
+    // Now see how many points on the right are correct.  When they are
+    // correct, decrease the right counter accordingly.  When we encounter one
+    // that isn't correct, stop.  We will switch it with the wrong point we
+    // found in the previous loop.
+    while ((left < right) && (data(split_dim, right - 1) >= split_val))
+      right--;
+
+    if (left >= right)
+      break; // Every point is now on the correct side.
+
+    // Swap columns.
+    data.swap_cols(left, right - 1);
+  }
+
+  assert(left == right);
+
+  return left;
+}
+
+template<typename Bound, typename Statistic>
+size_t BinarySpaceTree<Bound, Statistic>::GetSplitIndex(
+    arma::mat& data,
+    int split_dim,
+    double split_val,
+    std::vector<size_t>& old_from_new) {
+  // This method modifies the input dataset.  We loop both from the left and
+  // right sides of the points contained in this node.  The points less than
+  // split_val should be on the left side of the matrix, and the points greater
+  // than or equal to split_val should be on the right side of the matrix.
+  //
+  // 'right' is one past the last unexamined column, so it can never wrap
+  // around below begin_, and each bounds check happens before the column is
+  // read.
+  size_t left = begin_;
+  size_t right = begin_ + count_;
+
+  while (true) {
+    // See how many points on the left are correct.  When they are correct,
+    // increase the left counter accordingly.  When we encounter one that isn't
+    // correct, stop.  We will switch it later.
+    while ((left < right) && (data(split_dim, left) < split_val))
+      left++;
+
+    // Now see how many points on the right are correct.  When they are
+    // correct, decrease the right counter accordingly.  When we encounter one
+    // that isn't correct, stop.  We will switch it with the wrong point we
+    // found in the previous loop.
+    while ((left < right) && (data(split_dim, right - 1) >= split_val))
+      right--;
+
+    if (left >= right)
+      break; // Every point is now on the correct side.
+
+    // Swap columns.
+    data.swap_cols(left, right - 1);
+
+    // Update the indices for what we changed.
+    size_t t = old_from_new[left];
+    old_from_new[left] = old_from_new[right - 1];
+    old_from_new[right - 1] = t;
+  }
+
+  assert(left == right);
+
+  return left;
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif
Modified: mlpack/trunk/src/mlpack/core/tree/spacetree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/spacetree_impl.hpp 2011-10-26 22:38:11 UTC (rev 10045)
+++ mlpack/trunk/src/mlpack/core/tree/spacetree_impl.hpp 2011-10-27 00:08:59 UTC (rev 10046)
@@ -9,11 +9,13 @@
#ifndef __MLPACK_CORE_TREE_SPACETREE_IMPL_HPP
#define __MLPACK_CORE_TREE_SPACETREE_IMPL_HPP
-// In case it wasn't included already for some reason.
-#include "spacetree.hpp"
-#include <mlpack/core/io/cli.hpp>
-#include <mlpack/core/io/log.hpp>
+// Try to prevent direct inclusion
+#ifndef __MLPACK_CORE_TREE_SPACETREE_HPP
+#error "Do not include this header directly."
+#endif
+#include "mlpack/core/io/log.hpp"
+
namespace mlpack {
namespace tree {
Copied: mlpack/trunk/src/mlpack/core/tree/statistic.h (from rev 9912, mlpack/trunk/src/mlpack/core/tree/statistic.h)
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/statistic.h (rev 0)
+++ mlpack/trunk/src/mlpack/core/tree/statistic.h 2011-10-27 00:08:59 UTC (rev 10046)
@@ -0,0 +1,41 @@
+/**
+ * @file statistic.h
+ *
+ * Home for the concept of tree statistics.
+ *
+ * You should define your own statistic that looks like EmptyStatistic.
+ *
+ * @experimental
+ */
+
+#ifndef TREE_STATISTIC_H
+#define TREE_STATISTIC_H
+
+#include <armadillo>
+
+/**
+ * A tree statistic that records nothing. Use it when you do not need
+ * per-node statistics, or copy it as the skeleton for your own statistic.
+ *
+ * @experimental
+ */
+class EmptyStatistic
+{
+ public:
+  EmptyStatistic() { }
+  ~EmptyStatistic() { }
+
+  /**
+   * Initialization from the raw points of a node; intentionally a no-op.
+   */
+  void Init(const arma::mat& dataset, size_t start, size_t count)
+  {
+  }
+
+  /**
+   * Initialization by combining the statistics of two partitions;
+   * intentionally a no-op. A real statistic can use this hook to build its
+   * values bottom-up while the tree is being built.
+   */
+  void Init(const arma::mat& dataset,
+            size_t start,
+            size_t count,
+            const EmptyStatistic& left_stat,
+            const EmptyStatistic& right_stat)
+  {
+  }
+};
+
+#endif
Modified: mlpack/trunk/src/mlpack/core/tree/statistic.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/statistic.hpp 2011-10-26 22:38:11 UTC (rev 10045)
+++ mlpack/trunk/src/mlpack/core/tree/statistic.hpp 2011-10-27 00:08:59 UTC (rev 10046)
@@ -17,7 +17,6 @@
*
* @experimental
*/
-// TODO: determine how to handle this template <class TDataset>
class EmptyStatistic {
public:
EmptyStatistic() {}
Modified: mlpack/trunk/src/mlpack/core/utilities/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/core/model/CMakeLists.txt 2011-10-17 21:50:13 UTC (rev 9909)
+++ mlpack/trunk/src/mlpack/core/utilities/CMakeLists.txt 2011-10-27 00:08:59 UTC (rev 10046)
@@ -3,11 +3,10 @@
# Define the files we need to compile.
# Anything not in this list will not be compiled into MLPACK.
set(SOURCES
- model.hpp
- save_restore_model.cpp
- save_restore_model.hpp
- save_restore_model_impl.hpp
- save_restore_model_test.cpp
+ save_restore_utility.cpp
+ save_restore_utility.hpp
+ save_restore_utility_impl.hpp
)
# add directory name to sources
@@ -20,10 +19,10 @@
set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
# add test executable
-add_executable(save_restore_model_test
- save_restore_model_test.cpp
+add_executable(save_restore_utility_test
+ save_restore_utility_test.cpp
)
-target_link_libraries(save_restore_model_test
+target_link_libraries(save_restore_utility_test
mlpack
boost_unit_test_framework
xml2
Deleted: mlpack/trunk/src/mlpack/core/utilities/model.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/model/model.hpp 2011-10-17 21:50:13 UTC (rev 9909)
+++ mlpack/trunk/src/mlpack/core/utilities/model.hpp 2011-10-27 00:08:59 UTC (rev 10046)
@@ -1,19 +0,0 @@
-#ifndef MODEL_HPP
-#define MODEL_HPP
-
-namespace mlpack
-{
- namespace model
- {
-
- class Model
- {
- public:
- Model () {}
- virtual ~Model () {}
- virtual bool solve () = 0;
- };
- };
-};
-
-#endif
Deleted: mlpack/trunk/src/mlpack/core/utilities/save_restore_model.cpp
===================================================================
--- mlpack/trunk/src/mlpack/core/model/save_restore_model.cpp 2011-10-17 21:50:13 UTC (rev 9909)
+++ mlpack/trunk/src/mlpack/core/utilities/save_restore_model.cpp 2011-10-27 00:08:59 UTC (rev 10046)
@@ -1,151 +0,0 @@
-#include "save_restore_model.hpp"
-
-using namespace mlpack;
-using namespace model;
-
-bool SaveRestoreModel::readFile (std::string filename)
-{
- xmlDocPtr xmlDocTree = NULL;
- if (NULL == (xmlDocTree = xmlReadFile (filename.c_str(), NULL, 0)))
- {
- errx (1, "Clearly, we couldn't load the XML file\n");
- }
- xmlNodePtr root = xmlDocGetRootElement (xmlDocTree);
- parameters.clear();
-
- recurseOnNodes (root->children);
- xmlFreeDoc (xmlDocTree);
- return true;
-}
-void SaveRestoreModel::recurseOnNodes (xmlNode* n)
-{
- xmlNodePtr current = NULL;
- for (current = n; current; current = current->next)
- {
- if (current->type == XML_ELEMENT_NODE)
- {
- xmlChar* content = xmlNodeGetContent (current);
- parameters[(const char*) current->name] = (const char*) content;
- xmlFree (content);
- }
- recurseOnNodes (current->children);
- }
-}
-bool SaveRestoreModel::writeFile (std::string filename)
-{
- bool success = false;
- xmlDocPtr xmlDocTree = xmlNewDoc (BAD_CAST "1.0");
- xmlNodePtr root = xmlNewNode(NULL, BAD_CAST "root");
- xmlNodePtr child = NULL;
-
- xmlDocSetRootElement (xmlDocTree, root);
-
- for (std::map<std::string, std::string>::iterator it = parameters.begin();
- it != parameters.end();
- ++it)
- {
- child = xmlNewChild (root, NULL,
- BAD_CAST (*it).first.c_str(),
- BAD_CAST (*it).second.c_str());
- /* TODO: perhaps we'll add more later?
- * xmlNewProp (child, BAD_CAST "attr", BAD_CAST "add more addibutes?"); */
- }
- /* save the file */
- xmlSaveFormatFileEnc (filename.c_str(), xmlDocTree, "UTF-8", 1);
- xmlFreeDoc (xmlDocTree);
- return success;
-}
-arma::mat& SaveRestoreModel::loadParameter (arma::mat& matrix, std::string name)
-{
- std::map<std::string, std::string>::iterator it = parameters.find (name);
- if (it != parameters.end ())
- {
- std::string value = (*it).second;
- boost::char_separator<char> sep ("\n");
- boost::tokenizer<boost::char_separator<char> > tok (value, sep);
- std::list<std::list<double> > rows;
- for (boost::tokenizer<boost::char_separator<char> >::iterator
- tokIt = tok.begin ();
- tokIt != tok.end ();
- ++tokIt)
- {
- std::string row = *tokIt;
- boost::char_separator<char> sepComma (",");
- boost::tokenizer<boost::char_separator<char> >
- tokInner (row, sepComma);
- std::list<double> rowList;
- for (boost::tokenizer<boost::char_separator<char> >::iterator
- tokInnerIt = tokInner.begin ();
- tokInnerIt != tokInner.end ();
- ++tokInnerIt)
- {
- double element;
- std::istringstream iss (*tokInnerIt);
- iss >> element;
- rowList.push_back (element);
- }
- rows.push_back (rowList);
- }
- matrix.zeros (rows.size (), (*(rows.begin ())).size ());
- size_t rowCounter = 0;
- size_t columnCounter = 0;
- for (std::list<std::list<double> >::iterator rowIt = rows.begin ();
- rowIt != rows.end ();
- ++rowIt)
- {
- std::list<double> row = *rowIt;
- columnCounter = 0;
- for (std::list<double>::iterator elementIt = row.begin ();
- elementIt != row.end ();
- ++elementIt)
- {
- matrix(rowCounter, columnCounter) = *elementIt;
- columnCounter++;
- }
- rowCounter++;
- }
- return matrix;
- }
- else
- {
- errx (1, "Missing the correct name\n");
- }
-}
-char SaveRestoreModel::loadParameter (char c, std::string name)
-{
- int temp;
- std::map<std::string, std::string>::iterator it = parameters.find (name);
- if (it != parameters.end ())
- {
- std::string value = (*it).second;
- std::istringstream input (value);
- input >> temp;
- return (char) temp;
- }
- else
- {
- errx (1, "Missing the correct name\n");
- }
-}
-void SaveRestoreModel::saveParameter (char c, std::string name)
-{
- int temp = (int) c;
- std::ostringstream output;
- output << temp;
- parameters[name] = output.str();
-}
-void SaveRestoreModel::saveParameter (arma::mat& mat, std::string name)
-{
- std::ostringstream output;
- size_t columns = mat.n_cols;
- size_t rows = mat.n_rows;
- for (size_t r = 0; r < rows; ++r)
- {
- for (size_t c = 0; c < columns - 1; ++c)
- {
- output << mat(r,c) << ",";
- }
- output << mat(r,columns - 1) << std::endl;
- }
- parameters[name] = output.str();
-}
Deleted: mlpack/trunk/src/mlpack/core/utilities/save_restore_model.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/model/save_restore_model.hpp 2011-10-17 21:50:13 UTC (rev 9909)
+++ mlpack/trunk/src/mlpack/core/utilities/save_restore_model.hpp 2011-10-27 00:08:59 UTC (rev 10046)
@@ -1,55 +0,0 @@
-#ifndef SAVE_RESTORE_MODEL_HPP
-#define SAVE_RESTORE_MODEL_HPP
-
-#include <err.h>
-#include <list>
-#include <map>
-#include <sstream>
-#include <string>
-
-#include <libxml/parser.h>
-#include <libxml/tree.h>
-
-#include <armadillo>
-#include <boost/tokenizer.hpp>
-
-#include "model.hpp"
-
-namespace mlpack
-{
- namespace model
- {
- class SaveRestoreModel : public Model
- {
- private:
- std::map<std::string, std::string> parameters;
-
- public:
- SaveRestoreModel() {}
- ~SaveRestoreModel() { parameters.clear(); }
- bool readFile (std::string filename);
- void recurseOnNodes (xmlNode* n);
- bool writeFile (std::string filename);
- template<typename T>
- T& loadParameter (T& t, std::string name);
- char loadParameter (char c, std::string name);
- arma::mat& loadParameter (arma::mat& matrix, std::string name);
- template<typename T>
- void saveParameter (T& t, std::string name);
- void saveParameter (char c, std::string name);
- void saveParameter (arma::mat& mat, std::string name);
- virtual bool loadModel (std::string filename)
- {
- return true;
- }
- virtual bool saveModel (std::string filename)
- {
- return true;
- }
- };
- };
-};
-
-#include "save_restore_model_impl.hpp"
-
-#endif
Deleted: mlpack/trunk/src/mlpack/core/utilities/save_restore_model_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/model/save_restore_model_impl.hpp 2011-10-17 21:50:13 UTC (rev 9909)
+++ mlpack/trunk/src/mlpack/core/utilities/save_restore_model_impl.hpp 2011-10-27 00:08:59 UTC (rev 10046)
@@ -1,30 +0,0 @@
-#ifndef SAVE_RESTORE_MODEL_HPP
-#error "Do not include this header directly."
-#endif
-
-using namespace mlpack;
-using namespace mlpack::model;
-
-template<typename T>
-T& SaveRestoreModel::loadParameter (T& t, std::string name)
-{
- std::map<std::string, std::string>::iterator it = parameters.find (name);
- if (it != parameters.end ())
- {
- std::string value = (*it).second;
- std::istringstream input (value);
- input >> t;
- return t;
- }
- else
- {
- errx (1, "Missing the correct name\n");
- }
-}
-template<typename T>
-void SaveRestoreModel::saveParameter (T& t, std::string name)
-{
- std::ostringstream output;
- output << t;
- parameters[name] = output.str();
-}
Deleted: mlpack/trunk/src/mlpack/core/utilities/save_restore_model_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/core/model/save_restore_model_test.cpp 2011-10-17 21:50:13 UTC (rev 9909)
+++ mlpack/trunk/src/mlpack/core/utilities/save_restore_model_test.cpp 2011-10-27 00:08:59 UTC (rev 10046)
@@ -1,145 +0,0 @@
-/***
- * @file save_restore_model_test.cpp
- * @author Neil Slagle
- *
- * Here we have tests for the SaveRestoreModel class.
- */
-
-#include "save_restore_model.hpp"
-
-#define BOOST_TEST_MODULE SaveRestoreModel Test
-#include <boost/test/unit_test.hpp>
-
-#define ARGSTR(a) a,#a
-
-/***
- * We must override the purely virtual method solve; the
- * overridden saveModel and loadModel enable testing
- * of child class proper usage
- */
-class SaveRestoreModelTest : public SaveRestoreModel
-{
- private:
- size_t anInt;
- public:
- bool solve () { return true; }
- bool saveModel (std::string filename)
- {
- this->saveParameter (anInt, "anInt");
- return this->writeFile (filename);
- }
- bool loadModel (std::string filename)
- {
- bool success = this->readFile (filename);
- if (success)
- {
- anInt = this->loadParameter (anInt, "anInt");
- }
- return success;
- }
- size_t getAnInt () { return anInt; }
- void setAnInt (size_t s) { this->anInt = s; }
-};
-
-/***
- * Perform a save and restore on basic types.
- */
-BOOST_AUTO_TEST_CASE(save_basic_types)
-{
- bool b = false;
- char c = 67;
- unsigned u = 34;
- size_t s = 12;
- short sh = 100;
- int i = -23;
- float f = -2.34f;
- double d = 3.14159;
- std::string cc = "Hello world!";
-
- SaveRestoreModelTest* sRM = new SaveRestoreModelTest();
-
- sRM->saveParameter (ARGSTR(b));
- sRM->saveParameter (ARGSTR(c));
- sRM->saveParameter (ARGSTR(u));
- sRM->saveParameter (ARGSTR(s));
- sRM->saveParameter (ARGSTR(sh));
- sRM->saveParameter (ARGSTR(i));
- sRM->saveParameter (ARGSTR(f));
- sRM->saveParameter (ARGSTR(d));
- sRM->saveParameter (ARGSTR(cc));
- sRM->writeFile ("test_basic_types.xml");
-
- sRM->readFile ("test_basic_types.xml");
-
- bool b2 = sRM->loadParameter (ARGSTR(b));
- char c2 = sRM->loadParameter (ARGSTR(c));
- unsigned u2 = sRM->loadParameter (ARGSTR(u));
- size_t s2 = sRM->loadParameter (ARGSTR(s));
- short sh2 = sRM->loadParameter (ARGSTR(sh));
- int i2 = sRM->loadParameter (ARGSTR(i));
- float f2 = sRM->loadParameter (ARGSTR(f));
- double d2 = sRM->loadParameter (ARGSTR(d));
- std::string cc2 = sRM->loadParameter (ARGSTR(cc));
-
- BOOST_REQUIRE (b == b2);
- BOOST_REQUIRE (c == c2);
- BOOST_REQUIRE (u == u2);
- BOOST_REQUIRE (s == s2);
- BOOST_REQUIRE (sh == sh2);
- BOOST_REQUIRE (i == i2);
- BOOST_REQUIRE_CLOSE (f, f2, 1e-5);
- BOOST_REQUIRE_CLOSE (d, d2, 1e-5);
-
- delete sRM;
-}
-
-/***
- * Test the arma::mat functionality.
- */
-BOOST_AUTO_TEST_CASE(save_arma_mat)
-{
- arma::mat matrix;
- matrix << 1.2 << 2.3 << -0.1 << arma::endr
- << 3.5 << 2.4 << -1.2 << arma::endr
- << -0.1 << 3.4 << -7.8 << arma::endr;
- SaveRestoreModelTest* sRM = new SaveRestoreModelTest();
-
- sRM->saveParameter (ARGSTR (matrix));
-
- sRM->writeFile ("test_arma_mat_type.xml");
-
- sRM->readFile ("test_arma_mat_type.xml");
-
- arma::mat matrix2 = sRM->loadParameter (ARGSTR (matrix));
-
- for (size_t row = 0; row < matrix.n_rows; ++row)
- {
- for (size_t column = 0; column < matrix.n_cols; ++column)
- {
- BOOST_REQUIRE_CLOSE(matrix(row,column), matrix2(row,column), 1e-5);
- }
- }
-
- delete sRM;
-}
-/***
- * Test SaveRestoreModel proper usage in child classes and loading from
- * separately defined objects
- */
-BOOST_AUTO_TEST_CASE(save_restore_model_child_class_usage)
-{
- SaveRestoreModelTest* saver = new SaveRestoreModelTest();
- SaveRestoreModelTest* loader = new SaveRestoreModelTest();
- size_t s = 1200;
- const char* filename = "anInt.xml";
-
- saver->setAnInt (s);
- saver->saveModel (filename);
- delete saver;
-
- loader->loadModel (filename);
-
- BOOST_REQUIRE (loader->getAnInt () == s);
-
- delete loader;
-}
Copied: mlpack/trunk/src/mlpack/core/utilities/save_restore_utility.cpp (from rev 9909, mlpack/trunk/src/mlpack/core/model/save_restore_model.cpp)
===================================================================
--- mlpack/trunk/src/mlpack/core/utilities/save_restore_utility.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/core/utilities/save_restore_utility.cpp 2011-10-27 00:08:59 UTC (rev 10046)
@@ -0,0 +1,172 @@
+/**
+ * @file utilities/save_restore_utility.cpp
+ * @author Neil Slagle
+ *
+ * The SaveRestoreUtility provides helper functions in saving and
+ * restoring models. The current output file type is XML.
+ *
+ * @experimental
+ */
+#include "save_restore_utility.hpp"
+
+using namespace mlpack;
+using namespace utilities;
+
+/**
+ * Parse the XML file and replace the stored parameters with the text content
+ * of its elements. On failure the stored parameters are left untouched and
+ * false is returned (the library no longer terminates the process).
+ */
+bool SaveRestoreUtility::ReadFile (std::string filename)
+{
+  xmlDocPtr xmlDocTree = xmlReadFile (filename.c_str(), NULL, 0);
+  if (NULL == xmlDocTree)
+  {
+    warnx ("Clearly, we couldn't load the XML file\n");
+    return false;
+  }
+
+  xmlNodePtr root = xmlDocGetRootElement (xmlDocTree);
+  if (NULL == root)
+  {
+    // Without a root element there is nothing to read; do not dereference it.
+    warnx ("The XML file has no root element\n");
+    xmlFreeDoc (xmlDocTree);
+    return false;
+  }
+
+  parameters.clear();
+  RecurseOnNodes (root->children);
+  xmlFreeDoc (xmlDocTree);
+  return true;
+}
+/**
+ * Walk the sibling list starting at n, and recursively each node's children,
+ * storing the text content of every element node under the element's name.
+ */
+void SaveRestoreUtility::RecurseOnNodes (xmlNode* n)
+{
+  for (xmlNodePtr current = n; current != NULL; current = current->next)
+  {
+    if (current->type == XML_ELEMENT_NODE)
+    {
+      xmlChar* content = xmlNodeGetContent (current);
+      // xmlNodeGetContent() may return NULL; building a std::string from a
+      // null pointer is undefined, so store an empty value in that case.
+      parameters[reinterpret_cast<const char*> (current->name)] =
+          (content != NULL) ? reinterpret_cast<const char*> (content) : "";
+      if (content != NULL)
+      {
+        xmlFree (content);
+      }
+    }
+    RecurseOnNodes (current->children);
+  }
+}
+/**
+ * Write every stored parameter as a child element of a "root" element, with
+ * the parameter name as the tag and its string value as the text content.
+ *
+ * @return true if libxml2 wrote the file, false on failure.
+ */
+bool SaveRestoreUtility::WriteFile (std::string filename)
+{
+  xmlDocPtr xmlDocTree = xmlNewDoc (BAD_CAST "1.0");
+  xmlNodePtr root = xmlNewNode (NULL, BAD_CAST "root");
+
+  xmlDocSetRootElement (xmlDocTree, root);
+
+  for (std::map<std::string, std::string>::iterator it = parameters.begin();
+       it != parameters.end();
+       ++it)
+  {
+    // xmlNewTextChild escapes XML special characters ('&', '<', ...) in the
+    // value; xmlNewChild would interpret them as markup/entity references.
+    xmlNewTextChild (root, NULL,
+                     BAD_CAST (*it).first.c_str(),
+                     BAD_CAST (*it).second.c_str());
+  }
+
+  /* save the file; libxml2 returns the number of bytes written or -1 */
+  const int written =
+      xmlSaveFormatFileEnc (filename.c_str(), xmlDocTree, "UTF-8", 1);
+  xmlFreeDoc (xmlDocTree);
+  return written != -1;
+}
+/**
+ * Rebuild a matrix stored as newline-separated rows of comma-separated
+ * values. An empty stored value yields a 0x0 matrix. Rows shorter than the
+ * longest row are zero-padded instead of causing out-of-bounds writes.
+ * Terminates the program (errx) if no parameter called name is stored.
+ *
+ * @return Reference to matrix.
+ */
+arma::mat& SaveRestoreUtility::LoadParameter (arma::mat& matrix, std::string name)
+{
+  std::map<std::string, std::string>::iterator it = parameters.find (name);
+  if (it == parameters.end ())
+  {
+    errx (1, "Missing the correct name\n");
+  }
+
+  const std::string& value = (*it).second;
+  boost::char_separator<char> sep ("\n");
+  boost::tokenizer<boost::char_separator<char> > tok (value, sep);
+  std::list<std::list<double> > rows;
+  size_t maxColumns = 0;
+  for (boost::tokenizer<boost::char_separator<char> >::iterator
+       tokIt = tok.begin ();
+       tokIt != tok.end ();
+       ++tokIt)
+  {
+    std::string row = *tokIt;
+    boost::char_separator<char> sepComma (",");
+    boost::tokenizer<boost::char_separator<char> >
+        tokInner (row, sepComma);
+    std::list<double> rowList;
+    for (boost::tokenizer<boost::char_separator<char> >::iterator
+         tokInnerIt = tokInner.begin ();
+         tokInnerIt != tokInner.end ();
+         ++tokInnerIt)
+    {
+      // Initialized so an unparsable token yields 0 rather than garbage.
+      double element = 0.0;
+      std::istringstream iss (*tokInnerIt);
+      iss >> element;
+      rowList.push_back (element);
+    }
+    if (rowList.size () > maxColumns)
+    {
+      maxColumns = rowList.size ();
+    }
+    rows.push_back (rowList);
+  }
+
+  // Size from the widest row; an empty value gives a 0x0 matrix instead of
+  // dereferencing rows.begin() on an empty list.
+  matrix.zeros (rows.size (), maxColumns);
+  size_t rowCounter = 0;
+  for (std::list<std::list<double> >::const_iterator rowIt = rows.begin ();
+       rowIt != rows.end ();
+       ++rowIt, ++rowCounter)
+  {
+    size_t columnCounter = 0;
+    for (std::list<double>::const_iterator elementIt = rowIt->begin ();
+         elementIt != rowIt->end ();
+         ++elementIt, ++columnCounter)
+    {
+      matrix(rowCounter, columnCounter) = *elementIt;
+    }
+  }
+  return matrix;
+}
+/**
+ * Return the stored text for name verbatim (operator>> would stop at the
+ * first whitespace). The str argument only selects this overload.
+ * Terminates the program (errx) if no parameter called name is stored.
+ */
+std::string SaveRestoreUtility::LoadParameter (std::string str, std::string name)
+{
+  std::map<std::string, std::string>::iterator found = parameters.find (name);
+  if (found == parameters.end ())
+  {
+    errx (1, "Missing the correct name\n");
+  }
+  return found->second;
+}
+/**
+ * Return the character stored under name. Characters are stored as their
+ * integer code (see the char overload of SaveParameter). The argument c only
+ * selects this overload. Terminates the program (errx) if name is missing.
+ */
+char SaveRestoreUtility::LoadParameter (char c, std::string name)
+{
+  std::map<std::string, std::string>::iterator it = parameters.find (name);
+  if (it == parameters.end ())
+  {
+    errx (1, "Missing the correct name\n");
+  }
+  // Initialized so that unparsable text yields 0 instead of an
+  // indeterminate value.
+  int temp = 0;
+  std::istringstream input ((*it).second);
+  input >> temp;
+  return static_cast<char> (temp);
+}
+/**
+ * Store c under name as its integer code (presumably so that control and
+ * whitespace characters survive the round trip through XML text).
+ */
+void SaveRestoreUtility::SaveParameter (char c, std::string name)
+{
+  const int code = static_cast<int> (c);
+  std::ostringstream formatted;
+  formatted << code;
+  parameters[name] = formatted.str();
+}
+/**
+ * Store mat under name as newline-separated rows of comma-separated values.
+ * Matrices with zero columns are handled (the old `columns - 1` underflowed).
+ */
+void SaveRestoreUtility::SaveParameter (arma::mat& mat, std::string name)
+{
+  std::ostringstream output;
+  // The default precision (6 significant digits) truncates values; 17
+  // significant digits let an IEEE-754 double be read back exactly.
+  output.precision (17);
+  for (size_t r = 0; r < mat.n_rows; ++r)
+  {
+    for (size_t c = 0; c < mat.n_cols; ++c)
+    {
+      if (c > 0)
+      {
+        output << ",";
+      }
+      output << mat(r, c);
+    }
+    output << '\n';
+  }
+  parameters[name] = output.str();
+}
Copied: mlpack/trunk/src/mlpack/core/utilities/save_restore_utility.hpp (from rev 9909, mlpack/trunk/src/mlpack/core/model/save_restore_model.hpp)
===================================================================
--- mlpack/trunk/src/mlpack/core/utilities/save_restore_utility.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/core/utilities/save_restore_utility.hpp 2011-10-27 00:08:59 UTC (rev 10046)
@@ -0,0 +1,87 @@
+/**
+ * @file utilities/save_restore_utility.hpp
+ * @author Neil Slagle
+ *
+ * The SaveRestoreUtility provides helper functions in saving and
+ * restoring models. The current output file type is XML.
+ *
+ * @experimental
+ */
+
+#ifndef SAVE_RESTORE_MODEL_HPP
+#define SAVE_RESTORE_MODEL_HPP
+
+#include <err.h>
+#include <list>
+#include <map>
+#include <sstream>
+#include <string>
+
+#include <libxml/parser.h>
+#include <libxml/tree.h>
+
+#include <armadillo>
+#include <boost/tokenizer.hpp>
+
+namespace mlpack {
+namespace utilities {
+
+/**
+ * Collects named parameters in string form and saves them to, or restores
+ * them from, an XML file. Typical use: call SaveParameter() for each value
+ * and then WriteFile(); later call ReadFile() and LoadParameter() for each
+ * value.
+ */
+class SaveRestoreUtility
+{
+ private:
+  /**
+   * Maps each parameter name to its value in string form.
+   */
+  std::map<std::string, std::string> parameters;
+
+  /**
+   * Walks the sibling list starting at n and, recursively, each node's
+   * children, recording the text content of every element node in
+   * parameters under the element's name.
+   */
+  void RecurseOnNodes (xmlNode* n);
+
+ public:
+  SaveRestoreUtility() {}
+
+  /**
+   * Reads an XML file and replaces the stored parameters with its contents.
+   *
+   * @param filename Name of the XML file to read.
+   * @return Success flag (see the implementation for failure handling).
+   */
+  bool ReadFile (std::string filename);
+
+  /**
+   * Writes every stored parameter to an XML file, one child element of the
+   * root element per parameter.
+   *
+   * @param filename Name of the XML file to write.
+   * @return Success flag (see the implementation for its exact meaning).
+   */
+  bool WriteFile (std::string filename);
+
+  /**
+   * Parses the parameter called name into t with operator>>. Terminates the
+   * program (errx) if no such parameter is stored.
+   *
+   * @return Reference to t.
+   */
+  template<typename T>
+  T& LoadParameter (T& t, std::string name);
+
+  /**
+   * Returns the character stored (as an integer code) under name; c only
+   * selects this overload.
+   */
+  char LoadParameter (char c, std::string name);
+
+  /**
+   * Returns the stored string for name verbatim, including whitespace; str
+   * only selects this overload.
+   */
+  std::string LoadParameter (std::string str, std::string name);
+
+  /**
+   * Fills matrix from the rows stored under name.
+   *
+   * @return Reference to matrix.
+   */
+  arma::mat& LoadParameter (arma::mat& matrix, std::string name);
+
+  /**
+   * Stores t, formatted with operator<<, under name, overwriting any
+   * previous value with that name.
+   */
+  template<typename T>
+  void SaveParameter (T& t, std::string name);
+
+  /**
+   * Stores c under name as its integer code.
+   */
+  void SaveParameter (char c, std::string name);
+
+  /**
+   * Stores mat under name as newline-separated rows of comma-separated
+   * values.
+   */
+  void SaveParameter (arma::mat& mat, std::string name);
+};
+} /* namespace utilities */
+} /* namespace mlpack */
+
+#include "save_restore_utility_impl.hpp"
+
+#endif
Copied: mlpack/trunk/src/mlpack/core/utilities/save_restore_utility_impl.hpp (from rev 9909, mlpack/trunk/src/mlpack/core/model/save_restore_model_impl.hpp)
===================================================================
--- mlpack/trunk/src/mlpack/core/utilities/save_restore_utility_impl.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/core/utilities/save_restore_utility_impl.hpp 2011-10-27 00:08:59 UTC (rev 10046)
@@ -0,0 +1,39 @@
+/**
+ * @file utilities/save_restore_utility_impl.hpp
+ * @author Neil Slagle
+ *
+ * The SaveRestoreUtility provides helper functions in saving and
+ * restoring models. The current output file type is XML.
+ *
+ * @experimental
+ */
+#ifndef SAVE_RESTORE_MODEL_HPP
+#error "Do not include this header directly."
+#endif
+
+using namespace mlpack;
+using namespace mlpack::utilities;
+
+/**
+ * Parse the text stored under name into t via operator>>; a missing name
+ * terminates the program through errx.
+ */
+template<typename T>
+T& SaveRestoreUtility::LoadParameter (T& t, std::string name)
+{
+  std::map<std::string, std::string>::iterator found = parameters.find (name);
+  if (found == parameters.end ())
+  {
+    errx (1, "Missing the correct name\n");
+  }
+  std::istringstream reader (found->second);
+  reader >> t;
+  return t;
+}
+/**
+ * Store t, formatted with operator<<, under name (overwriting any previous
+ * value with that name).
+ */
+template<typename T>
+void SaveRestoreUtility::SaveParameter (T& t, std::string name)
+{
+  std::ostringstream output;
+  // The default stream precision (6 significant digits) silently truncates
+  // floating-point values; 17 significant digits let an IEEE-754 double (and
+  // hence a float) be read back exactly. Integral, bool and string output is
+  // unaffected by this setting.
+  output.precision (17);
+  output << t;
+  parameters[name] = output.str();
+}
Copied: mlpack/trunk/src/mlpack/core/utilities/save_restore_utility_test.cpp (from rev 9909, mlpack/trunk/src/mlpack/core/model/save_restore_model_test.cpp)
===================================================================
--- mlpack/trunk/src/mlpack/core/utilities/save_restore_utility_test.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/core/utilities/save_restore_utility_test.cpp 2011-10-27 00:08:59 UTC (rev 10046)
@@ -0,0 +1,149 @@
+/***
+ * @file save_restore_utility_test.cpp
+ * @author Neil Slagle
+ *
+ * Here we have tests for the SaveRestoreUtility class.
+ */
+
+#include "save_restore_utility.hpp"
+
+#define BOOST_TEST_MODULE SaveRestoreModel Test
+#include <boost/test/unit_test.hpp>
+
+#define ARGSTR(a) a,#a
+
+/***
+ * A small model class that owns a SaveRestoreUtility and uses it to save and
+ * restore its single member; it exhibits the intended usage pattern.
+ */
+class SaveRestoreTest
+{
+ private:
+  size_t anInt;
+  mlpack::utilities::SaveRestoreUtility saveRestore;
+ public:
+  // anInt is initialized so an object is never saved with an indeterminate
+  // value; saveRestore is default-constructed.
+  SaveRestoreTest() : anInt(0) {}
+  bool SaveModel (std::string filename)
+  {
+    saveRestore.SaveParameter (anInt, "anInt");
+    return saveRestore.WriteFile (filename);
+  }
+  bool LoadModel (std::string filename)
+  {
+    bool success = saveRestore.ReadFile (filename);
+    if (success)
+    {
+      anInt = saveRestore.LoadParameter (anInt, "anInt");
+    }
+    return success;
+  }
+  size_t AnInt () const { return anInt; }
+  void AnInt (size_t s) { this->anInt = s; }
+};
+
+/***
+ * Perform a save and restore on basic types. Values are loaded into fresh
+ * variables initialized to different values, so each comparison actually
+ * checks what was read back (LoadParameter<T> writes into its argument).
+ */
+BOOST_AUTO_TEST_CASE(save_basic_types)
+{
+  bool b = false;
+  char c = 67;
+  unsigned u = 34;
+  size_t s = 12;
+  short sh = 100;
+  int i = -23;
+  float f = -2.34f;
+  double d = 3.14159;
+  std::string cc = "Hello world!";
+
+  // A stack object is released even when a BOOST_REQUIRE below throws.
+  mlpack::utilities::SaveRestoreUtility sRM;
+
+  sRM.SaveParameter (ARGSTR(b));
+  sRM.SaveParameter (ARGSTR(c));
+  sRM.SaveParameter (ARGSTR(u));
+  sRM.SaveParameter (ARGSTR(s));
+  sRM.SaveParameter (ARGSTR(sh));
+  sRM.SaveParameter (ARGSTR(i));
+  sRM.SaveParameter (ARGSTR(f));
+  sRM.SaveParameter (ARGSTR(d));
+  sRM.SaveParameter (ARGSTR(cc));
+  sRM.WriteFile ("test_basic_types.xml");
+
+  BOOST_REQUIRE (sRM.ReadFile ("test_basic_types.xml"));
+
+  bool b2 = true;
+  char c2 = 0;
+  unsigned u2 = 0;
+  size_t s2 = 0;
+  short sh2 = 0;
+  int i2 = 0;
+  float f2 = 0.0f;
+  double d2 = 0.0;
+  std::string cc2;
+
+  b2 = sRM.LoadParameter (b2, "b");
+  c2 = sRM.LoadParameter (c2, "c");
+  u2 = sRM.LoadParameter (u2, "u");
+  s2 = sRM.LoadParameter (s2, "s");
+  sh2 = sRM.LoadParameter (sh2, "sh");
+  i2 = sRM.LoadParameter (i2, "i");
+  f2 = sRM.LoadParameter (f2, "f");
+  d2 = sRM.LoadParameter (d2, "d");
+  cc2 = sRM.LoadParameter (cc2, "cc");
+
+  BOOST_REQUIRE_EQUAL (b, b2);
+  BOOST_REQUIRE_EQUAL (c, c2);
+  BOOST_REQUIRE_EQUAL (u, u2);
+  BOOST_REQUIRE_EQUAL (s, s2);
+  BOOST_REQUIRE_EQUAL (sh, sh2);
+  BOOST_REQUIRE_EQUAL (i, i2);
+  BOOST_REQUIRE_EQUAL (cc, cc2);
+  BOOST_REQUIRE_CLOSE (f, f2, 1e-5);
+  BOOST_REQUIRE_CLOSE (d, d2, 1e-5);
+}
+
+/***
+ * Test the arma::mat functionality. The matrix is loaded into a separate
+ * object, because LoadParameter (arma::mat&, ...) overwrites its argument.
+ */
+BOOST_AUTO_TEST_CASE(save_arma_mat)
+{
+  arma::mat matrix;
+  matrix << 1.2 << 2.3 << -0.1 << arma::endr
+         << 3.5 << 2.4 << -1.2 << arma::endr
+         << -0.1 << 3.4 << -7.8 << arma::endr;
+  mlpack::utilities::SaveRestoreUtility sRM;
+
+  sRM.SaveParameter (ARGSTR (matrix));
+
+  sRM.WriteFile ("test_arma_mat_type.xml");
+
+  BOOST_REQUIRE (sRM.ReadFile ("test_arma_mat_type.xml"));
+
+  arma::mat matrix2;
+  sRM.LoadParameter (matrix2, "matrix");
+
+  // Check the shape first so the element loop never indexes out of bounds.
+  BOOST_REQUIRE_EQUAL (matrix2.n_rows, matrix.n_rows);
+  BOOST_REQUIRE_EQUAL (matrix2.n_cols, matrix.n_cols);
+
+  for (size_t row = 0; row < matrix.n_rows; ++row)
+  {
+    for (size_t column = 0; column < matrix.n_cols; ++column)
+    {
+      BOOST_REQUIRE_CLOSE(matrix(row,column), matrix2(row,column), 1e-5);
+    }
+  }
+}
+/***
+ * Test that a class holding a SaveRestoreUtility member can save its state,
+ * and that a separately constructed object can restore it from the file.
+ */
+BOOST_AUTO_TEST_CASE(save_restore_model_child_class_usage)
+{
+  const size_t s = 1200;
+  const char* filename = "anInt.xml";
+
+  {
+    // The saver is destroyed before loading, as in the original test.
+    SaveRestoreTest saver;
+    saver.AnInt (s);
+    saver.SaveModel (filename);
+  }
+
+  SaveRestoreTest loader;
+  BOOST_REQUIRE (loader.LoadModel (filename));
+
+  BOOST_REQUIRE_EQUAL (loader.AnInt (), s);
+}
Modified: mlpack/trunk/src/mlpack/core.h
===================================================================
--- mlpack/trunk/src/mlpack/core.h 2011-10-26 22:38:11 UTC (rev 10045)
+++ mlpack/trunk/src/mlpack/core.h 2011-10-27 00:08:59 UTC (rev 10046)
@@ -92,8 +92,7 @@
// Now MLPACK-specific includes.
#include <mlpack/core/math/math_lib.hpp>
#include <mlpack/core/math/range.hpp>
-#include <mlpack/core/model/model.hpp>
-#include <mlpack/core/model/save_restore_model.hpp>
+#include <mlpack/core/utilities/save_restore_utility.hpp>
#include <mlpack/core/file/textfile.h>
#include <mlpack/core/io/cli.hpp>
#include <mlpack/core/io/log.hpp>
More information about the mlpack-svn
mailing list