[mlpack-svn] r17140 - mlpack/trunk/src/mlpack/core/arma_extend

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Aug 30 00:06:08 EDT 2014


Author: rcurtin
Date: Sat Aug 30 00:06:08 2014
New Revision: 17140

Log:
Backport another sparse matrix constructor for softmax regression.


Modified:
   mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
   mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp

Modified: mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp	Sat Aug 30 00:06:08 2014
@@ -5,12 +5,14 @@
  * Add a batch constructor for SpMat, if the version is older than 3.810.0.
  */
 #if ARMA_VERSION_MAJOR == 3 && ARMA_VERSION_MINOR < 810
-template<typename T1, typename T2> inline SpMat(
+template<typename T1, typename T2>
+inline SpMat(
     const Base<uword, T1>& locations,
     const Base<eT, T2>& values,
     const bool sort_locations = true);
 
-template<typename T1, typename T2> inline SpMat(
+template<typename T1, typename T2>
+inline SpMat(
     const Base<uword, T1>& locations,
     const Base<eT, T2>& values,
     const uword n_rows,
@@ -18,6 +20,16 @@
     const bool sort_locations = true);
 #endif
 
+#if ARMA_VERSION_MAJOR == 3 && ARMA_VERSION_MINOR < 920  // backport: build an SpMat from CSC arrays (Armadillo 3.x < 3.920)
+template<typename T1, typename T2, typename T3>
+inline SpMat(
+    const Base<uword, T1>& rowind,  // row index of each stored value
+    const Base<uword, T2>& colptr,  // column pointers (n_cols + 1 entries)
+    const Base<eT, T3>& values,     // stored values, same length as rowind
+    const uword n_rows,
+    const uword n_cols);
+#endif
+
 /*
  * Extra functions for SpMat<eT>
  * Adding definition of row_col_iterator to generalize with Mat<eT>::row_col_iterator

Modified: mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp	Sat Aug 30 00:06:08 2014
@@ -250,6 +250,66 @@
 
 #endif
 
+#if ARMA_VERSION_MAJOR == 3 && ARMA_VERSION_MINOR < 920  // backport for Armadillo 3.x older than 3.920
+//! Construct a sparse matrix directly from compressed sparse column (CSC) arrays.
+//! rowind_expr holds the row index of each stored value, colptr_expr holds the
+//! n_cols + 1 column pointers, and values_expr holds the stored values (same
+//! length as rowind_expr). The size is given explicitly by in_n_rows/in_n_cols.
+//! Entries are assumed to be already sorted (nothing is sorted here), and the
+//! size information is trusted: only vector shapes and lengths are checked (via
+//! arma_debug_check); index ranges and pointer consistency are not verified.
+template<typename eT>
+template<typename T1, typename T2, typename T3>
+inline
+SpMat<eT>::SpMat
+  (
+  const Base<uword,T1>& rowind_expr,
+  const Base<uword,T2>& colptr_expr,
+  const Base<eT,   T3>& values_expr,
+  const uword           in_n_rows,
+  const uword           in_n_cols
+  )
+  : n_rows(0)
+  , n_cols(0)
+  , n_elem(0)
+  , n_nonzero(0)
+  , vec_state(0)
+  , values(NULL)
+  , row_indices(NULL)
+  , col_ptrs(NULL)
+  {
+  arma_extra_debug_sigprint_this(this);
+
+  init(in_n_rows, in_n_cols);  // presumably sets n_rows/n_cols; the colptr length check below reads n_cols
+
+  const unwrap<T1> rowind_tmp( rowind_expr.get_ref() );
+  const unwrap<T2> colptr_tmp( colptr_expr.get_ref() );
+  const unwrap<T3>   vals_tmp( values_expr.get_ref() );
+
+  const Mat<uword>& rowind = rowind_tmp.M;
+  const Mat<uword>& colptr = colptr_tmp.M;
+  const Mat<eT>&      vals = vals_tmp.M;
+
+  arma_debug_check( (rowind.is_vec() == false), "SpMat::SpMat(): given 'rowind' object is not a vector" );
+  arma_debug_check( (colptr.is_vec() == false), "SpMat::SpMat(): given 'colptr' object is not a vector" );
+  arma_debug_check( (vals.is_vec()   == false), "SpMat::SpMat(): given 'values' object is not a vector" );
+
+  arma_debug_check( (rowind.n_elem != vals.n_elem), "SpMat::SpMat(): number of row indices is not equal to number of values" );
+  arma_debug_check( (colptr.n_elem != (n_cols+1) ), "SpMat::SpMat(): number of column pointers is not equal to n_cols+1" );
+
+  // Resize to correct number of elements (this also sets n_nonzero)
+  mem_resize(vals.n_elem);
+
+  // copy supplied values into sparse matrix -- not checked for consistency
+  arrayops::copy(access::rwp(row_indices), rowind.memptr(), rowind.n_elem );
+  arrayops::copy(access::rwp(col_ptrs),    colptr.memptr(), colptr.n_elem );
+  arrayops::copy(access::rwp(values),      vals.memptr(),   vals.n_elem   );
+
+  // important: set the sentinel after the n_cols + 1 copied column pointers (assumes init() reserved that slot -- confirm)
+  access::rw(col_ptrs[n_cols + 1]) = std::numeric_limits<uword>::max();
+  }
+#endif  // ARMA_VERSION_MAJOR == 3 && ARMA_VERSION_MINOR < 920
+
 #if ARMA_VERSION_MAJOR < 4 || \
     (ARMA_VERSION_MAJOR == 4 && ARMA_VERSION_MINOR < 349)
 template<typename eT>



More information about the mlpack-svn mailing list