[mlpack-svn] r14161 - in mlpack/trunk/src/mlpack/bindings/matlab: . allkfn allknn hmm
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jan 23 14:39:51 EST 2013
Author: pmason8
Date: 2013-01-23 14:39:51 -0500 (Wed, 23 Jan 2013)
New Revision: 14161
Added:
mlpack/trunk/src/mlpack/bindings/matlab/hmm/
mlpack/trunk/src/mlpack/bindings/matlab/hmm/hmm_generate.cpp
mlpack/trunk/src/mlpack/bindings/matlab/hmm/hmm_generate.m
Modified:
mlpack/trunk/src/mlpack/bindings/matlab/allkfn/allkfn.cpp
mlpack/trunk/src/mlpack/bindings/matlab/allknn/allknn.cpp
Log:
cleaned up commented code in allknn allkfn. adding the incomplete hmm binding
Modified: mlpack/trunk/src/mlpack/bindings/matlab/allkfn/allkfn.cpp
===================================================================
--- mlpack/trunk/src/mlpack/bindings/matlab/allkfn/allkfn.cpp 2013-01-23 18:54:56 UTC (rev 14160)
+++ mlpack/trunk/src/mlpack/bindings/matlab/allkfn/allkfn.cpp 2013-01-23 19:39:51 UTC (rev 14161)
@@ -57,10 +57,6 @@
arma::mat queryData;
bool hasQueryData = ((mxGetM(prhs[2]) != 0) && (mxGetN(prhs[2]) != 0));
- //arma::mat referenceData;
- //arma::mat queryData; // So it doesn't go out of scope.
- //data::Load(referenceFile.c_str(), referenceData, true);
-
// Sanity check on k value: must be greater than 0, must be less than the
// number of reference points.
if (k > referenceData.n_cols)
@@ -107,7 +103,6 @@
std::vector<size_t> oldFromNewQueries;
- //f (CLI::GetParam<string>("query_file") != "")
if (hasQueryData)
{
// setting the values.
@@ -145,7 +140,6 @@
arma::Mat<size_t> neighborsOut(neighbors.n_rows, neighbors.n_cols);
// Do the actual remapping.
- //if (CLI::GetParam<string>("query_file") != "")
if (hasQueryData)
{
for (size_t i = 0; i < distances.n_cols; ++i)
Modified: mlpack/trunk/src/mlpack/bindings/matlab/allknn/allknn.cpp
===================================================================
--- mlpack/trunk/src/mlpack/bindings/matlab/allknn/allknn.cpp 2013-01-23 18:54:56 UTC (rev 14160)
+++ mlpack/trunk/src/mlpack/bindings/matlab/allknn/allknn.cpp 2013-01-23 19:39:51 UTC (rev 14161)
@@ -63,29 +63,6 @@
// cover-tree?
bool usesCoverTree = (mxGetScalar(prhs[6]) == 1.0);
- /*
- // Give CLI the command line parameters the user passed in.
- // CLI::ParseCommandLine(argc, argv);
-
- // Get all the parameters.
- string referenceFile = CLI::GetParam<string>("reference_file");
- string distancesFile = CLI::GetParam<string>("distances_file");
- string neighborsFile = CLI::GetParam<string>("neighbors_file");
-
- int lsInt = CLI::GetParam<int>("leaf_size");
- size_t k = CLI::GetParam<int>("k");
-
- bool naive = CLI::HasParam("naive");
- bool singleMode = CLI::HasParam("single_mode");
-
- arma::mat referenceData;
- arma::mat queryData; // So it doesn't go out of scope.
- data::Load(referenceFile.c_str(), referenceData, true);
-
- Log::Info << "Loaded reference data from '" << referenceFile << "' ("
- << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
- */
-
// Sanity check on k value: must be greater than 0, must be less than the
// number of reference points.
if (k > referenceData.n_cols)
@@ -138,12 +115,8 @@
std::vector<size_t> oldFromNewQueries;
- //if (CLI::GetParam<string>("query_file") != "")
if (hasQueryData)
{
- //string queryFile = CLI::GetParam<string>("query_file");
- //data::Load(queryFile.c_str(), queryData, true);
-
// setting the values.
mexDataPoints = mxGetPr(prhs[2]);
numPoints = mxGetN(prhs[2]);
@@ -185,7 +158,6 @@
distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
// Do the actual remapping.
- //if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
if ((hasQueryData) && !singleMode)
{
for (size_t i = 0; i < distancesOut.n_cols; ++i)
@@ -201,7 +173,6 @@
}
}
}
- //else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
else if ((hasQueryData) && singleMode)
{
// No remapping of queries is necessary. So distances are the same.
@@ -247,12 +218,8 @@
QueryStat<NearestNeighborSort> > >* allknn = NULL;
// See if we have query data.
- //if (CLI::HasParam("query_file"))
if (hasQueryData)
{
- //string queryFile = CLI::GetParam<string>("query_file");
- //data::Load(queryFile, queryData, true);
-
// setting the values.
mexDataPoints = mxGetPr(prhs[2]);
numPoints = mxGetN(prhs[2]);
@@ -293,8 +260,6 @@
}
// writing back to matlab
- //data::Save(distancesFile, distances);
- //data::Save(neighborsFile, neighbors);
// constructing matrix to return to matlab
plhs[0] = mxCreateDoubleMatrix(distances.n_rows, distances.n_cols, mxREAL);
plhs[1] = mxCreateDoubleMatrix(neighbors.n_rows, neighbors.n_cols, mxREAL);
Added: mlpack/trunk/src/mlpack/bindings/matlab/hmm/hmm_generate.cpp
===================================================================
--- mlpack/trunk/src/mlpack/bindings/matlab/hmm/hmm_generate.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/bindings/matlab/hmm/hmm_generate.cpp 2013-01-23 19:39:51 UTC (rev 14161)
@@ -0,0 +1,373 @@
+#include "mex.h"
+
+#include <mlpack/core.hpp>
+
+#include "hmm.hpp"
+#include "hmm_util.hpp"
+#include <mlpack/methods/gmm/gmm.hpp>
+
+/*
+PROGRAM_INFO("Hidden Markov Model (HMM) Sequence Generator", "This "
+ "utility takes an already-trained HMM (--model_file) and generates a "
+ "random observation sequence and hidden state sequence based on its "
+ "parameters, saving them to the specified files (--output_file and "
+ "--state_file)");
+
+PARAM_STRING_REQ("model_file", "File containing HMM (XML).", "m");
+PARAM_INT_REQ("length", "Length of sequence to generate.", "l");
+
+PARAM_INT("start_state", "Starting state of sequence.", "t", 0);
+PARAM_STRING("output_file", "File to save observation sequence to.", "o",
+ "output.csv");
+PARAM_STRING("state_file", "File to save hidden state sequence to (may be left "
+    "unspecified).", "S", "");
+PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
+*/
+
+
+using namespace mlpack;
+using namespace mlpack::hmm;
+using namespace mlpack::distribution;
+using namespace mlpack::utilities;
+using namespace mlpack::gmm;
+using namespace mlpack::math;
+using namespace arma;
+using namespace std;
+
+namespace {
+  /**
+   * Reads the 'transition' field of a MATLAB model struct into an Armadillo
+   * matrix.
+   *
+   * @param transition Output matrix; resized to the dimensions of the MATLAB
+   *     matrix and filled element by element in linear index order.
+   * @param mxarray MATLAB struct array; the field is read from element 0.
+   *
+   * Reports a MATLAB error through mexErrMsgTxt() when the field is missing,
+   * is not of class double, or is sparse.
+   */
+  void getTransition(mat & transition, const mxArray * mxarray)
+  {
+    mxArray * mxTransitions = mxGetField(mxarray, 0, "transition");
+    if (NULL == mxTransitions)
+    {
+      mexErrMsgTxt("Model struct did not have transition matrix 'transition'.");
+    }
+    if (mxDOUBLE_CLASS != mxGetClassID(mxTransitions))
+    {
+      mexErrMsgTxt("Transition matrix 'transition' must have type mxDOUBLE_CLASS.");
+    }
+    // A sparse double matrix also reports mxDOUBLE_CLASS, but its data buffer
+    // only holds the stored nonzeros, so copying m * n values out of it would
+    // read past the end of that buffer.
+    if (mxIsSparse(mxTransitions))
+    {
+      mexErrMsgTxt("Transition matrix 'transition' must be a full (non-sparse) matrix.");
+    }
+
+    const size_t m = mxGetM(mxTransitions);
+    const size_t n = mxGetN(mxTransitions);
+    transition.resize(m, n);
+
+    // The index is a size_t so it is compared against the element count
+    // without a signed/unsigned conversion.
+    const double * values = mxGetPr(mxTransitions);
+    const size_t count = m * n;
+    for (size_t i = 0; i < count; ++i)
+      transition(i) = values[i];
+  }
+
+  /**
+   * Copies the 'transition' matrix of a MATLAB model struct into an HMM.
+   *
+   * @tparam T Emission distribution type of the HMM.
+   * @param hmm Model whose Transition() matrix is overwritten.
+   * @param mxarray MATLAB struct array; the field is read from element 0.
+   *
+   * Reports a MATLAB error through mexErrMsgTxt() when the field is missing
+   * or is not of class double.
+   */
+  template <class T>
+  void writeTransition(HMM<T> & hmm, const mxArray * mxarray)
+  {
+    // getTransition() already looks up the field, checks that it exists and
+    // has class double, and copies the values; reuse it instead of keeping a
+    // second copy of that logic here.
+    arma::mat transition;
+    getTransition(transition, mxarray);
+
+    hmm.Transition() = transition;
+  }
+
+  /**
+   * Validates the 'emission' field of a model struct against the transition
+   * matrix.
+   *
+   * @param transition Transition matrix; its row count is the number of
+   *     emission entries that is required.
+   * @param mxarray The 'emission' field (may be NULL when the field is absent).
+   *
+   * Reports a MATLAB error through mexErrMsgTxt() when the field is absent or
+   * is not a 1 x transition.n_rows array.
+   */
+  void checkEmission(const mat & transition, const mxArray * mxarray)
+  {
+    if (NULL == mxarray)
+    {
+      mexErrMsgTxt("Model struct did not have 'emission' struct.");
+    }
+
+    // Compare as unsigned sizes rather than narrowing both sides to int.  The
+    // row count is checked too because the error message promises a 1 x n
+    // array and mexFunction() indexes the array elements linearly.
+    const size_t numStates = transition.n_rows;
+    if (mxGetM(mxarray) != 1 || mxGetN(mxarray) != numStates)
+    {
+      stringstream ss;
+      ss << "'emission' struct array must have dimensions 1 x "
+          << numStates << ".";
+      mexErrMsgTxt(ss.str().c_str());
+    }
+  }
+
+} // closing anonymous namespace
+
+/**
+ * Reports a MATLAB error unless startState is a valid state index, i.e. lies
+ * in [0, numStates).
+ */
+static void checkStartState(const int startState, const size_t numStates)
+{
+  if (startState < 0 || static_cast<size_t>(startState) >= numStates)
+  {
+    stringstream ss;
+    ss << "Invalid start state (" << startState << "); must be "
+        << "between 0 and number of states (" << numStates
+        << ")!";
+    mexErrMsgTxt(ss.str().c_str());
+  }
+}
+
+/**
+ * Creates a real double MATLAB matrix with the same dimensions as source and
+ * fills it with the elements of source converted to double.
+ */
+template <class MatType>
+static mxArray * toMxDoubleMatrix(const MatType & source)
+{
+  mxArray * result = mxCreateDoubleMatrix(source.n_rows, source.n_cols, mxREAL);
+  double * values = mxGetPr(result);
+  for (size_t i = 0; i < source.n_elem; ++i)
+    values[i] = static_cast<double>(source(i));
+  return result;
+}
+
+/**
+ * MEX entry point: generates an observation sequence and a hidden state
+ * sequence from an HMM.
+ *
+ * Inputs (exactly four):
+ *   prhs[0] - model struct; must contain the string field 'hmm_type'
+ *             ("discrete", "gaussian" or "gmm").  For "discrete" it must also
+ *             contain 'transition' and a struct array 'emission' whose
+ *             elements have a 'probabilities' field.
+ *   prhs[1] - length of the sequence to generate; must be greater than 0.
+ *   prhs[2] - start state; must be a valid state index.
+ *   prhs[3] - random seed; 0 means std::time(NULL) is used.
+ *
+ * Output (exactly one):
+ *   plhs[0] - 1 x 1 struct with the fields 'observations' and 'states'.
+ *
+ * Invalid arguments are reported through mexErrMsgTxt().
+ */
+void mexFunction(int nlhs, mxArray *plhs[],
+                 int nrhs, const mxArray *prhs[])
+{
+  // argument checks
+  if (nrhs != 4)
+  {
+    mexErrMsgTxt("Expecting four arguments.");
+  }
+
+  if (nlhs != 1)
+  {
+    mexErrMsgTxt("Output required.");
+  }
+
+  // seed argument
+  const size_t seed = (size_t) mxGetScalar(prhs[3]);
+
+  // Set random seed.
+  if (seed != 0)
+    mlpack::math::RandomSeed(seed);
+  else
+    mlpack::math::RandomSeed((size_t) std::time(NULL));
+
+  // length of observations
+  const int length = (int) mxGetScalar(prhs[1]);
+
+  // start state
+  const int startState = (int) mxGetScalar(prhs[2]);
+
+  // A length of 0 is rejected as well, so the message says "greater than 0".
+  if (length <= 0)
+  {
+    stringstream ss;
+    ss << "Invalid sequence length (" << length << "); must be greater "
+        << "than 0!";
+    mexErrMsgTxt(ss.str().c_str());
+  }
+
+  // getting the model type
+  if (mxIsStruct(prhs[0]) == 0)
+  {
+    mexErrMsgTxt("Model argument is not a struct.");
+  }
+
+  mxArray * mxHmmType = mxGetField(prhs[0], 0, "hmm_type");
+  if (mxHmmType == NULL)
+  {
+    mexErrMsgTxt("Model struct did not have 'hmm_type'.");
+  }
+  if (mxCHAR_CLASS != mxGetClassID(mxHmmType))
+  {
+    mexErrMsgTxt("'hmm_type' must have type mxCHAR_CLASS.");
+  }
+
+  // getting the model type string (one extra char for the terminating NUL)
+  const int bufLength = mxGetNumberOfElements(mxHmmType) + 1;
+  char * buf = (char *) mxCalloc(bufLength, sizeof(char));
+  mxGetString(mxHmmType, buf, bufLength);
+  const string type(buf);
+  mxFree(buf);
+
+  // to be filled by the generator
+  mat observations;
+  Col<size_t> sequence;
+
+  // to be removed!
+  SaveRestoreUtility sr;
+
+  if (type == "discrete")
+  {
+    HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1));
+
+    // writing transition matrix to the hmm
+    writeTransition(hmm, prhs[0]);
+    const size_t numStates = hmm.Transition().n_rows;
+
+    // writing emission matrix to the hmm; validate the field first so that a
+    // missing 'emission' is reported instead of being passed on as NULL.
+    mxArray * mxEmission = mxGetField(prhs[0], 0, "emission");
+    checkEmission(hmm.Transition(), mxEmission);
+
+    vector<DiscreteDistribution> emission(numStates);
+    for (size_t i = 0; i < numStates; ++i)
+    {
+      mxArray * mxProbabilities = mxGetField(mxEmission, i, "probabilities");
+      if (NULL == mxProbabilities)
+      {
+        mexErrMsgTxt("'probabilities' field could not be found in 'emission' struct.");
+      }
+      if (mxDOUBLE_CLASS != mxGetClassID(mxProbabilities))
+      {
+        mexErrMsgTxt("'probabilities' field must have type mxDOUBLE_CLASS.");
+      }
+
+      // Use the element count so that row and column vectors both work;
+      // mxGetN() alone would report 1 for a column vector.
+      const size_t numSymbols = mxGetNumberOfElements(mxProbabilities);
+      arma::vec probabilities(numSymbols);
+      const double * values = mxGetPr(mxProbabilities);
+      for (size_t j = 0; j < numSymbols; ++j)
+        probabilities(j) = values[j];
+
+      emission[i] = DiscreteDistribution(probabilities);
+    }
+
+    hmm.Emission() = emission;
+
+    // At this point, the HMM model should be fully formed.
+    checkStartState(startState, numStates);
+
+    hmm.Generate(size_t(length), observations, sequence, size_t(startState));
+  }
+  else if (type == "gaussian")
+  {
+    // TODO: the Gaussian model is not read from the MATLAB struct yet; a
+    // fixed model with three states and two-dimensional Gaussians is used
+    // regardless of the 'transition' and 'emission' fields that were passed.
+    HMM<GaussianDistribution> hmm(3, GaussianDistribution(2));
+    hmm.Transition() = arma::mat("0.4 0.6 0.8; 0.2 0.2 0.1; 0.4 0.2 0.1");
+    hmm.Emission()[0] = GaussianDistribution("0.0 0.0", "1.0 0.0; 0.0 1.0");
+    hmm.Emission()[1] = GaussianDistribution("2.0 2.0", "1.0 0.5; 0.5 1.2");
+    hmm.Emission()[2] = GaussianDistribution("-2.0 1.0", "2.0 0.1; 0.1 1.0");
+
+    // Validate before generating anything.
+    checkStartState(startState, hmm.Transition().n_rows);
+
+    hmm.Generate(size_t(length), observations, sequence, size_t(startState));
+  }
+  else if (type == "gmm")
+  {
+    HMM<GMM<> > hmm(1, GMM<>(1, 1));
+
+    // NOTE(review): nothing in this function fills sr from the MATLAB struct
+    // or from a file before this call, so the model loaded here is presumably
+    // not the one the caller passed in -- confirm and replace.
+    LoadHMM(hmm, sr);
+
+    // Errors go through mexErrMsgTxt() like every other check in this file.
+    checkStartState(startState, hmm.Transition().n_rows);
+
+    hmm.Generate(size_t(length), observations, sequence, size_t(startState));
+  }
+  else
+  {
+    stringstream ss;
+    ss << "Unknown HMM type '" << type << "'!";
+    mexErrMsgTxt(ss.str().c_str());
+  }
+
+  // Setting values to be returned to matlab
+  const mwSize dims[1] = { 1 };
+  const char * fieldNames[2] = { "observations", "states" };
+
+  plhs[0] = mxCreateStructArray(1, dims, 2, fieldNames);
+
+  // mxSetFieldByNumber() does not copy the array it is given, so the freshly
+  // created matrices are handed to the struct directly; no duplicate and no
+  // mxDestroyArray() are needed.
+
+  // setting the observations
+  mxSetFieldByNumber(plhs[0], 0, 0, toMxDoubleMatrix(observations));
+
+  // setting the hidden states
+  mxSetFieldByNumber(plhs[0], 0, 1, toMxDoubleMatrix(sequence));
+}
+
+ /*
+ mxArray * mxEmission = mxGetField(prhs[0], 0, "emission");
+ checkEmission(transition, mxEmission);
+
+ vector<GaussianDistribution> emission(transition.n_rows);
+ for (int i=0; i<transition.n_rows; ++i)
+ {
+ // mean
+ mxArray * mxMean = mxGetField(mxEmission, i, "mean");
+ if (NULL == mxMean)
+ {
+ mexErrMsgTxt("'mean' field could not be found in 'emission' struct.");
+ }
+
+ arma::vec mean(mxGetN(mxMean));
+ double * values = mxGetPr(mxMean);
+ for (int j=0; j<mxGetN(mxMean); ++j)
+ mean(j) = values[j];
+
+ cout << mean << endl;
+
+ // covariance
+ mxArray * mxCovariance = mxGetField(mxEmission, i, "covariance");
+ if (NULL == mxCovariance)
+ {
+ mexErrMsgTxt("'covariance' field could not be found in 'emission' struct.");
+ }
+
+ const size_t m = (size_t) mxGetM(mxCovariance);
+ const size_t n = (size_t) mxGetN(mxCovariance);
+ mat covariance(m, n);
+ values = mxGetPr(mxCovariance);
+ for (int j=0; j < m * n; ++j)
+ covariance(j) = values[j];
+
+ cout << covariance << endl;
+
+ emission[i] = GaussianDistribution(mean, covariance);
+ }
+ */
Added: mlpack/trunk/src/mlpack/bindings/matlab/hmm/hmm_generate.m
===================================================================
--- mlpack/trunk/src/mlpack/bindings/matlab/hmm/hmm_generate.m (rev 0)
+++ mlpack/trunk/src/mlpack/bindings/matlab/hmm/hmm_generate.m 2013-01-23 19:39:51 UTC (rev 14161)
@@ -0,0 +1,28 @@
+function sequence = hmm_generate(model, sequence_length, varargin)
+%Hidden Markov Model (HMM) Sequence Generator
+%
+%  This utility takes an already-trained HMM (model) and generates a
+%  random observation sequence and hidden state sequence based on its
+%  parameters.  Nothing is written to disk; the result is returned.
+%
+%Parameters:
+% model           - (required) HMM model struct.
+% sequence_length - (required) Length of the sequence to produce.  Must be a
+%                   numeric scalar greater than 0.
+% start_state     - (optional) Starting state of sequence.  Default value 0.
+% seed            - (optional) Random seed.  If 0, 'std::time(NULL)' is used.
+%                   Default value 0.
+%
+%Returns:
+% sequence        - The value returned by mex_hmm_generate; presumably the
+%                   struct with the fields 'observations' and 'states' that
+%                   hmm_generate.cpp builds.
+
+% a parser for the inputs; the required arguments are validated here so that
+% a bad call fails with a clear message before reaching the MEX function.
+p = inputParser;
+p.addRequired('model', @isstruct);
+p.addRequired('sequence_length', ...
+    @(x) isnumeric(x) && isscalar(x) && x > 0);
+p.addParamValue('start_state', 0, @isscalar);
+p.addParamValue('seed', 0, @isscalar);
+
+% parsing the required arguments and the varargin options
+p.parse(model, sequence_length, varargin{:});
+parsed = p.Results;
+
+% interfacing with mlpack.
+sequence = mex_hmm_generate(parsed.model, parsed.sequence_length, ...
+    parsed.start_state, parsed.seed);
+
+
More information about the mlpack-svn
mailing list