[mlpack-svn] r17318 - mlpack/trunk/src/mlpack/methods/decision_stump
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Nov 11 13:46:31 EST 2014
Author: rcurtin
Date: Tue Nov 11 13:46:31 2014
New Revision: 17318
Log:
Refactor CountMostFreq() so that it is faster and simpler, and no longer returns
uninitialized values for some inputs.
Modified:
mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp
mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
Modified: mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp Tue Nov 11 13:46:31 2014
@@ -165,7 +165,7 @@
* @param isWeight Whether we need to run a weighted Decision Stump.
*/
template <bool isWeight>
- void Train(const MatType& data, const arma::Row<size_t>& labels,
+ void Train(const MatType& data, const arma::Row<size_t>& labels,
const arma::rowvec& weightD);
};
Modified: mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp Tue Nov 11 13:46:31 2014
@@ -356,44 +356,34 @@
template <typename rType>
rType DecisionStump<MatType>::CountMostFreq(const arma::Row<rType>& subCols)
{
- // Sort subCols for easier processing.
- arma::Row<rType> sortCounts = arma::sort(subCols);
- rType element;
- int count = 0, localCount = 0;
+ // We'll create a map of elements and the number of times that each element is
+ // seen.
+ std::map<rType, size_t> countMap;
- if (sortCounts.n_elem == 1)
- return sortCounts[0];
-
- // An O(n) loop which counts the most frequent element in sortCounts
- for (size_t i = 0; i < sortCounts.n_elem; ++i)
+ for (size_t i = 0; i < subCols.n_elem; ++i)
{
- if (i == sortCounts.n_elem - 1)
- {
- if (sortCounts(i - 1) == sortCounts(i))
- {
- // element = sortCounts(i - 1);
- localCount++;
- }
- else if (localCount > count)
- count = localCount;
- }
- else if (sortCounts(i) != sortCounts(i + 1))
- {
- localCount = 0;
- count++;
- }
+ if (countMap.count(subCols[i]) == 0)
+ countMap[subCols[i]] = 1;
else
+ ++countMap[subCols[i]];
+ }
+
+ // Now find the maximum value.
+ typename std::map<rType, size_t>::iterator it = countMap.begin();
+ rType mostFreq = it->first;
+ size_t mostFreqCount = it->second;
+ while (it != countMap.end())
+ {
+ if (it->second >= mostFreqCount)
{
- localCount++;
- if (localCount > count)
- {
- count = localCount;
- if (localCount == 1)
- element = sortCounts(i);
- }
+ mostFreq = it->first;
+ mostFreqCount = it->second;
}
+
+ ++it;
}
- return element;
+
+ return mostFreq;
}
/**
More information about the mlpack-svn
mailing list