[mlpack-svn] r13878 - mlpack/trunk/src/mlpack/bindings/matlab/emst
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Nov 16 13:17:04 EST 2012
Author: rcurtin
Date: 2012-11-16 13:17:03 -0500 (Fri, 16 Nov 2012)
New Revision: 13878
Modified:
mlpack/trunk/src/mlpack/bindings/matlab/emst/emst.cpp
mlpack/trunk/src/mlpack/bindings/matlab/emst/emst.m
Log:
Clean up MATLAB bindings for EMST.
Modified: mlpack/trunk/src/mlpack/bindings/matlab/emst/emst.cpp
===================================================================
--- mlpack/trunk/src/mlpack/bindings/matlab/emst/emst.cpp 2012-11-16 15:48:34 UTC (rev 13877)
+++ mlpack/trunk/src/mlpack/bindings/matlab/emst/emst.cpp 2012-11-16 18:17:03 UTC (rev 13878)
@@ -1,3 +1,9 @@
+/**
+ * @file emst.cpp
+ * @author Patrick Mason
+ *
+ * MEX function for MATLAB EMST binding.
+ */
#include "mex.h"
#include <mlpack/core.hpp>
@@ -9,11 +15,11 @@
using namespace mlpack::emst;
using namespace mlpack::tree;
-// the gateway, required by all mex functions
+// The gateway, required by all mex functions.
void mexFunction(int nlhs, mxArray *plhs[],
int nrhs, const mxArray *prhs[])
{
- // argument checks
+ // Argument checks.
if (nrhs != 3)
{
mexErrMsgTxt("Expecting an datapoints matrix, isBoruvka, and leafSize.");
@@ -24,28 +30,26 @@
mexErrMsgTxt("Output required.");
}
- // getting the dimensions of the input matrix
const size_t numPoints = mxGetN(prhs[0]);
const size_t numDimensions = mxGetM(prhs[0]);
- // converting from mxArray to armadillo matrix
+ // Converting from mxArray to armadillo matrix.
arma::mat dataPoints(numDimensions, numPoints);
- // setting the values.
- double * mexDataPoints = mxGetPr(prhs[0]);
+ // Set the values.
+ double* mexDataPoints = mxGetPr(prhs[0]);
for (int i = 0, n = numPoints * numDimensions; i < n; ++i)
{
dataPoints(i) = mexDataPoints[i];
}
- // getting the isBoruvka bit
const bool isBoruvka = (mxGetScalar(prhs[1]) == 1.0);
- // running the computation
+ // Run the computation.
arma::mat result;
if (isBoruvka)
{
- // getting the number of leaves
+ // Get the number of leaves.
const size_t leafSize = (size_t) mxGetScalar(prhs[2]);
DualTreeBoruvka<> dtb(dataPoints, false, leafSize);
@@ -57,15 +61,12 @@
naive.ComputeMST(result);
}
- // constructing matrix to return to matlab
- plhs[0] = mxCreateDoubleMatrix(3, numPoints-1, mxREAL);
+ // Construct matrix to return to MATLAB.
+ plhs[0] = mxCreateDoubleMatrix(3, numPoints - 1, mxREAL);
- // setting the values
- double * out = mxGetPr(plhs[0]);
+ double* out = mxGetPr(plhs[0]);
for (int i = 0, n = (numPoints - 1) * 3; i < n; ++i)
{
out[i] = result(i);
}
-
- return;
}
Modified: mlpack/trunk/src/mlpack/bindings/matlab/emst/emst.m
===================================================================
--- mlpack/trunk/src/mlpack/bindings/matlab/emst/emst.m 2012-11-16 15:48:34 UTC (rev 13877)
+++ mlpack/trunk/src/mlpack/bindings/matlab/emst/emst.m 2012-11-16 18:17:03 UTC (rev 13878)
@@ -1,6 +1,7 @@
function result = emst(dataPoints, varargin)
-% Fast Euclidean Minimum Spanning Tree. This script can compute
-% the Euclidean minimum spanning tree of a set of input points using the
+% result = emst(dataPoints, varargin)
+%
+% Compute the Euclidean minimum spanning tree of a set of input points using the
% dual-tree Boruvka algorithm.
%
% The output is saved in a three-column matrix, where each row indicates an
@@ -9,36 +10,40 @@
% column corresponds to the distance between the two points.
%
% Parameters:
-% dataPoints - the matrix of data points. Columns are assumed to represent dimensions,
-% with rows representing seperate points.
-% method - the algorithm for computing the tree. 'naive' or 'boruvka', with
-% 'boruvka' being the default algorithm.
+%
+% dataPoints - The matrix of data points. Columns are assumed to represent
+% dimensions, with rows representing separate points.
+% method - The algorithm for computing the tree. 'naive' or 'boruvka', with
+% 'boruvka' being the default dual-tree Boruvka algorithm.
% leafSize - Leaf size in the kd-tree. One-element leaves give the
% empirically best performance, but at the cost of greater memory
-% requirements. One is default.
+% requirements. Defaults to 1.
%
% Examples:
+%
% result = emst(dataPoints);
-% or
-% esult = emst(dataPoints,'method','naive');
+% result = emst(dataPoints, 'method', 'naive');
+% result = emst(dataPoints, 'method', 'naive', 'leafSize', 5);
-% a parser for the inputs
+% A parser for the inputs.
p = inputParser;
-p.addParamValue('method', 'boruvka', @(x) strcmpi(x, 'naive') || strcmpi(x, 'boruvka'));
+p.addParamValue('method', 'boruvka', ...
+ @(x) strcmpi(x, 'naive') || strcmpi(x, 'boruvka'));
p.addParamValue('leafSize', 1, @isscalar);
-% parsing the varargin options
+% Parse the varargin options.
p.parse(varargin{:});
parsed = p.Results;
-% interfacing with mlpack. transposing to machine learning standards.
+% Interface with mlpack. Transpose to machine learning standards. MLPACK
+% expects column-major matrices; the user has passed in a row-major matrix.
if strcmpi(parsed.method, 'boruvka')
result = emst_mex(dataPoints', 1, parsed.leafSize);
- result = result';
+ result = result';
return;
else
result = emst_mex(dataPoints', 0, 1);
- result = result';
+ result = result';
return;
end
More information about the mlpack-svn mailing list