[mlpack-svn] r11568 - mlpack/trunk/src/mlpack/core/tree

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Feb 21 05:17:10 EST 2012


Author: jcline3
Date: 2012-02-21 05:17:10 -0500 (Tue, 21 Feb 2012)
New Revision: 11568

Added:
   mlpack/trunk/src/mlpack/core/tree/mrkd_statistic.hpp
Modified:
   mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt
Log:
MRKDStatistic for mrkd-trees. Probably still needs improvements.


Modified: mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt	2012-02-21 05:27:49 UTC (rev 11567)
+++ mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt	2012-02-21 10:17:10 UTC (rev 11568)
@@ -13,6 +13,7 @@
   periodichrectbound.hpp
   periodichrectbound_impl.hpp
   statistic.hpp
+  mrkd_statistic.hpp
 )
 
 # add directory name to sources

Added: mlpack/trunk/src/mlpack/core/tree/mrkd_statistic.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/mrkd_statistic.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/core/tree/mrkd_statistic.hpp	2012-02-21 10:17:10 UTC (rev 11568)
@@ -0,0 +1,140 @@
+/**
+ * @file mrkd_statistic.hpp
+ *
+ * Definition of MRKDStatistic, the statistic stored in each node of a
+ * multi-resolution kd-tree (mrkd-tree).  It caches per-node summaries (the
+ * center of mass and the sum of squared norms of the node's points) so that
+ * algorithms can work with a whole node without revisiting its points.
+ */
+
+#ifndef __MLPACK_CORE_TREE_MRKD_STATISTIC_HPP
+#define __MLPACK_CORE_TREE_MRKD_STATISTIC_HPP
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Statistic for multi-resolution kd-trees.
+ *
+ * For the points dataset.col(begin), ..., dataset.col(begin + count - 1)
+ * held by a node, it stores their center of mass (mean) and the sum of their
+ * squared Euclidean norms.  Leaf statistics are computed from the data;
+ * non-leaf statistics are combined bottom-up from the two child statistics.
+ *
+ * The dataset and the child statistics are referenced by pointer, not
+ * copied, so they must outlive this object.  NOTE(review): assumes one point
+ * per column and MatType == arma::mat, since the dataset is stored as a
+ * const arma::mat*.
+ */
+class MRKDStatistic
+{
+  public:
+    //! Create an empty statistic that refers to no data.
+    MRKDStatistic()
+    :
+      dataset(NULL),
+      leftStat(NULL),
+      rightStat(NULL),
+      begin(0),
+      count(0),
+      sumOfSquaredNorms(0.0)
+    { }
+
+    /**
+     * This constructor is called when a leaf is created.
+     *
+     * @param dataset Matrix that the tree is being built on.
+     * @param begin Starting index corresponding to this leaf.
+     * @param count Number of points held in this leaf.
+     */
+    template<typename MatType>
+    MRKDStatistic(const MatType& dataset,
+                  const size_t begin,
+                  const size_t count)
+    :
+      dataset(&dataset),
+      leftStat(NULL),
+      rightStat(NULL),
+      begin(begin),
+      count(count),
+      sumOfSquaredNorms(0.0)
+    {
+      // Each column is one point.  Starting from zeros keeps an empty leaf
+      // well-defined (zero center of mass, zero sum).
+      centerOfMass.zeros(dataset.n_rows);
+      const size_t end = begin + count;
+      for (size_t i = begin; i < end; ++i)
+      {
+        centerOfMass += dataset.col(i);
+        // dot(x, x) is the squared 2-norm ||x||^2.
+        sumOfSquaredNorms += arma::dot(dataset.col(i), dataset.col(i));
+      }
+
+      if (count > 0)
+        centerOfMass /= static_cast<double>(count);
+    }
+
+    /**
+     * This constructor is called when a non-leaf node is created.
+     * This lets you build fast bottom-up statistics when building trees.
+     *
+     * @param dataset Matrix that the tree is being built on.
+     * @param begin Starting index corresponding to this node.
+     * @param count Number of points held in this node.
+     * @param leftStat MRKDStatistic object of the left child node.
+     * @param rightStat MRKDStatistic object of the right child node.
+     */
+    template<typename MatType>
+    MRKDStatistic(const MatType& dataset,
+                  const size_t begin,
+                  const size_t count,
+                  const MRKDStatistic& leftStat,
+                  const MRKDStatistic& rightStat)
+    :
+      dataset(&dataset),
+      leftStat(&leftStat),
+      rightStat(&rightStat),
+      begin(begin),
+      count(count),
+      sumOfSquaredNorms(leftStat.sumOfSquaredNorms +
+                        rightStat.sumOfSquaredNorms)
+    {
+      // No points below this node: avoid dividing by zero.
+      if (leftStat.count + rightStat.count == 0)
+      {
+        centerOfMass.zeros(dataset.n_rows);
+        return;
+      }
+
+      // The mean of the children's combined points is the average of their
+      // centers of mass, weighted by the number of points in each child.
+      const double leftWeight = static_cast<double>(leftStat.count);
+      const double rightWeight = static_cast<double>(rightStat.count);
+      centerOfMass = (leftStat.centerOfMass * leftWeight +
+                      rightStat.centerOfMass * rightWeight) /
+                     (leftWeight + rightWeight);
+    }
+
+    //! The dataset the node's points belong to (not owned).
+    const arma::mat* dataset;
+    //! Statistic of the left child; NULL for leaves (not owned).
+    const MRKDStatistic* leftStat;
+    //! Statistic of the right child; NULL for leaves (not owned).
+    const MRKDStatistic* rightStat;
+    //! Index of the node's first point in the dataset.  (Not const, so that
+    //! the statistic remains copy-assignable.)
+    size_t begin;
+    //! Number of points held in the node.
+    size_t count;
+
+    // Computed statistics
+    //! The center of mass (mean) of the node's points.
+    arma::vec centerOfMass;
+    //! The sum of the squared Euclidean norms of the node's points.
+    double sumOfSquaredNorms;
+};
+
+} // namespace tree
+} // namespace mlpack
+
+#endif // __MLPACK_CORE_TREE_MRKD_STATISTIC_HPP




More information about the mlpack-svn mailing list