[mlpack-svn] r17140 - mlpack/trunk/src/mlpack/core/arma_extend
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Aug 30 00:06:08 EDT 2014
Author: rcurtin
Date: Sat Aug 30 00:06:08 2014
New Revision: 17140
Log:
Backport another sparse matrix constructor for softmax regression.
Modified:
mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp
Modified: mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp (original)
+++ mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp Sat Aug 30 00:06:08 2014
@@ -5,12 +5,14 @@
* Add a batch constructor for SpMat, if the version is older than 3.810.0.
*/
#if ARMA_VERSION_MAJOR == 3 && ARMA_VERSION_MINOR < 810
-template<typename T1, typename T2> inline SpMat(
+template<typename T1, typename T2>
+inline SpMat(
const Base<uword, T1>& locations,
const Base<eT, T2>& values,
const bool sort_locations = true);
-template<typename T1, typename T2> inline SpMat(
+template<typename T1, typename T2>
+inline SpMat(
const Base<uword, T1>& locations,
const Base<eT, T2>& values,
const uword n_rows,
@@ -18,6 +20,16 @@
const bool sort_locations = true);
#endif
+#if ARMA_VERSION_MAJOR == 3 && ARMA_VERSION_MINOR < 920
+template<typename T1, typename T2, typename T3>
+inline SpMat(
+ const Base<uword, T1>& rowind,
+ const Base<uword, T2>& colptr,
+ const Base<eT, T3>& values,
+ const uword n_rows,
+ const uword n_cols);
+#endif
+
/*
* Extra functions for SpMat<eT>
* Adding definition of row_col_iterator to generalize with Mat<eT>::row_col_iterator
Modified: mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp (original)
+++ mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp Sat Aug 30 00:06:08 2014
@@ -250,6 +250,66 @@
#endif
+#if ARMA_VERSION_MAJOR == 3 && ARMA_VERSION_MINOR < 920
+//! Batch constructor: build a sparse matrix directly from CSC arrays.
+//! Per the CSC format, rowind_expr should hold the row indices,
+//! colptr_expr should hold the column pointers (n_cols + 1 entries),
+//! and values_expr should hold the corresponding nonzero values.
+//! In this constructor the size is explicitly given.
+//! Values are assumed to be sorted, and the size
+//! information is trusted; the inputs are not checked for consistency.
+template<typename eT>
+template<typename T1, typename T2, typename T3>
+inline
+SpMat<eT>::SpMat
+ (
+ const Base<uword,T1>& rowind_expr,
+ const Base<uword,T2>& colptr_expr,
+ const Base<eT, T3>& values_expr,
+ const uword in_n_rows,
+ const uword in_n_cols
+ )
+ : n_rows(0)
+ , n_cols(0)
+ , n_elem(0)
+ , n_nonzero(0)
+ , vec_state(0)
+ , values(NULL)
+ , row_indices(NULL)
+ , col_ptrs(NULL)
+ {
+ arma_extra_debug_sigprint_this(this);
+
+ init(in_n_rows, in_n_cols);
+
+ const unwrap<T1> rowind_tmp( rowind_expr.get_ref() );
+ const unwrap<T2> colptr_tmp( colptr_expr.get_ref() );
+ const unwrap<T3> vals_tmp( values_expr.get_ref() );
+
+ const Mat<uword>& rowind = rowind_tmp.M;
+ const Mat<uword>& colptr = colptr_tmp.M;
+ const Mat<eT>& vals = vals_tmp.M;
+
+ arma_debug_check( (rowind.is_vec() == false), "SpMat::SpMat(): given 'rowind' object is not a vector" );
+ arma_debug_check( (colptr.is_vec() == false), "SpMat::SpMat(): given 'colptr' object is not a vector" );
+ arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object is not a vector" );
+
+ arma_debug_check( (rowind.n_elem != vals.n_elem), "SpMat::SpMat(): number of row indices is not equal to number of values" );
+ arma_debug_check( (colptr.n_elem != (n_cols+1) ), "SpMat::SpMat(): number of column pointers is not equal to n_cols+1" );
+
+ // Resize storage to hold one entry per supplied value (this also sets n_nonzero)
+ mem_resize(vals.n_elem);
+
+ // Copy the three supplied arrays verbatim into the matrix -- not checked for consistency
+ arrayops::copy(access::rwp(row_indices), rowind.memptr(), rowind.n_elem );
+ arrayops::copy(access::rwp(col_ptrs), colptr.memptr(), colptr.n_elem );
+ arrayops::copy(access::rwp(values), vals.memptr(), vals.n_elem );
+
+ // important: also set the sentinel in the slot just past the copied column pointers
+ access::rw(col_ptrs[n_cols + 1]) = std::numeric_limits<uword>::max();
+ }
+#endif
+
#if ARMA_VERSION_MAJOR < 4 || \
(ARMA_VERSION_MAJOR == 4 && ARMA_VERSION_MINOR < 349)
template<typename eT>
More information about the mlpack-svn
mailing list