[mlpack-git] master: Refactor to make instance weights optional. (2bd5e78)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Sep 4 14:07:05 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/f3bd5e8853a795f4ff41849dd8ef844d53199412...2bd5e7889d810e86f7cdd586485f4b4e7a3b9cf0
>---------------------------------------------------------------
commit 2bd5e7889d810e86f7cdd586485f4b4e7a3b9cf0
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri Sep 4 18:06:49 2015 +0000
Refactor to make instance weights optional.
>---------------------------------------------------------------
2bd5e7889d810e86f7cdd586485f4b4e7a3b9cf0
src/mlpack/methods/perceptron/perceptron.hpp | 2 +-
src/mlpack/methods/perceptron/perceptron_impl.hpp | 16 ++++++++++------
2 files changed, 11 insertions(+), 7 deletions(-)
diff --git a/src/mlpack/methods/perceptron/perceptron.hpp b/src/mlpack/methods/perceptron/perceptron.hpp
index 50eaa4d..fd674c6 100644
--- a/src/mlpack/methods/perceptron/perceptron.hpp
+++ b/src/mlpack/methods/perceptron/perceptron.hpp
@@ -98,7 +98,7 @@ private:
*/
void Train(const MatType& data,
const arma::Row<size_t>& labels,
- const arma::rowvec& D);
+ const arma::rowvec& D = arma::rowvec());
};
} // namespace perceptron
diff --git a/src/mlpack/methods/perceptron/perceptron_impl.hpp b/src/mlpack/methods/perceptron/perceptron_impl.hpp
index 5b5cbc2..72b720a 100644
--- a/src/mlpack/methods/perceptron/perceptron_impl.hpp
+++ b/src/mlpack/methods/perceptron/perceptron_impl.hpp
@@ -37,10 +37,7 @@ Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
// Start training.
iter = iterations;
- arma::rowvec D(data.n_cols);
- D.fill(1.0);// giving equal weight to all the points.
-
- Train(data, labels, D);
+ Train(data, labels);
}
@@ -138,6 +135,8 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Train(
LearnPolicy LP;
+ const bool hasWeights = (D.n_elem > 0);
+
while ((i < iter) && (!converged))
{
// This outer loop is for each iteration, and we use the 'converged'
@@ -160,11 +159,16 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Train(
// Due to incorrect prediction, convergence set to false.
converged = false;
tempLabel = labels(0, j);
+
// Send maxIndexRow for knowing which weight to update, send j to know
// the value of the vector to update it with. Send tempLabel to know
// the correct class.
- LP.UpdateWeights(data.col(j), weights, biases, maxIndexRow, tempLabel,
- D(j));
+ if (hasWeights)
+ LP.UpdateWeights(data.col(j), weights, biases, maxIndexRow, tempLabel,
+ D(j));
+ else
+ LP.UpdateWeights(data.col(j), weights, biases, maxIndexRow,
+ tempLabel);
}
}
}
More information about the mlpack-git mailing list