[mlpack-svn] r15926 - mlpack/trunk/src/mlpack/core/arma_extend

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Oct 4 00:47:13 EDT 2013


Author: rcurtin
Date: Fri Oct  4 00:47:12 2013
New Revision: 15926

Log:
Backport batch insertion constructor to Armadillo < 3.810.0.


Added:
   mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
   mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp
Modified:
   mlpack/trunk/src/mlpack/core/arma_extend/CMakeLists.txt
   mlpack/trunk/src/mlpack/core/arma_extend/arma_extend.hpp

Modified: mlpack/trunk/src/mlpack/core/arma_extend/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/core/arma_extend/CMakeLists.txt	(original)
+++ mlpack/trunk/src/mlpack/core/arma_extend/CMakeLists.txt	Fri Oct  4 00:47:12 2013
@@ -12,6 +12,8 @@
   restrictors.hpp
   traits.hpp
   typedef.hpp
+  SpMat_extra_bones.hpp
+  SpMat_extra_meat.hpp
 )
 
 # add directory name to sources

Added: mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp	Fri Oct  4 00:47:12 2013
@@ -0,0 +1,19 @@
+/**
+ * @file SpMat_extra_bones.hpp
+ * @author Ryan Curtin
+ *
+ * Declare the SpMat batch constructors for Armadillo 3.x releases older than 3.810.0.
+ */
+#if ARMA_VERSION_MAJOR == 3 && ARMA_VERSION_MINOR < 810
+template<typename T1, typename T2> // Batch constructor without explicit dimensions.
+inline SpMat(const Base<uword, T1>& locations,
+             const Base<eT, T2>& values,
+             const bool sort_locations = true);
+
+template<typename T1, typename T2> // Batch constructor with explicit dimensions.
+inline SpMat(const Base<uword, T1>& locations,
+             const Base<eT, T2>& values,
+             const uword n_rows,
+             const uword n_cols,
+             const bool sort_locations = true);
+#endif

