[mlpack-svn] r13805 - in mlpack/trunk/src/mlpack/bindings/matlab: . gmm kmeans range_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Oct 31 16:54:04 EDT 2012
Author: pmason8
Date: 2012-10-31 16:54:03 -0400 (Wed, 31 Oct 2012)
New Revision: 13805
Added:
mlpack/trunk/src/mlpack/bindings/matlab/gmm/
mlpack/trunk/src/mlpack/bindings/matlab/gmm/CMakeLists.txt
mlpack/trunk/src/mlpack/bindings/matlab/gmm/Makefile
mlpack/trunk/src/mlpack/bindings/matlab/gmm/gmm.cpp
mlpack/trunk/src/mlpack/bindings/matlab/gmm/gmm.m
mlpack/trunk/src/mlpack/bindings/matlab/kmeans/CMakeLists.txt
mlpack/trunk/src/mlpack/bindings/matlab/range_search/
mlpack/trunk/src/mlpack/bindings/matlab/range_search/CMakeLists.txt
mlpack/trunk/src/mlpack/bindings/matlab/range_search/Makefile
mlpack/trunk/src/mlpack/bindings/matlab/range_search/range_search.cpp
mlpack/trunk/src/mlpack/bindings/matlab/range_search/range_search.m
Modified:
mlpack/trunk/src/mlpack/bindings/matlab/CMakeLists.txt
mlpack/trunk/src/mlpack/bindings/matlab/kmeans/kmeans.cpp
Log:
added range_search and gmm; CMake updates for kmeans
Modified: mlpack/trunk/src/mlpack/bindings/matlab/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/bindings/matlab/CMakeLists.txt 2012-10-31 20:50:28 UTC (rev 13804)
+++ mlpack/trunk/src/mlpack/bindings/matlab/CMakeLists.txt 2012-10-31 20:54:03 UTC (rev 13805)
@@ -70,7 +70,11 @@
# Set MATLAB toolbox install directory.
set(MATLAB_TOOLBOX_DIR "${MATLAB_ROOT}/toolbox")
+# CHANGE HERE FOR NEW BINDINGS!!!!
add_subdirectory(emst)
+add_subdirectory(kmeans)
+add_subdirectory(range_search)
+add_subdirectory(gmm)
# Create a target whose sole purpose is to modify the pathdef.m MATLAB file so
# that the MLPACK toolbox is added to the MATLAB default path.
Added: mlpack/trunk/src/mlpack/bindings/matlab/gmm/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/bindings/matlab/gmm/CMakeLists.txt (rev 0)
+++ mlpack/trunk/src/mlpack/bindings/matlab/gmm/CMakeLists.txt 2012-10-31 20:54:03 UTC (rev 13805)
@@ -0,0 +1,19 @@
+# Build the mex file as a shared library. The _mex suffix avoids target name
+# conflicts, and the mex file must have a different name than the .m file.
+# NOTE(review): gmm.m calls mex_gmm, but no OUTPUT_NAME is set here; confirm.
+add_library(gmm_mex SHARED
+ gmm.cpp
+)
+target_link_libraries(gmm_mex
+ mlpack
+ ${LIBXML2_LIBRARIES}
+)
+
+# Install both the mex library and the MATLAB wrapper (gmm.m) into the toolbox.
+install(TARGETS gmm_mex
+ LIBRARY DESTINATION "${MATLAB_TOOLBOX_DIR}/mlpack/"
+)
+install(FILES
+ gmm.m
+ DESTINATION "${MATLAB_TOOLBOX_DIR}/mlpack/"
+)
Added: mlpack/trunk/src/mlpack/bindings/matlab/gmm/Makefile
===================================================================
--- mlpack/trunk/src/mlpack/bindings/matlab/gmm/Makefile (rev 0)
+++ mlpack/trunk/src/mlpack/bindings/matlab/gmm/Makefile 2012-10-31 20:54:03 UTC (rev 13805)
@@ -0,0 +1,25 @@
+# Hand-written build of mex_gmm.mexa64 against MATLAB 2010b (glnxa64); the
+# MATLAB and rpath locations below are machine-specific.
+gmm: gmm.o
+	g++ -O -pthread -shared \
+-Wl,--version-script,/opt/matlab/2010b/extern/lib/glnxa64/mexFunction.map \
+-Wl,--no-undefined -o 'mex_gmm.mexa64' gmm.o \
+-L../../build/lib -lmlpack \
+-Wl,-rpath-link,/opt/matlab/2010b/bin/glnxa64 \
+-L/opt/matlab/2010b/bin/glnxa64 -lmx -lmex -lmat -lm \
+-Wl,-rpath=/net/hu19/pmason8/mlpack/trunk/build/lib \
+-L/usr/lib64 -larmadillo
+
+# Depend on the source so that edits to gmm.cpp trigger a rebuild.
+gmm.o: gmm.cpp
+	g++ -c \
+-I../../build/include \
+-I../../build/include/mlpack/methods/gmm \
+-I/usr/include/libxml2 \
+-I/opt/matlab/2010b/extern/include \
+-DMATLAB_MEX_FILE \
+-ansi -D_GNU_SOURCE -fPIC -fno-omit-frame-pointer -pthread \
+-DMX_COMPAT_32 -O -DNDEBUG 'gmm.cpp'
+
+clean:
+	rm -f *.o *.mexa64
Added: mlpack/trunk/src/mlpack/bindings/matlab/gmm/gmm.cpp
===================================================================
--- mlpack/trunk/src/mlpack/bindings/matlab/gmm/gmm.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/bindings/matlab/gmm/gmm.cpp 2012-10-31 20:54:03 UTC (rev 13805)
@@ -0,0 +1,121 @@
+#include "mex.h"
+
+#include "gmm.hpp"
+
+#include <ctime>
+#include <iostream>
+#include <sstream>
+
+using namespace mlpack;
+using namespace mlpack::gmm;
+using namespace mlpack::utilities;
+
+/**
+ * MEX gateway for Gaussian mixture model training; gmm.m calls it as
+ * mex_gmm(dataPoints', gaussians, seed).
+ *
+ * Inputs:
+ *   prhs[0] - real double data matrix, one point per column.
+ *   prhs[1] - number of Gaussians; must be at least 1.
+ *   prhs[2] - random seed; 0 means std::time(NULL) is used instead.
+ *
+ * Output:
+ *   plhs[0] - 1x1 struct with the fields 'dimensionality', 'weights' (a
+ *             column vector) and 'gaussians' (a struct array with one element
+ *             per Gaussian and the fields 'mean' and 'covariance').
+ *
+ * Errors are reported through mexErrMsgTxt(), which does not return.
+ */
+void mexFunction(int nlhs, mxArray *plhs[],
+                 int nrhs, const mxArray *prhs[])
+{
+  // argument checks
+  if (nrhs != 3)
+  {
+    mexErrMsgTxt("Expecting three inputs.");
+  }
+
+  if (nlhs != 1)
+  {
+    mexErrMsgTxt("Output required.");
+  }
+
+  // mxGetPr() only yields meaningful data for real double arrays.
+  if (!mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]))
+  {
+    mexErrMsgTxt("Data points must be a real double matrix.");
+  }
+
+  const size_t numPoints = mxGetN(prhs[0]);
+  const size_t numDimensions = mxGetM(prhs[0]);
+  if (numPoints == 0 || numDimensions == 0)
+  {
+    mexErrMsgTxt("Data points matrix must not be empty.");
+  }
+
+  const int gaussians = (int) mxGetScalar(prhs[1]);
+  if (gaussians <= 0)
+  {
+    std::stringstream ss;
+    ss << "Invalid number of Gaussians (" << gaussians << "); must "
+        "be greater than or equal to 1." << std::endl;
+    mexErrMsgTxt(ss.str().c_str());
+  }
+
+  // Seed the random number generator; 0 means "seed from the current time".
+  const size_t seed = (size_t) mxGetScalar(prhs[2]);
+  if (seed != 0)
+    math::RandomSeed(seed);
+  else
+    math::RandomSeed((size_t) std::time(NULL));
+
+  // Copy the MATLAB data (column-major, one point per column) into Armadillo.
+  arma::mat dataPoints(mxGetPr(prhs[0]), numDimensions, numPoints);
+
+  // Compute the parameters of the model using the EM algorithm.
+  GMM<> gmm(size_t(gaussians), dataPoints.n_rows);
+  gmm.Estimate(dataPoints);
+
+  // Set up the 1x1 struct returned to MATLAB. mxSetFieldByNumber() takes
+  // ownership of each array handed to it, so nothing below needs to be
+  // duplicated or destroyed.
+  const char* fieldNames[3] = { "dimensionality", "weights", "gaussians" };
+  mwSize dims[1] = { 1 };
+  plhs[0] = mxCreateStructArray(1, dims, 3, fieldNames);
+
+  // dimensionality
+  mxArray* fieldValue = mxCreateDoubleMatrix(1, 1, mxREAL);
+  *mxGetPr(fieldValue) = (double) numDimensions;
+  mxSetFieldByNumber(plhs[0], 0, 0, fieldValue);
+
+  // mixture weights, as a column vector
+  const size_t numWeights = gmm.Weights().n_elem;
+  fieldValue = mxCreateDoubleMatrix(numWeights, 1, mxREAL);
+  double* values = mxGetPr(fieldValue);
+  for (size_t i = 0; i < numWeights; ++i)
+    values[i] = gmm.Weights()[i];
+  mxSetFieldByNumber(plhs[0], 0, 1, fieldValue);
+
+  // one struct element per Gaussian, holding its mean and covariance
+  const char* gaussianNames[2] = { "mean", "covariance" };
+  dims[0] = gmm.Gaussians();
+  fieldValue = mxCreateStructArray(1, dims, 2, gaussianNames);
+  for (size_t i = 0; i < gmm.Gaussians(); ++i)
+  {
+    const arma::mat mean = gmm.Means()[i];
+    mxArray* meanValue = mxCreateDoubleMatrix(numDimensions, 1, mxREAL);
+    values = mxGetPr(meanValue);
+    for (size_t j = 0; j < numDimensions; ++j)
+      values[j] = mean(j);
+    mxSetFieldByNumber(fieldValue, i, 0, meanValue);
+
+    const arma::mat covariance = gmm.Covariances()[i];
+    mxArray* covarianceValue =
+        mxCreateDoubleMatrix(numDimensions, numDimensions, mxREAL);
+    values = mxGetPr(covarianceValue);
+    for (size_t j = 0; j < numDimensions * numDimensions; ++j)
+      values[j] = covariance(j);
+    mxSetFieldByNumber(fieldValue, i, 1, covarianceValue);
+  }
+  mxSetFieldByNumber(plhs[0], 0, 2, fieldValue);
+}
Added: mlpack/trunk/src/mlpack/bindings/matlab/gmm/gmm.m
===================================================================
--- mlpack/trunk/src/mlpack/bindings/matlab/gmm/gmm.m (rev 0)
+++ mlpack/trunk/src/mlpack/bindings/matlab/gmm/gmm.m 2012-10-31 20:54:03 UTC (rev 13805)
@@ -0,0 +1,28 @@
+function result = gmm(dataPoints, varargin)
%Gaussian Mixture Model (GMM) Training
%
% Fits a Gaussian mixture model (GMM) to the data with the EM algorithm and
% returns the maximum likelihood estimate as a struct with the fields
% 'dimensionality', 'weights' and 'gaussians' (each with 'mean', 'covariance').
%
%Parameters (optional ones are given as name/value pairs):
% dataPoints - (required) Matrix of data to fit, one point per row.
% seed - (optional) Random seed. If 0, 'std::time(NULL)' is used.
% Default value is 0.
% gaussians - (optional) Number of gaussians in the GMM. Default value is 1.

