[mlpack-git] master: Refactor weight vectors to be column-major. (19a4620)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Sep 4 13:32:58 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/424383cb02dcf2d73728e1c3c4b582bdb7cba627...f3bd5e8853a795f4ff41849dd8ef844d53199412
>---------------------------------------------------------------
commit 19a4620fc4081a63772b95987ad157b9d6dfd50b
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri Sep 4 16:28:00 2015 +0000
Refactor weight vectors to be column-major.
>---------------------------------------------------------------
19a4620fc4081a63772b95987ad157b9d6dfd50b
.../perceptron/initialization_methods/random_init.hpp | 10 +++++-----
.../perceptron/initialization_methods/zero_init.hpp | 10 +++++-----
.../learning_policies/simple_weight_update.hpp | 14 +++++++-------
src/mlpack/methods/perceptron/perceptron.hpp | 6 +++++-
src/mlpack/methods/perceptron/perceptron_impl.hpp | 16 ++++++++--------
5 files changed, 30 insertions(+), 26 deletions(-)
diff --git a/src/mlpack/methods/perceptron/initialization_methods/random_init.hpp b/src/mlpack/methods/perceptron/initialization_methods/random_init.hpp
index f88e952..7ec04df 100644
--- a/src/mlpack/methods/perceptron/initialization_methods/random_init.hpp
+++ b/src/mlpack/methods/perceptron/initialization_methods/random_init.hpp
@@ -22,14 +22,14 @@ class RandomInitialization
RandomInitialization() { }
inline static void Initialize(arma::mat& W,
- const size_t row,
- const size_t col)
+ const size_t numFeatures,
+ const size_t numClasses)
{
- W = arma::randu<arma::mat>(row, col);
+ W = arma::randu<arma::mat>(numFeatures, numClasses);
}
}; // class RandomInitialization
-}; // namespace perceptron
-}; // namespace mlpack
+} // namespace perceptron
+} // namespace mlpack
#endif
diff --git a/src/mlpack/methods/perceptron/initialization_methods/zero_init.hpp b/src/mlpack/methods/perceptron/initialization_methods/zero_init.hpp
index 0a02459..7fdf009 100644
--- a/src/mlpack/methods/perceptron/initialization_methods/zero_init.hpp
+++ b/src/mlpack/methods/perceptron/initialization_methods/zero_init.hpp
@@ -21,17 +21,17 @@ class ZeroInitialization
ZeroInitialization() { }
inline static void Initialize(arma::mat& W,
- const size_t row,
- const size_t col)
+ const size_t numFeatures,
+ const size_t numClasses)
{
- arma::mat tempWeights(row, col);
+ arma::mat tempWeights(numFeatures, numClasses);
tempWeights.fill(0.0);
W = tempWeights;
}
}; // class ZeroInitialization
-}; // namespace perceptron
-}; // namespace mlpack
+} // namespace perceptron
+} // namespace mlpack
#endif
diff --git a/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp b/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp
index b78adc9..262c2cd 100644
--- a/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp
+++ b/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp
@@ -32,7 +32,7 @@ class SimpleWeightUpdate
*
* @param trainData The training dataset.
* @param weightVectors Matrix of weight vectors.
- * @param rowIndex Index of the row which has been incorrectly predicted.
+ * @param colIndex Index of the column which has been incorrectly predicted.
* @param labelIndex Index of the vector in trainData.
* @param vectorIndex Index of the class which should have been predicted.
* @param D Cost of mispredicting the labelIndex instance.
@@ -41,15 +41,15 @@ class SimpleWeightUpdate
arma::mat& weightVectors,
const size_t labelIndex,
const size_t vectorIndex,
- const size_t rowIndex,
+ const size_t colIndex,
const arma::rowvec& D)
{
- weightVectors.row(rowIndex).subvec(1, weightVectors.n_cols - 1) -=
- D(labelIndex) * trainData.col(labelIndex).t();
- weightVectors(rowIndex, 0) -= D(labelIndex);
+ weightVectors.col(colIndex).subvec(1, weightVectors.n_rows - 1) -=
+ D(labelIndex) * trainData.col(labelIndex);
+ weightVectors(0, colIndex) -= D(labelIndex);
- weightVectors.row(vectorIndex).subvec(1, weightVectors.n_cols - 1) +=
- D(labelIndex) * trainData.col(labelIndex).t();
- weightVectors(vectorIndex, 0) += D(labelIndex);
+ weightVectors.col(vectorIndex).subvec(1, weightVectors.n_rows - 1) +=
+ D(labelIndex) * trainData.col(labelIndex);
+ weightVectors(0, vectorIndex) += D(labelIndex);
}
};
diff --git a/src/mlpack/methods/perceptron/perceptron.hpp b/src/mlpack/methods/perceptron/perceptron.hpp
index f2c9275..b556684 100644
--- a/src/mlpack/methods/perceptron/perceptron.hpp
+++ b/src/mlpack/methods/perceptron/perceptron.hpp
@@ -80,7 +80,11 @@ private:
//! To store the number of iterations
size_t iter;
- //! Stores the weight vectors for each of the input class labels.
+ /**
+ * Stores the weight vectors for each of the input class labels. Each column
+ * corresponds to the weights for one class label, and each row corresponds to
+ * the weights for one dimension of the input data.
+ */
arma::mat weightVectors;
/**
diff --git a/src/mlpack/methods/perceptron/perceptron_impl.hpp b/src/mlpack/methods/perceptron/perceptron_impl.hpp
index c1aafc2..1f4bf6c 100644
--- a/src/mlpack/methods/perceptron/perceptron_impl.hpp
+++ b/src/mlpack/methods/perceptron/perceptron_impl.hpp
@@ -34,7 +34,7 @@ Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
const int iterations)
{
WeightInitializationPolicy WIP;
- WIP.Initialize(weightVectors, arma::max(labels) + 1, data.n_rows + 1);
+ WIP.Initialize(weightVectors, data.n_rows + 1, arma::max(labels) + 1);
// Start training.
iter = iterations;
@@ -68,9 +68,9 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Classify(
// Could probably be faster if done in batch.
for (size_t i = 0; i < test.n_cols; i++)
{
- tempLabelMat = weightVectors.submat(0, 1, weightVectors.n_rows - 1,
- weightVectors.n_cols - 1) *
- test.col(i) + weightVectors.col(0);
+ tempLabelMat = weightVectors.submat(1, 0, weightVectors.n_rows - 1,
+ weightVectors.n_cols - 1).t() *
+ test.col(i) + weightVectors.row(0).t();
tempLabelMat.max(maxIndex);
predictedLabels(0, i) = maxIndex;
}
@@ -101,7 +101,7 @@ Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
// Insert a row of ones at the top of the training data set.
WeightInitializationPolicy WIP;
- WIP.Initialize(weightVectors, arma::max(labels) + 1, data.n_rows + 1);
+ WIP.Initialize(weightVectors, data.n_rows + 1, arma::max(labels) + 1);
Train(data, labels, D);
}
@@ -153,8 +153,8 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Train(
{
// Multiply for each variable and check whether the current weight vector
// correctly classifies this.
- tempLabelMat = weightVectors.cols(1, weightVectors.n_cols - 1) *
- data.col(j) + weightVectors.col(0);
+ tempLabelMat = weightVectors.rows(1, weightVectors.n_rows - 1).t() *
+ data.col(j) + weightVectors.row(0).t();
tempLabelMat.max(maxIndexRow, maxIndexCol);
@@ -167,7 +167,9 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Train(
// Send maxIndexRow for knowing which weight to update, send j to know
// the value of the vector to update it with. Send tempLabel to know
// the correct class.
+ // tempLabelMat is a column vector with one entry per class, so the
+ // predicted class is maxIndexRow; maxIndexCol is always 0.
LP.UpdateWeights(data, weightVectors, j, tempLabel, maxIndexRow, D);
}
}
}
More information about the mlpack-git
mailing list