[mlpack-git] master: Move splitHistory to XTreeSplit. Update serialization of the RectangleTree class. (4208f38)

gitdub at mlpack.org gitdub at mlpack.org
Tue May 3 13:41:03 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/4c8a8d1ccbc33916794fe0f6142fa5378ff8503d...93eb34cc4daf08cc331d1a22fccf44950b276daa

>---------------------------------------------------------------

commit 4208f38a73cb46f147602064eb13d32fb9aabe24
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date:   Tue May 3 20:41:03 2016 +0300

    Move splitHistory to XTreeSplit. Update serialization of the RectangleTree class.


>---------------------------------------------------------------

4208f38a73cb46f147602064eb13d32fb9aabe24
 .../core/tree/rectangle_tree/r_star_tree_split.hpp |  7 +++
 .../core/tree/rectangle_tree/r_tree_split.hpp      |  8 ++++
 .../core/tree/rectangle_tree/rectangle_tree.hpp    | 30 -------------
 .../tree/rectangle_tree/rectangle_tree_impl.hpp    |  7 +--
 .../core/tree/rectangle_tree/x_tree_split.hpp      | 36 ++++++++++++++++
 .../core/tree/rectangle_tree/x_tree_split_impl.hpp | 50 +++++++++++++++-------
 6 files changed, 86 insertions(+), 52 deletions(-)

diff --git a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp
index 908d6c2..e3550ff 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp
@@ -72,6 +72,13 @@ class RStarTreeSplit
    * Insert a node into another node.
    */
   static void InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode);
+
+ public:
+  /**
+   * Serialize the split.
+   */
+  template<typename Archive>
+  void Serialize(Archive &, const unsigned int /* version */) { };
 };
 
 } // namespace tree
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp
index ebb640c..5312125 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp
@@ -79,6 +79,14 @@ class RTreeSplit
    * Insert a node into another node.
    */
   static void InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode);
+
+ public:
+  /**
+   * Serialize the split.
+   */
+  template<typename Archive>
+  void Serialize(Archive &, const unsigned int /* version */) { };
+
 };
 
 } // namespace tree
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index 8e762f9..8432f44 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -53,29 +53,6 @@ class RectangleTree
   //! The element type held by the matrix type.
   typedef typename MatType::elem_type ElemType;
 
-  /**
-   * The X tree requires that the tree records it's "split history".  To make
-   * this easy, we use the following structure.
-   */
-  typedef struct SplitHistoryStruct
-  {
-    int lastDimension;
-    std::vector<bool> history;
-
-    SplitHistoryStruct(int dim) : lastDimension(0), history(dim)
-    {
-      for (int i = 0; i < dim; i++)
-        history[i] = false;
-    }
-
-    template<typename Archive>
-    void Serialize(Archive& ar, const unsigned int /* version */)
-    {
-      ar & data::CreateNVP(lastDimension, "lastDimension");
-      ar & data::CreateNVP(history, "history");
-    }
-  } SplitHistoryStruct;
-
  private:
   //! The max number of child nodes a non-leaf node can have.
   size_t maxNumChildren;
@@ -103,8 +80,6 @@ class RectangleTree
   bound::HRectBound<metric::EuclideanDistance, ElemType> bound;
   //! Any extra data contained in the node.
   StatisticType stat;
-  //! A struct to store the "split history" for X trees.
-  SplitHistoryStruct splitHistory;
   //! The distance from the centroid of this node to the centroid of the parent.
   ElemType parentDistance;
   //! The dataset.
@@ -316,11 +291,6 @@ class RectangleTree
   //! Modify the statistic object for this node.
   StatisticType& Stat() { return stat; }
 
