[mlpack-git] master: Add Serialize() and a test. (e3c3f78)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Dec 3 14:37:17 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/c829fc1a2415f3dddb672431bb51ff05cbc40a76...9e76f0b82b8bc4fe038179abe77a78146b40c195

>---------------------------------------------------------------

commit e3c3f782746cb2c24473403be2610efaa707a4ae
Author: ryan <ryan at ratml.org>
Date:   Thu Dec 3 11:35:33 2015 -0500

    Add Serialize() and a test.


>---------------------------------------------------------------

e3c3f782746cb2c24473403be2610efaa707a4ae
 src/mlpack/methods/lars/CMakeLists.txt  |  5 ++--
 src/mlpack/methods/lars/lars.cpp        | 14 ++++-----
 src/mlpack/methods/lars/lars.hpp        | 13 ++++++--
 src/mlpack/methods/lars/lars_impl.hpp   | 53 +++++++++++++++++++++++++++++++++
 src/mlpack/tests/serialization_test.cpp | 38 +++++++++++++++++++++++
 5 files changed, 112 insertions(+), 11 deletions(-)

diff --git a/src/mlpack/methods/lars/CMakeLists.txt b/src/mlpack/methods/lars/CMakeLists.txt
index a502937..b876c07 100644
--- a/src/mlpack/methods/lars/CMakeLists.txt
+++ b/src/mlpack/methods/lars/CMakeLists.txt
@@ -1,8 +1,9 @@
 # Define the files we need to compile
 # Anything not in this list will not be compiled into the output library
 set(SOURCES
-   lars.hpp
-   lars.cpp
+  lars.hpp
+  lars_impl.hpp
+  lars.cpp
 )
 
 # add directory name to sources
diff --git a/src/mlpack/methods/lars/lars.cpp b/src/mlpack/methods/lars/lars.cpp
index f3b9df0..856d6f0 100644
--- a/src/mlpack/methods/lars/lars.cpp
+++ b/src/mlpack/methods/lars/lars.cpp
@@ -13,7 +13,7 @@ LARS::LARS(const bool useCholesky,
            const double lambda1,
            const double lambda2,
            const double tolerance) :
-    matGram(matGramInternal),
+    matGram(&matGramInternal),
     useCholesky(useCholesky),
     lasso((lambda1 != 0)),
     lambda1(lambda1),
@@ -27,7 +27,7 @@ LARS::LARS(const bool useCholesky,
            const double lambda1,
            const double lambda2,
            const double tolerance) :
-    matGram(gramMatrix),
+    matGram(&gramMatrix),
     useCholesky(useCholesky),
     lasso((lambda1 != 0)),
     lambda1(lambda1),
