[mlpack-git] master: change binarize to col_major and add overload (57a1b19)

gitdub at mlpack.org gitdub at mlpack.org
Mon Jun 6 14:14:45 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/a0b31abe5ff69117645c664dbeac1476dd5e48f7...2da9c5bac14a00145c757b8139c245913b86e034

>---------------------------------------------------------------

commit 57a1b195b2770b0555aee4951726b29a4a01d23f
Author: Keon Kim <kwk236 at gmail.com>
Date:   Tue Jun 7 03:02:37 2016 +0900

    change binarize to col_major and add overload


>---------------------------------------------------------------

57a1b195b2770b0555aee4951726b29a4a01d23f
 src/mlpack/core/data/binarize.hpp  | 34 ++++++++++++++++++++++++++++++++--
 src/mlpack/tests/binarize_test.cpp | 20 --------------------
 2 files changed, 32 insertions(+), 22 deletions(-)

diff --git a/src/mlpack/core/data/binarize.hpp b/src/mlpack/core/data/binarize.hpp
index c92fb1d..59d1356 100644
--- a/src/mlpack/core/data/binarize.hpp
+++ b/src/mlpack/core/data/binarize.hpp
@@ -66,8 +66,38 @@ template<typename T>
 void Binarize(arma::Mat<T>& input,
               const double threshold)
 {
-  for (size_t i = 0; i < input.n_rows; ++i)
-    Binarize(input, threshold, i);
+  // Armadillo is column-major: walk columns in the outer loop and address
+  // each element as input(row, col) so non-square matrices stay in bounds.
+  for (size_t col = 0; col < input.n_cols; ++col)
+  {
+    for (size_t row = 0; row < input.n_rows; ++row)
+    {
+      // Strictly greater than the threshold maps to 1; everything else to 0.
+      input(row, col) = (input(row, col) > threshold) ? 1 : 0;
+    }
+  }
+}
+
+template<typename T>
+void Binarize(const arma::Mat<T>& input,
+              arma::Mat<T>& output,
+              const double threshold)
+{
+  // Compare every element at once: (input > threshold) is a umat of 0/1
+  // values, which is converted back to the element type T.  Assigning the
+  // whole matrix also gives output the same shape as input, so an empty
+  // output matrix may be passed in.
+  output = arma::conv_to<arma::Mat<T>>::from(input > threshold);
+}
+
+template<typename T>
+void Binarize(const arma::Mat<T>& input,
+              arma::Mat<T>& output,
+              const double threshold, const size_t dimension)
+{
+  output = input;  // Only row `dimension` is binarized; the rest is copied.
+  output.row(dimension) =
+      arma::conv_to<arma::Mat<T>>::from(input.row(dimension) > threshold);
 }
 
 } // namespace data
diff --git a/src/mlpack/tests/binarize_test.cpp b/src/mlpack/tests/binarize_test.cpp
index d456f14..5ccaf38 100644
--- a/src/mlpack/tests/binarize_test.cpp
+++ b/src/mlpack/tests/binarize_test.cpp
@@ -44,26 +44,6 @@ BOOST_AUTO_TEST_CASE(BinarizeThreshold)
 {
   mat input(10, 10, fill::randu); // fill input with randome Number
   mat constMat(10, 10);
-  math::RandomSeed((size_t) std::time(NULL));
-  double threshold = math::Random(); // random number threshold
-  constMat.fill(threshold);
-
-  umat answer = input > constMat;
-
-  // Binarize every values inside the matrix with threshold of 0;
-  Binarize(input, threshold);
-
-  CheckAnswer(input, answer);
-}
-
-/**
- * The same test as above, but on a larger dataset.
- */
-BOOST_AUTO_TEST_CASE(BinarizeThresholdLargerTest)
-{
-  mat input(10, 500, fill::randu); // fill input with randome Number
-  mat constMat(10, 500);
-  math::RandomSeed((size_t) std::time(NULL));
   double threshold = math::Random(); // random number threshold
   constMat.fill(threshold);
 




More information about the mlpack-git mailing list