Added: mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp	Fri Oct  4 00:47:12 2013
@@ -0,0 +1,251 @@
+/**
+ * @file SpMat_extra_meat.hpp
+ * @author Ryan Curtin
+ *
+ * Take the Armadillo batch sparse matrix constructor function from newer
+ * Armadillo versions and port it to versions earlier than 3.810.0.
+ */
+#if ARMA_VERSION_MAJOR == 3 && ARMA_VERSION_MINOR < 810
+
+//! Batch constructor: insert a large number of values at once, inferring the matrix size.
+//! locations.row[0] should be row indices, locations.row[1] should be column indices,
+//! and values should be the corresponding values (one value per column of locations).
+//! If sort_locations is false, then it is assumed that the locations and values
+//! are already sorted in column-major ordering; this is only checked by arma_debug_check().
+template<typename eT>
+template<typename T1, typename T2>
+inline
+SpMat<eT>::SpMat(const Base<uword,T1>& locations_expr, const Base<eT,T2>& vals_expr, const bool sort_locations)
+  : n_rows(0)
+  , n_cols(0)
+  , n_elem(0)
+  , n_nonzero(0)
+  , vec_state(0)
+  , values(NULL)
+  , row_indices(NULL)
+  , col_ptrs(NULL)
+  {
+  arma_extra_debug_sigprint_this(this);
+
+  const unwrap<T1>         locs_tmp( locations_expr.get_ref() );
+  const Mat<uword>& locs = locs_tmp.M;
+
+  const unwrap<T2> vals_tmp( vals_expr.get_ref() );
+  const Mat<eT>& vals = vals_tmp.M;
+
+  arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object is not a vector" );
+
+  arma_debug_check((locs.n_cols != vals.n_elem), "SpMat::SpMat(): number of locations is different than number of values");
+
+  // If there are no elements in the list, max() will fail; produce an empty 0x0 matrix instead.
+  if (locs.n_cols == 0)
+    {
+    init(0, 0);
+    return;
+    }
+
+  arma_debug_check((locs.n_rows != 2), "SpMat::SpMat(): locations matrix must have two rows");
+
+  // Automatically determine size: (largest row index + 1) x (largest column index + 1).
+  uvec bounds = arma::max(locs, 1);
+  init(bounds[0] + 1, bounds[1] + 1);
+
+  // Resize to correct number of elements.
+  mem_resize(vals.n_elem);
+
+  // Reset column pointers to zero (entries 0..n_cols); they are used as per-column counts below.
+  arrayops::inplace_set(access::rwp(col_ptrs), uword(0), n_cols + 1);
+
+  bool actually_sorted = true;
+  if(sort_locations == true)
+    {
+    // sort_index() uses std::sort() which may use quicksort... so we better
+    // make sure it's not already sorted before taking an O(N^2) sort penalty.
+    for (uword i = 1; i < locs.n_cols; ++i)
+      {
+      if ((locs.at(1, i) < locs.at(1, i - 1)) || (locs.at(1, i) == locs.at(1, i - 1) && locs.at(0, i) <= locs.at(0, i - 1)))
+        {
+        actually_sorted = false;
+        break;
+        }
+      }
+
+    if(actually_sorted == false)
+      {
+      // This may not be the fastest possible implementation but it maximizes code reuse.
+      Col<uword> abslocs(locs.n_cols);
+
+      for (uword i = 0; i < locs.n_cols; ++i)
+        {
+        abslocs[i] = locs.at(1, i) * n_rows + locs.at(0, i);
+        }
+
+      // Now we will sort with sort_index().
+      uvec sorted_indices = sort_index(abslocs); // Ascending sort.
+
+      // Now we add the elements in this sorted order. NOTE(review): duplicate locations are not rejected on this path.
+      for (uword i = 0; i < sorted_indices.n_elem; ++i)
+        {
+        arma_debug_check((locs.at(0, sorted_indices[i]) >= n_rows), "SpMat::SpMat(): invalid row index");
+        arma_debug_check((locs.at(1, sorted_indices[i]) >= n_cols), "SpMat::SpMat(): invalid column index");
+
+        access::rw(values[i])      = vals[sorted_indices[i]];
+        access::rw(row_indices[i]) = locs.at(0, sorted_indices[i]);
+
+        access::rw(col_ptrs[locs.at(1, sorted_indices[i]) + 1])++;
+        }
+      }
+    }
+  if( (sort_locations == false) || (actually_sorted == true) )
+    {
+    // Now set the values and row indices correctly.
+    // Increment the column pointers in each column (so they are column "counts").
+    for (uword i = 0; i < vals.n_elem; ++i)
+      {
+      arma_debug_check((locs.at(0, i) >= n_rows), "SpMat::SpMat(): invalid row index");
+      arma_debug_check((locs.at(1, i) >= n_cols), "SpMat::SpMat(): invalid column index");
+
+      // Check ordering in debug mode.
+      if(i > 0)
+        {
+        arma_debug_check
+          (
+          ( (locs.at(1, i) < locs.at(1, i - 1)) || (locs.at(1, i) == locs.at(1, i - 1) && locs.at(0, i) < locs.at(0, i - 1)) ),
+          "SpMat::SpMat(): out of order points; either pass sort_locations = true, or sort points in column-major ordering"
+          );
+        arma_debug_check((locs.at(1, i) == locs.at(1, i - 1) && locs.at(0, i) == locs.at(0, i - 1)), "SpMat::SpMat(): two identical point locations in list");
+        }
+
+      access::rw(values[i])      = vals[i];
+      access::rw(row_indices[i]) = locs.at(0, i);
+
+      access::rw(col_ptrs[locs.at(1, i) + 1])++;
+      }
+    }
+
+  // Turn the per-column counts into cumulative offsets; stop at col_ptrs[n_cols] so the slot after it is left untouched.
+  for (uword i = 0; i < n_cols; ++i)
+    {
+    access::rw(col_ptrs[i + 1]) += col_ptrs[i];
+    }
+  }
+
+
+
+//! Batch constructor: insert a large number of values at once into a matrix of given size.
+//! locations.row[0] should be row indices, locations.row[1] should be column indices,
+//! and values should be the corresponding values (one value per column of locations).
+//! If sort_locations is false, then it is assumed that the locations and values
+//! are already sorted in column-major ordering; this is only checked by arma_debug_check().
+//! In this constructor the size is explicitly given as in_n_rows x in_n_cols.
+template<typename eT>
+template<typename T1, typename T2>
+inline
+SpMat<eT>::SpMat(const Base<uword,T1>& locations_expr, const Base<eT,T2>& vals_expr, const uword in_n_rows, const uword in_n_cols, const bool sort_locations)
+  : n_rows(0)
+  , n_cols(0)
+  , n_elem(0)
+  , n_nonzero(0)
+  , vec_state(0)
+  , values(NULL)
+  , row_indices(NULL)
+  , col_ptrs(NULL)
+  {
+  arma_extra_debug_sigprint_this(this);
+
+  init(in_n_rows, in_n_cols);
+
+  const unwrap<T1>         locs_tmp( locations_expr.get_ref() );
+  const Mat<uword>& locs = locs_tmp.M;
+
+  const unwrap<T2> vals_tmp( vals_expr.get_ref() );
+  const Mat<eT>& vals = vals_tmp.M;
+
+  arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object is not a vector" );
+
+  arma_debug_check((locs.n_rows != 2), "SpMat::SpMat(): locations matrix must have two rows");
+
+  arma_debug_check((locs.n_cols != vals.n_elem), "SpMat::SpMat(): number of locations is different than number of values");
+
+  // Resize to correct number of elements.
+  mem_resize(vals.n_elem);
+
+  // Reset column pointers to zero (entries 0..n_cols); they are used as per-column counts below.
+  arrayops::inplace_set(access::rwp(col_ptrs), uword(0), n_cols + 1);
+
+  bool actually_sorted = true;
+  if(sort_locations == true)
+    {
+    // sort_index() uses std::sort() which may use quicksort... so we better
+    // make sure it's not already sorted before taking an O(N^2) sort penalty.
+    for (uword i = 1; i < locs.n_cols; ++i)
+      {
+      if ((locs.at(1, i) < locs.at(1, i - 1)) || (locs.at(1, i) == locs.at(1, i - 1) && locs.at(0, i) <= locs.at(0, i - 1)))
+        {
+        actually_sorted = false;
+        break;
+        }
+      }
+
+    if(actually_sorted == false)
+      {
+      // This may not be the fastest possible implementation but it maximizes code reuse.
+      Col<uword> abslocs(locs.n_cols);
+
+      for (uword i = 0; i < locs.n_cols; ++i)
+        {
+        abslocs[i] = locs.at(1, i) * n_rows + locs.at(0, i);
+        }
+
+      // Now we will sort with sort_index().
+      uvec sorted_indices = sort_index(abslocs); // Ascending sort.
+
+      // Now we add the elements in this sorted order. NOTE(review): duplicate locations are not rejected on this path.
+      for (uword i = 0; i < sorted_indices.n_elem; ++i)
+        {
+        arma_debug_check((locs.at(0, sorted_indices[i]) >= n_rows), "SpMat::SpMat(): invalid row index");
+        arma_debug_check((locs.at(1, sorted_indices[i]) >= n_cols), "SpMat::SpMat(): invalid column index");
+
+        access::rw(values[i])      = vals[sorted_indices[i]];
+        access::rw(row_indices[i]) = locs.at(0, sorted_indices[i]);
+
+        access::rw(col_ptrs[locs.at(1, sorted_indices[i]) + 1])++;
+        }
+      }
+    }
+
+  if( (sort_locations == false) || (actually_sorted == true) )
+    {
+    // Now set the values and row indices correctly.
+    // Increment the column pointers in each column (so they are column "counts").
+    for (uword i = 0; i < vals.n_elem; ++i)
+      {
+      arma_debug_check((locs.at(0, i) >= n_rows), "SpMat::SpMat(): invalid row index");
+      arma_debug_check((locs.at(1, i) >= n_cols), "SpMat::SpMat(): invalid column index");
+
+      // Check ordering in debug mode.
+      if(i > 0)
+        {
+        arma_debug_check
+          (
+          ( (locs.at(1, i) < locs.at(1, i - 1)) || (locs.at(1, i) == locs.at(1, i - 1) && locs.at(0, i) < locs.at(0, i - 1)) ),
+          "SpMat::SpMat(): out of order points; either pass sort_locations = true or sort points in column-major ordering"
+          );
+        arma_debug_check((locs.at(1, i) == locs.at(1, i - 1) && locs.at(0, i) == locs.at(0, i - 1)), "SpMat::SpMat(): two identical point locations in list");
+        }
+
+      access::rw(values[i])      = vals[i];
+      access::rw(row_indices[i]) = locs.at(0, i);
+
+      access::rw(col_ptrs[locs.at(1, i) + 1])++;
+      }
+    }
+
+  // Turn the per-column counts into cumulative offsets; stop at col_ptrs[n_cols] so the slot after it is left untouched.
+  for (uword i = 0; i < n_cols; ++i)
+    {
+    access::rw(col_ptrs[i + 1]) += col_ptrs[i];
+    }
+  }
+
+#endif

Modified: mlpack/trunk/src/mlpack/core/arma_extend/arma_extend.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/arma_extend/arma_extend.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/arma_extend/arma_extend.hpp	Fri Oct  4 00:47:12 2013
@@ -19,6 +19,10 @@
 #define ARMA_EXTRA_ROW_PROTO mlpack/core/arma_extend/Row_extra_bones.hpp
 #define ARMA_EXTRA_ROW_MEAT  mlpack/core/arma_extend/Row_extra_meat.hpp
 
+// Add batch constructor for sparse matrix (if version < 3.810.0).
+#define ARMA_EXTRA_SPMAT_PROTO mlpack/core/arma_extend/SpMat_extra_bones.hpp
+#define ARMA_EXTRA_SPMAT_MEAT  mlpack/core/arma_extend/SpMat_extra_meat.hpp
+
 #include <armadillo>
 
 // To get CSV support on versions of Armadillo prior to 2.0.0, we'll do this.  I



More information about the mlpack-svn mailing list