-  //! Return the split history object of this node.
-  const SplitHistoryStruct& SplitHistory() const { return splitHistory; }
-  //! Modify the split history object of this node.
-  SplitHistoryStruct& SplitHistory() { return splitHistory; }
-
   //! Return the split object of this node.
   const SplitType<RectangleTree>& Split() const { return split; }
   //! Modify the split object of this node.
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index 4e1526c..6793649 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -39,7 +39,6 @@ RectangleTree(const MatType& data,
     maxLeafSize(maxLeafSize),
     minLeafSize(minLeafSize),
     bound(data.n_rows),
-    splitHistory(bound.Dim()),
     parentDistance(0),
     dataset(new MatType(data)),
     ownsDataset(true),
@@ -80,7 +79,6 @@ RectangleTree(MatType&& data,
     maxLeafSize(maxLeafSize),
     minLeafSize(minLeafSize),
     bound(data.n_rows),
-    splitHistory(bound.Dim()),
     parentDistance(0),
     dataset(new MatType(std::move(data))),
     ownsDataset(true),
@@ -118,7 +116,6 @@ RectangleTree(
     maxLeafSize(parentNode->MaxLeafSize()),
     minLeafSize(parentNode->MinLeafSize()),
     bound(parentNode->Bound().Dim()),
-    splitHistory(bound.Dim()),
     parentDistance(0),
     dataset(&parentNode->Dataset()),
     ownsDataset(false),
@@ -153,7 +150,6 @@ RectangleTree(
     maxLeafSize(other.MaxLeafSize()),
     minLeafSize(other.MinLeafSize()),
     bound(other.bound),
-    splitHistory(other.SplitHistory()),
     parentDistance(other.ParentDistance()),
     dataset(deepCopy ? new MatType(*other.dataset) : &other.Dataset()),
     ownsDataset(deepCopy),
@@ -701,7 +697,6 @@ RectangleTree() :
     count(0),
     maxLeafSize(0),
     minLeafSize(0),
-    splitHistory(0),
     parentDistance(0.0),
     dataset(NULL),
     ownsDataset(false),
@@ -1023,7 +1018,6 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
   ar & CreateNVP(minLeafSize, "minLeafSize");
   ar & CreateNVP(bound, "bound");
   ar & CreateNVP(stat, "stat");
-  ar & CreateNVP(splitHistory, "splitHistory");
   ar & CreateNVP(parentDistance, "parentDistance");
   ar & CreateNVP(dataset, "dataset");
 
@@ -1033,6 +1027,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 
   ar & CreateNVP(points, "points");
   ar & CreateNVP(localDataset, "localDataset");
+  ar & CreateNVP(split, "split");
 
   // Because 'children' holds mlpack types (that have Serialize()), we can't use
   // the std::vector serialization.
diff --git a/src/mlpack/core/tree/rectangle_tree/x_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/x_tree_split.hpp
index 3d112c6..8d3e5d8 100644
--- a/src/mlpack/core/tree/rectangle_tree/x_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/x_tree_split.hpp
@@ -57,9 +57,34 @@ class XTreeSplit
    */
   bool SplitNonLeafNode(TreeType *tree,std::vector<bool>& relevels);
 
+  /**
+   * The X tree requires that the tree records its "split history".  To make
+   * this easy, we use the following structure.
+   */
+  typedef struct SplitHistoryStruct
+  {
+    int lastDimension;
+    std::vector<bool> history;
+
+    SplitHistoryStruct(int dim) : lastDimension(0), history(dim)
+    {
+      for (int i = 0; i < dim; i++)
+        history[i] = false;
+    }
+
+    template<typename Archive>
+    void Serialize(Archive& ar, const unsigned int /* version */)
+    {
+      ar & data::CreateNVP(lastDimension, "lastDimension");
+      ar & data::CreateNVP(history, "history");
+    }
+  } SplitHistoryStruct;
+
  private:
   //! The max number of child nodes a non-leaf normal node can have.
   size_t normalNodeMaxNumChildren;
+  //! A struct to store the "split history" for X trees.
+  SplitHistoryStruct splitHistory;
 
   /**
    * Class to allow for faster sorting.
@@ -92,6 +117,17 @@ class XTreeSplit
   size_t NormalNodeMaxNumChildren() const { return normalNodeMaxNumChildren; }
   //! Modify the maximum number of a normal node's children.
   size_t& NormalNodeMaxNumChildren() { return normalNodeMaxNumChildren; }
+  //! Return the split history of the node associated with this object.
+  const SplitHistoryStruct& SplitHistory() const { return splitHistory; }
+  //! Modify the split history of the node associated with this object.
+  SplitHistoryStruct& SplitHistory() { return splitHistory; }
+
+
+  /**
+   * Serialize the split.
+   */
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
 };
 
 } // namespace tree
diff --git a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
index b1929d0..1bde106 100644
--- a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
@@ -16,28 +16,32 @@ namespace tree {
 
 template<typename TreeType>
 XTreeSplit<TreeType>::XTreeSplit() :
-    normalNodeMaxNumChildren(0)
+    normalNodeMaxNumChildren(0),
+    splitHistory(0)
 {
 
 }
 
 template<typename TreeType>
 XTreeSplit<TreeType>::XTreeSplit(const TreeType *node) :
-    normalNodeMaxNumChildren(node->MaxNumChildren())
+    normalNodeMaxNumChildren(node->MaxNumChildren()),
+    splitHistory(node->Bound().Dim())
 {
 
 }
 
 template<typename TreeType>
-XTreeSplit<TreeType>::XTreeSplit(const TreeType *,const TreeType *parent) :
-    normalNodeMaxNumChildren(parent->Split().NormalNodeMaxNumChildren())
+XTreeSplit<TreeType>::XTreeSplit(const TreeType *node,const TreeType *parent) :
+    normalNodeMaxNumChildren(parent->Split().NormalNodeMaxNumChildren()),
+    splitHistory(node->Bound().Dim())
 {
 
 }
 
 template<typename TreeType>
 XTreeSplit<TreeType>::XTreeSplit(const TreeType &other) :
-    normalNodeMaxNumChildren(other.Split().NormalNodeMaxNumChildren())
+    normalNodeMaxNumChildren(other.Split().NormalNodeMaxNumChildren()),
+    splitHistory(other.Split().SplitHistory())
 {
 
 }
@@ -290,10 +294,10 @@ void XTreeSplit<TreeType>::SplitLeafNode(TreeType *tree,std::vector<bool>& relev
   par->Children()[par->NumChildren()++] = treeTwo;
 
   // We now update the split history of each new node.
-  treeOne->SplitHistory().history[bestAxis] = true;
-  treeOne->SplitHistory().lastDimension = bestAxis;
-  treeTwo->SplitHistory().history[bestAxis] = true;
-  treeTwo->SplitHistory().lastDimension = bestAxis;
+  treeOne->Split().SplitHistory().history[bestAxis] = true;
+  treeOne->Split().SplitHistory().lastDimension = bestAxis;
+  treeTwo->Split().SplitHistory().history[bestAxis] = true;
+  treeTwo->Split().SplitHistory().lastDimension = bestAxis;
 
   // We only add one at a time, so we should only need to test for equality just
   // in case, we use an assert.
@@ -353,7 +357,7 @@ bool XTreeSplit<TreeType>::SplitNonLeafNode(TreeType *tree,std::vector<bool>& re
   std::vector<bool> axes(tree->Bound().Dim());
   std::vector<int> dimensionsLastUsed(tree->NumChildren());
   for (size_t i = 0; i < tree->NumChildren(); i++)
-    dimensionsLastUsed[i] = tree->Child(i).SplitHistory().lastDimension;
+    dimensionsLastUsed[i] = tree->Child(i).Split().SplitHistory().lastDimension;
   std::sort(dimensionsLastUsed.begin(), dimensionsLastUsed.end());
 
   size_t lastDim = dimensionsLastUsed[dimensionsLastUsed.size()/2];
@@ -364,7 +368,7 @@ bool XTreeSplit<TreeType>::SplitNonLeafNode(TreeType *tree,std::vector<bool>& re
   {
     axes[i] = true;
     for (size_t j = 0; j < tree->NumChildren(); j++)
-      axes[i] = axes[i] & tree->Child(j).SplitHistory().history[i];
+      axes[i] = axes[i] & tree->Child(j).Split().SplitHistory().history[i];
     if (axes[i] == true)
     {
       minOverlapSplitDimension = i;
@@ -377,7 +381,7 @@ bool XTreeSplit<TreeType>::SplitNonLeafNode(TreeType *tree,std::vector<bool>& re
     {
       axes[i] = true;
       for (size_t j = 0; j < tree->NumChildren(); j++)
-        axes[i] = axes[i] & tree->Child(j).SplitHistory().history[i];
+        axes[i] = axes[i] & tree->Child(j).Split().SplitHistory().history[i];
       if (axes[i] == true)
       {
         minOverlapSplitDimension = i;
@@ -802,10 +806,10 @@ bool XTreeSplit<TreeType>::SplitNonLeafNode(TreeType *tree,std::vector<bool>& re
   }
 
   // Update the split history of each child.
-  treeOne->SplitHistory().history[bestAxis] = true;
-  treeOne->SplitHistory().lastDimension = bestAxis;
-  treeTwo->SplitHistory().history[bestAxis] = true;
-  treeTwo->SplitHistory().lastDimension = bestAxis;
+  treeOne->Split().SplitHistory().history[bestAxis] = true;
+  treeOne->Split().SplitHistory().lastDimension = bestAxis;
+  treeTwo->Split().SplitHistory().history[bestAxis] = true;
+  treeTwo->Split().SplitHistory().lastDimension = bestAxis;
 
   // Remove this node and insert treeOne and treeTwo
   TreeType* par = tree->Parent();
@@ -868,6 +872,20 @@ void XTreeSplit<TreeType>::InsertNodeIntoTree(TreeType* destTree, TreeType* srcN
   destTree->NumChildren()++;
 }
 
+/**
+ * Serialize the split.
+ */
+template<typename TreeType>
+template<typename Archive>
+void XTreeSplit<TreeType>::Serialize(Archive& ar,const unsigned int /* version */)
+{
+  using data::CreateNVP;
+
+  ar & CreateNVP(normalNodeMaxNumChildren, "normalNodeMaxNumChildren");
+  ar & CreateNVP(splitHistory, "splitHistory");
+
+}
+
 } // namespace tree
 } // namespace mlpack
 




More information about the mlpack-git mailing list