[mlpack-svn] r10370 - mlpack/trunk/src/mlpack/methods/naive_bayes

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Nov 23 20:39:41 EST 2011


Author: rcurtin
Date: 2011-11-23 20:39:41 -0500 (Wed, 23 Nov 2011)
New Revision: 10370

Added:
   mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp
   mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.cpp
   mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.hpp
Removed:
   mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cc
   mlpack/trunk/src/mlpack/methods/naive_bayes/phi.cc
   mlpack/trunk/src/mlpack/methods/naive_bayes/phi.h
   mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.cc
   mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.h
Modified:
   mlpack/trunk/src/mlpack/methods/naive_bayes/CMakeLists.txt
Log:
Move filenames to .hpp and .cpp to help finish #152, and fix style in accordance
with #153.  Removed phi() since it is a copy of the version found in
methods/gmm/ (which will also have to be moved).


Modified: mlpack/trunk/src/mlpack/methods/naive_bayes/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/CMakeLists.txt	2011-11-23 23:43:22 UTC (rev 10369)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/CMakeLists.txt	2011-11-24 01:39:41 UTC (rev 10370)
@@ -3,8 +3,8 @@
 # Define the files we need to compile.
 # Anything not in this list will not be compiled into MLPACK.
 set(SOURCES
-  simple_nbc.cc
-  phi.cc
+  simple_nbc.hpp
+  simple_nbc.cpp
 )
 
 # Add directory name to sources.