@@ -102,7 +102,7 @@ void LARS::Train(const arma::mat& matX,
 
   // Compute the Gram matrix.  If this is the elastic net problem, we will add
   // lambda2 * I_n to the matrix.
-  if (matGram.n_elem != dataRef.n_cols * dataRef.n_cols)
+  if (matGram->n_elem != dataRef.n_cols * dataRef.n_cols)
   {
     // In this case, matGram should reference matGramInternal.
     matGramInternal = trans(dataRef) * dataRef;
@@ -136,10 +136,10 @@ void LARS::Train(const arma::mat& matX,
         //   newGramCol[i] = dot(matX.col(activeSet[i]), matX.col(changeInd));
         // }
         // This is equivalent to the above 5 lines.
-        arma::vec newGramCol = matGram.elem(changeInd * dataRef.n_cols +
+        arma::vec newGramCol = matGram->elem(changeInd * dataRef.n_cols +
             arma::conv_to<arma::uvec>::from(activeSet));
 
-        CholeskyInsert(matGram(changeInd, changeInd), newGramCol);
+        CholeskyInsert((*matGram)(changeInd, changeInd), newGramCol);
       }
 
       // Add variable to active set.
@@ -200,7 +200,7 @@ void LARS::Train(const arma::mat& matX,
       arma::mat matGramActive = arma::mat(activeSet.size(), activeSet.size());
       for (size_t i = 0; i < activeSet.size(); i++)
         for (size_t j = 0; j < activeSet.size(); j++)
-          matGramActive(i, j) = matGram(activeSet[i], activeSet[j]);
+          matGramActive(i, j) = (*matGram)(activeSet[i], activeSet[j]);
 
       // Check for singularity.
       arma::mat matS = s * arma::ones<arma::mat>(1, activeSet.size());
@@ -502,7 +502,7 @@ std::string LARS::ToString() const
 {
   std::ostringstream convert;
   convert << "LARS [" << this << "]" << std::endl;
-  convert << "  Gram Matrix: " << matGram.n_rows << "x" << matGram.n_cols;
+  convert << "  Gram Matrix: " << matGram->n_rows << "x" << matGram->n_cols;
   convert << std::endl;
   convert << "  Tolerance: " << tolerance << std::endl;
   return convert.str();
diff --git a/src/mlpack/methods/lars/lars.hpp b/src/mlpack/methods/lars/lars.hpp
index dca181d..67f5aed 100644
--- a/src/mlpack/methods/lars/lars.hpp
+++ b/src/mlpack/methods/lars/lars.hpp
@@ -163,6 +163,12 @@ class LARS
   //! Access the upper triangular cholesky factor.
   const arma::mat& MatUtriCholFactor() const { return matUtriCholFactor; }
 
+  /**
+   * Serialize the LARS model.
+   */
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
   // Returns a string representation of this object.
   std::string ToString() const;
 
@@ -170,8 +176,8 @@ class LARS
   //! Gram matrix.
   arma::mat matGramInternal;
 
-  //! Reference to the Gram matrix we will use.
-  const arma::mat& matGram;
+  //! Pointer to the Gram matrix we will use.
+  const arma::mat* matGram;
 
   //! Upper triangular cholesky factor; initially 0x0 matrix.
   arma::mat matUtriCholFactor;
@@ -255,4 +261,7 @@ class LARS
 } // namespace regression
 } // namespace mlpack
 
+// Include implementation of Serialize().
+#include "lars_impl.hpp"
+
 #endif
diff --git a/src/mlpack/methods/lars/lars_impl.hpp b/src/mlpack/methods/lars/lars_impl.hpp
new file mode 100644
index 0000000..735f9da
--- /dev/null
+++ b/src/mlpack/methods/lars/lars_impl.hpp
@@ -0,0 +1,53 @@
+/**
+ * @file lars_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of templated LARS functions.
+ */
+#ifndef __MLPACK_METHODS_LARS_LARS_IMPL_HPP
+#define __MLPACK_METHODS_LARS_LARS_IMPL_HPP
+
+//! In case it hasn't been included yet.
+#include "lars.hpp"
+
+namespace mlpack {
+namespace regression {
+
+/**
+ * Save or load the Gram matrix, Cholesky factor, parameters, and path state.
+ */
+template<typename Archive>
+void LARS::Serialize(Archive& ar, const unsigned int /* version */)
+{
+  using data::CreateNVP;
+
+  // On load, always restore into (and point matGram at) internal storage.
+  if (Archive::is_loading::value)
+  {
+    matGram = &matGramInternal; // Drop any reference to an external matrix.
+    ar & CreateNVP(matGramInternal, "matGramInternal");
+  }
+  else // Saving: *matGram may be user-supplied; stored as matGramInternal.
+  { // NOTE(review): const_cast presumed safe since saving should only read.
+    ar & CreateNVP(const_cast<arma::mat&>(*matGram), "matGramInternal");
+  }
+
+  ar & CreateNVP(matUtriCholFactor, "matUtriCholFactor");
+  ar & CreateNVP(useCholesky, "useCholesky");
+  ar & CreateNVP(lasso, "lasso");
+  ar & CreateNVP(lambda1, "lambda1");
+  ar & CreateNVP(elasticNet, "elasticNet");
+  ar & CreateNVP(lambda2, "lambda2");
+  ar & CreateNVP(tolerance, "tolerance");
+  ar & CreateNVP(betaPath, "betaPath");
+  ar & CreateNVP(lambdaPath, "lambdaPath");
+  ar & CreateNVP(activeSet, "activeSet");
+  ar & CreateNVP(isActive, "isActive");
+  ar & CreateNVP(ignoreSet, "ignoreSet");
+  ar & CreateNVP(isIgnored, "isIgnored");
+}
+
+} // namespace regression
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index d63a516..4a9a182 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -33,6 +33,7 @@
 #include <mlpack/methods/rann/ra_search.hpp>
 #include <mlpack/methods/lsh/lsh_search.hpp>
 #include <mlpack/methods/decision_stump/decision_stump.hpp>
+#include <mlpack/methods/lars/lars.hpp>
 
 using namespace mlpack;
 using namespace mlpack::distribution;
@@ -1467,4 +1468,41 @@ BOOST_AUTO_TEST_CASE(DecisionStumpTest)
       binaryDs.BinLabels());
 }
 
+// Check that a trained LARS model's predictions survive serialization.
+BOOST_AUTO_TEST_CASE(LARSTest)
+{
+  using namespace mlpack::regression;
+
+  // Noise-free linear data: 75 dimensions, 250 points (one per column).
+  arma::mat X = arma::randn(75, 250);
+  arma::vec beta = arma::randn(75, 1);
+  arma::vec y = trans(X) * beta;
+
+  LARS lars(true, 0.1, 0.1);
+  arma::vec betaOpt;
+  lars.Train(X, y, betaOpt);
+
+  // Destination models start with settings different from lars.
+  LARS xmlLars(false, 0.5, 0.0), binaryLars(true, 1.0, 0.0),
+      textLars(false, 0.1, 0.1);
+
+  // Also train textLars on other data; loading should overwrite its state.
+  arma::mat textX = arma::randn(25, 150);
+  arma::vec textBeta = arma::randn(25, 1);
+  arma::vec textY = trans(textX) * textBeta;
+  arma::vec textBetaOpt;
+  textLars.Train(textX, textY, textBetaOpt);
+
+  SerializeObjectAll(lars, xmlLars, binaryLars, textLars);
+
+  // Compare predictions on X from the original and all deserialized models.
+  arma::vec pred, xmlPred, textPred, binaryPred;
+  lars.Predict(X, pred);
+  xmlLars.Predict(X, xmlPred);
+  textLars.Predict(X, textPred);
+  binaryLars.Predict(X, binaryPred);
+
+  CheckMatrices(pred, xmlPred, textPred, binaryPred);
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list