[mlpack-git] master: Backport another sparse matrix constructor for softmax regression. (d049835)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:59:59 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit d049835831aa7da3a82383e906d59ab9ecbba790
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sat Aug 30 04:06:08 2014 +0000

    Backport another sparse matrix constructor for softmax regression.


>---------------------------------------------------------------

d049835831aa7da3a82383e906d59ab9ecbba790
 src/mlpack/core/arma_extend/SpMat_extra_bones.hpp | 16 +++++-
 src/mlpack/core/arma_extend/SpMat_extra_meat.hpp  | 60 +++++++++++++++++++++++
 2 files changed, 74 insertions(+), 2 deletions(-)

diff --git a/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp b/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
index 1c84221..51b724c 100644
--- a/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
+++ b/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
@@ -5,12 +5,14 @@
  * Add a batch constructor for SpMat, if the version is older than 3.810.0.
  */
 #if ARMA_VERSION_MAJOR == 3 && ARMA_VERSION_MINOR < 810
-template<typename T1, typename T2> inline SpMat(
+template<typename T1, typename T2>
+inline SpMat(
     const Base<uword, T1>& locations,
     const Base<eT, T2>& values,
     const bool sort_locations = true);
 
-template<typename T1, typename T2> inline SpMat(
+template<typename T1, typename T2>
+inline SpMat(
     const Base<uword, T1>& locations,
     const Base<eT, T2>& values,
     const uword n_rows,
@@ -18,6 +20,16 @@ template<typename T1, typename T2> inline SpMat(
     const bool sort_locations = true);
 #endif
 
+#if ARMA_VERSION_MAJOR == 3 && ARMA_VERSION_MINOR < 920
+template<typename T1, typename T2, typename T3>
+inline SpMat( // construct from CSC arrays; only declared for Armadillo 3.x older than 3.920
+    const Base<uword, T1>& rowind, // row index of each nonzero value
+    const Base<uword, T2>& colptr, // column pointers, expected to have n_cols + 1 entries
+    const Base<eT, T3>& values,    // nonzero values, parallel to rowind
+    const uword n_rows,
+    const uword n_cols);
+#endif
+
 /*
  * Extra functions for SpMat<eT>
  * Adding definition of row_col_iterator to generalize with Mat<eT>::row_col_iterator
diff --git a/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp b/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp
index 2cf980b..d2ad10e 100644
--- a/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp
+++ b/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp
@@ -250,6 +250,66 @@ SpMat<eT>::SpMat(const Base<uword,T1>& locations_expr, const Base<eT,T2>& vals_e
 
 #endif
 
+#if ARMA_VERSION_MAJOR == 3 && ARMA_VERSION_MINOR < 920
+//! Batch constructor: build the matrix directly from CSC (compressed sparse
+//! column) arrays. rowind_expr holds one row index per nonzero value,
+//! colptr_expr holds the n_cols + 1 column pointers, and values_expr holds
+//! the nonzero values in the same order as rowind_expr. The size is given
+//! explicitly as in_n_rows x in_n_cols. The arrays are copied as-is: values
+//! are assumed to be sorted and the CSC structure is trusted (only vector
+//! shapes and lengths are checked, via arma_debug_check).
+template<typename eT>
+template<typename T1, typename T2, typename T3>
+inline
+SpMat<eT>::SpMat
+  (
+  const Base<uword,T1>& rowind_expr,
+  const Base<uword,T2>& colptr_expr,
+  const Base<eT,   T3>& values_expr,
+  const uword           in_n_rows,
+  const uword           in_n_cols
+  )
+  : n_rows(0)
+  , n_cols(0)
+  , n_elem(0)
+  , n_nonzero(0)
+  , vec_state(0)
+  , values(NULL)
+  , row_indices(NULL)
+  , col_ptrs(NULL)
+  {
+  arma_extra_debug_sigprint_this(this);
+
+  init(in_n_rows, in_n_cols);
+
+  const unwrap<T1> rowind_tmp( rowind_expr.get_ref() );
+  const unwrap<T2> colptr_tmp( colptr_expr.get_ref() );
+  const unwrap<T3>   vals_tmp( values_expr.get_ref() );
+
+  const Mat<uword>& rowind = rowind_tmp.M;
+  const Mat<uword>& colptr = colptr_tmp.M;
+  const Mat<eT>&      vals = vals_tmp.M;
+
+  arma_debug_check( (rowind.is_vec() == false), "SpMat::SpMat(): given 'rowind' object is not a vector" );
+  arma_debug_check( (colptr.is_vec() == false), "SpMat::SpMat(): given 'colptr' object is not a vector" );
+  arma_debug_check( (vals.is_vec()   == false), "SpMat::SpMat(): given 'values' object is not a vector" );
+
+  arma_debug_check( (rowind.n_elem != vals.n_elem), "SpMat::SpMat(): number of row indices is not equal to number of values" );
+  arma_debug_check( (colptr.n_elem != (n_cols+1) ), "SpMat::SpMat(): number of column pointers is not equal to n_cols+1" );
+
+  // Allocate storage for vals.n_elem nonzeros (mem_resize() also sets n_nonzero).
+  mem_resize(vals.n_elem);
+
+  // Copy the CSC arrays verbatim; sortedness and index ranges are NOT checked.
+  arrayops::copy(access::rwp(row_indices), rowind.memptr(), rowind.n_elem );
+  arrayops::copy(access::rwp(col_ptrs),    colptr.memptr(), colptr.n_elem );
+  arrayops::copy(access::rwp(values),      vals.memptr(),   vals.n_elem   );
+
+  // Important: colptr only fills col_ptrs[0..n_cols], so also write the sentinel at col_ptrs[n_cols + 1].
+  access::rw(col_ptrs[n_cols + 1]) = std::numeric_limits<uword>::max();
+  }
+#endif
+
 #if ARMA_VERSION_MAJOR < 4 || \
     (ARMA_VERSION_MAJOR == 4 && ARMA_VERSION_MINOR < 349)
 template<typename eT>



More information about the mlpack-git mailing list