[mlpack-git] master: change binarize to col_major and add overload (57a1b19)
gitdub at mlpack.org
gitdub at mlpack.org
Mon Jun 6 14:14:45 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/a0b31abe5ff69117645c664dbeac1476dd5e48f7...2da9c5bac14a00145c757b8139c245913b86e034
>---------------------------------------------------------------
commit 57a1b195b2770b0555aee4951726b29a4a01d23f
Author: Keon Kim <kwk236 at gmail.com>
Date: Tue Jun 7 03:02:37 2016 +0900
change binarize to col_major and add overload
>---------------------------------------------------------------
57a1b195b2770b0555aee4951726b29a4a01d23f
src/mlpack/core/data/binarize.hpp | 34 ++++++++++++++++++++++++++++++++--
src/mlpack/tests/binarize_test.cpp | 20 --------------------
2 files changed, 32 insertions(+), 22 deletions(-)
diff --git a/src/mlpack/core/data/binarize.hpp b/src/mlpack/core/data/binarize.hpp
index c92fb1d..59d1356 100644
--- a/src/mlpack/core/data/binarize.hpp
+++ b/src/mlpack/core/data/binarize.hpp
@@ -66,8 +66,38 @@ template<typename T>
void Binarize(arma::Mat<T>& input,
const double threshold)
{
- for (size_t i = 0; i < input.n_rows; ++i)
- Binarize(input, threshold, i);
+ // Armadillo stores matrices column-major and indexes them as (row, col), so
+ // walk columns in the outer loop and rows in the inner loop.
+ for (size_t i = 0; i < input.n_cols; ++i)
+ {
+ for (size_t j = 0; j < input.n_rows; ++j)
+ {
+ // j is the row index and i is the column index; swapping them would go
+ // out of bounds for any non-square matrix.
+ if (input(j, i) > threshold)
+ input(j, i) = 1;
+ else
+ input(j, i) = 0;
+ }
+ }
+ }
+
+/**
+ * Write a binarized copy of the whole input matrix to output: elements greater
+ * than threshold become 1, all other elements become 0. The input is not
+ * modified.
+ *
+ * @param input Input matrix to binarize.
+ * @param output Matrix that receives the result; it is resized to match input.
+ * @param threshold Values greater than this become 1; all others become 0.
+ */
+template<typename T>
+void Binarize(const arma::Mat<T>& input,
+ arma::Mat<T>& output,
+ const double threshold)
+{
+ // Compare every element against the threshold at once (yielding 0/1 values)
+ // and convert back to the element type. Assigning the whole result gives
+ // output the same shape as input, so output may be passed in empty.
+ output = arma::conv_to<arma::Mat<T>>::from(input > threshold);
+}
+
+/**
+ * Copy input to output and binarize only row `dimension` of the copy: values
+ * in that row greater than threshold become 1, all others become 0. All other
+ * rows are copied unchanged. The input is not modified.
+ *
+ * @param input Input matrix.
+ * @param output Matrix that receives the result; it is resized to match input.
+ * @param threshold Values greater than this become 1; all others become 0.
+ * @param dimension Index of the row to binarize.
+ */
+template<typename T>
+void Binarize(const arma::Mat<T>& input,
+ arma::Mat<T>& output,
+ const double threshold,
+ const size_t dimension)
+{
+ // Start from a full copy so output has input's shape; without this, row()
+ // on an empty or smaller output is out of bounds.
+ output = input;
+ output.row(dimension) =
+ arma::conv_to<arma::Mat<T>>::from(input.row(dimension) > threshold);
}
} // namespace data
diff --git a/src/mlpack/tests/binarize_test.cpp b/src/mlpack/tests/binarize_test.cpp
index d456f14..5ccaf38 100644
--- a/src/mlpack/tests/binarize_test.cpp
+++ b/src/mlpack/tests/binarize_test.cpp
@@ -44,26 +44,6 @@ BOOST_AUTO_TEST_CASE(BinarizeThreshold)
{
mat input(10, 10, fill::randu); // fill input with randome Number
mat constMat(10, 10);
- math::RandomSeed((size_t) std::time(NULL));
- double threshold = math::Random(); // random number threshold
- constMat.fill(threshold);
-
- umat answer = input > constMat;
-
- // Binarize every values inside the matrix with threshold of 0;
- Binarize(input, threshold);
-
- CheckAnswer(input, answer);
-}
-
-/**
- * The same test as above, but on a larger dataset.
- */
-BOOST_AUTO_TEST_CASE(BinarizeThresholdLargerTest)
-{
- mat input(10, 500, fill::randu); // fill input with randome Number
- mat constMat(10, 500);
- math::RandomSeed((size_t) std::time(NULL));
double threshold = math::Random(); // random number threshold
constMat.fill(threshold);
More information about the mlpack-git
mailing list