[mlpack-git] master: Refactor perceptron to not modify input dataset. Minimize internally-held variables. (424383c)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Sep 4 11:50:41 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/f5893d5d190d5f5b4b6dc94e2593f50c56d406e4...424383cb02dcf2d73728e1c3c4b582bdb7cba627
>---------------------------------------------------------------
commit 424383cb02dcf2d73728e1c3c4b582bdb7cba627
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri Sep 4 15:50:04 2015 +0000
Refactor perceptron to not modify input dataset.
Minimize internally-held variables.
>---------------------------------------------------------------
424383cb02dcf2d73728e1c3c4b582bdb7cba627
.../learning_policies/simple_weight_update.hpp | 20 ++---
src/mlpack/methods/perceptron/perceptron.hpp | 18 ++---
src/mlpack/methods/perceptron/perceptron_impl.hpp | 85 +++++++++++-----------
3 files changed, 60 insertions(+), 63 deletions(-)
diff --git a/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp b/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp
index 12c67ea..b78adc9 100644
--- a/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp
+++ b/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp
@@ -26,9 +26,9 @@ class SimpleWeightUpdate
{
public:
/**
- * This function is called to update the weightVectors matrix.
- * It decreases the weights of the incorrectly classified class while
- * increasing the weight of the correct class it should have been classified to.
+ * This function is called to update the weightVectors matrix. It decreases
+ * the weights of the incorrectly classified class while increasing the weight
+ * of the correct class it should have been classified to.
*
* @param trainData The training dataset.
* @param weightVectors Matrix of weight vectors.
@@ -44,15 +44,17 @@ class SimpleWeightUpdate
const size_t rowIndex,
const arma::rowvec& D)
{
- weightVectors.row(rowIndex) = weightVectors.row(rowIndex) -
- D(labelIndex) * trainData.col(labelIndex).t();
+ weightVectors.row(rowIndex).subvec(1, weightVectors.n_cols - 1) -=
+ D(labelIndex) * trainData.col(labelIndex).t();
+ weightVectors(rowIndex, 0) -= D(labelIndex);
- weightVectors.row(vectorIndex) = weightVectors.row(vectorIndex) +
- D(labelIndex) * trainData.col(labelIndex).t();
+ weightVectors.row(vectorIndex).subvec(1, weightVectors.n_cols - 1) +=
+ D(labelIndex) * trainData.col(labelIndex).t();
+ weightVectors(vectorIndex, 0) += D(labelIndex);
}
};
-}; // namespace perceptron
-}; // namespace mlpack
+} // namespace perceptron
+} // namespace mlpack
#endif
diff --git a/src/mlpack/methods/perceptron/perceptron.hpp b/src/mlpack/methods/perceptron/perceptron.hpp
index b58d538..f2c9275 100644
--- a/src/mlpack/methods/perceptron/perceptron.hpp
+++ b/src/mlpack/methods/perceptron/perceptron.hpp
@@ -43,7 +43,7 @@ class Perceptron
*/
Perceptron(const MatType& data,
const arma::Row<size_t>& labels,
- int iterations);
+ const int iterations);
/**
* Classification function. After training, use the weightVectors matrix to
@@ -66,7 +66,7 @@ class Perceptron
* @param labels The labels of data.
*/
Perceptron(const Perceptron<>& other,
- MatType& data,
+ const MatType& data,
const arma::rowvec& D,
const arma::Row<size_t>& labels);
@@ -80,21 +80,17 @@ private:
//! To store the number of iterations
size_t iter;
- //! Stores the class labels for the input data.
- arma::Row<size_t> classLabels;
-
//! Stores the weight vectors for each of the input class labels.
arma::mat weightVectors;
- //! Stores the training data to be used later on in UpdateWeights.
- arma::mat trainData;
-
/**
- * Training Function. It trains on trainData using the cost matrix D
+ * Training function. Trains on the given data and labels using cost matrix D.
*
- * @param D Cost matrix. Stores the cost of mispredicting instances
+ * @param D Cost matrix. Stores the cost of mispredicting instances
*/
- void Train(const arma::rowvec& D);
+ void Train(const MatType& data,
+ const arma::Row<size_t>& labels,
+ const arma::rowvec& D);
};
} // namespace perceptron
diff --git a/src/mlpack/methods/perceptron/perceptron_impl.hpp b/src/mlpack/methods/perceptron/perceptron_impl.hpp
index 01a105f..c1aafc2 100644
--- a/src/mlpack/methods/perceptron/perceptron_impl.hpp
+++ b/src/mlpack/methods/perceptron/perceptron_impl.hpp
@@ -31,25 +31,17 @@ template<
Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
const MatType& data,
const arma::Row<size_t>& labels,
- int iterations)
+ const int iterations)
{
WeightInitializationPolicy WIP;
WIP.Initialize(weightVectors, arma::max(labels) + 1, data.n_rows + 1);
// Start training.
- classLabels = labels;
-
- trainData = data;
- // Insert a row of ones at the top of the training data set.
- MatType zOnes(1, data.n_cols);
- zOnes.fill(1);
- trainData.insert_rows(0, zOnes);
-
iter = iterations;
arma::rowvec D(data.n_cols);
D.fill(1.0);// giving equal weight to all the points.
- Train(D);
+ Train(data, labels, D);
}
@@ -61,53 +53,57 @@ Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
* @param predictedLabels vector to store the predicted classes after
* classifying test
*/
-template <typename LearnPolicy, typename WeightInitializationPolicy, typename MatType>
+template<
+ typename LearnPolicy,
+ typename WeightInitializationPolicy,
+ typename MatType
+>
void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Classify(
const MatType& test,
arma::Row<size_t>& predictedLabels)
{
- arma::mat tempLabelMat;
- arma::uword maxIndexRow, maxIndexCol;
+ arma::vec tempLabelMat;
+ arma::uword maxIndex;
+ // Could probably be faster if done in batch.
for (size_t i = 0; i < test.n_cols; i++)
{
tempLabelMat = weightVectors.submat(0, 1, weightVectors.n_rows - 1,
weightVectors.n_cols - 1) *
test.col(i) + weightVectors.col(0);
- tempLabelMat.max(maxIndexRow, maxIndexCol);
- predictedLabels(0, i) = maxIndexRow;
+ tempLabelMat.max(maxIndex);
+ predictedLabels(0, i) = maxIndex;
}
- // predictedLabels.print("These are the labels predicted by the perceptron");
}
/**
- * Alternate constructor which copies parameters from an already initiated
- * perceptron.
+ * Alternate constructor which copies parameters from an already initialized
+ * perceptron.
*
- * @param other The other initiated Perceptron object from which we copy the
- * values from.
- * @param data The data on which to train this Perceptron object on.
- * @param D Weight vector to use while training. For boosting purposes.
- * @param labels The labels of data.
+ * @param other The other initialized Perceptron object from which we copy
+ * the values.
+ * @param data The data on which to train this Perceptron object.
+ * @param D Weight vector to use while training (for boosting purposes).
+ * @param labels The labels of data.
*/
-template <typename LearnPolicy, typename WeightInitializationPolicy, typename MatType>
+template<
+ typename LearnPolicy,
+ typename WeightInitializationPolicy,
+ typename MatType
+>
Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
- const Perceptron<>& other, MatType& data, const arma::rowvec& D, const arma::Row<size_t>& labels)
+ const Perceptron<>& other,
+ const MatType& data,
+ const arma::rowvec& D,
+ const arma::Row<size_t>& labels)
{
-
- classLabels = labels;
- trainData = data;
iter = other.iter;
// Insert a row of ones at the top of the training data set.
- MatType zOnes(1, data.n_cols);
- zOnes.fill(1);
- trainData.insert_rows(0, zOnes);
-
WeightInitializationPolicy WIP;
WIP.Initialize(weightVectors, arma::max(labels) + 1, data.n_rows + 1);
- Train(D);
+ Train(data, labels, D);
}
//! Serialize the perceptron.
@@ -123,9 +119,9 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Serialize(
}
/**
- * Training Function. It trains on trainData using the cost matrix D
+ * Training function. Trains on the given data and labels using cost matrix D.
*
- * @param D Cost matrix. Stores the cost of mispredicting instances
+ * @param D Cost matrix. Stores the cost of mispredicting instances
*/
template<
typename LearnPolicy,
@@ -133,7 +129,9 @@ template<
typename MatType
>
void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Train(
- const arma::rowvec& D)
+ const MatType& data,
+ const arma::Row<size_t>& labels,
+ const arma::rowvec& D)
{
size_t j, i = 0;
bool converged = false;
@@ -151,30 +149,31 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Train(
converged = true;
// Now this inner loop is for going through the dataset in each iteration.
- for (j = 0; j < trainData.n_cols; j++)
+ for (j = 0; j < data.n_cols; j++)
{
// Multiply for each variable and check whether the current weight vector
// correctly classifies this.
- tempLabelMat = weightVectors * trainData.col(j);
+ tempLabelMat = weightVectors.cols(1, weightVectors.n_cols - 1) *
+ data.col(j) + weightVectors.col(0);
tempLabelMat.max(maxIndexRow, maxIndexCol);
// Check whether prediction is correct.
- if (maxIndexRow != classLabels(0, j))
+ if (maxIndexRow != labels(0, j))
{
// Due to incorrect prediction, convergence set to false.
converged = false;
- tempLabel = classLabels(0, j);
+ tempLabel = labels(0, j);
// Send maxIndexRow for knowing which weight to update, send j to know
// the value of the vector to update it with. Send tempLabel to know
// the correct class.
- LP.UpdateWeights(trainData, weightVectors, j, tempLabel, maxIndexRow, D);
+ LP.UpdateWeights(data, weightVectors, j, tempLabel, maxIndexRow, D);
}
}
}
}
-}; // namespace perceptron
-}; // namespace mlpack
+} // namespace perceptron
+} // namespace mlpack
#endif
More information about the mlpack-git
mailing list