[mlpack-svn] r15412 - mlpack/conf/jenkins-conf/benchmark/methods/weka/src/nbc

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Jul 4 09:08:49 EDT 2013


Author: marcus
Date: Thu Jul  4 09:08:49 2013
New Revision: 15412

Log:
Add weka nbc method src.

Added:
   mlpack/conf/jenkins-conf/benchmark/methods/weka/src/nbc/
   mlpack/conf/jenkins-conf/benchmark/methods/weka/src/nbc/NBC.java

Added: mlpack/conf/jenkins-conf/benchmark/methods/weka/src/nbc/NBC.java
==============================================================================
--- (empty file)
+++ mlpack/conf/jenkins-conf/benchmark/methods/weka/src/nbc/NBC.java	Thu Jul  4 09:08:49 2013
@@ -0,0 +1,66 @@
+/**
+ * @file NBC.java
+ * @author Marcus Edel
+ *
+ * Naive Bayes Classifier with weka.
+ */
+
+import weka.classifiers.Classifier;
+import weka.classifiers.bayes.NaiveBayes;
+import weka.core.Instances;
+import weka.core.Utils;
+import weka.core.converters.ConverterUtils.DataSource;
+
+/**
+ * Benchmark driver that uses the weka library to train a Naive Bayes
+ * classifier on a training set and classify every point of a test set,
+ * timing the training and classification steps.
+ */
+public class NBC {
+
+  private static final String USAGE =
+      "This program trains the Naive Bayes classifier on the given\n"
+      + "labeled training set and then uses the trained classifier to classify\n"
+      + "the points in the given test set.\n\n"
+      + "Required options:\n"
+      + "-T [string]     A file containing the test set.\n"
+      + "-t [string]     A file containing the training set.";
+
+  /**
+   * Entry point: parses the -t/-T options, loads both data sets, then
+   * measures training plus classification under the "total_time" timer.
+   *
+   * If either option is missing, the usage text is printed to stderr. Any
+   * other failure is reported via its stack trace. This includes an
+   * IllegalArgumentException raised by weka itself, which is no longer
+   * mistaken for a usage error.
+   *
+   * @param args command-line arguments (-t training file, -T test file)
+   */
+  public static void main(String[] args) {
+    Timers timer = new Timers();
+    try {
+      // Get the data set paths; an absent option yields an empty string.
+      String trainFile = Utils.getOption('t', args);
+      String testFile = Utils.getOption('T', args);
+      if (trainFile.length() == 0 || testFile.length() == 0) {
+        System.err.println(USAGE);
+        return;
+      }
+
+      // Load train and test data sets; loading is not part of the timing.
+      Instances trainData = loadDataset(trainFile);
+      Instances testData = loadDataset(testFile);
+
+      timer.StartTimer("total_time");
+      // Create and train the classifier.
+      Classifier classifier = new NaiveBayes();
+      classifier.buildClassifier(trainData);
+
+      // Run the classifier on every test instance. The predictions are not
+      // used; only the time spent producing them is measured.
+      for (int i = 0; i < testData.numInstances(); i++) {
+        classifier.classifyInstance(testData.instance(i));
+      }
+
+      timer.StopTimer("total_time");
+      timer.PrintTimer("total_time");
+    } catch (Exception e) {
+      e.printStackTrace();
+    }
+  }
+
+  /**
+   * Loads a data set through weka's DataSource and marks its last
+   * attribute (column) as the class label.
+   *
+   * @param path path of the data file to load
+   * @return the loaded instances with the class index set
+   * @throws Exception if weka cannot read or parse the file
+   */
+  private static Instances loadDataset(String path) throws Exception {
+    Instances data = new DataSource(path).getDataSet();
+    data.setClassIndex(data.numAttributes() - 1);
+    return data;
+  }
+}



More information about the mlpack-svn mailing list