% parse the optional name/value pair arguments
p = inputParser;
p.addParamValue('gaussians', 1, @isscalar);
p.addParamValue('seed', 0, @isscalar);

% parsing the varargin options
p.parse(varargin{:});
parsed = p.Results;

% transpose so that each point is a column, as mex_gmm expects
result = mex_gmm(dataPoints', parsed.gaussians, parsed.seed);




Added: mlpack/trunk/src/mlpack/bindings/matlab/kmeans/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/bindings/matlab/kmeans/CMakeLists.txt (rev 0)
+++ mlpack/trunk/src/mlpack/bindings/matlab/kmeans/CMakeLists.txt 2012-10-31 20:54:03 UTC (rev 13805)
@@ -0,0 +1,19 @@
+# Build the mex file as a shared library. The _mex suffix avoids target name
+# conflicts, and the mex file must have a different name than the .m file.
+# NOTE(review): no OUTPUT_NAME is set here; confirm it matches kmeans.m's call.
+add_library(kmeans_mex SHARED
+ kmeans.cpp
+)
+target_link_libraries(kmeans_mex
+ mlpack
+ ${LIBXML2_LIBRARIES}
+)
+
+# Install both the mex library and the MATLAB wrapper (kmeans.m).
+install(TARGETS kmeans_mex
+ LIBRARY DESTINATION "${MATLAB_TOOLBOX_DIR}/mlpack/"
+)
+install(FILES
+ kmeans.m
+ DESTINATION "${MATLAB_TOOLBOX_DIR}/mlpack/"
+)
Modified: mlpack/trunk/src/mlpack/bindings/matlab/kmeans/kmeans.cpp
===================================================================
--- mlpack/trunk/src/mlpack/bindings/matlab/kmeans/kmeans.cpp 2012-10-31 20:50:28 UTC (rev 13804)
+++ mlpack/trunk/src/mlpack/bindings/matlab/kmeans/kmeans.cpp 2012-10-31 20:54:03 UTC (rev 13805)
@@ -51,12 +51,12 @@
mexErrMsgTxt("Output required.");
}
- size_t seed = (size_t) mxGetScalar(prhs[6]);
+ size_t seed = (size_t) mxGetScalar(prhs[6]);
// Initialize random seed.
//if (CLI::GetParam<int>("seed") != 0)
- //math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
- if (seed != 0)
+ //math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
+ if (seed != 0)
math::RandomSeed(seed);
else
math::RandomSeed((size_t) std::time(NULL));
@@ -64,53 +64,53 @@
// Now do validation of options.
//string inputFile = CLI::GetParam<string>("inputFile");
//int clusters = CLI::GetParam<int>("clusters");
- int clusters = (int) mxGetScalar(prhs[1]);
+ int clusters = (int) mxGetScalar(prhs[1]);
if (clusters < 1)
{
- stringstream ss;
- ss << "Invalid number of clusters requested (" << clusters << ")! "
+ stringstream ss;
+ ss << "Invalid number of clusters requested (" << clusters << ")! "
<< "Must be greater than or equal to 1.";
- mexErrMsgTxt(ss.str().c_str());
+ mexErrMsgTxt(ss.str().c_str());
}
//int maxIterations = CLI::GetParam<int>("max_iterations");
- int maxIterations = (int) mxGetScalar(prhs[2]);
+ int maxIterations = (int) mxGetScalar(prhs[2]);
if (maxIterations < 0)
{
- stringstream ss;
+ stringstream ss;
ss << "Invalid value for maximum iterations (" << maxIterations <<
")! Must be greater than or equal to 0.";
- mexErrMsgTxt(ss.str().c_str());
+ mexErrMsgTxt(ss.str().c_str());
}
//double overclustering = CLI::GetParam<double>("overclustering");
- double overclustering = mxGetScalar(prhs[3]);
+ double overclustering = mxGetScalar(prhs[3]);
if (overclustering < 1)
{
- stringstream ss;
+ stringstream ss;
ss << "Invalid value for overclustering (" << overclustering <<
")! Must be greater than or equal to 1.";
- mexErrMsgTxt(ss.str().c_str());
+ mexErrMsgTxt(ss.str().c_str());
}
- const bool allow_empty_clusters = (mxGetScalar(prhs[4]) == 1.0);
- const bool fast_kmeans = (mxGetScalar(prhs[5]) == 1.0);
+ const bool allow_empty_clusters = (mxGetScalar(prhs[4]) == 1.0);
+ const bool fast_kmeans = (mxGetScalar(prhs[5]) == 1.0);
- /*
+ /*
// Make sure we have an output file if we're not doing the work in-place.
if (!CLI::HasParam("in_place") && !CLI::HasParam("outputFile"))
{
Log::Fatal << "--outputFile not specified (and --in_place not set)."
<< std::endl;
}
- */
+ */
// Load our dataset.
- const size_t numPoints = mxGetN(prhs[0]);
+ const size_t numPoints = mxGetN(prhs[0]);
const size_t numDimensions = mxGetM(prhs[0]);
arma::mat dataset(numDimensions, numPoints);
- // setting the values.
+ // setting the values.
double * mexDataPoints = mxGetPr(prhs[0]);
for (int i = 0, n = numPoints * numDimensions; i < n; ++i)
{
@@ -122,13 +122,13 @@
arma::Col<size_t> assignments;
//if (CLI::HasParam("allow_empty_clusters"))
- if (allow_empty_clusters)
+ if (allow_empty_clusters)
{
KMeans<metric::SquaredEuclideanDistance, RandomPartition,
AllowEmptyClusters> k(maxIterations, overclustering);
//if (CLI::HasParam("fast_kmeans"))
- if (fast_kmeans)
+ if (fast_kmeans)
k.FastCluster(dataset, clusters, assignments);
else
k.Cluster(dataset, clusters, assignments);
@@ -138,13 +138,13 @@
KMeans<> k(maxIterations, overclustering);
//if (CLI::HasParam("fast_kmeans"))
- if (fast_kmeans)
+ if (fast_kmeans)
k.FastCluster(dataset, clusters, assignments);
else
k.Cluster(dataset, clusters, assignments);
}
- /*
+ /*
// Now figure out what to do with our results.
if (CLI::HasParam("in_place"))
{
@@ -182,7 +182,7 @@
data::Save(outputFile.c_str(), dataset);
}
}
- */
+ */
// constructing matrix to return to matlab
plhs[0] = mxCreateDoubleMatrix(assignments.n_elem, 1, mxREAL);
Added: mlpack/trunk/src/mlpack/bindings/matlab/range_search/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/bindings/matlab/range_search/CMakeLists.txt (rev 0)
+++ mlpack/trunk/src/mlpack/bindings/matlab/range_search/CMakeLists.txt 2012-10-31 20:54:03 UTC (rev 13805)
@@ -0,0 +1,19 @@
+# Build the mex file as a shared library. The _mex suffix avoids target name
+# conflicts, and the mex file must have a different name than the .m file.
+# NOTE(review): range_search.m calls mex_range_search; no OUTPUT_NAME is set.
+add_library(range_search_mex SHARED
+ range_search.cpp
+)
+target_link_libraries(range_search_mex
+ mlpack
+ ${LIBXML2_LIBRARIES}
+)
+
+# Install both the mex library and the MATLAB wrapper (range_search.m).
+install(TARGETS range_search_mex
+ LIBRARY DESTINATION "${MATLAB_TOOLBOX_DIR}/mlpack/"
+)
+install(FILES
+ range_search.m
+ DESTINATION "${MATLAB_TOOLBOX_DIR}/mlpack/"
+)
Added: mlpack/trunk/src/mlpack/bindings/matlab/range_search/Makefile
===================================================================
--- mlpack/trunk/src/mlpack/bindings/matlab/range_search/Makefile (rev 0)
+++ mlpack/trunk/src/mlpack/bindings/matlab/range_search/Makefile 2012-10-31 20:54:03 UTC (rev 13805)
@@ -0,0 +1,25 @@
+# Hand-written build of mex_range_search.mexa64 against MATLAB 2010b
+# (glnxa64); the MATLAB and rpath locations below are machine-specific.
+range_search: range_search.o
+	g++ -O -pthread -shared \
+-Wl,--version-script,/opt/matlab/2010b/extern/lib/glnxa64/mexFunction.map \
+-Wl,--no-undefined -o 'mex_range_search.mexa64' range_search.o \
+-L../../build/lib -lmlpack \
+-Wl,-rpath-link,/opt/matlab/2010b/bin/glnxa64 \
+-L/opt/matlab/2010b/bin/glnxa64 -lmx -lmex -lmat -lm \
+-Wl,-rpath=/net/hu19/pmason8/mlpack/trunk/build/lib \
+-L/usr/lib64 -larmadillo
+
+# Depend on the source so that edits to range_search.cpp trigger a rebuild.
+range_search.o: range_search.cpp
+	g++ -c \
+-I../../build/include \
+-I../../build/include/mlpack/methods/range_search \
+-I/usr/include/libxml2 \
+-I/opt/matlab/2010b/extern/include \
+-DMATLAB_MEX_FILE \
+-ansi -D_GNU_SOURCE -fPIC -fno-omit-frame-pointer -pthread \
+-DMX_COMPAT_32 -O -DNDEBUG 'range_search.cpp'
+
+clean:
+	rm -f *.o *.mexa64
Added: mlpack/trunk/src/mlpack/bindings/matlab/range_search/range_search.cpp
===================================================================
--- mlpack/trunk/src/mlpack/bindings/matlab/range_search/range_search.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/bindings/matlab/range_search/range_search.cpp 2012-10-31 20:54:03 UTC (rev 13805)
@@ -0,0 +1,196 @@
+#include "mex.h"
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+
+#include "range_search.hpp"
+
+#include <sstream>
+#include <vector>
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::range;
+using namespace mlpack::tree;
+
+typedef BinarySpaceTree<bound::HRectBound<2>, EmptyStatistic> TreeType;
+typedef RangeSearch<metric::SquaredEuclideanDistance, TreeType> RSType;
+
+// Copies a MATLAB matrix (one point per column) into an Armadillo matrix,
+// raising a MATLAB error (which does not return) unless it holds real doubles.
+static arma::mat MatlabToArma(const mxArray* array, const char* description)
+{
+  if (!mxIsDouble(array) || mxIsComplex(array))
+  {
+    stringstream ss;
+    ss << description << " must be a real double matrix.";
+    mexErrMsgTxt(ss.str().c_str());
+  }
+  return arma::mat(mxGetPr(array), mxGetM(array), mxGetN(array));
+}
+
+/**
+ * MEX gateway for range search; range_search.m calls it as
+ * mex_range_search(dataPoints', maxDistance, minDistance, queryPoints',
+ * leafSize, naive, singleMode).
+ *
+ * Inputs (points are stored one per column):
+ *   prhs[0] - reference data (real double, non-empty).
+ *   prhs[1] - upper bound of the range.
+ *   prhs[2] - lower bound of the range; must be below the upper bound.
+ *   prhs[3] - query data (real double); empty to query the reference set.
+ *   prhs[4] - leaf size for tree building; must be at least 0.
+ *   prhs[5] - naive flag; 1 enables O(n^2) naive mode (overrides prhs[6]).
+ *   prhs[6] - single-tree flag; 1 enables single-tree search.
+ *
+ * Output:
+ *   plhs[0] - struct array with one element per query point (in the original
+ *             point order) and the fields 'neighbors' (1-based reference
+ *             indices) and 'distances'.
+ *
+ * NOTE(review): the metric is SquaredEuclideanDistance, so the bounds and the
+ * returned distances are presumably squared Euclidean distances -- confirm.
+ */
+void mexFunction(int nlhs, mxArray *plhs[],
+                 int nrhs, const mxArray *prhs[])
+{
+  // argument checks
+  if (nrhs != 7)
+  {
+    mexErrMsgTxt("Expecting seven inputs.");
+  }
+
+  if (nlhs != 1)
+  {
+    mexErrMsgTxt("Output required.");
+  }
+
+  const double maxDistance = mxGetScalar(prhs[1]);
+  const double minDistance = mxGetScalar(prhs[2]);
+  const int lsInt = (int) mxGetScalar(prhs[4]);
+  const bool naive = (mxGetScalar(prhs[5]) == 1.0);
+  const bool singleMode = (mxGetScalar(prhs[6]) == 1.0);
+
+  // An empty query matrix means the reference set doubles as the query set.
+  const bool hasQueryData = ((mxGetM(prhs[3]) != 0) && (mxGetN(prhs[3]) != 0));
+
+  if (mxGetM(prhs[0]) == 0 || mxGetN(prhs[0]) == 0)
+  {
+    mexErrMsgTxt("Reference data must not be empty.");
+  }
+  arma::mat referenceData = MatlabToArma(prhs[0], "Reference data");
+
+  // Sanity check on range value: max must be greater than min.
+  if (maxDistance <= minDistance)
+  {
+    stringstream ss;
+    ss << "Invalid range: maximum (" << maxDistance << ") must be greater "
+        << "than minimum (" << minDistance << ").";
+    mexErrMsgTxt(ss.str().c_str());
+  }
+
+  // Sanity check on leaf size.
+  if (lsInt < 0)
+  {
+    stringstream ss;
+    ss << "Invalid leaf size: " << lsInt << ". Must be greater "
+        "than or equal to 0.";
+    mexErrMsgTxt(ss.str().c_str());
+  }
+
+  // Load and validate the query data before any tree is built.
+  arma::mat queryData;
+  if (hasQueryData)
+  {
+    queryData = MatlabToArma(prhs[3], "Query data");
+    if (queryData.n_rows != referenceData.n_rows)
+    {
+      stringstream ss;
+      ss << "Query data dimensionality (" << queryData.n_rows << ") does not "
+          << "match reference data dimensionality (" << referenceData.n_rows
+          << ").";
+      mexErrMsgTxt(ss.str().c_str());
+    }
+  }
+
+  // Naive mode overrides single mode.
+  if (singleMode && naive)
+  {
+    mexWarnMsgTxt("single_mode ignored because naive is present.");
+  }
+
+  // In naive mode the leaf size is raised to the number of points.
+  size_t leafSize = lsInt;
+  if (naive)
+    leafSize = referenceData.n_cols;
+
+  // Build the trees by hand; as with NeighborSearch, passing a tree is meant
+  // to keep RangeSearch from copying the matrix (TODO confirm). Building a
+  // tree reorders its points; oldFromNew* record the original indices.
+  vector<size_t> oldFromNewRefs;
+  TreeType refTree(referenceData, oldFromNewRefs, leafSize);
+
+  vector<size_t> oldFromNewQueries;
+  TreeType* queryTree = NULL;
+  RSType* rangeSearch = NULL;
+  if (hasQueryData)
+  {
+    if (naive && leafSize < queryData.n_cols)
+      leafSize = queryData.n_cols;
+
+    queryTree = new TreeType(queryData, oldFromNewQueries, leafSize);
+    rangeSearch = new RSType(&refTree, queryTree, referenceData, queryData,
+        singleMode);
+  }
+  else
+  {
+    rangeSearch = new RSType(&refTree, referenceData, singleMode);
+  }
+
+  vector<vector<size_t> > neighbors;
+  vector<vector<double> > distances;
+  math::Range searchRange(minDistance, maxDistance);
+  rangeSearch->Search(searchRange, neighbors, distances);
+
+  // The search object and the query tree are no longer needed.
+  delete rangeSearch;
+  delete queryTree;
+
+  // Map the results back to the original (pre-tree-construction) indices.
+  const vector<size_t>& oldFromNewQ =
+      hasQueryData ? oldFromNewQueries : oldFromNewRefs;
+  vector<vector<double> > distancesOut(distances.size());
+  vector<vector<size_t> > neighborsOut(neighbors.size());
+  for (size_t i = 0; i < distances.size(); ++i)
+  {
+    const size_t query = oldFromNewQ[i];
+    distancesOut[query] = distances[i];
+    neighborsOut[query].resize(neighbors[i].size());
+    for (size_t j = 0; j < neighbors[i].size(); ++j)
+      neighborsOut[query][j] = oldFromNewRefs[neighbors[i][j]];
+  }
+
+  // Return one struct element per query point. mxSetFieldByNumber() takes
+  // ownership of each array handed to it, so none is duplicated or destroyed.
+  const char* fieldNames[2] = { "neighbors", "distances" };
+  mwSize dims[1];
+  dims[0] = distancesOut.size();
+  plhs[0] = mxCreateStructArray(1, dims, 2, fieldNames);
+
+  for (size_t i = 0; i < distancesOut.size(); ++i)
+  {
+    const size_t numElements = distancesOut[i].size();
+
+    mxArray* neighborValues = mxCreateDoubleMatrix(1, numElements, mxREAL);
+    double* values = mxGetPr(neighborValues);
+    for (size_t j = 0; j < numElements; ++j)
+      values[j] = neighborsOut[i][j] + 1; // MATLAB indices are 1-based.
+    mxSetFieldByNumber(plhs[0], i, 0, neighborValues);
+
+    mxArray* distanceValues = mxCreateDoubleMatrix(1, numElements, mxREAL);
+    values = mxGetPr(distanceValues);
+    for (size_t j = 0; j < numElements; ++j)
+      values[j] = distancesOut[i][j];
+    mxSetFieldByNumber(plhs[0], i, 1, distanceValues);
+  }
+}
Added: mlpack/trunk/src/mlpack/bindings/matlab/range_search/range_search.m
===================================================================
--- mlpack/trunk/src/mlpack/bindings/matlab/range_search/range_search.m (rev 0)
+++ mlpack/trunk/src/mlpack/bindings/matlab/range_search/range_search.m 2012-10-31 20:54:03 UTC (rev 13805)
@@ -0,0 +1,47 @@
+function result = range_search(dataPoints, maxDistance, varargin)
%Range Search
%
% For each query point, returns all reference points whose distance to it lies
% in the given range. You may specify a separate set of reference and query
% points, or only a reference set -- which is then used as both the reference
% and query set. The range is inclusive (points at a distance exactly equal to
% the minimum or maximum are included in the results).
% NOTE(review): mex_range_search is built on SquaredEuclideanDistance, so the
% bounds and returned distances may be squared Euclidean distances -- confirm.
%
% Example: find, for each row of X, the rows of X at a distance in [2, 5]:
%   result = range_search(X, 5, 'minDistance', 2);
%
%Parameters (optional ones are name/value pairs; points are stored as rows):
% dataPoints  - (required) Matrix containing the reference dataset.
% maxDistance - (required) The upper bound of the range.
% minDistance - (optional) The lower bound. The default value is zero.
% queryPoints - (optional) Range search query points.
% leafSize    - (optional) Leaf size for tree building. Default value 20.
% naive       - (optional) If true, O(n^2) naive mode is used for computation.
% singleMode  - (optional) If true, single-tree search is used (as opposed to
%               dual-tree search).
%
%Returns a struct array with one element per query point and the fields
% 'neighbors' (1-based indices into dataPoints) and 'distances'.

% parse the optional name/value pair arguments
p = inputParser;
p.addParamValue('minDistance', 0, @isscalar);
p.addParamValue('queryPoints', zeros(0), @ismatrix);
p.addParamValue('leafSize', 20, @isscalar);
p.addParamValue('naive', false, @(x) (x == true) || (x == false));
p.addParamValue('singleMode', false, @(x) (x == true) || (x == false));

% parsing the varargin options
p.parse(varargin{:});
parsed = p.Results;

% transpose so that each point is a column, as mex_range_search expects
result = mex_range_search(dataPoints', maxDistance, ...
 parsed.minDistance, parsed.queryPoints', parsed.leafSize, ...
 parsed.naive, parsed.singleMode);




More information about the mlpack-svn mailing list