@@ -17,7 +17,7 @@
 set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
 
 add_executable(nbc
-  nbc_main.cc
+  nbc_main.cpp
 )
 target_link_libraries(nbc
   mlpack

Deleted: mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cc
===================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cc	2011-11-23 23:43:22 UTC (rev 10369)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cc	2011-11-24 01:39:41 UTC (rev 10370)
@@ -1,116 +0,0 @@
-/**
- * @author Parikshit Ram (pram at cc.gatech.edu)
- * @file nbc_main.cc
- *
- * This program test drives the Simple Naive Bayes Classifier
- *
- * This classifier does parametric naive bayes classification
- * assuming that the features are sampled from a Gaussian
- * distribution.
- *
- * PARAMETERS TO BE INPUT:
- *
- * --train
- * This is the file that contains the training data
- *
- * --nbc/classes
- * This is the number of classes present in the training data
- *
- * --test
- * This file contains the data points which the trained
- * classifier would classify
- *
- * --output
- * This file will contain the classes to which the corresponding
- * data points in the testing data
- *
- */
-#include "simple_nbc.h"
-
-#include <mlpack/core.h>
-
-
-/*const fx_entry_doc parm_nbc_main_entries[] = {
-  {"train", FX_REQUIRED, FX_STR, NULL,
-   " A file containing the training set\n"},
-  {"test", FX_REQUIRED, FX_STR, NULL,
-   " A file containing the test set\n"},
-  {"output", FX_PARAM, FX_STR, NULL,
-   " The file in which the output of the test would be "
-   "written (defaults to 'output.csv')\n"},
-  FX_ENTRY_DOC_DONE
-};*/
-
-PARAM_STRING_REQ("train", "A file containing the training set", "nbc");
-PARAM_STRING_REQ("test", "A file containing the test set", "nbc");
-PARAM_STRING("output", "The file in which the output of the test would\
- be written, defaults to 'output.csv')", "nbc", "output.csv");
-
-/*const fx_submodule_doc parm_nbc_main_submodules[] = {
-  {"nbc", &parm_nbc_doc,
-   " Trains on a given set and number of classes and "
-   "tests them on a given set\n"},
-  FX_SUBMODULE_DOC_DONE
-};*/
-
-PARAM_MODULE("nbc", "Trains on a given set and number\
- of classes and tests them on a given set");
-
-/*const fx_module_doc parm_nbc_main_doc = {
-  parm_nbc_main_entries, parm_nbc_main_submodules,
-  "This program test drives the Parametric Naive Bayes \n"
-  "Classifier assuming that the features are sampled \n"
-  "from a Gaussian distribution.\n"
-};*/
-
-PROGRAM_INFO("Parametric Naive Bayes", "This program test drives the\
- Parametric Naive Bayes Classifier assuming that the features are\
- sampled from a Gaussian distribution.", "nbc");
-
-using namespace mlpack;
-using namespace naive_bayes;
-
-int main(int argc, char* argv[]) {
-
-  CLI::ParseCommandLine(argc, argv);
-
-  ////// READING PARAMETERS AND LOADING DATA //////
-
-  const char *training_data_filename = CLI::GetParam<std::string>("nbc/train").c_str();
-  arma::mat training_data;
-  data::Load(training_data_filename, training_data, true);
-
-  const char *testing_data_filename = CLI::GetParam<std::string>("nbc/test").c_str();
-  arma::mat testing_data;
-  data::Load(testing_data_filename, testing_data, true);
-
-  ////// SIMPLE NAIVE BAYES CLASSIFICATCLIN ASSUMING THE DATA TO BE UNIFORMLY DISTRIBUTED //////
-
-  ////// Timing the training of the Naive Bayes Classifier //////
-  Timers::StartTimer("nbc/training");
-
-  ////// Create and train the classifier
-  SimpleNaiveBayesClassifier nbc = SimpleNaiveBayesClassifier(training_data);
-
-  ////// Stop training timer //////
-  Timers::StopTimer("nbc/training");
-
-  ////// Timing the testing of the Naive Bayes Classifier //////
-  ////// The variable that contains the result of the classification
-  arma::vec results;
-
-  Timers::StartTimer("nbc/testing");
-
-  ////// Calling the function that classifies the test data
-  nbc.Classify(testing_data, results);
-
-  ////// Stop testing timer //////
-  Timers::StopTimer("nbc/testing");
-
-  ////// OUTPUT RESULTS //////
-  std::string output_filename = CLI::GetParam<std::string>("nbc/output");
-
-  data::Save(output_filename.c_str(), results, true);
-
-  return 1;
-}

Copied: mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp (from rev 10352, mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cc)
===================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/nbc_main.cpp	2011-11-24 01:39:41 UTC (rev 10370)
@@ -0,0 +1,77 @@
+/**
+ * @file nbc_main.cpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * This program runs the Simple Naive Bayes Classifier.
+ *
+ * This classifier does parametric naive bayes classification assuming that the
+ * features are sampled from a Gaussian distribution.
+ *
+ * PARAMETERS TO BE INPUT:
+ *
+ * --nbc/train
+ * This is the file that contains the training data.
+ *
+ * --nbc/classes
+ * This is the number of classes present in the training data.
+ *
+ * --nbc/test
+ * This file contains the data points which the trained classifier should
+ * classify.
+ *
+ * --nbc/output
+ * This file will receive the class assigned to each of the corresponding data
+ * points in the testing data (defaults to 'output.csv').
+ */
+#include <mlpack/core.h>
+
+#include "simple_nbc.hpp"
+
+PARAM_STRING_REQ("train", "A file containing the training set", "nbc");
+PARAM_STRING_REQ("test", "A file containing the test set", "nbc");
+PARAM_STRING("output", "The file in which the output of the test would "
+    "be written (defaults to 'output.csv').", "nbc", "output.csv");
+
+PARAM_MODULE("nbc", "Trains on a given set and number of classes and tests "
+    "them on a given set");
+
+PROGRAM_INFO("Parametric Naive Bayes", "This program test drives the Parametric"
+    " Naive Bayes Classifier assuming that the features are sampled from a "
+    "Gaussian distribution.", "nbc");
+
+using namespace mlpack;
+using namespace naive_bayes;
+
+int main(int argc, char* argv[])
+{
+  CLI::ParseCommandLine(argc, argv);
+
+  // Keep owned copies of the filenames instead of raw c_str() pointers, so
+  // they stay valid however CLI::GetParam() hands back the string.
+  const std::string training_data_filename =
+      CLI::GetParam<std::string>("nbc/train");
+  arma::mat training_data;
+  data::Load(training_data_filename.c_str(), training_data, true);
+
+  const std::string testing_data_filename =
+      CLI::GetParam<std::string>("nbc/test");
+  arma::mat testing_data;
+  data::Load(testing_data_filename.c_str(), testing_data, true);
+
+  // Create and train the classifier.
+  Timers::StartTimer("nbc/training");
+  SimpleNaiveBayesClassifier nbc(training_data);
+  Timers::StopTimer("nbc/training");
+
+  // Time the classification of the test set.
+  arma::vec results;
+  Timers::StartTimer("nbc/testing");
+  nbc.Classify(testing_data, results);
+  Timers::StopTimer("nbc/testing");
+
+  // Save the predicted class of each test point.
+  const std::string output_filename = CLI::GetParam<std::string>("nbc/output");
+  data::Save(output_filename.c_str(), results, true);
+
+  return 0;
+}

Deleted: mlpack/trunk/src/mlpack/methods/naive_bayes/phi.cc
===================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/phi.cc	2011-11-23 23:43:22 UTC (rev 10369)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/phi.cc	2011-11-24 01:39:41 UTC (rev 10370)
@@ -1,113 +0,0 @@
-/**
- * @author Parikshit Ram (pram at cc.gatech.edu)
- * @file phi.h
- *
- * This file computes the Gaussian probability
- * density function
- */
-#include "phi.h"
-
-long double phi(const arma::vec& x, const arma::vec& mean, const arma::mat& cov) {
-
-  long double det, f;
-  double exponent;
-  size_t dim;
-  arma::mat inv;
-  arma::vec diff, tmp;
-
-  dim = x.n_rows;
-  inv = arma::inv(cov);
-  det = arma::det(cov);
-
-  if( det < 0){
-    det = -det;
-  }
-
-  diff = mean - x;
-  tmp = inv*diff;
-  exponent = arma::dot(diff,tmp);
-
-  long double tmp1, tmp2, tmp3;
-  tmp1 = 1;
-  tmp2 = dim;
-  tmp2 = tmp2/2;
-  tmp2 = pow(2 * M_PI,tmp2);
-  tmp1 = tmp1/tmp2;
-  tmp3 = 1;
-  tmp2 = sqrt(det);
-  tmp3 = tmp3/tmp2;
-  tmp2 = -exponent;
-  tmp2 = tmp2 / 2;
-
-  f = (tmp1*tmp3*exp(tmp2));
-
-  return f;
-}
-
-long double phi(const double x, const double mean, const double var) {
-
-  long double f;
-
-  f = exp( -1.0*( (x-mean)*(x-mean)/(2*var) ) )/sqrt(2* M_PI*var);
-  return f;
-}
-
-long double phi(const arma::vec& x, const arma::vec& mean, const arma::mat& cov, const std::vector<arma::mat>& d_cov, arma::vec& g_mean, arma::vec& g_cov){
-
-  long double det, f;
-  double exponent;
-  size_t dim;
-  arma::mat inv;
-  arma::vec diff, tmp;
-
-  // First calculate the multivariate Gaussian probability density function
-  // We don't just call phi() to do this because we need some of the values later
-  dim = x.n_rows;
-  inv = arma::inv(cov);
-  det = arma::det(cov);
-
-  if( det < 0){
-    det = -det;
-  }
-
-  diff = mean - x;
-  tmp = inv*diff;
-  exponent = arma::dot(diff,tmp);
-
-  long double tmp1, tmp2, tmp3;
-  tmp1 = 1;
-  tmp2 = dim;
-  tmp2 = tmp2/2;
-  tmp2 = pow(2 * M_PI,tmp2);
-  tmp1 = tmp1/tmp2;
-  tmp3 = 1;
-  tmp2 = sqrt(det);
-  tmp3 = tmp3/tmp2;
-  tmp2 = -exponent;
-  tmp2 = tmp2 / 2;
-
-  f = (tmp1*tmp3*exp(tmp2));
-
-  // Calculating the g_mean values  which would be a (1 X dim) vector
-  g_mean = f*tmp;
-
-  // Calculating the g_cov values which would be a (1 X (dim*(dim+1)/2)) vector
-  arma::vec g_cov_tmp(d_cov.size());
-  for(size_t i = 0; i < d_cov.size(); i++){
-    arma::vec tmp_d;
-    arma::mat inv_d;
-    long double tmp_d_cov_d_r;
-
-    tmp_d = d_cov[i]*tmp;
-    tmp_d_cov_d_r = arma::dot(tmp_d,tmp);
-    inv_d = inv*d_cov[i];
-
-    for(size_t j = 0; j < dim; j++)
-      tmp_d_cov_d_r += inv_d(j,j);
-
-    g_cov_tmp[i] = f*tmp_d_cov_d_r/2;
-  }
-  g_cov = g_cov_tmp;
-
-  return f;
-}

Deleted: mlpack/trunk/src/mlpack/methods/naive_bayes/phi.h
===================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/phi.h	2011-11-23 23:43:22 UTC (rev 10369)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/phi.h	2011-11-24 01:39:41 UTC (rev 10370)
@@ -1,56 +0,0 @@
-/**
- * @author Parikshit Ram (pram at cc.gatech.edu)
- * @file phi.h
- *
- * This file computes the Gaussian probability
- * density function
- */
-
-#ifndef MLPACK_PHI_H
-#define MLPACK_PHI_H
-
-#include <mlpack/core.h>
-
-/**
- * Calculates the multivariate Gaussian probability density function
- *
- * Example use:
- * @code
- * arma::vec x, mean;
- * arma::mat cov;
- * ....
- * long double f = phi(x, mean, cov);
- * @endcode
- */
-
-long double phi(const arma::vec& x, const arma::vec& mean, const arma::mat& cov);
-
-/**
- * Calculates the univariate Gaussian probability density function
- *
- * Example use:
- * @code
- * double x, mean, var;
- * ....
- * long double f = phi(x, mean, var);
- * @endcode
- */
-
-long double phi(const double x, const double mean, const double var);
-
-/**
- * Calculates the multivariate Gaussian probability density function
- * and also the gradients with respect to the mean and the variance
- *
- * Example use:
- * @code
- * arma::vec x, mean, g_mean, g_cov;
- * std::vector<arma::mat> d_cov; // the dSigma
- * ....
- * long double f = phi(x, mean, cov, d_cov, &g_mean, &g_cov);
- * @endcode
- */
-
-long double phi(const arma::vec& x, const arma::vec& mean, const arma::mat& cov, const std::vector<arma::mat>& d_cov, arma::vec& g_mean, arma::vec& g_cov);
-
-#endif // MLPACK_PHI_H

Deleted: mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.cc
===================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.cc	2011-11-23 23:43:22 UTC (rev 10369)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.cc	2011-11-24 01:39:41 UTC (rev 10370)
@@ -1,124 +0,0 @@
-/**
- * @author Parikshit Ram (pram at cc.gatech.edu)
- * @file simple_nbc.h
- *
- * A Naive Bayes Classifier which parametrically
- * estimates the distribution of the features.
- * It is assumed that the features have been
- * sampled from a Gaussian PDF
- *
- */
-#include <mlpack/core.h>
-
-#include "simple_nbc.h"
-#include "phi.h"
-
-namespace mlpack {
-namespace naive_bayes {
-
-SimpleNaiveBayesClassifier::SimpleNaiveBayesClassifier(const arma::mat& data)
-{
-
-  size_t number_examples = data.n_cols;
-  size_t number_features = data.n_rows - 1;
-
-  arma::vec feature_sum, feature_sum_squared;
-  feature_sum.zeros(number_features);
-  feature_sum_squared.zeros(number_features);
-
-  // updating the variables, private and local, according to
-  // the number of features and classes present in the data
-  number_of_classes_ = mlpack::CLI::GetParam<int>("nbc/classes");
-  class_probabilities_.set_size(number_of_classes_);
-  means_.set_size(number_features,number_of_classes_);
-  variances_.set_size(number_features,number_of_classes_);
-
-  Log::Info << number_examples << " examples with " << number_features
-    << " features each" << std::endl;
-
-  CLI::GetParam<int>("nbc/features") = number_features;
-  CLI::GetParam<int>("nbc/examples") = number_examples;
-
-  // calculating the class probabilities as well as the
-  // sample mean and variance for each of the features
-  // with respect to each of the labels
-  for(size_t i = 0; i < number_of_classes_; i++ ) {
-    size_t number_of_occurrences = 0;
-    for (size_t j = 0; j < number_examples; j++) {
-      size_t flag = (size_t)  data(number_features, j);
-      if(i == flag) {
-	++number_of_occurrences;
-	for(size_t k = 0; k < number_features; k++) {
-	  double tmp = data(k, j);
-	  feature_sum(k) += tmp;
-	  feature_sum_squared(k) += tmp*tmp;
-	}
-      }
-    }
-    class_probabilities_[i] = (double)number_of_occurrences
-      / (double)number_examples ;
-    for(size_t k = 0; k < number_features; k++) {
-      double sum = feature_sum(k),
-	     sum_squared = feature_sum_squared(k);
-
-      means_(k, i) = (sum / number_of_occurrences);
-      variances_(k, i) = (sum_squared
-			    - (sum * sum / number_of_occurrences))
-			   /(number_of_occurrences - 1);
-    }
-    // Reset the summations to zero for the next iteration
-    feature_sum.zeros(number_features);
-    feature_sum_squared.zeros(number_features);
-  }
-}
-
-void SimpleNaiveBayesClassifier::Classify(const arma::mat& test_data, arma::vec& results){
-
-  // Checking that the number of features in the test data is same
-  // as in the training data
-  Log::Assert(test_data.n_rows - 1 == means_.n_rows);
-
-  arma::vec tmp_vals(number_of_classes_);
-  size_t number_features = test_data.n_rows - 1;
-
-  results.zeros(test_data.n_cols);
-
-  Log::Info << test_data.n_cols << " test cases with " << number_features
-    << " features each" << std::endl;
-
-  CLI::GetParam<int>("nbc/tests") = test_data.n_cols;
-  // Calculating the joint probability for each of the data points
-  // for each of the classes
-
-  // looping over every test case
-  for (size_t n = 0; n < test_data.n_cols; n++) {
-
-    //looping over every class
-    for (size_t i = 0; i < number_of_classes_; i++) {
-      // Using the log values to prevent floating point underflow
-      tmp_vals(i) = log(class_probabilities_(i));
-
-      //looping over every feature
-      for (size_t j = 0; j < number_features; j++) {
-	tmp_vals(i) += log(phi(test_data(j, n),
-			       means_(j, i),
-			       variances_(j, i))
-			   );
-      }
-    }
-
-    // Find the index of the maximum value in tmp_vals.
-    size_t max = 0;
-    for (size_t k = 0; k < number_of_classes_; k++) {
-      if(tmp_vals(max) < tmp_vals(k))
-	max = k;
-    }
-    results(n) = max;
-  }
-
-
-  return;
-}
-
-}; // namespace naive_bayes
-}; // namespace mlpack

Copied: mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.cpp (from rev 10352, mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.cc)
===================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.cpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.cpp	2011-11-24 01:39:41 UTC (rev 10370)
@@ -0,0 +1,197 @@
+/**
+ * @file simple_nbc.cpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * A Naive Bayes Classifier which parametrically estimates the distribution of
+ * the features.  It is assumed that the features have been sampled from a
+ * Gaussian PDF.
+ */
+#include <mlpack/core.h>
+
+#include <limits>
+#include <vector>
+
+#include "simple_nbc.hpp"
+
+namespace mlpack {
+namespace naive_bayes {
+
+/**
+ * Train the classifier.  Each column of data is one example: the last row
+ * holds the class label and the remaining rows hold the features.  The number
+ * of classes is read from the "nbc/classes" parameter; examples whose label is
+ * not in [0, number of classes) are ignored (with a warning).
+ *
+ * The per-class means and variances are computed in two passes (sum of squared
+ * deviations from the mean) instead of from running sums of x and x^2, which
+ * suffer from catastrophic cancellation and can even yield negative variances.
+ */
+SimpleNaiveBayesClassifier::SimpleNaiveBayesClassifier(const arma::mat& data)
+{
+  // We need a label row and at least one example.
+  Log::Assert(data.n_rows >= 1);
+  Log::Assert(data.n_cols >= 1);
+
+  const size_t number_examples = data.n_cols;
+  const size_t number_features = data.n_rows - 1;
+
+  // A non-positive class count would wrap around when stored in a size_t.
+  const int requested_classes = mlpack::CLI::GetParam<int>("nbc/classes");
+  Log::Assert(requested_classes > 0);
+  number_of_classes_ = (size_t) requested_classes;
+
+  class_probabilities_.zeros(number_of_classes_);
+  means_.zeros(number_features, number_of_classes_);
+  variances_.zeros(number_features, number_of_classes_);
+
+  Log::Info << number_examples << " examples with " << number_features
+      << " features each" << std::endl;
+
+  CLI::GetParam<int>("nbc/features") = number_features;
+  CLI::GetParam<int>("nbc/examples") = number_examples;
+
+  // First pass: validate the labels, count the occurrences of each class and
+  // accumulate the per-class feature sums in means_.  An invalid label is
+  // recorded as number_of_classes_ so that the second pass can skip it.
+  std::vector<size_t> labels(number_examples, number_of_classes_);
+  std::vector<size_t> occurrences(number_of_classes_, 0);
+  size_t ignored_examples = 0;
+  for (size_t j = 0; j < number_examples; j++)
+  {
+    const double label_value = data(number_features, j);
+
+    // Reject negative, too large and NaN labels before casting to size_t
+    // (converting an out-of-range double to an integer is undefined).
+    if (!(label_value >= 0.0 && label_value < (double) number_of_classes_))
+    {
+      ++ignored_examples;
+      continue;
+    }
+
+    const size_t label = (size_t) label_value;
+    labels[j] = label;
+    ++occurrences[label];
+    for (size_t k = 0; k < number_features; k++)
+      means_(k, label) += data(k, j);
+  }
+
+  if (ignored_examples > 0)
+  {
+    Log::Warn << ignored_examples << " training examples have a label outside "
+        << "[0, " << number_of_classes_ << ") and were ignored." << std::endl;
+  }
+
+  for (size_t i = 0; i < number_of_classes_; i++)
+  {
+    if (occurrences[i] > 0)
+      means_.col(i) /= (double) occurrences[i];
+  }
+
+  // Second pass: accumulate the squared deviations from the class means.
+  for (size_t j = 0; j < number_examples; j++)
+  {
+    const size_t label = labels[j];
+    if (label == number_of_classes_)
+      continue; // Invalid label, already counted above.
+
+    for (size_t k = 0; k < number_features; k++)
+    {
+      const double deviation = data(k, j) - means_(k, label);
+      variances_(k, label) += deviation * deviation;
+    }
+  }
+
+  // Turn the counts into class probabilities and the sums of squared
+  // deviations into unbiased sample variances.
+  for (size_t i = 0; i < number_of_classes_; i++)
+  {
+    class_probabilities_[i] = (double) occurrences[i]
+        / (double) number_examples;
+
+    if (occurrences[i] == 0)
+    {
+      // The prior is 0, so log(prior) is -inf and this class can never be
+      // predicted.  A unit variance avoids feeding a zero variance to the
+      // likelihood, so the score stays -inf instead of becoming NaN.
+      variances_.col(i).fill(1.0);
+      Log::Warn << "Class " << i << " does not occur in the training data."
+          << std::endl;
+    }
+    else if (occurrences[i] == 1)
+    {
+      // The sample variance needs at least two points (dividing by n - 1
+      // would divide by zero), so the variances are left at 0.
+      Log::Warn << "Class " << i << " occurs only once in the training data; "
+          << "its variances are left at 0." << std::endl;
+    }
+    else
+    {
+      variances_.col(i) /= (double) (occurrences[i] - 1);
+    }
+  }
+}
+
+/**
+ * Classify each column of test_data, writing the index of the most probable
+ * class to the corresponding element of results.  test_data must have one row
+ * per trained feature plus one extra last row, which is not read (presumably
+ * the label row, mirroring the training layout -- TODO confirm with callers).
+ *
+ * The score of a class is log(prior) plus the sum over features of
+ * log(gmm::phi(x, mean, variance)).  The largest score wins; ties go to the
+ * lowest class index.  A NaN score (e.g. from a zero variance, if gmm::phi()
+ * divides by the variance -- verify) is never selected, and if no class scores
+ * above -inf, class 0 is reported.
+ */
+void SimpleNaiveBayesClassifier::Classify(const arma::mat& test_data,
+                                          arma::vec& results)
+{
+  // Compare against n_rows + 1 instead of computing test_data.n_rows - 1,
+  // which would wrap around for an empty matrix.
+  Log::Assert(test_data.n_rows == means_.n_rows + 1);
+
+  const size_t number_features = means_.n_rows;
+
+  results.zeros(test_data.n_cols);
+
+  Log::Info << test_data.n_cols << " test cases with " << number_features
+      << " features each" << std::endl;
+
+  CLI::GetParam<int>("nbc/tests") = test_data.n_cols;
+
+  // The log priors are the same for every test point; compute them once.
+  arma::vec log_priors(number_of_classes_);
+  for (size_t i = 0; i < number_of_classes_; i++)
+    log_priors(i) = log(class_probabilities_(i));
+
+  // Loop over every test case.
+  for (size_t n = 0; n < test_data.n_cols; n++)
+  {
+    size_t best_class = 0;
+    double best_score = -std::numeric_limits<double>::infinity();
+
+    for (size_t i = 0; i < number_of_classes_; i++)
+    {
+      // Use log values to prevent floating point underflow.
+      double score = log_priors(i);
+      for (size_t j = 0; j < number_features; j++)
+      {
+        score += log(gmm::phi(test_data(j, n), means_(j, i),
+            variances_(j, i)));
+      }
+
+      // "score > best_score" is false for NaN, so a NaN score is never
+      // chosen, and a tie keeps the earlier (lower-index) class.
+      if (score > best_score)
+      {
+        best_score = score;
+        best_class = i;
+      }
+    }
+
+    results(n) = best_class;
+  }
+}
+
+}; // namespace naive_bayes
+}; // namespace mlpack

Deleted: mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.h
===================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.h	2011-11-23 23:43:22 UTC (rev 10369)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.h	2011-11-24 01:39:41 UTC (rev 10370)
@@ -1,146 +0,0 @@
-/**
- * @author Parikshit Ram (pram at cc.gatech.edu)
- * @file simple_nbc.h
- *
- * A Naive Bayes Classifier which parametrically
- * estimates the distribution of the features.
- * It is assumed that the features have been
- * sampled from a Gaussian PDF
- *
- */
-#ifndef NBC_H
-#define NBC_H
-
-#include <mlpack/core.h>
-#include "phi.h"
-
-namespace mlpack {
-namespace naive_bayes {
-
-/*const fx_entry_doc parm_nbc_entries[] ={
-  {"training", FX_TIMER, FX_CUSTOM, NULL,
-   " The timer to record the training time\n"},
-  {"testing", FX_TIMER, FX_CUSTOM, NULL,
-   " The timer to record the testing time\n"},
-  {"classes", FX_REQUIRED, FX_INT, NULL,
-   " The number of classes present in the data\n"},
-  {"features", FX_RESULT, FX_INT, NULL,
-   " The number of features in the data\n"},
-  {"examples", FX_RESULT, FX_INT, NULL,
-   " The number of examples in the training set\n"},
-  {"tests", FX_RESULT, FX_INT, NULL,
-   " The number of data points in the test set\n"},
-  FX_ENTRY_DOC_DONE
-};*/
-
-PARAM_INT_REQ("classes", "The number of classes present in the data", "nbc");
-PARAM_INT("features", "The number of features in the data", "nbc", 0);
-PARAM_INT("examples", "The number of examples in the training set", "nbc", 0);
-PARAM_INT("tests", "The number of data points in the test set", "nbc", 0);
-
-/*const fx_submodule_doc parm_nbc_submodules[] = {
-  FX_SUBMODULE_DOC_DONE
-};*/
-
-/*const fx_module_doc parm_nbc_doc = {
-  parm_nbc_entries, parm_nbc_submodules,
-  " Trains the classifier using the training set and "
-  "outputs the results for the test set\n"
-};*/
-
-PARAM_MODULE("nbc", "Trains the classifier using the training set \
-and outputs the results for the test set");
-
-/**
- * A classification class. The class labels are assumed
- * to be positive integers - 0,1,2,....
- *
- * This class trains on the data by calculating the
- * sample mean and variance of the features with
- * respect to each of the labels, and also the class
- * probabilities.
- *
- * Mathematically, it computes P(X_i = x_i | Y = y_j)
- * for each feature X_i for each of the labels y_j.
- * Alongwith this, it also computes the classs probabilities
- * P( Y = y_j)
- *
- * For classifying a data point (x_1, x_2, ..., x_n),
- * it computes the following:
- * arg max_y(P(Y = y)*P(X_1 = x_1 | Y = y) * ... * P(X_n = x_n | Y = y))
- *
- * Example use:
- *
- * @code
- * SimpleNaiveBayesClassifier nbc;
- * arma::mat training_data, testing_data;
- * datanode *nbc_module = fx_submodule(NULL,"nbc","nbc");
- * arma::vec results;
- *
- * nbc.InitTrain(training_data, nbc_module);
- * nbc.Classify(testing_data, &results);
- * @endcode
- */
-class SimpleNaiveBayesClassifier {
-
-  // The class for testing this class is made a friend class
-  //friend class TestClassSimpleNBC;
-
- private:
-
-
- public:
-
-
-  // The variables containing the sample mean and variance
-  // for each of the features with respect to each class
-
-  // The variables containing the sample mean and variance
-  // for each of the features with respect to each class
-  arma::mat means_, variances_;
-
-  // The variable containing the class probabilities
-  arma::vec class_probabilities_;
-
-  // The variable keeping the information about the
-  // number of classes present
-  size_t number_of_classes_;
-
- /**
-  * Initializes the classifier as per the input and then trains it
-  * by calculating the sample mean and variances
-  *
-  * Example use:
-  * @code
-  * arma::mat training_data, testing_data;
-  * datanode nbc_module = fx_submodule(NULL,"nbc","nbc");
-  * ....
-  * SimpleNaiveBayesClassifier nbc(training_data, nbc_module);
-  * @endcode
-  */
-  SimpleNaiveBayesClassifier(const arma::mat& data);
-  /**
-   * Default constructor, you need to use the other one.
-  */
-  SimpleNaiveBayesClassifier();
-  ~SimpleNaiveBayesClassifier(){
-  }
-
-  /**
-   * Given a bunch of data points, this function evaluates the class
-   * of each of those data points, and puts it in the vector 'results'
-   *
-   * @code
-   * arma::mat test_data; // each column is a test point
-   * arma::vec results;
-   * ...
-   * nbc.Classify(test_data, &results);
-   * @endcode
-   */
-  void Classify(const arma::mat& test_data, arma::vec& results);
-};
-
-}; // namespace naive_bayes
-}; // namespace mlpack
-
-#endif

Copied: mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.hpp (from rev 10352, mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.h)
===================================================================
--- mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/naive_bayes/simple_nbc.hpp	2011-11-24 01:39:41 UTC (rev 10370)
@@ -0,0 +1,109 @@
+/**
+ * @file simple_nbc.hpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * A Naive Bayes Classifier which parametrically estimates the distribution of
+ * the features.  It is assumed that the features have been sampled from a
+ * Gaussian PDF.
+ */
+#ifndef __MLPACK_METHODS_NBC_SIMPLE_NBC_HPP
+#define __MLPACK_METHODS_NBC_SIMPLE_NBC_HPP
+
+#include <mlpack/core.h>
+#include <mlpack/methods/gmm/phi.hpp>
+
+namespace mlpack {
+namespace naive_bayes {
+
+PARAM_INT_REQ("classes", "The number of classes present in the data.", "nbc");
+PARAM_INT("features", "The number of features in the data.", "nbc", 0);
+PARAM_INT("examples", "The number of examples in the training set.", "nbc", 0);
+PARAM_INT("tests", "The number of data points in the test set.", "nbc", 0);
+
+PARAM_MODULE("nbc", "Trains the classifier using the training set "
+    "and outputs the results for the test set.");
+
+/**
+ * A Gaussian naive Bayes classifier.  The class labels are assumed to be
+ * non-negative integers: 0, 1, 2, ...
+ *
+ * Training calculates the class probabilities P(Y = y_j) and, for each class,
+ * the sample mean and variance of every feature.  Each feature X_i is modeled
+ * as Gaussian given the class, giving the density p(X_i = x_i | Y = y_j).
+ *
+ * To classify a data point (x_1, x_2, ..., x_n), it computes
+ *
+ *   arg max_y P(Y = y) * p(X_1 = x_1 | Y = y) * ... * p(X_n = x_n | Y = y),
+ *
+ * working with logarithms to avoid floating point underflow.
+ *
+ * Data layout: each column is one point.  In the training matrix the last row
+ * holds the labels and the other rows hold the features; the test matrix must
+ * have the same number of rows (its last row is not read).
+ *
+ * Example use:
+ *
+ * @code
+ * arma::mat training_data, testing_data;
+ * // ... load the data and set the "nbc/classes" parameter ...
+ * SimpleNaiveBayesClassifier nbc(training_data);
+ *
+ * arma::vec results;
+ * nbc.Classify(testing_data, results);
+ * @endcode
+ */
+class SimpleNaiveBayesClassifier
+{
+ public:
+  //! Sample mean of each feature (row) for each class (column).
+  arma::mat means_;
+
+  //! Sample variance of each feature (row) for each class (column).
+  arma::mat variances_;
+
+  //! Prior probability of each class.
+  arma::vec class_probabilities_;
+
+  //! The number of classes present.
+  size_t number_of_classes_;
+
+  /**
+   * Initializes the classifier and trains it by calculating the class
+   * probabilities and the per-class sample means and variances.  The number of
+   * classes is read from the "nbc/classes" parameter.
+   *
+   * @param data Training data; one column per example, with the class label
+   *     in the last row.
+   */
+  SimpleNaiveBayesClassifier(const arma::mat& data);
+
+  /**
+   * Default constructor.  It is declared but not defined in simple_nbc.cpp,
+   * so using it will fail to link; use the constructor taking training data.
+   */
+  SimpleNaiveBayesClassifier();
+
+  ~SimpleNaiveBayesClassifier() { }
+
+  /**
+   * Given a bunch of data points, this function evaluates the class of each of
+   * those data points, and puts it in the vector 'results'.
+   *
+   * @param test_data Points to classify, one per column, with the same number
+   *     of rows as the training data (the last row is not read).
+   * @param results Receives the predicted class index of each column.
+   *
+   * @code
+   * arma::mat test_data; // each column is a test point
+   * arma::vec results;
+   * ...
+   * nbc.Classify(test_data, results);
+   * @endcode
+   */
+  void Classify(const arma::mat& test_data, arma::vec& results);
+};
+
+}; // namespace naive_bayes
+}; // namespace mlpack
+
+#endif




More information about the mlpack-svn mailing list