[mlpack-git] master: Add support for Greedy Single Tree Search inside the NeighborSearch class. (1c855b0)
gitdub at mlpack.org
gitdub at mlpack.org
Sat Aug 20 14:56:07 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/3274b05fcc545c3b36f783316fea2e22f79c3d03...1c77230c7d3b9c45fb102cd3c632d9c7248e085e
>---------------------------------------------------------------
commit 1c855b0323c7f736daa32f766871b83809f78678
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Tue Aug 16 03:58:55 2016 -0300
Add support for Greedy Single Tree Search inside the NeighborSearch class.
>---------------------------------------------------------------
1c855b0323c7f736daa32f766871b83809f78678
.../methods/neighbor_search/neighbor_search.hpp | 3 +-
.../neighbor_search/neighbor_search_impl.hpp | 71 ++++++++++++++++------
.../neighbor_search/neighbor_search_rules.hpp | 8 +++
.../neighbor_search/neighbor_search_rules_impl.hpp | 8 +++
4 files changed, 72 insertions(+), 18 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 88daf5f..beac1bf 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -78,7 +78,8 @@ class NeighborSearch
{
NAIVE_MODE,
SINGLE_TREE_MODE,
- DUAL_TREE_MODE
+ DUAL_TREE_MODE,
+ GREEDY_SINGLE_TREE_MODE
};
/**
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index f769de1..7bb1589 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -9,7 +9,7 @@
#define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_IMPL_HPP
#include <mlpack/core.hpp>
-
+#include <mlpack/core/tree/greedy_single_tree_traverser.hpp>
#include "neighbor_search_rules.hpp"
#include <mlpack/core/tree/spill_tree/is_spill_tree.hpp>
@@ -84,15 +84,13 @@ SingleTreeTraversalType>::NeighborSearch(const MatType& referenceSetIn,
&referenceTree->Dataset()),
treeOwner(mode != NAIVE_MODE),
setOwner(false),
- searchMode(mode),
- naive(mode == NAIVE_MODE),
- singleMode(mode == SINGLE_TREE_MODE),
epsilon(epsilon),
metric(metric),
baseCases(0),
scores(0),
treeNeedsReset(false)
{
+ SetSearchMode(mode);
if (epsilon < 0)
throw std::invalid_argument("epsilon must be non-negative");
}
@@ -118,15 +116,13 @@ SingleTreeTraversalType>::NeighborSearch(MatType&& referenceSetIn,
&referenceTree->Dataset()),
treeOwner(mode != NAIVE_MODE),
setOwner(mode == NAIVE_MODE),
- searchMode(mode),
- naive(mode == NAIVE_MODE),
- singleMode(mode == SINGLE_TREE_MODE),
epsilon(epsilon),
metric(metric),
baseCases(0),
scores(0),
treeNeedsReset(false)
{
+ SetSearchMode(mode);
if (epsilon < 0)
throw std::invalid_argument("epsilon must be non-negative");
}
@@ -149,15 +145,13 @@ SingleTreeTraversalType>::NeighborSearch(Tree* referenceTree,
referenceSet(&referenceTree->Dataset()),
treeOwner(false),
setOwner(false),
- searchMode(mode),
- naive(mode == NAIVE_MODE),
- singleMode(mode == SINGLE_TREE_MODE),
epsilon(epsilon),
metric(metric),
baseCases(0),
scores(0),
treeNeedsReset(false)
{
+ SetSearchMode(mode);
if (mode == NAIVE_MODE)
throw std::invalid_argument("invalid constructor for naive mode");
if (epsilon < 0)
@@ -181,15 +175,13 @@ SingleTreeTraversalType>::NeighborSearch(const SearchMode mode,
referenceSet(new MatType()), // Empty matrix.
treeOwner(false),
setOwner(true),
- searchMode(mode),
- naive(mode == NAIVE_MODE),
- singleMode(mode == SINGLE_TREE_MODE),
epsilon(epsilon),
metric(metric),
baseCases(0),
scores(0),
treeNeedsReset(false)
{
+ SetSearchMode(mode);
if (epsilon < 0)
throw std::invalid_argument("epsilon must be non-negative");
// Build the tree on the empty dataset, if necessary.
@@ -618,6 +610,29 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
delete queryTree;
break;
}
+ case GREEDY_SINGLE_TREE_MODE:
+ {
+ // Create the helper object for the tree traversal.
+ RuleType rules(*referenceSet, querySet, k, metric);
+
+ // Create the traverser.
+ tree::GreedySingleTreeTraverser<Tree, RuleType> traverser(rules);
+
+ // Now have it traverse for each point.
+ for (size_t i = 0; i < querySet.n_cols; ++i)
+ traverser.Traverse(i, *referenceTree);
+
+ scores += rules.Scores();
+ baseCases += rules.BaseCases();
+
+ Log::Info << rules.Scores() << " node combinations were scored."
+ << std::endl;
+ Log::Info << rules.BaseCases() << " base cases were calculated."
+ << std::endl;
+
+ rules.GetResults(*neighborPtr, *distancePtr);
+ break;
+ }
}
Timer::Stop("computing_neighbors");
@@ -814,8 +829,8 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
// Create the helper object for the traversal.
typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
- RuleType rules(*referenceSet, *referenceSet, k, metric, epsilon,
- true /* don't return the same point as nearest neighbor */);
+ RuleType rules(*referenceSet, *referenceSet, k, metric, true
+ /* don't return the same point as nearest neighbor */);
switch (searchMode)
{
@@ -898,6 +913,24 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
treeNeedsReset = true;
break;
}
+ case GREEDY_SINGLE_TREE_MODE:
+ {
+ // Create the traverser.
+ tree::GreedySingleTreeTraverser<Tree, RuleType> traverser(rules);
+
+ // Now have it traverse for each point.
+ for (size_t i = 0; i < referenceSet->n_cols; ++i)
+ traverser.Traverse(i, *referenceTree);
+
+ scores += rules.Scores();
+ baseCases += rules.BaseCases();
+
+ Log::Info << rules.Scores() << " node combinations were scored."
+ << std::endl;
+ Log::Info << rules.BaseCases() << " base cases were calculated."
+ << std::endl;
+ break;
+ }
}
rules.GetResults(*neighborPtr, *distancePtr);
@@ -1112,6 +1145,10 @@ DualTreeTraversalType, SingleTreeTraversalType>::SetSearchMode(
naive = false;
singleMode = false;
break;
+ case GREEDY_SINGLE_TREE_MODE:
+ naive = false;
+ singleMode = true;
+ break;
}
}
@@ -1131,9 +1168,9 @@ DualTreeTraversalType, SingleTreeTraversalType>::UpdateSearchMode()
{
if (naive)
searchMode = NAIVE_MODE;
- else if (singleMode)
+ else if (singleMode && (searchMode != GREEDY_SINGLE_TREE_MODE))
searchMode = SINGLE_TREE_MODE;
- else
+ else if (!singleMode)
searchMode = DUAL_TREE_MODE;
}
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index e7a7ce1..55d2163 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -77,6 +77,14 @@ class NeighborSearchRules
double Score(const size_t queryIndex, TreeType& referenceNode);
/**
+ * Get the child node with the best score.
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ */
+ TreeType& GetBestChild(const size_t queryIndex, TreeType& referenceNode);
+
+ /**
* Re-evaluate the score for recursion order. A low score indicates priority
* for recursion, while DBL_MAX indicates that the node should not be recursed
* into at all (it should be pruned). This is used when the score has already
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index 70f0bc0..e06683a 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -146,6 +146,14 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
}
template<typename SortPolicy, typename MetricType, typename TreeType>
+inline TreeType& NeighborSearchRules<SortPolicy, MetricType, TreeType>::
+GetBestChild(const size_t queryIndex, TreeType& referenceNode)
+{
+ ++scores;
+ return SortPolicy::GetBestChild(querySet.col(queryIndex), referenceNode);
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
const size_t queryIndex,
TreeType& /* referenceNode */,
More information about the mlpack-git mailing list