[mlpack-svn] r13065 - in mlpack/trunk/src/mlpack: methods/nmf tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Jun 14 18:59:04 EDT 2012
Author: rmohan
Date: 2012-06-14 18:59:04 -0400 (Thu, 14 Jun 2012)
New Revision: 13065
Added:
mlpack/trunk/src/mlpack/methods/nmf/randominit.hpp
mlpack/trunk/src/mlpack/tests/nmf_test.cpp
Modified:
mlpack/trunk/src/mlpack/methods/nmf/CMakeLists.txt
mlpack/trunk/src/mlpack/methods/nmf/mdivupdate.hpp
mlpack/trunk/src/mlpack/methods/nmf/nmf.hpp
mlpack/trunk/src/mlpack/methods/nmf/nmf_impl.hpp
mlpack/trunk/src/mlpack/methods/nmf/nmf_main.cpp
mlpack/trunk/src/mlpack/tests/CMakeLists.txt
Log:
Unit test for NMF added successfully. Validation conditions need to be tuned
Modified: mlpack/trunk/src/mlpack/methods/nmf/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/nmf/CMakeLists.txt 2012-06-14 21:49:36 UTC (rev 13064)
+++ mlpack/trunk/src/mlpack/methods/nmf/CMakeLists.txt 2012-06-14 22:59:04 UTC (rev 13065)
@@ -5,6 +5,7 @@
set(SOURCES
mdistupdate.hpp
mdivupdate.hpp
+ randominit.hpp
nmf.hpp
nmf_impl.hpp
)
Modified: mlpack/trunk/src/mlpack/methods/nmf/mdivupdate.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/nmf/mdivupdate.hpp 2012-06-14 21:49:36 UTC (rev 13064)
+++ mlpack/trunk/src/mlpack/methods/nmf/mdivupdate.hpp 2012-06-14 22:59:04 UTC (rev 13065)
@@ -42,7 +42,7 @@
* @param H Encoding matrix to output
*/
- inline static void Update(const arma::mat& V,
+ inline static void Init(const arma::mat& V,
arma::mat& W,
const arma::mat& H)
{
Modified: mlpack/trunk/src/mlpack/methods/nmf/nmf.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/nmf/nmf.hpp 2012-06-14 21:49:36 UTC (rev 13064)
+++ mlpack/trunk/src/mlpack/methods/nmf/nmf.hpp 2012-06-14 22:59:04 UTC (rev 13065)
@@ -10,6 +10,7 @@
#include <mlpack/core.hpp>
#include "mdistupdate.hpp"
+#include "randominit.hpp"
namespace mlpack {
namespace nmf {
@@ -43,7 +44,8 @@
* @tparam HUpdateRule The update rule for calculating H matrix at each
* iteration; @see MultiplicativeDistanceH for an example.
*/
-template<typename WUpdateRule = MultiplicativeDistanceW,
+template<typename InitializeRule = RandomInitialization,
+ typename WUpdateRule = MultiplicativeDistanceW,
typename HUpdateRule = MultiplicativeDistanceH>
class NMF
{
@@ -59,13 +61,16 @@
* A low residual value denotes that subsequent iterationas are not
* producing much different values of W and H. Once the difference goes
* below the supplied value, the iteration terminates.
+ * @param Initialize Optional Initialization object for initializing the
+ * W and H matrices
* @param WUpdate Optional WUpdateRule object; for when the update rule for
* the W vector has states that it needs to store.
* @param HUpdate Optional HUpdateRule object; for when the update rule for
* the H vector has states that it needs to store.
*/
- NMF(const size_t maxIterations = 1000,
+ NMF(const size_t maxIterations = 10000,
const double maxResidue = 1e-10,
+ const InitializeRule Initialize = InitializeRule(),
const WUpdateRule WUpdate = WUpdateRule(),
const HUpdateRule HUpdate = HUpdateRule());
@@ -85,6 +90,8 @@
size_t maxIterations;
//! The maximum residue below which iteration is considered converged
double maxResidue;
+ //! Instantiated W&H Initialization Rule
+ InitializeRule Initialize;
//! Instantiated W Update Rule
WUpdateRule WUpdate;
//! Instantiated H Update Rule
Modified: mlpack/trunk/src/mlpack/methods/nmf/nmf_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/nmf/nmf_impl.hpp 2012-06-14 21:49:36 UTC (rev 13064)
+++ mlpack/trunk/src/mlpack/methods/nmf/nmf_impl.hpp 2012-06-14 22:59:04 UTC (rev 13065)
@@ -6,6 +6,7 @@
* on the given matrix.
*/
#include "nmf.hpp"
+#include <iostream>
namespace mlpack {
namespace nmf {
@@ -13,16 +14,20 @@
/**
* Construct the NMF object.
*/
-template<typename WUpdateRule,
+template<typename InitializeRule,
+ typename WUpdateRule,
typename HUpdateRule>
-NMF<WUpdateRule,
+NMF<InitializeRule,
+ WUpdateRule,
HUpdateRule>::
NMF(const size_t maxIterations,
const double maxResidue,
+ const InitializeRule Initialize,
const WUpdateRule WUpdate,
const HUpdateRule HUpdate) :
maxIterations(maxIterations),
maxResidue(maxResidue),
+ Initialize(Initialize),
WUpdate(WUpdate),
HUpdate(HUpdate)
{
@@ -43,20 +48,23 @@
* @param H Encoding matrix to output
* @param r Rank r of the factorization
*/
-template<typename WUpdateRule,
+template<typename InitializeRule,
+ typename WUpdateRule,
typename HUpdateRule>
-void NMF<WUpdateRule,
+void NMF<InitializeRule,
+ WUpdateRule,
HUpdateRule>::
Apply(const arma::mat& V, arma::mat& W, arma::mat& H, size_t& r) const
{
size_t n = V.n_rows;
size_t m = V.n_cols;
+
// old and new product WH for residue checking
arma::mat WHold,WH,diff;
- // Allocate random values to the starting iteration
- W.randu(n,r);
- H.randu(r,m);
+ // Initialize W and H
+ Initialize.Init(V,W,H,r);
+
// Store the original calculated value for residue checking
WHold = W*H;
@@ -77,6 +85,8 @@
diff = diff%diff;
residue = accu(diff)/(double)(n*m);
WHold = WH;
+ Log::Debug << "Iteration: " << iteration << " Residue: "
+ << residue << std::endl;
iteration++;
Modified: mlpack/trunk/src/mlpack/methods/nmf/nmf_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/nmf/nmf_main.cpp 2012-06-14 21:49:36 UTC (rev 13064)
+++ mlpack/trunk/src/mlpack/methods/nmf/nmf_main.cpp 2012-06-14 22:59:04 UTC (rev 13065)
@@ -25,7 +25,7 @@
"h");
PARAM_INT_REQ("rank", "Rank of the factorization.", "r");
PARAM_INT("max_iterations", "Number of iterations before NMF terminates",
- "m", 1000);
+ "m", 10000);
PARAM_DOUBLE("max_residue", "The maximum root mean square allowed below which "
"the program termiates", "e", 1e-10);
Added: mlpack/trunk/src/mlpack/methods/nmf/randominit.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/nmf/randominit.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/nmf/randominit.hpp 2012-06-14 22:59:04 UTC (rev 13065)
@@ -0,0 +1,43 @@
+/**
+ * @file randominit.hpp
+ * @author Mohan Rajendran
+ *
+ * Initialization rule for the Non-negative Matrix Factorization. This simple
+ * initialization is performed by assigning a random matrix to W and H.
+ *
+ */
+
+#ifndef __MLPACK_METHODS_NMF_RANDOMINIT_HPP
+#define __MLPACK_METHODS_NMF_RANDOMINIT_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace nmf {
+
+class RandomInitialization
+{
+ public:
+ // Empty constructor; NMF default-constructs its InitializeRule argument
+ RandomInitialization() { }
+
+ inline static void Init(const arma::mat& V,
+ arma::mat& W,
+ arma::mat& H,
+ const size_t& r)
+ {
+ // Simple implementation. This can be left here. V only supplies the sizes.
+ size_t n = V.n_rows;
+ size_t m = V.n_cols;
+
+ // Initialize W (n x r) and H (r x m) to random values via randu()
+ W.randu(n,r);
+ H.randu(r,m);
+ }
+
+}; // Class RandomInitialization
+
+}; // namespace nmf
+}; // namespace mlpack
+
+#endif
Modified: mlpack/trunk/src/mlpack/tests/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/tests/CMakeLists.txt 2012-06-14 21:49:36 UTC (rev 13064)
+++ mlpack/trunk/src/mlpack/tests/CMakeLists.txt 2012-06-14 22:59:04 UTC (rev 13065)
@@ -24,6 +24,7 @@
max_ip_test.cpp
nbc_test.cpp
nca_test.cpp
+ nmf_test.cpp
pca_test.cpp
radical_test.cpp
range_search_test.cpp
Added: mlpack/trunk/src/mlpack/tests/nmf_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/nmf_test.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/tests/nmf_test.cpp 2012-06-14 22:59:04 UTC (rev 13065)
@@ -0,0 +1,44 @@
+/**
+ * @file nmf_test.cpp
+ * @author Mohan Rajendran
+ *
+ * Test file for NMF class.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/methods/nmf/nmf.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+BOOST_AUTO_TEST_SUITE(NMFTest);
+
+using namespace std;
+using namespace arma;
+using namespace mlpack;
+using namespace mlpack::nmf;
+
+/**
+ * Check that the product W * H of the calculated factorization is close to the
+ * input matrix V, element by element.
+ */
+BOOST_AUTO_TEST_CASE(NMFTest)
+{
+ mat V = randu<mat>(5,5);
+ size_t r = 4;
+ mat W,H;
+
+ NMF<> nmf;
+ nmf.Apply(V,W,H,r);
+
+ mat WH = W*H;
+
+ V.print("V=");
+ WH.print("WH=");
+
+ for (size_t row = 0; row < 5; row++)
+ for (size_t col = 0; col < 5; col++)
+ BOOST_REQUIRE_CLOSE(V(row, col), WH(row, col), 5);
+}
+
+
+BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn
mailing list