[mlpack-svn] r13406 - in mlpack/tags: . mlpack-1.0.2/CMake mlpack-1.0.2/src/mlpack mlpack-1.0.2/src/mlpack/core/tree mlpack-1.0.2/src/mlpack/core/tree/cover_tree mlpack-1.0.2/src/mlpack/methods mlpack-1.0.2/src/mlpack/methods/neighbor_search mlpack-1.0.2/src/mlpack/methods/nmf mlpack-1.0.2/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Aug 15 13:40:10 EDT 2012
Author: rcurtin
Date: 2012-08-15 13:40:09 -0400 (Wed, 15 Aug 2012)
New Revision: 13406
Added:
mlpack/tags/mlpack-1.0.2/
mlpack/tags/mlpack-1.0.2/CMake/allexec2man.sh
mlpack/tags/mlpack-1.0.2/CMake/exec2man.sh
mlpack/tags/mlpack-1.0.2/src/mlpack/core.hpp
mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/CMakeLists.txt
mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/
mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/cover_tree.hpp
mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
mlpack/tags/mlpack-1.0.2/src/mlpack/methods/det/
mlpack/tags/mlpack-1.0.2/src/mlpack/methods/neighbor_search/allknn_main.cpp
mlpack/tags/mlpack-1.0.2/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
mlpack/tags/mlpack-1.0.2/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
mlpack/tags/mlpack-1.0.2/src/mlpack/methods/nmf/nmf_main.cpp
mlpack/tags/mlpack-1.0.2/src/mlpack/tests/allknn_test.cpp
mlpack/tags/mlpack-1.0.2/src/mlpack/tests/sparse_coding_test.cpp
mlpack/tags/mlpack-1.0.2/src/mlpack/tests/tree_test.cpp
Removed:
mlpack/tags/mlpack-1.0.2/CMake/allexec2man.sh
mlpack/tags/mlpack-1.0.2/CMake/exec2man.sh
mlpack/tags/mlpack-1.0.2/src/mlpack/core.hpp
mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/CMakeLists.txt
mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/
mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/cover_tree.hpp
mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
mlpack/tags/mlpack-1.0.2/src/mlpack/methods/det/
mlpack/tags/mlpack-1.0.2/src/mlpack/methods/neighbor_search/allknn_main.cpp
mlpack/tags/mlpack-1.0.2/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
mlpack/tags/mlpack-1.0.2/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
mlpack/tags/mlpack-1.0.2/src/mlpack/methods/nmf/nmf_main.cpp
mlpack/tags/mlpack-1.0.2/src/mlpack/tests/allknn_test.cpp
mlpack/tags/mlpack-1.0.2/src/mlpack/tests/sparse_coding_test.cpp
mlpack/tags/mlpack-1.0.2/src/mlpack/tests/tree_test.cpp
Log:
Tag mlpack-1.0.2 for release.
Deleted: mlpack/tags/mlpack-1.0.2/CMake/allexec2man.sh
===================================================================
--- mlpack/trunk/CMake/allexec2man.sh 2012-08-09 09:02:13 UTC (rev 13379)
+++ mlpack/tags/mlpack-1.0.2/CMake/allexec2man.sh 2012-08-15 17:40:09 UTC (rev 13406)
@@ -1,22 +0,0 @@
-#!/bin/bash
-#
-# Convert all of the executables in this directory that are not tests to man
-# pages in the given directory.
-#
-# Usage:
-# allexec2man.sh /full/path/of/exec2man.sh output_directory/
-#
-# For the executable 'cheese', the file 'cheese.1.gz' will be created in the
-# output directory.
-exec2man=$1
-outdir=$2
-
-mkdir -p $outdir
-for program in `find . -perm /u=x,g=x,o=x | \
- grep -v '[.]$' | \
- grep -v '_test$' | \
- sed 's|^./||'`; do
- echo "Generating man page for $program...";
- $1 $program $outdir/$program.1
- gzip -f $outdir/$program.1
-done
Copied: mlpack/tags/mlpack-1.0.2/CMake/allexec2man.sh (from rev 13389, mlpack/trunk/CMake/allexec2man.sh)
===================================================================
--- mlpack/tags/mlpack-1.0.2/CMake/allexec2man.sh (rev 0)
+++ mlpack/tags/mlpack-1.0.2/CMake/allexec2man.sh 2012-08-15 17:40:09 UTC (rev 13406)
@@ -0,0 +1,22 @@
+#!/bin/bash
+#
+# Convert all of the executables in this directory that are not tests to man
+# pages in the given directory.
+#
+# Usage:
+# allexec2man.sh /full/path/of/exec2man.sh output_directory/
+#
+# For the executable 'cheese', the file 'cheese.1.gz' will be created in the
+# output directory.
+exec2man="$1"
+outdir="$2"
+
+mkdir -p "$outdir"
+for program in `find . -perm /u=x,g=x,o=x | \
+ grep -v '[.]$' | \
+ grep -v '_test$' | \
+ sed 's|^./||'`; do
+ echo "Generating man page for $program...";
+ "$1" "$program" "$outdir/$program.1"
+ gzip -f "$outdir/$program.1"
+done
Deleted: mlpack/tags/mlpack-1.0.2/CMake/exec2man.sh
===================================================================
--- mlpack/trunk/CMake/exec2man.sh 2012-08-09 09:02:13 UTC (rev 13379)
+++ mlpack/tags/mlpack-1.0.2/CMake/exec2man.sh 2012-08-15 17:40:09 UTC (rev 13406)
@@ -1,80 +0,0 @@
-#!/bin/bash
-# Convert the output of an MLPACK executable into a man page. This assumes that
-# the CLI subsystem is used to output help, that the executable is properly
-# documented, and that the program is run in the directory that the executable
-# is in. Usually, this is used by CMake on Linux/UNIX systems to generate the
-# man pages.
-#
-# Usage:
-# exec2man.sh executable_name output_file_name
-#
-# No warranties...
-#
-# @author Ryan Curtin
-name=$1
-output=$2
-
-# Generate the synopsis.
-# First, required options.
-reqoptions=`./$name -h | \
- awk '/Required options:/,/Options:/' | \
- grep '^ --' | \
- sed 's/^ --/--/' | \
- sed 's/^--[A-Za-z0-9_-]* (\(-[A-Za-z0-9]\))/\1/' | \
- sed 's/\(^-[A-Za-z0-9]\) [^\[].*/\1/' | \
- sed 's/\(^-[A-Za-z0-9] \[[A-Za-z0-9]*\]\) .*/\1/' | \
- sed 's/\(^--[A-Za-z0-9_-]*\) [^[].*/\1/' | \
- sed 's/\(^--[A-Za-z0-9_-]* \[[A-Za-z0-9]*\]\) [^[].*/\1/' | \
- tr '\n' ' ' | \
- sed 's/\[//g' | \
- sed 's/\]//g'`
-
-# Then, regular options.
-options=`./$name -h | \
- awk '/Options:/,/For further information,/' | \
- grep '^ --' | \
- sed 's/^ --/--/' | \
- grep -v -- '--help' | \
- grep -v -- '--info' | \
- grep -v -- '--verbose' | \
- sed 's/^--[A-Za-z0-9_-]* (\(-[A-Za-z0-9]\))/\1/' | \
- sed 's/\(^-[A-Za-z0-9]\) [^\[].*/\1/' | \
- sed 's/\(^-[A-Za-z0-9] \[[A-Za-z0-9]*\]\) .*/\1/' | \
- sed 's/\(^--[A-Za-z0-9_-]*\) [^[].*/\1/' | \
- sed 's/\(^--[A-Za-z0-9_-]* \[[A-Za-z0-9]*\]\) [^[].*/\1/' | \
- tr '\n' ' ' | \
- sed 's/\[//g' | \
- sed 's/\]//g' | \
- sed 's/\(-[A-Za-z0-9]\)\( [^a-z]\)/\[\1\]\2/g' | \
- sed 's/\(--[A-Za-z0-9_-]*\)\( [^a-z]\)/\[\1\]\2/g' | \
- sed 's/\(-[A-Za-z0-9] [a-z]*\) /\[\1\] /g' | \
- sed 's/\(--[A-Za-z0-9_-]* [a-z]*\) /\[\1\] /g'`
-
-synopsis="$name [-h] [-v] $reqoptions $options";
-
-# Preview the whole thing first.
-#./$name -h | \
-# awk -v syn="$synopsis" \
-# '{ if (NR == 1) print "NAME\n '$name' - "tolower($0)"\nSYNOPSIS\n "syn" \nDESCRIPTION\n" ; else print } ' | \
-# sed '/^[^ ]/ y/qwertyuiopasdfghjklzxcvbnm:/QWERTYUIOPASDFGHJKLZXCVBNM /' | \
-# txt2man -T -P mlpack -t $name -d 1
-
-# Now do it.
-# The awk script is a little ugly, but it is meant to format parameters
-# correctly so that the entire description of the parameter is on one line (this
-# helps avoid 'man' warnings).
-# The sed line at the end removes accidental macros from the output, replacing
-# 'word' with "word".
-./$name -h | \
- sed 's/^For further information/Additional Information\n\n For further information/' | \
- sed 's/^consult the documentation/ consult the documentation/' | \
- sed 's/^distribution of MLPACK./ distribution of MLPACK./' | \
- awk -v syn="$synopsis" \
- '{ if (NR == 1) print "NAME\n '$name' - "tolower($0)"\nSYNOPSIS\n "syn" \nDESCRIPTION\n" ; else print } ' | \
- sed '/^[^ ]/ y/qwertyuiopasdfghjklzxcvbnm:/QWERTYUIOPASDFGHJKLZXCVBNM /' | \
- sed 's/ / /g' | \
- awk '/NAME/,/REQUIRED OPTIONS/ { print; } /ADDITIONAL INFORMATION/,0 { print; } /REQUIRED OPTIONS/,/ADDITIONAL INFORMATION/ { if (!/REQUIRED_OPTIONS/ && !/OPTIONS/ && !/ADDITIONAL INFORMATION/) { if (/ --/) { printf "\n" } sub(/^[ ]*/, ""); sub(/ [ ]*/, " "); printf "%s ", $0; } else { if (!/REQUIRED OPTIONS/ && !/ADDITIONAL INFORMATION/) { print "\n"$0; } } }' | \
- sed 's/ ADDITIONAL INFORMATION/\n\nADDITIONAL INFORMATION/' | \
- txt2man -P mlpack -t $name -d 1 | \
- sed "s/^'\([A-Za-z0-9 ]*\)'/\"\1\"/" > $output
-
Copied: mlpack/tags/mlpack-1.0.2/CMake/exec2man.sh (from rev 13388, mlpack/trunk/CMake/exec2man.sh)
===================================================================
--- mlpack/tags/mlpack-1.0.2/CMake/exec2man.sh (rev 0)
+++ mlpack/tags/mlpack-1.0.2/CMake/exec2man.sh 2012-08-15 17:40:09 UTC (rev 13406)
@@ -0,0 +1,80 @@
+#!/bin/bash
+# Convert the output of an MLPACK executable into a man page. This assumes that
+# the CLI subsystem is used to output help, that the executable is properly
+# documented, and that the program is run in the directory that the executable
+# is in. Usually, this is used by CMake on Linux/UNIX systems to generate the
+# man pages.
+#
+# Usage:
+# exec2man.sh executable_name output_file_name
+#
+# No warranties...
+#
+# @author Ryan Curtin
+name="$1"
+output="$2"
+
+# Generate the synopsis.
+# First, required options.
+reqoptions=`./"$name" -h | \
+ awk '/Required options:/,/Options:/' | \
+ grep '^ --' | \
+ sed 's/^ --/--/' | \
+ sed 's/^--[A-Za-z0-9_-]* (\(-[A-Za-z0-9]\))/\1/' | \
+ sed 's/\(^-[A-Za-z0-9]\) [^\[].*/\1/' | \
+ sed 's/\(^-[A-Za-z0-9] \[[A-Za-z0-9]*\]\) .*/\1/' | \
+ sed 's/\(^--[A-Za-z0-9_-]*\) [^[].*/\1/' | \
+ sed 's/\(^--[A-Za-z0-9_-]* \[[A-Za-z0-9]*\]\) [^[].*/\1/' | \
+ tr '\n' ' ' | \
+ sed 's/\[//g' | \
+ sed 's/\]//g'`
+
+# Then, regular options.
+options=`./"$name" -h | \
+ awk '/Options:/,/For further information,/' | \
+ grep '^ --' | \
+ sed 's/^ --/--/' | \
+ grep -v -- '--help' | \
+ grep -v -- '--info' | \
+ grep -v -- '--verbose' | \
+ sed 's/^--[A-Za-z0-9_-]* (\(-[A-Za-z0-9]\))/\1/' | \
+ sed 's/\(^-[A-Za-z0-9]\) [^\[].*/\1/' | \
+ sed 's/\(^-[A-Za-z0-9] \[[A-Za-z0-9]*\]\) .*/\1/' | \
+ sed 's/\(^--[A-Za-z0-9_-]*\) [^[].*/\1/' | \
+ sed 's/\(^--[A-Za-z0-9_-]* \[[A-Za-z0-9]*\]\) [^[].*/\1/' | \
+ tr '\n' ' ' | \
+ sed 's/\[//g' | \
+ sed 's/\]//g' | \
+ sed 's/\(-[A-Za-z0-9]\)\( [^a-z]\)/\[\1\]\2/g' | \
+ sed 's/\(--[A-Za-z0-9_-]*\)\( [^a-z]\)/\[\1\]\2/g' | \
+ sed 's/\(-[A-Za-z0-9] [a-z]*\) /\[\1\] /g' | \
+ sed 's/\(--[A-Za-z0-9_-]* [a-z]*\) /\[\1\] /g'`
+
+synopsis="$name [-h] [-v] $reqoptions $options";
+
+# Preview the whole thing first.
+#./$name -h | \
+# awk -v syn="$synopsis" \
+# '{ if (NR == 1) print "NAME\n '$name' - "tolower($0)"\nSYNOPSIS\n "syn" \nDESCRIPTION\n" ; else print } ' | \
+# sed '/^[^ ]/ y/qwertyuiopasdfghjklzxcvbnm:/QWERTYUIOPASDFGHJKLZXCVBNM /' | \
+# txt2man -T -P mlpack -t $name -d 1
+
+# Now do it.
+# The awk script is a little ugly, but it is meant to format parameters
+# correctly so that the entire description of the parameter is on one line (this
+# helps avoid 'man' warnings).
+# The sed line at the end removes accidental macros from the output, replacing
+# 'word' with "word".
+./"$name" -h | \
+ sed 's/^For further information/Additional Information\n\n For further information/' | \
+ sed 's/^consult the documentation/ consult the documentation/' | \
+ sed 's/^distribution of MLPACK./ distribution of MLPACK./' | \
+ awk -v syn="$synopsis" \
+ '{ if (NR == 1) print "NAME\n '"$name"' - "tolower($0)"\nSYNOPSIS\n "syn" \nDESCRIPTION\n" ; else print } ' | \
+ sed '/^[^ ]/ y/qwertyuiopasdfghjklzxcvbnm:/QWERTYUIOPASDFGHJKLZXCVBNM /' | \
+ sed 's/ / /g' | \
+ awk '/NAME/,/REQUIRED OPTIONS/ { print; } /ADDITIONAL INFORMATION/,0 { print; } /REQUIRED OPTIONS/,/ADDITIONAL INFORMATION/ { if (!/REQUIRED_OPTIONS/ && !/OPTIONS/ && !/ADDITIONAL INFORMATION/) { if (/ --/) { printf "\n" } sub(/^[ ]*/, ""); sub(/ [ ]*/, " "); printf "%s ", $0; } else { if (!/REQUIRED OPTIONS/ && !/ADDITIONAL INFORMATION/) { print "\n"$0; } } }' | \
+ sed 's/ ADDITIONAL INFORMATION/\n\nADDITIONAL INFORMATION/' | \
+ txt2man -P mlpack -t "$name" -d 1 | \
+ sed "s/^'\([A-Za-z0-9 ]*\)'/\"\1\"/" > "$output"
+
Deleted: mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt 2012-08-09 09:02:13 UTC (rev 13379)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/CMakeLists.txt 2012-08-15 17:40:09 UTC (rev 13406)
@@ -1,37 +0,0 @@
-cmake_minimum_required(VERSION 2.8)
-
-# Define the files we need to compile.
-# Anything not in this list will not be compiled into MLPACK.
-set(SOURCES
- ballbound.hpp
- ballbound_impl.hpp
- binary_space_tree/binary_space_tree.hpp
- binary_space_tree/binary_space_tree_impl.hpp
- binary_space_tree/dual_tree_traverser.hpp
- binary_space_tree/dual_tree_traverser_impl.hpp
- binary_space_tree/single_tree_traverser.hpp
- binary_space_tree/single_tree_traverser_impl.hpp
- bounds.hpp
- cover_tree/cover_tree.hpp
- cover_tree/cover_tree_impl.hpp
- cover_tree/first_point_is_root.hpp
- cover_tree/single_tree_traverser.hpp
- cover_tree/single_tree_traverser_impl.hpp
- cover_tree/dual_tree_traverser.hpp
- cover_tree/dual_tree_traverser_impl.hpp
- hrectbound.hpp
- hrectbound_impl.hpp
- periodichrectbound.hpp
- periodichrectbound_impl.hpp
- statistic.hpp
- mrkd_statistic.hpp
-)
-
-# add directory name to sources
-set(DIR_SRCS)
-foreach(file ${SOURCES})
- set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
-endforeach()
-# Append sources (with directory name) to list of all MLPACK sources (used at
-# the parent scope).
-set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
Copied: mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/CMakeLists.txt (from rev 13384, mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt)
===================================================================
--- mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/CMakeLists.txt (rev 0)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/CMakeLists.txt 2012-08-15 17:40:09 UTC (rev 13406)
@@ -0,0 +1,37 @@
+cmake_minimum_required(VERSION 2.8)
+
+# Define the files we need to compile.
+# Anything not in this list will not be compiled into MLPACK.
+set(SOURCES
+ ballbound.hpp
+ ballbound_impl.hpp
+ binary_space_tree/binary_space_tree.hpp
+ binary_space_tree/binary_space_tree_impl.hpp
+ binary_space_tree/dual_tree_traverser.hpp
+ binary_space_tree/dual_tree_traverser_impl.hpp
+ binary_space_tree/single_tree_traverser.hpp
+ binary_space_tree/single_tree_traverser_impl.hpp
+ bounds.hpp
+ cover_tree/cover_tree.hpp
+ cover_tree/cover_tree_impl.hpp
+ cover_tree/first_point_is_root.hpp
+ cover_tree/single_tree_traverser.hpp
+ cover_tree/single_tree_traverser_impl.hpp
+ cover_tree/dual_tree_traverser.hpp
+ cover_tree/dual_tree_traverser_impl.hpp
+ hrectbound.hpp
+ hrectbound_impl.hpp
+ periodichrectbound.hpp
+ periodichrectbound_impl.hpp
+ statistic.hpp
+ mrkd_statistic.hpp
+)
+
+# add directory name to sources
+set(DIR_SRCS)
+foreach(file ${SOURCES})
+ set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
+endforeach()
+# Append sources (with directory name) to list of all MLPACK sources (used at
+# the parent scope).
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
Deleted: mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/cover_tree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp 2012-08-15 16:44:38 UTC (rev 13396)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/cover_tree.hpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -1,351 +0,0 @@
-/**
- * @file cover_tree.hpp
- * @author Ryan Curtin
- *
- * Definition of CoverTree, which can be used in place of the BinarySpaceTree.
- */
-#ifndef __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
-#define __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
-
-#include <mlpack/core.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-#include "first_point_is_root.hpp"
-#include "../statistic.hpp"
-
-namespace mlpack {
-namespace tree {
-
-/**
- * A cover tree is a tree specifically designed to speed up nearest-neighbor
- * computation in high-dimensional spaces. Each non-leaf node references a
- * point and has a nonzero number of children, including a "self-child" which
- * references the same point. A leaf node represents only one point.
- *
- * The tree can be thought of as a hierarchy with the root node at the top level
- * and the leaf nodes at the bottom level. Each level in the tree has an
- * assigned 'scale' i. The tree follows these three conditions:
- *
- * - nesting: the level C_i is a subset of the level C_{i - 1}.
- * - covering: all node in level C_{i - 1} have at least one node in the
- * level C_i with distance less than or equal to EC^i (exactly one of these
- * is a parent of the point in level C_{i - 1}.
- * - separation: all nodes in level C_i have distance greater than EC^i to all
- * other nodes in level C_i.
- *
- * The value 'EC' refers to the expansion constant, which is a parameter of the
- * tree. These three properties make the cover tree very good for fast,
- * high-dimensional nearest-neighbor search.
- *
- * The theoretical structure of the tree contains many 'implicit' nodes which
- * only have a "self-child" (a child referencing the same point, but at a lower
- * scale level). This practical implementation only constructs explicit nodes
- * -- non-leaf nodes with more than one child. A leaf node has no children, and
- * its scale level is INT_MIN.
- *
- * For more information on cover trees, see
- *
- * @code
- * @inproceedings{
- * author = {Beygelzimer, Alina and Kakade, Sham and Langford, John},
- * title = {Cover trees for nearest neighbor},
- * booktitle = {Proceedings of the 23rd International Conference on Machine
- * Learning},
- * series = {ICML '06},
- * year = {2006},
- *   pages = {97--104}
- * }
- * @endcode
- *
- * For information on runtime bounds of the nearest-neighbor computation using
- * cover trees, see the following paper, presented at NIPS 2009:
- *
- * @code
- * @inproceedings{
- * author = {Ram, P., and Lee, D., and March, W.B., and Gray, A.G.},
- * title = {Linear-time Algorithms for Pairwise Statistical Problems},
- * booktitle = {Advances in Neural Information Processing Systems 22},
- * editor = {Y. Bengio and D. Schuurmans and J. Lafferty and C.K.I. Williams
- * and A. Culotta},
- * pages = {1527--1535},
- * year = {2009}
- * }
- * @endcode
- *
- * The CoverTree class offers three template parameters; a custom metric type
- * can be used with MetricType (this class defaults to the L2-squared metric).
- * The root node's point can be chosen with the RootPointPolicy; by default, the
- * FirstPointIsRoot policy is used, meaning the first point in the dataset is
- * used. The StatisticType policy allows you to define statistics which can be
- * gathered during the creation of the tree.
- *
- * @tparam MetricType Metric type to use during tree construction.
- * @tparam RootPointPolicy Determines which point to use as the root node.
- * @tparam StatisticType Statistic to be used during tree creation.
- */
-template<typename MetricType = metric::LMetric<2>,
- typename RootPointPolicy = FirstPointIsRoot,
- typename StatisticType = EmptyStatistic>
-class CoverTree
-{
- public:
- typedef arma::mat Mat;
-
- /**
- * Create the cover tree with the given dataset and given expansion constant.
- * The dataset will not be modified during the building procedure (unlike
- * BinarySpaceTree).
- *
- * @param dataset Reference to the dataset to build a tree on.
- * @param expansionConstant Expansion constant (EC) to use during tree
- * building (default 2.0).
- */
- CoverTree(const arma::mat& dataset,
- const double expansionConstant = 2.0,
- MetricType* metric = NULL);
-
- /**
- * Construct a child cover tree node. This constructor is not meant to be
- * used externally, but it could be used to insert another node into a tree.
- * This procedure uses only one vector for the near set, the far set, and the
- * used set (this is to prevent unnecessary memory allocation in recursive
- * calls to this constructor). Therefore, the size of the near set, far set,
- * and used set must be passed in. The near set will be entirely used up, and
- * some of the far set may be used. The value of usedSetSize will be set to
- * the number of points used in the construction of this node, and the value
- * of farSetSize will be modified to reflect the number of points in the far
- * set _after_ the construction of this node.
- *
- * If you are calling this manually, be careful that the given scale is
- * as small as possible, or you may be creating an implicit node in your tree.
- *
- * @param dataset Reference to the dataset to build a tree on.
- * @param expansionConstant Expansion constant (EC) to use during tree
- * building.
- * @param pointIndex Index of the point this node references.
- * @param scale Scale of this level in the tree.
- * @param indices Array of indices, ordered [ nearSet | farSet | usedSet ];
- * will be modified to [ farSet | usedSet ].
- * @param distances Array of distances, ordered the same way as the indices.
- * These represent the distances between the point specified by pointIndex
- * and each point in the indices array.
- * @param nearSetSize Size of the near set; if 0, this will be a leaf.
- * @param farSetSize Size of the far set; may be modified (if this node uses
- * any points in the far set).
- * @param usedSetSize The number of points used will be added to this number.
- */
- CoverTree(const arma::mat& dataset,
- const double expansionConstant,
- const size_t pointIndex,
- const int scale,
- const double parentDistance,
- arma::Col<size_t>& indices,
- arma::vec& distances,
- size_t nearSetSize,
- size_t& farSetSize,
- size_t& usedSetSize,
- MetricType& metric = NULL);
-
- /**
- * Delete this cover tree node and its children.
- */
- ~CoverTree();
-
- //! A single-tree cover tree traverser; see single_tree_traverser.hpp for
- //! implementation.
- template<typename RuleType>
- class SingleTreeTraverser;
-
- //! A dual-tree cover tree traverser; see dual_tree_traverser.hpp.
- template<typename RuleType>
- class DualTreeTraverser;
-
- //! Get a reference to the dataset.
- const arma::mat& Dataset() const { return dataset; }
-
- //! Get the index of the point which this node represents.
- size_t Point() const { return point; }
- //! For compatibility with other trees; the argument is ignored.
- size_t Point(const size_t) const { return point; }
-
- // Fake
- CoverTree* Left() const { return NULL; }
- CoverTree* Right() const { return NULL; }
- size_t Begin() const { return 0; }
- size_t Count() const { return 0; }
- size_t End() const { return 0; }
- bool IsLeaf() const { return (children.size() == 0); }
- size_t NumPoints() const { return 1; }
-
- //! Get a particular child node.
- const CoverTree& Child(const size_t index) const { return *children[index]; }
- //! Modify a particular child node.
- CoverTree& Child(const size_t index) { return *children[index]; }
-
- //! Get the number of children.
- size_t NumChildren() const { return children.size(); }
-
- //! Get the children.
- const std::vector<CoverTree*>& Children() const { return children; }
- //! Modify the children manually (maybe not a great idea).
- std::vector<CoverTree*>& Children() { return children; }
-
- //! Get the scale of this node.
- int Scale() const { return scale; }
- //! Modify the scale of this node. Be careful...
- int& Scale() { return scale; }
-
- //! Get the expansion constant.
- double ExpansionConstant() const { return expansionConstant; }
- //! Modify the expansion constant; don't do this, you'll break everything.
- double& ExpansionConstant() { return expansionConstant; }
-
- //! Get the statistic for this node.
- const StatisticType& Stat() const { return stat; }
- //! Modify the statistic for this node.
- StatisticType& Stat() { return stat; }
-
- //! Return the minimum distance to another node.
- double MinDistance(const CoverTree* other) const;
-
- //! Return the minimum distance to another node given that the point-to-point
- //! distance has already been calculated.
- double MinDistance(const CoverTree* other, const double distance) const;
-
- //! Return the minimum distance to another point.
- double MinDistance(const arma::vec& other) const;
-
- //! Return the minimum distance to another point given that the distance from
- //! the center to the point has already been calculated.
- double MinDistance(const arma::vec& other, const double distance) const;
-
- //! Return the maximum distance to another node.
- double MaxDistance(const CoverTree* other) const;
-
- //! Return the maximum distance to another node given that the point-to-point
- //! distance has already been calculated.
- double MaxDistance(const CoverTree* other, const double distance) const;
-
- //! Return the maximum distance to another point.
- double MaxDistance(const arma::vec& other) const;
-
- //! Return the maximum distance to another point given that the distance from
- //! the center to the point has already been calculated.
- double MaxDistance(const arma::vec& other, const double distance) const;
-
- //! Returns true: this tree does have self-children.
- static bool HasSelfChildren() { return true; }
-
- //! Get the distance to the parent.
- double ParentDistance() const { return parentDistance; }
-
- //! Get the distance to the furthest descendant.
- double FurthestDescendantDistance() const
- { return furthestDescendantDistance; }
-
- private:
- //! Reference to the matrix which this tree is built on.
- const arma::mat& dataset;
-
- //! Index of the point in the matrix which this node represents.
- size_t point;
-
- //! The list of children; the first is the self-child.
- std::vector<CoverTree*> children;
-
- //! Scale level of the node.
- int scale;
-
- //! The expansion constant used to construct the tree.
- double expansionConstant;
-
- //! The instantiated statistic.
- StatisticType stat;
-
- //! Distance to the parent.
- double parentDistance;
-
- //! Distance to the furthest descendant.
- double furthestDescendantDistance;
-
- /**
- * Fill the vector of distances with the distances between the point specified
- * by pointIndex and each point in the indices array. The distances of the
- * first pointSetSize points in indices are calculated (so, this does not
- * necessarily need to use all of the points in the arrays).
- *
- * @param pointIndex Point to build the distances for.
- * @param indices List of indices to compute distances for.
- * @param distances Vector to store calculated distances in.
- * @param pointSetSize Number of points in arrays to calculate distances for.
- */
- void ComputeDistances(const size_t pointIndex,
- const arma::Col<size_t>& indices,
- arma::vec& distances,
- const size_t pointSetSize,
- MetricType& metric);
- /**
- * Split the given indices and distances into a near and a far set, returning
- * the number of points in the near set. The distances must already be
- * initialized. This will order the indices and distances such that the
- * points in the near set make up the first part of the array and the far set
- * makes up the rest: [ nearSet | farSet ].
- *
- * @param indices List of indices; will be reordered.
- * @param distances List of distances; will be reordered.
- * @param bound If the distance is less than or equal to this bound, the point
- * is placed into the near set.
- * @param pointSetSize Size of point set (because we may be sorting a smaller
- * list than the indices vector will hold).
- */
- size_t SplitNearFar(arma::Col<size_t>& indices,
- arma::vec& distances,
- const double bound,
- const size_t pointSetSize);
-
- /**
- * Assuming that the list of indices and distances is sorted as
- * [ childFarSet | childUsedSet | farSet | usedSet ],
- * resort the sets so the organization is
- * [ childFarSet | farSet | childUsedSet | usedSet ].
- *
- * The size_t parameters specify the sizes of each set in the array. Only the
- * ordering of the indices and distances arrays will be modified (not their
- * actual contents).
- *
- * The size of any of the four sets can be zero and this method will handle
- * that case accordingly.
- *
- * @param indices List of indices to sort.
- * @param distances List of distances to sort.
- * @param childFarSetSize Number of points in child far set (childFarSet).
- * @param childUsedSetSize Number of points in child used set (childUsedSet).
- * @param farSetSize Number of points in far set (farSet).
- */
- size_t SortPointSet(arma::Col<size_t>& indices,
- arma::vec& distances,
- const size_t childFarSetSize,
- const size_t childUsedSetSize,
- const size_t farSetSize);
-
- void MoveToUsedSet(arma::Col<size_t>& indices,
- arma::vec& distances,
- size_t& nearSetSize,
- size_t& farSetSize,
- size_t& usedSetSize,
- arma::Col<size_t>& childIndices,
- const size_t childFarSetSize,
- const size_t childUsedSetSize);
- size_t PruneFarSet(arma::Col<size_t>& indices,
- arma::vec& distances,
- const double bound,
- const size_t nearSetSize,
- const size_t pointSetSize);
-};
-
-}; // namespace tree
-}; // namespace mlpack
-
-// Include implementation.
-#include "cover_tree_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/cover_tree.hpp (from rev 13403, mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/cover_tree.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/cover_tree.hpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -0,0 +1,349 @@
+/**
+ * @file cover_tree.hpp
+ * @author Ryan Curtin
+ *
+ * Definition of CoverTree, which can be used in place of the BinarySpaceTree.
+ */
+#ifndef __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
+#define __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+#include "first_point_is_root.hpp"
+#include "../statistic.hpp"
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * A cover tree is a tree specifically designed to speed up nearest-neighbor
+ * computation in high-dimensional spaces. Each non-leaf node references a
+ * point and has a nonzero number of children, including a "self-child" which
+ * references the same point. A leaf node represents only one point.
+ *
+ * The tree can be thought of as a hierarchy with the root node at the top level
+ * and the leaf nodes at the bottom level. Each level in the tree has an
+ * assigned 'scale' i. The tree follows these three conditions:
+ *
+ * - nesting: the level C_i is a subset of the level C_{i - 1}.
+ * - covering: all nodes in level C_{i - 1} have at least one node in the
+ * level C_i with distance less than or equal to EC^i (exactly one of these
+ * is a parent of the point in level C_{i - 1}).
+ * - separation: all nodes in level C_i have distance greater than EC^i to all
+ * other nodes in level C_i.
+ *
+ * The value 'EC' refers to the base, which is a parameter of the
+ * tree. These three properties make the cover tree very good for fast,
+ * high-dimensional nearest-neighbor search.
+ *
+ * The theoretical structure of the tree contains many 'implicit' nodes which
+ * only have a "self-child" (a child referencing the same point, but at a lower
+ * scale level). This practical implementation only constructs explicit nodes
+ * -- non-leaf nodes with more than one child. A leaf node has no children, and
+ * its scale level is INT_MIN.
+ *
+ * For more information on cover trees, see
+ *
+ * @code
+ * @inproceedings{
+ * author = {Beygelzimer, Alina and Kakade, Sham and Langford, John},
+ * title = {Cover trees for nearest neighbor},
+ * booktitle = {Proceedings of the 23rd International Conference on Machine
+ * Learning},
+ * series = {ICML '06},
+ * year = {2006},
+ * pages = {97--104}
+ * }
+ * @endcode
+ *
+ * For information on runtime bounds of the nearest-neighbor computation using
+ * cover trees, see the following paper, presented at NIPS 2009:
+ *
+ * @code
+ * @inproceedings{
+ * author = {Ram, P., and Lee, D., and March, W.B., and Gray, A.G.},
+ * title = {Linear-time Algorithms for Pairwise Statistical Problems},
+ * booktitle = {Advances in Neural Information Processing Systems 22},
+ * editor = {Y. Bengio and D. Schuurmans and J. Lafferty and C.K.I. Williams
+ * and A. Culotta},
+ * pages = {1527--1535},
+ * year = {2009}
+ * }
+ * @endcode
+ *
+ * The CoverTree class offers three template parameters; a custom metric type
+ * can be used with MetricType (this class defaults to the L2-squared metric).
+ * The root node's point can be chosen with the RootPointPolicy; by default, the
+ * FirstPointIsRoot policy is used, meaning the first point in the dataset is
+ * used. The StatisticType policy allows you to define statistics which can be
+ * gathered during the creation of the tree.
+ *
+ * @tparam MetricType Metric type to use during tree construction.
+ * @tparam RootPointPolicy Determines which point to use as the root node.
+ * @tparam StatisticType Statistic to be used during tree creation.
+ */
+template<typename MetricType = metric::LMetric<2, true>,
+ typename RootPointPolicy = FirstPointIsRoot,
+ typename StatisticType = EmptyStatistic>
+class CoverTree
+{
+ public:
+ typedef arma::mat Mat;
+
+ /**
+ * Create the cover tree with the given dataset and given base.
+ * The dataset will not be modified during the building procedure (unlike
+ * BinarySpaceTree).
+ *
+ * @param dataset Reference to the dataset to build a tree on.
+ * @param base Base to use during tree building (default 2.0).
+ */
+ CoverTree(const arma::mat& dataset,
+ const double base = 2.0,
+ MetricType* metric = NULL);
+
+ /**
+ * Construct a child cover tree node. This constructor is not meant to be
+ * used externally, but it could be used to insert another node into a tree.
+ * This procedure uses only one vector for the near set, the far set, and the
+ * used set (this is to prevent unnecessary memory allocation in recursive
+ * calls to this constructor). Therefore, the size of the near set, far set,
+ * and used set must be passed in. The near set will be entirely used up, and
+ * some of the far set may be used. The value of usedSetSize will be set to
+ * the number of points used in the construction of this node, and the value
+ * of farSetSize will be modified to reflect the number of points in the far
+ * set _after_ the construction of this node.
+ *
+ * If you are calling this manually, be careful that the given scale is
+ * as small as possible, or you may be creating an implicit node in your tree.
+ *
+ * @param dataset Reference to the dataset to build a tree on.
+ * @param base Base to use during tree building.
+ * @param pointIndex Index of the point this node references.
+ * @param scale Scale of this level in the tree.
+ * @param indices Array of indices, ordered [ nearSet | farSet | usedSet ];
+ * will be modified to [ farSet | usedSet ].
+ * @param distances Array of distances, ordered the same way as the indices.
+ * These represent the distances between the point specified by pointIndex
+ * and each point in the indices array.
+ * @param nearSetSize Size of the near set; if 0, this will be a leaf.
+ * @param farSetSize Size of the far set; may be modified (if this node uses
+ * any points in the far set).
+ * @param usedSetSize The number of points used will be added to this number.
+ */
+ CoverTree(const arma::mat& dataset,
+ const double base,
+ const size_t pointIndex,
+ const int scale,
+ const double parentDistance,
+ arma::Col<size_t>& indices,
+ arma::vec& distances,
+ size_t nearSetSize,
+ size_t& farSetSize,
+ size_t& usedSetSize,
+ MetricType& metric = NULL);
+
+ /**
+ * Delete this cover tree node and its children.
+ */
+ ~CoverTree();
+
+ //! A single-tree cover tree traverser; see single_tree_traverser.hpp for
+ //! implementation.
+ template<typename RuleType>
+ class SingleTreeTraverser;
+
+ //! A dual-tree cover tree traverser; see dual_tree_traverser.hpp.
+ template<typename RuleType>
+ class DualTreeTraverser;
+
+ //! Get a reference to the dataset.
+ const arma::mat& Dataset() const { return dataset; }
+
+ //! Get the index of the point which this node represents.
+ size_t Point() const { return point; }
+ //! For compatibility with other trees; the argument is ignored.
+ size_t Point(const size_t) const { return point; }
+
+ // Fake
+ CoverTree* Left() const { return NULL; }
+ CoverTree* Right() const { return NULL; }
+ size_t Begin() const { return 0; }
+ size_t Count() const { return 0; }
+ size_t End() const { return 0; }
+ bool IsLeaf() const { return (children.size() == 0); }
+ size_t NumPoints() const { return 1; }
+
+ //! Get a particular child node.
+ const CoverTree& Child(const size_t index) const { return *children[index]; }
+ //! Modify a particular child node.
+ CoverTree& Child(const size_t index) { return *children[index]; }
+
+ //! Get the number of children.
+ size_t NumChildren() const { return children.size(); }
+
+ //! Get the children.
+ const std::vector<CoverTree*>& Children() const { return children; }
+ //! Modify the children manually (maybe not a great idea).
+ std::vector<CoverTree*>& Children() { return children; }
+
+ //! Get the scale of this node.
+ int Scale() const { return scale; }
+ //! Modify the scale of this node. Be careful...
+ int& Scale() { return scale; }
+
+ //! Get the base.
+ double Base() const { return base; }
+ //! Modify the base; don't do this, you'll break everything.
+ double& Base() { return base; }
+
+ //! Get the statistic for this node.
+ const StatisticType& Stat() const { return stat; }
+ //! Modify the statistic for this node.
+ StatisticType& Stat() { return stat; }
+
+ //! Return the minimum distance to another node.
+ double MinDistance(const CoverTree* other) const;
+
+ //! Return the minimum distance to another node given that the point-to-point
+ //! distance has already been calculated.
+ double MinDistance(const CoverTree* other, const double distance) const;
+
+ //! Return the minimum distance to another point.
+ double MinDistance(const arma::vec& other) const;
+
+ //! Return the minimum distance to another point given that the distance from
+ //! the center to the point has already been calculated.
+ double MinDistance(const arma::vec& other, const double distance) const;
+
+ //! Return the maximum distance to another node.
+ double MaxDistance(const CoverTree* other) const;
+
+ //! Return the maximum distance to another node given that the point-to-point
+ //! distance has already been calculated.
+ double MaxDistance(const CoverTree* other, const double distance) const;
+
+ //! Return the maximum distance to another point.
+ double MaxDistance(const arma::vec& other) const;
+
+ //! Return the maximum distance to another point given that the distance from
+ //! the center to the point has already been calculated.
+ double MaxDistance(const arma::vec& other, const double distance) const;
+
+ //! Returns true: this tree does have self-children.
+ static bool HasSelfChildren() { return true; }
+
+ //! Get the distance to the parent.
+ double ParentDistance() const { return parentDistance; }
+
+ //! Get the distance to the furthest descendant.
+ double FurthestDescendantDistance() const
+ { return furthestDescendantDistance; }
+
+ private:
+ //! Reference to the matrix which this tree is built on.
+ const arma::mat& dataset;
+
+ //! Index of the point in the matrix which this node represents.
+ size_t point;
+
+ //! The list of children; the first is the self-child.
+ std::vector<CoverTree*> children;
+
+ //! Scale level of the node.
+ int scale;
+
+ //! The base used to construct the tree.
+ double base;
+
+ //! The instantiated statistic.
+ StatisticType stat;
+
+ //! Distance to the parent.
+ double parentDistance;
+
+ //! Distance to the furthest descendant.
+ double furthestDescendantDistance;
+
+ /**
+ * Fill the vector of distances with the distances between the point specified
+ * by pointIndex and each point in the indices array. The distances of the
+ * first pointSetSize points in indices are calculated (so, this does not
+ * necessarily need to use all of the points in the arrays).
+ *
+ * @param pointIndex Point to build the distances for.
+ * @param indices List of indices to compute distances for.
+ * @param distances Vector to store calculated distances in.
+ * @param pointSetSize Number of points in arrays to calculate distances for.
+ */
+ void ComputeDistances(const size_t pointIndex,
+ const arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const size_t pointSetSize,
+ MetricType& metric);
+ /**
+ * Split the given indices and distances into a near and a far set, returning
+ * the number of points in the near set. The distances must already be
+ * initialized. This will order the indices and distances such that the
+ * points in the near set make up the first part of the array and the far set
+ * makes up the rest: [ nearSet | farSet ].
+ *
+ * @param indices List of indices; will be reordered.
+ * @param distances List of distances; will be reordered.
+ * @param bound If the distance is less than or equal to this bound, the point
+ * is placed into the near set.
+ * @param pointSetSize Size of point set (because we may be sorting a smaller
+ * list than the indices vector will hold).
+ */
+ size_t SplitNearFar(arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const double bound,
+ const size_t pointSetSize);
+
+ /**
+ * Assuming that the list of indices and distances is sorted as
+ * [ childFarSet | childUsedSet | farSet | usedSet ],
+ * resort the sets so the organization is
+ * [ childFarSet | farSet | childUsedSet | usedSet ].
+ *
+ * The size_t parameters specify the sizes of each set in the array. Only the
+ * ordering of the indices and distances arrays will be modified (not their
+ * actual contents).
+ *
+ * The size of any of the four sets can be zero and this method will handle
+ * that case accordingly.
+ *
+ * @param indices List of indices to sort.
+ * @param distances List of distances to sort.
+ * @param childFarSetSize Number of points in child far set (childFarSet).
+ * @param childUsedSetSize Number of points in child used set (childUsedSet).
+ * @param farSetSize Number of points in far set (farSet).
+ */
+ size_t SortPointSet(arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const size_t childFarSetSize,
+ const size_t childUsedSetSize,
+ const size_t farSetSize);
+
+ void MoveToUsedSet(arma::Col<size_t>& indices,
+ arma::vec& distances,
+ size_t& nearSetSize,
+ size_t& farSetSize,
+ size_t& usedSetSize,
+ arma::Col<size_t>& childIndices,
+ const size_t childFarSetSize,
+ const size_t childUsedSetSize);
+ size_t PruneFarSet(arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const double bound,
+ const size_t nearSetSize,
+ const size_t pointSetSize);
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+// Include implementation.
+#include "cover_tree_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp 2012-08-15 16:44:38 UTC (rev 13396)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -1,804 +0,0 @@
-/**
- * @file cover_tree_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of CoverTree class.
- */
-#ifndef __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_IMPL_HPP
-#define __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_IMPL_HPP
-
-// In case it hasn't already been included.
-#include "cover_tree.hpp"
-
-namespace mlpack {
-namespace tree {
-
-// Create the cover tree.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
- const arma::mat& dataset,
- const double expansionConstant,
- MetricType* metric) :
- dataset(dataset),
- point(RootPointPolicy::ChooseRoot(dataset)),
- expansionConstant(expansionConstant),
- parentDistance(0),
- furthestDescendantDistance(0)
-{
- // If we need to create a metric, do that. We'll just do it on the heap.
- bool localMetric = false;
- if (metric == NULL)
- {
- localMetric = true; // So we know we need to free it.
- metric = new MetricType();
- }
-
- // Kick off the building. Create the indices array and the distances array.
- arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
- dataset.n_cols - 1, dataset.n_cols - 1);
- // This is now [1 2 3 4 ... n]. We must be sure that our point does not
- // occur.
- if (point != 0)
- indices[point - 1] = 0; // Put 0 back into the set; remove what was there.
-
- arma::vec distances(dataset.n_cols - 1);
-
- // Build the initial distances.
- ComputeDistances(point, indices, distances, dataset.n_cols - 1, *metric);
-
- // Now determine the scale factor of the root node.
- const double maxDistance = max(distances);
- scale = (int) ceil(log(maxDistance) / log(expansionConstant));
- const double bound = pow(expansionConstant, scale - 1);
-
- // Unfortunately, we can't call out to other constructors, so we have to copy
- // a little bit of code from the other constructor. First we build the self
- // child.
- size_t childNearSetSize = SplitNearFar(indices, distances, bound,
- dataset.n_cols - 1);
- size_t childFarSetSize = (dataset.n_cols - 1) - childNearSetSize;
- size_t usedSetSize = 0;
- children.push_back(new CoverTree(dataset, expansionConstant, point, scale - 1,
- 0, indices, distances, childNearSetSize, childFarSetSize, usedSetSize,
- *metric));
-
- furthestDescendantDistance = children[0]->FurthestDescendantDistance();
-
- // If we created an implicit node, take its self-child instead (this could
- // happen multiple times).
- while (children[children.size() - 1]->NumChildren() == 1)
- {
- CoverTree* old = children[children.size() - 1];
- children.erase(children.begin() + children.size() - 1);
-
- // Now take its child.
- children.push_back(&(old->Child(0)));
-
- // Remove its child (so it doesn't delete it).
- old->Children().erase(old->Children().begin() + old->Children().size() - 1);
-
- // Now delete it.
- delete old;
- }
-
- size_t nearSetSize = (dataset.n_cols - 1) - usedSetSize;
-
- // We have no far set, so the array is organized thusly:
- // [ near | used ]. No resorting is necessary.
- // Therefore, go ahead and build the children.
- while (nearSetSize > 0)
- {
- // We want to select the furthest point in the near set as the next child.
- size_t newPointIndex = nearSetSize - 1;
-
- // Swap to front if necessary.
- if (newPointIndex != 0)
- {
- const size_t tempIndex = indices[newPointIndex];
- const double tempDist = distances[newPointIndex];
-
- indices[newPointIndex] = indices[0];
- distances[newPointIndex] = distances[0];
-
- indices[0] = tempIndex;
- distances[0] = tempDist;
- }
-
- if (distances[0] > furthestDescendantDistance)
- furthestDescendantDistance = distances[0];
-
- size_t childUsedSetSize = 0;
-
- // If there's only one point left, we don't need this crap.
- if (nearSetSize == 1)
- {
- size_t childNearSetSize = 0;
- size_t childFarSetSize = 0;
- children.push_back(new CoverTree(dataset, expansionConstant,
- indices[0], scale - 1, distances[0], indices, distances,
- childNearSetSize, childFarSetSize, usedSetSize, *metric));
-
- // And we're done.
- break;
- }
-
- // Create the near and far set indices and distance vectors.
- arma::Col<size_t> childIndices(nearSetSize);
- childIndices.rows(0, (nearSetSize - 2)) = indices.rows(1, nearSetSize - 1);
- // Put the current point into the used set, so when we move our indices to
- // the used set, this will be done for us.
- childIndices(nearSetSize - 1) = indices[0];
- arma::vec childDistances(nearSetSize);
-
- // Build distances for the child.
- ComputeDistances(indices[0], childIndices, childDistances,
- nearSetSize - 1, *metric);
- childDistances(nearSetSize - 1) = 0;
-
- // Split into near and far sets for this point.
- childNearSetSize = SplitNearFar(childIndices, childDistances, bound,
- nearSetSize - 1);
-
- // Build this child (recursively).
- childUsedSetSize = 1; // Mark self point as used.
- childFarSetSize = ((nearSetSize - 1) - childNearSetSize);
- children.push_back(new CoverTree(dataset, expansionConstant, indices[0],
- scale - 1, distances[0], childIndices, childDistances, childNearSetSize,
- childFarSetSize, childUsedSetSize, *metric));
-
- // If we created an implicit node, take its self-child instead (this could
- // happen multiple times).
- while (children[children.size() - 1]->NumChildren() == 1)
- {
- CoverTree* old = children[children.size() - 1];
- children.erase(children.begin() + children.size() - 1);
-
- // Now take its child.
- children.push_back(&(old->Child(0)));
-
- // Remove its child (so it doesn't delete it).
- old->Children().erase(old->Children().begin() + old->Children().size()
- - 1);
-
- // Now delete it.
- delete old;
- }
-
- // Now with the child created, it returns the childIndices and
- // childDistances vectors in this form:
- // [ childFar | childUsed ]
- // For each point in the childUsed set, we must move that point to the used
- // set in our own vector.
- size_t farSetSize = 0;
- MoveToUsedSet(indices, distances, nearSetSize, farSetSize, usedSetSize,
- childIndices, childFarSetSize, childUsedSetSize);
- }
-
- // Calculate furthest descendant.
- for (size_t i = 0; i < usedSetSize; ++i)
- if (distances[i] > furthestDescendantDistance)
- furthestDescendantDistance = distances[i];
-
- Log::Assert(furthestDescendantDistance <= pow(expansionConstant, scale + 1));
-
- if (localMetric)
- delete metric;
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
- const arma::mat& dataset,
- const double expansionConstant,
- const size_t pointIndex,
- const int scale,
- const double parentDistance,
- arma::Col<size_t>& indices,
- arma::vec& distances,
- size_t nearSetSize,
- size_t& farSetSize,
- size_t& usedSetSize,
- MetricType& metric) :
- dataset(dataset),
- point(pointIndex),
- scale(scale),
- expansionConstant(expansionConstant),
- parentDistance(parentDistance),
- furthestDescendantDistance(0)
-{
- // If the size of the near set is 0, this is a leaf.
- if (nearSetSize == 0)
- {
- this->scale = INT_MIN;
- return;
- }
-
- // Determine the next scale level. This should be the first level where there
- // are any points in the far set. So, if we know the maximum distance in the
- // distances array, this will be the largest i such that
- // maxDistance > pow(ec, i)
- // and using this for the scale factor should guarantee we are not creating an
- // implicit node. If the maximum distance is 0, every point in the near set
- // will be created as a leaf, and a child to this node. We also do not need
- // to change the furthestChildDistance or furthestDescendantDistance.
- const double maxDistance = max(distances.rows(0,
- nearSetSize + farSetSize - 1));
- if (maxDistance == 0)
- {
- // Make the self child at the lowest possible level.
- // This should not modify farSetSize or usedSetSize.
- size_t tempSize = 0;
- children.push_back(new CoverTree(dataset, expansionConstant, pointIndex,
- INT_MIN, 0, indices, distances, 0, tempSize, usedSetSize, metric));
-
- // Every point in the near set should be a leaf.
- for (size_t i = 0; i < nearSetSize; ++i)
- {
- // farSetSize and usedSetSize will not be modified.
- children.push_back(new CoverTree(dataset, expansionConstant, indices[i],
- INT_MIN, 0, indices, distances, 0, tempSize, usedSetSize, metric));
- usedSetSize++;
- }
-
- // Re-sort the dataset. We have
- // [ used | far | other used ]
- // and we want
- // [ far | all used ].
- SortPointSet(indices, distances, 0, usedSetSize, farSetSize);
-
- return;
- }
-
- const int nextScale = std::min(scale,
- (int) ceil(log(maxDistance) / log(expansionConstant))) - 1;
- const double bound = pow(expansionConstant, nextScale);
-
- // This needs to be taken out. It's a sanity check for now.
- Log::Assert(nextScale < scale);
-
- // First, make the self child. We must split the given near set into the near
- // set and far set for the self child.
- size_t childNearSetSize =
- SplitNearFar(indices, distances, bound, nearSetSize);
-
- // Build the self child (recursively).
- size_t childFarSetSize = nearSetSize - childNearSetSize;
- size_t childUsedSetSize = 0;
- children.push_back(new CoverTree(dataset, expansionConstant, pointIndex,
- nextScale, 0, indices, distances, childNearSetSize, childFarSetSize,
- childUsedSetSize, metric));
-
- // The self-child can't modify the furthestChildDistance away from 0, but it
- // can modify the furthestDescendantDistance.
- furthestDescendantDistance = children[0]->FurthestDescendantDistance();
-
- // If we created an implicit node, take its self-child instead (this could
- // happen multiple times).
- while (children[children.size() - 1]->NumChildren() == 1)
- {
- CoverTree* old = children[children.size() - 1];
- children.erase(children.begin() + children.size() - 1);
-
- // Now take its child.
- children.push_back(&(old->Child(0)));
-
- // Remove its child (so it doesn't delete it).
- old->Children().erase(old->Children().begin() + old->Children().size() - 1);
-
- // Now delete it.
- delete old;
- }
-
- // Now the arrays, in memory, look like this:
- // [ childFar | childUsed | far | used ]
- // but we need to move the used points past our far set:
- // [ childFar | far | childUsed + used ]
- // and keeping in mind that childFar = our near set,
- // [ near | far | childUsed + used ]
- // is what we are trying to make.
- SortPointSet(indices, distances, childFarSetSize, childUsedSetSize,
- farSetSize);
-
- // Update size of near set and used set.
- nearSetSize -= childUsedSetSize;
- usedSetSize += childUsedSetSize;
-
- // Now for each point in the near set, we need to make children. To save
- // computation later, we'll create an array holding the points in the near
- // set, and then after each run we'll check which of those (if any) were used
- // and we will remove them. ...if that's faster. I think it is.
- while (nearSetSize > 0)
- {
- size_t newPointIndex = nearSetSize - 1;
-
- // Swap to front if necessary.
- if (newPointIndex != 0)
- {
- const size_t tempIndex = indices[newPointIndex];
- const double tempDist = distances[newPointIndex];
-
- indices[newPointIndex] = indices[0];
- distances[newPointIndex] = distances[0];
-
- indices[0] = tempIndex;
- distances[0] = tempDist;
- }
-
- // Will this be a new furthest child?
- if (distances[0] > furthestDescendantDistance)
- furthestDescendantDistance = distances[0];
-
- // If there's only one point left, we don't need this crap.
- if ((nearSetSize == 1) && (farSetSize == 0))
- {
- size_t childNearSetSize = 0;
- children.push_back(new CoverTree(dataset, expansionConstant,
- indices[0], nextScale, distances[0], indices, distances,
- childNearSetSize, farSetSize, usedSetSize, metric));
-
- // Because the far set size is 0, we don't have to do any swapping to
- // move the point into the used set.
- ++usedSetSize;
- --nearSetSize;
-
- // And we're done.
- break;
- }
-
- // Create the near and far set indices and distance vectors. We don't fill
- // in the self-point, yet.
- arma::Col<size_t> childIndices(nearSetSize + farSetSize);
- childIndices.rows(0, (nearSetSize + farSetSize - 2)) = indices.rows(1,
- nearSetSize + farSetSize - 1);
- arma::vec childDistances(nearSetSize + farSetSize);
-
- // Build distances for the child.
- ComputeDistances(indices[0], childIndices, childDistances, nearSetSize
- + farSetSize - 1, metric);
-
- // Split into near and far sets for this point.
- childNearSetSize = SplitNearFar(childIndices, childDistances, bound,
- nearSetSize + farSetSize - 1);
- childFarSetSize = PruneFarSet(childIndices, childDistances,
- expansionConstant * bound, childNearSetSize,
- (nearSetSize + farSetSize - 1));
-
- // Now that we know the near and far set sizes, we can put the used point
- // (the self point) in the correct place; now, when we call
- // MoveToUsedSet(), it will move the self-point correctly. The distance
- // does not matter.
- childIndices(childNearSetSize + childFarSetSize) = indices[0];
- childDistances(childNearSetSize + childFarSetSize) = 0;
-
- // Build this child (recursively).
- childUsedSetSize = 1; // Mark self point as used.
- children.push_back(new CoverTree(dataset, expansionConstant, indices[0],
- nextScale, distances[0], childIndices, childDistances, childNearSetSize,
- childFarSetSize, childUsedSetSize, metric));
-
- // If we created an implicit node, take its self-child instead (this could
- // happen multiple times).
- while (children[children.size() - 1]->NumChildren() == 1)
- {
- CoverTree* old = children[children.size() - 1];
- children.erase(children.begin() + children.size() - 1);
-
- // Now take its child.
- children.push_back(&(old->Child(0)));
-
- // Remove its child (so it doesn't delete it).
- old->Children().erase(old->Children().begin() + old->Children().size()
- - 1);
-
- // Now delete it.
- delete old;
- }
-
- // Now with the child created, it returns the childIndices and
- // childDistances vectors in this form:
- // [ childFar | childUsed ]
- // For each point in the childUsed set, we must move that point to the used
- // set in our own vector.
- MoveToUsedSet(indices, distances, nearSetSize, farSetSize, usedSetSize,
- childIndices, childFarSetSize, childUsedSetSize);
- }
-
- // Calculate furthest descendant.
- for (size_t i = (nearSetSize + farSetSize); i < (nearSetSize + farSetSize +
- usedSetSize); ++i)
- if (distances[i] > furthestDescendantDistance)
- furthestDescendantDistance = distances[i];
-
- Log::Assert(furthestDescendantDistance <= pow(expansionConstant, scale + 1));
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::~CoverTree()
-{
- // Delete each child.
- for (size_t i = 0; i < children.size(); ++i)
- delete children[i];
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
- const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
-{
- // Every cover tree node will contain points up to EC^(scale + 1) away.
- return std::max(MetricType::Evaluate(dataset.unsafe_col(point),
- other->Dataset().unsafe_col(other->Point())) -
- furthestDescendantDistance - other->FurthestDescendantDistance(), 0.0);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
- const CoverTree<MetricType, RootPointPolicy, StatisticType>* other,
- const double distance) const
-{
- // We already have the distance as evaluated by the metric.
- return std::max(distance - furthestDescendantDistance -
- other->FurthestDescendantDistance(), 0.0);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
- const arma::vec& other) const
-{
- return std::max(MetricType::Evaluate(dataset.unsafe_col(point), other) -
- furthestDescendantDistance, 0.0);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
- const arma::vec& /* other */,
- const double distance) const
-{
- return std::max(distance - furthestDescendantDistance, 0.0);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
- const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
-{
- return MetricType::Evaluate(dataset.unsafe_col(point),
- other->Dataset().unsafe_col(other->Point())) +
- furthestDescendantDistance + other->FurthestDescendantDistance();
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
- const CoverTree<MetricType, RootPointPolicy, StatisticType>* other,
- const double distance) const
-{
- // We already have the distance as evaluated by the metric.
- return distance + furthestDescendantDistance +
- other->FurthestDescendantDistance();
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
- const arma::vec& other) const
-{
- return MetricType::Evaluate(dataset.unsafe_col(point), other) +
- furthestDescendantDistance;
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
- const arma::vec& /* other */,
- const double distance) const
-{
- return distance + furthestDescendantDistance;
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SplitNearFar(
- arma::Col<size_t>& indices,
- arma::vec& distances,
- const double bound,
- const size_t pointSetSize)
-{
- // Sanity check; there is no guarantee that this condition will not be true.
- // ...or is there?
- if (pointSetSize <= 1)
- return 0;
-
- // We'll traverse from both left and right.
- size_t left = 0;
- size_t right = pointSetSize - 1;
-
- // A modification of quicksort, with the pivot value set to the bound.
- // Everything on the left of the pivot will be less than or equal to the
- // bound; everything on the right will be greater than the bound.
- while ((distances[left] <= bound) && (left != right))
- ++left;
- while ((distances[right] > bound) && (left != right))
- --right;
-
- while (left != right)
- {
- // Now swap the values and indices.
- const size_t tempPoint = indices[left];
- const double tempDist = distances[left];
-
- indices[left] = indices[right];
- distances[left] = distances[right];
-
- indices[right] = tempPoint;
- distances[right] = tempDist;
-
- // Traverse the left, seeing how many points are correctly on that side.
- // When we encounter an incorrect point, stop. We will switch it later.
- while ((distances[left] <= bound) && (left != right))
- ++left;
-
- // Traverse the right, seeing how many points are correctly on that side.
- // When we encounter an incorrect point, stop. We will switch it with the
- // wrong point from the left side.
- while ((distances[right] > bound) && (left != right))
- --right;
- }
-
- // The final left value is the index of the first far value.
- return left;
-}
-
-// Returns the maximum distance between points.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::ComputeDistances(
- const size_t pointIndex,
- const arma::Col<size_t>& indices,
- arma::vec& distances,
- const size_t pointSetSize,
- MetricType& metric)
-{
- // For each point, rebuild the distances. The indices do not need to be
- // modified.
- for (size_t i = 0; i < pointSetSize; ++i)
- {
- distances[i] = metric.Evaluate(dataset.unsafe_col(pointIndex),
- dataset.unsafe_col(indices[i]));
- }
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SortPointSet(
- arma::Col<size_t>& indices,
- arma::vec& distances,
- const size_t childFarSetSize,
- const size_t childUsedSetSize,
- const size_t farSetSize)
-{
- // We'll use low-level memcpy calls ourselves, just to ensure it's done
- // quickly and the way we want it to be. Unfortunately this takes up more
- // memory than one-element swaps, but there's not a great way around that.
- const size_t bufferSize = std::min(farSetSize, childUsedSetSize);
- const size_t bigCopySize = std::max(farSetSize, childUsedSetSize);
-
- // Sanity check: there is no need to sort if the buffer size is going to be
- // zero.
- if (bufferSize == 0)
- return (childFarSetSize + farSetSize);
-
- size_t* indicesBuffer = new size_t[bufferSize];
- double* distancesBuffer = new double[bufferSize];
-
- // The start of the memory region to copy to the buffer.
- const size_t bufferFromLocation = ((bufferSize == farSetSize) ?
- (childFarSetSize + childUsedSetSize) : childFarSetSize);
- // The start of the memory region to move directly to the new place.
- const size_t directFromLocation = ((bufferSize == farSetSize) ?
- childFarSetSize : (childFarSetSize + childUsedSetSize));
- // The destination to copy the buffer back to.
- const size_t bufferToLocation = ((bufferSize == farSetSize) ?
- childFarSetSize : (childFarSetSize + farSetSize));
- // The destination of the directly moved memory region.
- const size_t directToLocation = ((bufferSize == farSetSize) ?
- (childFarSetSize + farSetSize) : childFarSetSize);
-
- // Copy the smaller piece to the buffer.
- memcpy(indicesBuffer, indices.memptr() + bufferFromLocation,
- sizeof(size_t) * bufferSize);
- memcpy(distancesBuffer, distances.memptr() + bufferFromLocation,
- sizeof(double) * bufferSize);
-
- // Now move the other memory.
- memmove(indices.memptr() + directToLocation,
- indices.memptr() + directFromLocation, sizeof(size_t) * bigCopySize);
- memmove(distances.memptr() + directToLocation,
- distances.memptr() + directFromLocation, sizeof(double) * bigCopySize);
-
- // Now copy the temporary memory to the right place.
- memcpy(indices.memptr() + bufferToLocation, indicesBuffer,
- sizeof(size_t) * bufferSize);
- memcpy(distances.memptr() + bufferToLocation, distancesBuffer,
- sizeof(double) * bufferSize);
-
- delete[] indicesBuffer;
- delete[] distancesBuffer;
-
- // This returns the complete size of the far set.
- return (childFarSetSize + farSetSize);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::MoveToUsedSet(
- arma::Col<size_t>& indices,
- arma::vec& distances,
- size_t& nearSetSize,
- size_t& farSetSize,
- size_t& usedSetSize,
- arma::Col<size_t>& childIndices,
- const size_t childFarSetSize, // childNearSetSize is 0 in this case.
- const size_t childUsedSetSize)
-{
- const size_t originalSum = nearSetSize + farSetSize + usedSetSize;
-
- // Loop across the set. We will swap points as we need. It should be noted
- // that farSetSize and nearSetSize may change with each iteration of this loop
- // (depending on if we make a swap or not).
- size_t startChildUsedSet = 0; // Where to start in the child set.
- for (size_t i = 0; i < nearSetSize; ++i)
- {
- // Discover if this point was in the child's used set.
- for (size_t j = startChildUsedSet; j < childUsedSetSize; ++j)
- {
- if (childIndices[childFarSetSize + j] == indices[i])
- {
- // We have found a point; a swap is necessary.
-
- // Since this point is from the near set, to preserve the near set, we
- // must do a swap.
- if (farSetSize > 0)
- {
- if ((nearSetSize - 1) != i)
- {
- // In this case it must be a three-way swap.
- size_t tempIndex = indices[nearSetSize + farSetSize - 1];
- double tempDist = distances[nearSetSize + farSetSize - 1];
-
- size_t tempNearIndex = indices[nearSetSize - 1];
- double tempNearDist = distances[nearSetSize - 1];
-
- indices[nearSetSize + farSetSize - 1] = indices[i];
- distances[nearSetSize + farSetSize - 1] = distances[i];
-
- indices[nearSetSize - 1] = tempIndex;
- distances[nearSetSize - 1] = tempDist;
-
- indices[i] = tempNearIndex;
- distances[i] = tempNearDist;
- }
- else
- {
- // We can do a two-way swap.
- size_t tempIndex = indices[nearSetSize + farSetSize - 1];
- double tempDist = distances[nearSetSize + farSetSize - 1];
-
- indices[nearSetSize + farSetSize - 1] = indices[i];
- distances[nearSetSize + farSetSize - 1] = distances[i];
-
- indices[i] = tempIndex;
- distances[i] = tempDist;
- }
- }
- else if ((nearSetSize - 1) != i)
- {
- // A two-way swap is possible.
- size_t tempIndex = indices[nearSetSize + farSetSize - 1];
- double tempDist = distances[nearSetSize + farSetSize - 1];
-
- indices[nearSetSize + farSetSize - 1] = indices[i];
- distances[nearSetSize + farSetSize - 1] = distances[i];
-
- indices[i] = tempIndex;
- distances[i] = tempDist;
- }
- else
- {
- // No swap is necessary.
- }
-
- // We don't need to do a complete preservation of the child index set,
- // but we want to make sure we only loop over points we haven't seen.
- // So increment the child counter by 1 and move a point if we need.
- if (j != startChildUsedSet)
- {
- childIndices[childFarSetSize + j] = childIndices[childFarSetSize +
- startChildUsedSet];
- }
-
- // Update all counters from the swaps we have done.
- ++startChildUsedSet;
- --nearSetSize;
- --i; // Since we moved a point out of the near set we must step back.
-
- break; // Break out of this for loop; back to the first one.
- }
- }
- }
-
- // Now loop over the far set. This loop is different because we only require
- // a normal two-way swap instead of the three-way swap to preserve the near
- // set / far set ordering.
- for (size_t i = 0; i < farSetSize; ++i)
- {
- // Discover if this point was in the child's used set.
- for (size_t j = startChildUsedSet; j < childUsedSetSize; ++j)
- {
- if (childIndices[childFarSetSize + j] == indices[i + nearSetSize])
- {
- // We have found a point to swap.
-
- // Perform the swap.
- size_t tempIndex = indices[nearSetSize + farSetSize - 1];
- double tempDist = distances[nearSetSize + farSetSize - 1];
-
- indices[nearSetSize + farSetSize - 1] = indices[nearSetSize + i];
- distances[nearSetSize + farSetSize - 1] = distances[nearSetSize + i];
-
- indices[nearSetSize + i] = tempIndex;
- distances[nearSetSize + i] = tempDist;
-
- if (j != startChildUsedSet)
- {
- childIndices[childFarSetSize + j] = childIndices[childFarSetSize +
- startChildUsedSet];
- }
-
- // Update all counters from the swaps we have done.
- ++startChildUsedSet;
- --farSetSize;
- --i;
-
- break; // Break out of this for loop; back to the first one.
- }
- }
- }
-
- // Update used set size.
- usedSetSize += childUsedSetSize;
-
- Log::Assert(originalSum == (nearSetSize + farSetSize + usedSetSize));
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::PruneFarSet(
- arma::Col<size_t>& indices,
- arma::vec& distances,
- const double bound,
- const size_t nearSetSize,
- const size_t pointSetSize)
-{
- // What we are trying to do is remove any points greater than the bound from
- // the far set. We don't care what happens to those indices and distances...
- // so, we don't need to properly swap points -- just drop new ones in place.
- size_t left = nearSetSize;
- size_t right = pointSetSize - 1;
- while ((distances[left] <= bound) && (left != right))
- ++left;
- while ((distances[right] > bound) && (left != right))
- --right;
-
- while (left != right)
- {
- // We don't care what happens to the point which should be on the right.
- indices[left] = indices[right];
- distances[left] = distances[right];
- --right; // Since we aren't changing the right.
-
- // Advance to next location which needs to switch.
- while ((distances[left] <= bound) && (left != right))
- ++left;
- while ((distances[right] > bound) && (left != right))
- --right;
- }
-
- // The far set size is the left pointer, with the near set size subtracted
- // from it.
- return (left - nearSetSize);
-}
-
-}; // namespace tree
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp (from rev 13401, mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -0,0 +1,804 @@
+/**
+ * @file cover_tree_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of CoverTree class.
+ */
+#ifndef __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_IMPL_HPP
+#define __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_IMPL_HPP
+
+// In case it hasn't already been included.
+#include "cover_tree.hpp"
+
+namespace mlpack {
+namespace tree {
+
+// Create the cover tree.
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
+ const arma::mat& dataset,
+ const double base,
+ MetricType* metric) :
+ dataset(dataset),
+ point(RootPointPolicy::ChooseRoot(dataset)),
+ base(base),
+ parentDistance(0),
+ furthestDescendantDistance(0)
+{
+ // If we need to create a metric, do that. We'll just do it on the heap.
+ bool localMetric = false;
+ if (metric == NULL)
+ {
+ localMetric = true; // So we know we need to free it.
+ metric = new MetricType();
+ }
+
+ // Kick off the building. Create the indices array and the distances array.
+ arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
+ dataset.n_cols - 1, dataset.n_cols - 1);
+ // This is now [1 2 3 4 ... n]. We must be sure that our point does not
+ // occur.
+ if (point != 0)
+ indices[point - 1] = 0; // Put 0 back into the set; remove what was there.
+
+ arma::vec distances(dataset.n_cols - 1);
+
+ // Build the initial distances.
+ ComputeDistances(point, indices, distances, dataset.n_cols - 1, *metric);
+
+ // Now determine the scale factor of the root node.
+ const double maxDistance = max(distances);
+ scale = (int) ceil(log(maxDistance) / log(base));
+ const double bound = pow(base, scale - 1);
+
+ // Unfortunately, we can't call out to other constructors, so we have to copy
+ // a little bit of code from the other constructor. First we build the self
+ // child.
+ size_t childNearSetSize = SplitNearFar(indices, distances, bound,
+ dataset.n_cols - 1);
+ size_t childFarSetSize = (dataset.n_cols - 1) - childNearSetSize;
+ size_t usedSetSize = 0;
+ children.push_back(new CoverTree(dataset, base, point, scale - 1,
+ 0, indices, distances, childNearSetSize, childFarSetSize, usedSetSize,
+ *metric));
+
+ furthestDescendantDistance = children[0]->FurthestDescendantDistance();
+
+ // If we created an implicit node, take its self-child instead (this could
+ // happen multiple times).
+ while (children[children.size() - 1]->NumChildren() == 1)
+ {
+ CoverTree* old = children[children.size() - 1];
+ children.erase(children.begin() + children.size() - 1);
+
+ // Now take its child.
+ children.push_back(&(old->Child(0)));
+
+ // Remove its child (so it doesn't delete it).
+ old->Children().erase(old->Children().begin() + old->Children().size() - 1);
+
+ // Now delete it.
+ delete old;
+ }
+
+ size_t nearSetSize = (dataset.n_cols - 1) - usedSetSize;
+
+ // We have no far set, so the array is organized thusly:
+ // [ near | used ]. No resorting is necessary.
+ // Therefore, go ahead and build the children.
+ while (nearSetSize > 0)
+ {
+ // We want to select the furthest point in the near set as the next child.
+ size_t newPointIndex = nearSetSize - 1;
+
+ // Swap to front if necessary.
+ if (newPointIndex != 0)
+ {
+ const size_t tempIndex = indices[newPointIndex];
+ const double tempDist = distances[newPointIndex];
+
+ indices[newPointIndex] = indices[0];
+ distances[newPointIndex] = distances[0];
+
+ indices[0] = tempIndex;
+ distances[0] = tempDist;
+ }
+
+ if (distances[0] > furthestDescendantDistance)
+ furthestDescendantDistance = distances[0];
+
+ size_t childUsedSetSize = 0;
+
+ // If there's only one point left, we don't need this crap.
+ if (nearSetSize == 1)
+ {
+ size_t childNearSetSize = 0;
+ size_t childFarSetSize = 0;
+ children.push_back(new CoverTree(dataset, base,
+ indices[0], scale - 1, distances[0], indices, distances,
+ childNearSetSize, childFarSetSize, usedSetSize, *metric));
+
+ // And we're done.
+ break;
+ }
+
+ // Create the near and far set indices and distance vectors.
+ arma::Col<size_t> childIndices(nearSetSize);
+ childIndices.rows(0, (nearSetSize - 2)) = indices.rows(1, nearSetSize - 1);
+ // Put the current point into the used set, so when we move our indices to
+ // the used set, this will be done for us.
+ childIndices(nearSetSize - 1) = indices[0];
+ arma::vec childDistances(nearSetSize);
+
+ // Build distances for the child.
+ ComputeDistances(indices[0], childIndices, childDistances,
+ nearSetSize - 1, *metric);
+ childDistances(nearSetSize - 1) = 0;
+
+ // Split into near and far sets for this point.
+ childNearSetSize = SplitNearFar(childIndices, childDistances, bound,
+ nearSetSize - 1);
+
+ // Build this child (recursively).
+ childUsedSetSize = 1; // Mark self point as used.
+ childFarSetSize = ((nearSetSize - 1) - childNearSetSize);
+ children.push_back(new CoverTree(dataset, base, indices[0],
+ scale - 1, distances[0], childIndices, childDistances, childNearSetSize,
+ childFarSetSize, childUsedSetSize, *metric));
+
+ // If we created an implicit node, take its self-child instead (this could
+ // happen multiple times).
+ while (children[children.size() - 1]->NumChildren() == 1)
+ {
+ CoverTree* old = children[children.size() - 1];
+ children.erase(children.begin() + children.size() - 1);
+
+ // Now take its child.
+ children.push_back(&(old->Child(0)));
+
+ // Remove its child (so it doesn't delete it).
+ old->Children().erase(old->Children().begin() + old->Children().size()
+ - 1);
+
+ // Now delete it.
+ delete old;
+ }
+
+ // Now with the child created, it returns the childIndices and
+ // childDistances vectors in this form:
+ // [ childFar | childUsed ]
+ // For each point in the childUsed set, we must move that point to the used
+ // set in our own vector.
+ size_t farSetSize = 0;
+ MoveToUsedSet(indices, distances, nearSetSize, farSetSize, usedSetSize,
+ childIndices, childFarSetSize, childUsedSetSize);
+ }
+
+ // Calculate furthest descendant.
+ for (size_t i = 0; i < usedSetSize; ++i)
+ if (distances[i] > furthestDescendantDistance)
+ furthestDescendantDistance = distances[i];
+
+ Log::Assert(furthestDescendantDistance <= pow(base, scale + 1));
+
+ if (localMetric)
+ delete metric;
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
+ const arma::mat& dataset,
+ const double base,
+ const size_t pointIndex,
+ const int scale,
+ const double parentDistance,
+ arma::Col<size_t>& indices,
+ arma::vec& distances,
+ size_t nearSetSize,
+ size_t& farSetSize,
+ size_t& usedSetSize,
+ MetricType& metric) :
+ dataset(dataset),
+ point(pointIndex),
+ scale(scale),
+ base(base),
+ parentDistance(parentDistance),
+ furthestDescendantDistance(0)
+{
+ // If the size of the near set is 0, this is a leaf.
+ if (nearSetSize == 0)
+ {
+ this->scale = INT_MIN;
+ return;
+ }
+
+ // Determine the next scale level. This should be the first level where there
+ // are any points in the far set. So, if we know the maximum distance in the
+ // distances array, this will be the largest i such that
+ // maxDistance > pow(ec, i)
+ // and using this for the scale factor should guarantee we are not creating an
+ // implicit node. If the maximum distance is 0, every point in the near set
+ // will be created as a leaf, and a child to this node. We also do not need
+ // to change the furthestChildDistance or furthestDescendantDistance.
+ const double maxDistance = max(distances.rows(0,
+ nearSetSize + farSetSize - 1));
+ if (maxDistance == 0)
+ {
+ // Make the self child at the lowest possible level.
+ // This should not modify farSetSize or usedSetSize.
+ size_t tempSize = 0;
+ children.push_back(new CoverTree(dataset, base, pointIndex,
+ INT_MIN, 0, indices, distances, 0, tempSize, usedSetSize, metric));
+
+ // Every point in the near set should be a leaf.
+ for (size_t i = 0; i < nearSetSize; ++i)
+ {
+ // farSetSize and usedSetSize will not be modified.
+ children.push_back(new CoverTree(dataset, base, indices[i],
+ INT_MIN, 0, indices, distances, 0, tempSize, usedSetSize, metric));
+ usedSetSize++;
+ }
+
+ // Re-sort the dataset. We have
+ // [ used | far | other used ]
+ // and we want
+ // [ far | all used ].
+ SortPointSet(indices, distances, 0, usedSetSize, farSetSize);
+
+ return;
+ }
+
+ const int nextScale = std::min(scale,
+ (int) ceil(log(maxDistance) / log(base))) - 1;
+ const double bound = pow(base, nextScale);
+
+ // This needs to be taken out. It's a sanity check for now.
+ Log::Assert(nextScale < scale);
+
+ // First, make the self child. We must split the given near set into the near
+ // set and far set for the self child.
+ size_t childNearSetSize =
+ SplitNearFar(indices, distances, bound, nearSetSize);
+
+ // Build the self child (recursively).
+ size_t childFarSetSize = nearSetSize - childNearSetSize;
+ size_t childUsedSetSize = 0;
+ children.push_back(new CoverTree(dataset, base, pointIndex,
+ nextScale, 0, indices, distances, childNearSetSize, childFarSetSize,
+ childUsedSetSize, metric));
+
+ // The self-child can't modify the furthestChildDistance away from 0, but it
+ // can modify the furthestDescendantDistance.
+ furthestDescendantDistance = children[0]->FurthestDescendantDistance();
+
+ // If we created an implicit node, take its self-child instead (this could
+ // happen multiple times).
+ while (children[children.size() - 1]->NumChildren() == 1)
+ {
+ CoverTree* old = children[children.size() - 1];
+ children.erase(children.begin() + children.size() - 1);
+
+ // Now take its child.
+ children.push_back(&(old->Child(0)));
+
+ // Remove its child (so it doesn't delete it).
+ old->Children().erase(old->Children().begin() + old->Children().size() - 1);
+
+ // Now delete it.
+ delete old;
+ }
+
+ // Now the arrays, in memory, look like this:
+ // [ childFar | childUsed | far | used ]
+ // but we need to move the used points past our far set:
+ // [ childFar | far | childUsed + used ]
+ // and keeping in mind that childFar = our near set,
+ // [ near | far | childUsed + used ]
+ // is what we are trying to make.
+ SortPointSet(indices, distances, childFarSetSize, childUsedSetSize,
+ farSetSize);
+
+ // Update size of near set and used set.
+ nearSetSize -= childUsedSetSize;
+ usedSetSize += childUsedSetSize;
+
+ // Now for each point in the near set, we need to make children. To save
+ // computation later, we'll create an array holding the points in the near
+ // set, and then after each run we'll check which of those (if any) were used
+ // and we will remove them. ...if that's faster. I think it is.
+ while (nearSetSize > 0)
+ {
+ size_t newPointIndex = nearSetSize - 1;
+
+ // Swap to front if necessary.
+ if (newPointIndex != 0)
+ {
+ const size_t tempIndex = indices[newPointIndex];
+ const double tempDist = distances[newPointIndex];
+
+ indices[newPointIndex] = indices[0];
+ distances[newPointIndex] = distances[0];
+
+ indices[0] = tempIndex;
+ distances[0] = tempDist;
+ }
+
+ // Will this be a new furthest child?
+ if (distances[0] > furthestDescendantDistance)
+ furthestDescendantDistance = distances[0];
+
+ // If there's only one point left, we don't need this crap.
+ if ((nearSetSize == 1) && (farSetSize == 0))
+ {
+ size_t childNearSetSize = 0;
+ children.push_back(new CoverTree(dataset, base,
+ indices[0], nextScale, distances[0], indices, distances,
+ childNearSetSize, farSetSize, usedSetSize, metric));
+
+ // Because the far set size is 0, we don't have to do any swapping to
+ // move the point into the used set.
+ ++usedSetSize;
+ --nearSetSize;
+
+ // And we're done.
+ break;
+ }
+
+ // Create the near and far set indices and distance vectors. We don't fill
+ // in the self-point, yet.
+ arma::Col<size_t> childIndices(nearSetSize + farSetSize);
+ childIndices.rows(0, (nearSetSize + farSetSize - 2)) = indices.rows(1,
+ nearSetSize + farSetSize - 1);
+ arma::vec childDistances(nearSetSize + farSetSize);
+
+ // Build distances for the child.
+ ComputeDistances(indices[0], childIndices, childDistances, nearSetSize
+ + farSetSize - 1, metric);
+
+ // Split into near and far sets for this point.
+ childNearSetSize = SplitNearFar(childIndices, childDistances, bound,
+ nearSetSize + farSetSize - 1);
+ childFarSetSize = PruneFarSet(childIndices, childDistances,
+ base * bound, childNearSetSize,
+ (nearSetSize + farSetSize - 1));
+
+ // Now that we know the near and far set sizes, we can put the used point
+ // (the self point) in the correct place; now, when we call
+ // MoveToUsedSet(), it will move the self-point correctly. The distance
+ // does not matter.
+ childIndices(childNearSetSize + childFarSetSize) = indices[0];
+ childDistances(childNearSetSize + childFarSetSize) = 0;
+
+ // Build this child (recursively).
+ childUsedSetSize = 1; // Mark self point as used.
+ children.push_back(new CoverTree(dataset, base, indices[0],
+ nextScale, distances[0], childIndices, childDistances, childNearSetSize,
+ childFarSetSize, childUsedSetSize, metric));
+
+ // If we created an implicit node, take its self-child instead (this could
+ // happen multiple times).
+ while (children[children.size() - 1]->NumChildren() == 1)
+ {
+ CoverTree* old = children[children.size() - 1];
+ children.erase(children.begin() + children.size() - 1);
+
+ // Now take its child.
+ children.push_back(&(old->Child(0)));
+
+ // Remove its child (so it doesn't delete it).
+ old->Children().erase(old->Children().begin() + old->Children().size()
+ - 1);
+
+ // Now delete it.
+ delete old;
+ }
+
+ // Now with the child created, it returns the childIndices and
+ // childDistances vectors in this form:
+ // [ childFar | childUsed ]
+ // For each point in the childUsed set, we must move that point to the used
+ // set in our own vector.
+ MoveToUsedSet(indices, distances, nearSetSize, farSetSize, usedSetSize,
+ childIndices, childFarSetSize, childUsedSetSize);
+ }
+
+ // Calculate furthest descendant.
+ for (size_t i = (nearSetSize + farSetSize); i < (nearSetSize + farSetSize +
+ usedSetSize); ++i)
+ if (distances[i] > furthestDescendantDistance)
+ furthestDescendantDistance = distances[i];
+
+ Log::Assert(furthestDescendantDistance <= pow(base, scale + 1));
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::~CoverTree()
+{
+ // Delete each child.
+ for (size_t i = 0; i < children.size(); ++i)
+ delete children[i];
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
+ const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
+{
+ // Every cover tree node will contain points up to EC^(scale + 1) away.
+ return std::max(MetricType::Evaluate(dataset.unsafe_col(point),
+ other->Dataset().unsafe_col(other->Point())) -
+ furthestDescendantDistance - other->FurthestDescendantDistance(), 0.0);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
+ const CoverTree<MetricType, RootPointPolicy, StatisticType>* other,
+ const double distance) const
+{
+ // We already have the distance as evaluated by the metric.
+ return std::max(distance - furthestDescendantDistance -
+ other->FurthestDescendantDistance(), 0.0);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
+ const arma::vec& other) const
+{
+ return std::max(MetricType::Evaluate(dataset.unsafe_col(point), other) -
+ furthestDescendantDistance, 0.0);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
+ const arma::vec& /* other */,
+ const double distance) const
+{
+ return std::max(distance - furthestDescendantDistance, 0.0);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
+ const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
+{
+ return MetricType::Evaluate(dataset.unsafe_col(point),
+ other->Dataset().unsafe_col(other->Point())) +
+ furthestDescendantDistance + other->FurthestDescendantDistance();
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
+ const CoverTree<MetricType, RootPointPolicy, StatisticType>* other,
+ const double distance) const
+{
+ // We already have the distance as evaluated by the metric.
+ return distance + furthestDescendantDistance +
+ other->FurthestDescendantDistance();
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
+ const arma::vec& other) const
+{
+ return MetricType::Evaluate(dataset.unsafe_col(point), other) +
+ furthestDescendantDistance;
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
+ const arma::vec& /* other */,
+ const double distance) const
+{
+ return distance + furthestDescendantDistance;
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SplitNearFar(
+ arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const double bound,
+ const size_t pointSetSize)
+{
+ // Sanity check; there is no guarantee that this condition will not be true.
+ // ...or is there?
+ if (pointSetSize <= 1)
+ return 0;
+
+ // We'll traverse from both left and right.
+ size_t left = 0;
+ size_t right = pointSetSize - 1;
+
+ // A modification of quicksort, with the pivot value set to the bound.
+ // Everything on the left of the pivot will be less than or equal to the
+ // bound; everything on the right will be greater than the bound.
+ while ((distances[left] <= bound) && (left != right))
+ ++left;
+ while ((distances[right] > bound) && (left != right))
+ --right;
+
+ while (left != right)
+ {
+ // Now swap the values and indices.
+ const size_t tempPoint = indices[left];
+ const double tempDist = distances[left];
+
+ indices[left] = indices[right];
+ distances[left] = distances[right];
+
+ indices[right] = tempPoint;
+ distances[right] = tempDist;
+
+ // Traverse the left, seeing how many points are correctly on that side.
+ // When we encounter an incorrect point, stop. We will switch it later.
+ while ((distances[left] <= bound) && (left != right))
+ ++left;
+
+ // Traverse the right, seeing how many points are correctly on that side.
+ // When we encounter an incorrect point, stop. We will switch it with the
+ // wrong point from the left side.
+ while ((distances[right] > bound) && (left != right))
+ --right;
+ }
+
+ // The final left value is the index of the first far value.
+ return left;
+}
+
+// Returns the maximum distance between points.
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+void CoverTree<MetricType, RootPointPolicy, StatisticType>::ComputeDistances(
+ const size_t pointIndex,
+ const arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const size_t pointSetSize,
+ MetricType& metric)
+{
+ // For each point, rebuild the distances. The indices do not need to be
+ // modified.
+ for (size_t i = 0; i < pointSetSize; ++i)
+ {
+ distances[i] = metric.Evaluate(dataset.unsafe_col(pointIndex),
+ dataset.unsafe_col(indices[i]));
+ }
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SortPointSet(
+ arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const size_t childFarSetSize,
+ const size_t childUsedSetSize,
+ const size_t farSetSize)
+{
+ // We'll use low-level memcpy calls ourselves, just to ensure it's done
+ // quickly and the way we want it to be. Unfortunately this takes up more
+ // memory than one-element swaps, but there's not a great way around that.
+ const size_t bufferSize = std::min(farSetSize, childUsedSetSize);
+ const size_t bigCopySize = std::max(farSetSize, childUsedSetSize);
+
+ // Sanity check: there is no need to sort if the buffer size is going to be
+ // zero.
+ if (bufferSize == 0)
+ return (childFarSetSize + farSetSize);
+
+ size_t* indicesBuffer = new size_t[bufferSize];
+ double* distancesBuffer = new double[bufferSize];
+
+ // The start of the memory region to copy to the buffer.
+ const size_t bufferFromLocation = ((bufferSize == farSetSize) ?
+ (childFarSetSize + childUsedSetSize) : childFarSetSize);
+ // The start of the memory region to move directly to the new place.
+ const size_t directFromLocation = ((bufferSize == farSetSize) ?
+ childFarSetSize : (childFarSetSize + childUsedSetSize));
+ // The destination to copy the buffer back to.
+ const size_t bufferToLocation = ((bufferSize == farSetSize) ?
+ childFarSetSize : (childFarSetSize + farSetSize));
+ // The destination of the directly moved memory region.
+ const size_t directToLocation = ((bufferSize == farSetSize) ?
+ (childFarSetSize + farSetSize) : childFarSetSize);
+
+ // Copy the smaller piece to the buffer.
+ memcpy(indicesBuffer, indices.memptr() + bufferFromLocation,
+ sizeof(size_t) * bufferSize);
+ memcpy(distancesBuffer, distances.memptr() + bufferFromLocation,
+ sizeof(double) * bufferSize);
+
+ // Now move the other memory.
+ memmove(indices.memptr() + directToLocation,
+ indices.memptr() + directFromLocation, sizeof(size_t) * bigCopySize);
+ memmove(distances.memptr() + directToLocation,
+ distances.memptr() + directFromLocation, sizeof(double) * bigCopySize);
+
+ // Now copy the temporary memory to the right place.
+ memcpy(indices.memptr() + bufferToLocation, indicesBuffer,
+ sizeof(size_t) * bufferSize);
+ memcpy(distances.memptr() + bufferToLocation, distancesBuffer,
+ sizeof(double) * bufferSize);
+
+ delete[] indicesBuffer;
+ delete[] distancesBuffer;
+
+ // This returns the complete size of the far set.
+ return (childFarSetSize + farSetSize);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+void CoverTree<MetricType, RootPointPolicy, StatisticType>::MoveToUsedSet(
+ arma::Col<size_t>& indices,
+ arma::vec& distances,
+ size_t& nearSetSize,
+ size_t& farSetSize,
+ size_t& usedSetSize,
+ arma::Col<size_t>& childIndices,
+ const size_t childFarSetSize, // childNearSetSize is 0 in this case.
+ const size_t childUsedSetSize)
+{
+ const size_t originalSum = nearSetSize + farSetSize + usedSetSize;
+
+ // Loop across the set. We will swap points as we need. It should be noted
+ // that farSetSize and nearSetSize may change with each iteration of this loop
+ // (depending on if we make a swap or not).
+ size_t startChildUsedSet = 0; // Where to start in the child set.
+ for (size_t i = 0; i < nearSetSize; ++i)
+ {
+ // Discover if this point was in the child's used set.
+ for (size_t j = startChildUsedSet; j < childUsedSetSize; ++j)
+ {
+ if (childIndices[childFarSetSize + j] == indices[i])
+ {
+ // We have found a point; a swap is necessary.
+
+ // Since this point is from the near set, to preserve the near set, we
+ // must do a swap.
+ if (farSetSize > 0)
+ {
+ if ((nearSetSize - 1) != i)
+ {
+ // In this case it must be a three-way swap.
+ size_t tempIndex = indices[nearSetSize + farSetSize - 1];
+ double tempDist = distances[nearSetSize + farSetSize - 1];
+
+ size_t tempNearIndex = indices[nearSetSize - 1];
+ double tempNearDist = distances[nearSetSize - 1];
+
+ indices[nearSetSize + farSetSize - 1] = indices[i];
+ distances[nearSetSize + farSetSize - 1] = distances[i];
+
+ indices[nearSetSize - 1] = tempIndex;
+ distances[nearSetSize - 1] = tempDist;
+
+ indices[i] = tempNearIndex;
+ distances[i] = tempNearDist;
+ }
+ else
+ {
+ // We can do a two-way swap.
+ size_t tempIndex = indices[nearSetSize + farSetSize - 1];
+ double tempDist = distances[nearSetSize + farSetSize - 1];
+
+ indices[nearSetSize + farSetSize - 1] = indices[i];
+ distances[nearSetSize + farSetSize - 1] = distances[i];
+
+ indices[i] = tempIndex;
+ distances[i] = tempDist;
+ }
+ }
+ else if ((nearSetSize - 1) != i)
+ {
+ // A two-way swap is possible.
+ size_t tempIndex = indices[nearSetSize + farSetSize - 1];
+ double tempDist = distances[nearSetSize + farSetSize - 1];
+
+ indices[nearSetSize + farSetSize - 1] = indices[i];
+ distances[nearSetSize + farSetSize - 1] = distances[i];
+
+ indices[i] = tempIndex;
+ distances[i] = tempDist;
+ }
+ else
+ {
+ // No swap is necessary.
+ }
+
+ // We don't need to do a complete preservation of the child index set,
+ // but we want to make sure we only loop over points we haven't seen.
+ // So increment the child counter by 1 and move a point if we need.
+ if (j != startChildUsedSet)
+ {
+ childIndices[childFarSetSize + j] = childIndices[childFarSetSize +
+ startChildUsedSet];
+ }
+
+ // Update all counters from the swaps we have done.
+ ++startChildUsedSet;
+ --nearSetSize;
+ --i; // Since we moved a point out of the near set we must step back.
+
+ break; // Break out of this for loop; back to the first one.
+ }
+ }
+ }
+
+ // Now loop over the far set. This loop is different because we only require
+ // a normal two-way swap instead of the three-way swap to preserve the near
+ // set / far set ordering.
+ for (size_t i = 0; i < farSetSize; ++i)
+ {
+ // Discover if this point was in the child's used set.
+ for (size_t j = startChildUsedSet; j < childUsedSetSize; ++j)
+ {
+ if (childIndices[childFarSetSize + j] == indices[i + nearSetSize])
+ {
+ // We have found a point to swap.
+
+ // Perform the swap.
+ size_t tempIndex = indices[nearSetSize + farSetSize - 1];
+ double tempDist = distances[nearSetSize + farSetSize - 1];
+
+ indices[nearSetSize + farSetSize - 1] = indices[nearSetSize + i];
+ distances[nearSetSize + farSetSize - 1] = distances[nearSetSize + i];
+
+ indices[nearSetSize + i] = tempIndex;
+ distances[nearSetSize + i] = tempDist;
+
+ if (j != startChildUsedSet)
+ {
+ childIndices[childFarSetSize + j] = childIndices[childFarSetSize +
+ startChildUsedSet];
+ }
+
+ // Update all counters from the swaps we have done.
+ ++startChildUsedSet;
+ --farSetSize;
+ --i;
+
+ break; // Break out of this for loop; back to the first one.
+ }
+ }
+ }
+
+ // Update used set size.
+ usedSetSize += childUsedSetSize;
+
+ Log::Assert(originalSum == (nearSetSize + farSetSize + usedSetSize));
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::PruneFarSet(
+ arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const double bound,
+ const size_t nearSetSize,
+ const size_t pointSetSize)
+{
+ // What we are trying to do is remove any points greater than the bound from
+ // the far set. We don't care what happens to those indices and distances...
+ // so, we don't need to properly swap points -- just drop new ones in place.
+ size_t left = nearSetSize;
+ size_t right = pointSetSize - 1;
+ while ((distances[left] <= bound) && (left != right))
+ ++left;
+ while ((distances[right] > bound) && (left != right))
+ --right;
+
+ while (left != right)
+ {
+ // We don't care what happens to the point which should be on the right.
+ indices[left] = indices[right];
+ distances[left] = distances[right];
+ --right; // Since we aren't changing the right.
+
+ // Advance to next location which needs to switch.
+ while ((distances[left] <= bound) && (left != right))
+ ++left;
+ while ((distances[right] > bound) && (left != right))
+ --right;
+ }
+
+ // The far set size is the left pointer, with the near set size subtracted
+ // from it.
+ return (left - nearSetSize);
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp 2012-08-15 16:44:38 UTC (rev 13396)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -1,499 +0,0 @@
-/**
- * @file dual_tree_traverser_impl.hpp
- * @author Ryan Curtin
- *
- * A dual-tree traverser for the cover tree.
- */
-#ifndef __MLPACK_CORE_TREE_COVER_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
-#define __MLPACK_CORE_TREE_COVER_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
-
-#include <mlpack/core.hpp>
-#include <queue>
-
-namespace mlpack {
-namespace tree {
-
-//! The object placed in the map for tree traversal.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-struct DualCoverTreeMapEntry
-{
- //! The node this entry refers to.
- CoverTree<MetricType, RootPointPolicy, StatisticType>* referenceNode;
- //! The score of the node.
- double score;
- //! The index of the reference node used for the base case evaluation.
- size_t referenceIndex;
- //! The index of the query node used for the base case evaluation.
- size_t queryIndex;
- //! The base case evaluation.
- double baseCase;
-
- //! Comparison operator, for sorting within the map.
- bool operator<(const DualCoverTreeMapEntry& other) const
- {
- return (score < other.score);
- }
-};
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-template<typename RuleType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::
-DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
- rule(rule),
- numPrunes(0)
-{ /* Nothing to do. */ }
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
-DualTreeTraverser<RuleType>::Traverse(
- CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
- CoverTree<MetricType, RootPointPolicy, StatisticType>& referenceNode)
-{
- typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
- MapEntryType;
-
- // Start by creating a map and adding the reference node to it.
- std::map<int, std::vector<MapEntryType> > refMap;
-
- MapEntryType rootRefEntry;
-
- rootRefEntry.referenceNode = &referenceNode;
- rootRefEntry.score = 0.0; // Must recurse into.
- rootRefEntry.referenceIndex = referenceNode.Point();
- rootRefEntry.queryIndex = queryNode.Point();
- rootRefEntry.baseCase = rule.BaseCase(queryNode.Point(),
- referenceNode.Point());
- rule.UpdateAfterRecursion(queryNode, referenceNode);
-
- refMap[referenceNode.Scale()].push_back(rootRefEntry);
-
- Traverse(queryNode, refMap);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
-DualTreeTraverser<RuleType>::Traverse(
- CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
- std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
- StatisticType> > >& referenceMap)
-{
-// Log::Debug << "Recursed into query node " << queryNode.Point() << ", scale "
-// << queryNode.Scale() << "\n";
-
- // Convenience typedef.
- typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
- MapEntryType;
-
- if (referenceMap.size() == 0)
- return; // Nothing to do!
-
- // First recurse down the reference nodes as necessary.
- ReferenceRecursion(queryNode, referenceMap);
-
- // Now, reduce the scale of the query node by recursing. But we can't recurse
- // if the query node is a leaf node.
- if ((queryNode.Scale() != INT_MIN) &&
- (queryNode.Scale() >= (*referenceMap.rbegin()).first))
- {
- // Recurse into the non-self-children first.
- for (size_t i = 1; i < queryNode.NumChildren(); ++i)
- {
- std::map<int, std::vector<MapEntryType> > childMap;
- PruneMap(queryNode.Child(i), referenceMap, childMap);
-
- Log::Debug << "Recurse into query child " << i << ": " <<
- queryNode.Child(i).Point() << " scale " << queryNode.Child(i).Scale()
- << "; this parent is " << queryNode.Point() << " scale " <<
- queryNode.Scale() << std::endl;
- Traverse(queryNode.Child(i), childMap);
- }
-
- PruneMapForSelfChild(queryNode.Child(0), referenceMap);
-
- // Now we can use the existing map (without a copy) for the self-child.
- Log::Warn << "Recurse into query self-child " << queryNode.Child(0).Point()
- << " scale " << queryNode.Child(0).Scale() << "; this parent is "
- << queryNode.Point() << " scale " << queryNode.Scale() << std::endl;
- Traverse(queryNode.Child(0), referenceMap);
- }
-
- if (queryNode.Scale() != INT_MIN)
- return; // No need to evaluate base cases at this level. It's all done.
-
- // If we have made it this far, all we have is a bunch of base case
- // evaluations to do.
- Log::Assert((*referenceMap.begin()).first == INT_MIN);
- Log::Assert(queryNode.Scale() == INT_MIN);
- std::vector<MapEntryType>& pointVector = (*referenceMap.begin()).second;
-// Log::Debug << "Onto base case evaluations\n";
-
- for (size_t i = 0; i < pointVector.size(); ++i)
- {
- // Get a reference to the frame.
- const MapEntryType& frame = pointVector[i];
-
- CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
- frame.referenceNode;
- const double oldScore = frame.score;
- const size_t refIndex = frame.referenceIndex;
- const size_t queryIndex = frame.queryIndex;
-// Log::Debug << "Consider query " << queryNode.Point() << ", reference "
-// << refNode->Point() << "\n";
-// Log::Debug << "Old score " << oldScore << " with refParent " << refIndex
-// << " and parent query node " << queryIndex << "\n";
-
- // First, ensure that we have not already calculated the base case.
- if ((refIndex == refNode->Point()) && (queryIndex == queryNode.Point()))
- {
-// Log::Debug << "Pruned because we already did the base case and its value "
-// << " was " << frame.baseCase << std::endl;
- ++numPrunes;
- continue;
- }
-
- // Now, check if we can prune it.
- const double rescore = rule.Rescore(queryNode, *refNode, oldScore);
-
- if (rescore == DBL_MAX)
- {
-// Log::Debug << "Pruned after rescoring\n";
- ++numPrunes;
- continue;
- }
-
- // If not, compute the base case.
-// Log::Debug << "Not pruned, performing base case " << queryNode.Point() <<
-// " " << pointVector[i].referenceNode->Point() << "\n";
- rule.BaseCase(queryNode.Point(), pointVector[i].referenceNode->Point());
- rule.UpdateAfterRecursion(queryNode, *pointVector[i].referenceNode);
-// Log::Debug << "Bound for point " << queryNode.Point() << " scale " <<
-// queryNode.Scale() << " is now " << queryNode.Stat().Bound() <<
-// std::endl;
- }
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
-DualTreeTraverser<RuleType>::PruneMap(
- CoverTree& candidateQueryNode,
- std::map<int, std::vector<DualCoverTreeMapEntry<MetricType,
- RootPointPolicy, StatisticType> > >& referenceMap,
- std::map<int, std::vector<DualCoverTreeMapEntry<MetricType,
- RootPointPolicy, StatisticType> > >& childMap)
-{
-// Log::Debug << "Prep for recurse into query child point " <<
-// queryNode.Child(i).Point() << " scale " << queryNode.Child(i).Scale()
-// << std::endl;
- typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
- MapEntryType;
-
- typename std::map<int, std::vector<MapEntryType> >::reverse_iterator it =
- referenceMap.rbegin();
-
- while ((it != referenceMap.rend()) && ((*it).first != INT_MIN))
- {
- // Get a reference to the vector representing the entries at this scale.
- const std::vector<MapEntryType>& scaleVector = (*it).second;
-
- std::vector<MapEntryType>& newScaleVector = childMap[(*it).first];
- newScaleVector.reserve(scaleVector.size()); // Maximum possible size.
-
- // Loop over each entry in the vector.
- for (size_t j = 0; j < scaleVector.size(); ++j)
- {
- const MapEntryType& frame = scaleVector[j];
-
- // First evaluate if we can prune without performing the base case.
- CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
- frame.referenceNode;
- const double oldScore = frame.score;
-
-// Log::Debug << "Recheck reference node " << refNode->Point() <<
-// " scale " << refNode->Scale() << " which has old score " <<
-// oldScore << " with old reference index " << frame.referenceIndex
-// << " and old query index " << frame.queryIndex << std::endl;
-
- double score = rule.Rescore(candidateQueryNode, *refNode, oldScore);
-
-// Log::Debug << "Rescored as " << score << std::endl;
-
- if (score == DBL_MAX)
- {
- // Pruned. Move on.
- ++numPrunes;
- continue;
- }
-
- // Evaluate base case.
-// Log::Debug << "Must evaluate base case " << queryNode.Child(i).Point()
-// << " " << refNode->Point() << "\n";
- double baseCase = rule.BaseCase(candidateQueryNode.Point(),
- refNode->Point());
- rule.UpdateAfterRecursion(candidateQueryNode, *refNode);
-// Log::Debug << "Base case was " << baseCase << std::endl;
-
- score = rule.Score(candidateQueryNode, *refNode, baseCase);
-
- if (score == DBL_MAX)
- {
- // Pruned. Move on.
- ++numPrunes;
- continue;
- }
-
- // Add to child map.
- newScaleVector.push_back(frame);
- newScaleVector.back().score = score;
- newScaleVector.back().baseCase = baseCase;
- newScaleVector.back().referenceIndex = refNode->Point();
- newScaleVector.back().queryIndex = candidateQueryNode.Point();
- }
-
- // If we didn't add anything, then strike this vector from the map.
- if (newScaleVector.size() == 0)
- childMap.erase((*it).first);
-
- ++it; // Advance to next scale.
- }
-
- childMap[INT_MIN] = referenceMap[INT_MIN];
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
-DualTreeTraverser<RuleType>::PruneMapForSelfChild(
- CoverTree& candidateQueryNode,
- std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
- StatisticType> > >& referenceMap)
-{
-// Log::Debug << "Prep for recurse into query self-child point " <<
-// queryNode.Child(0).Point() << " scale " << queryNode.Child(0).Scale()
-// << std::endl;
- typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
- MapEntryType;
-
- // Create the child reference map. We will do this by recursing through
- // every entry in the reference map and evaluating (or pruning) it. But
- // in this setting we do not recurse into any children of the reference
- // entries.
- typename std::map<int, std::vector<MapEntryType> >::reverse_iterator it =
- referenceMap.rbegin();
-
- while (it != referenceMap.rend() && (*it).first != INT_MIN)
- {
- // Get a reference to the vector representing the entries at this scale.
- std::vector<MapEntryType>& newScaleVector = (*it).second;
- const std::vector<MapEntryType> scaleVector = newScaleVector;
-
- newScaleVector.clear();
- newScaleVector.reserve(scaleVector.size());
-
- // Loop over each entry in the vector.
- for (size_t j = 0; j < scaleVector.size(); ++j)
- {
- const MapEntryType& frame = scaleVector[j];
-
- // First evaluate if we can prune without performing the base case.
- CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
- frame.referenceNode;
- const double oldScore = frame.score;
- double baseCase = frame.baseCase;
- const size_t queryIndex = frame.queryIndex;
- const size_t refIndex = frame.referenceIndex;
-
- // Log::Debug << "Recheck reference node " << refNode->Point() <<
- // " scale " << refNode->Scale() << " which has old score " <<
- // oldScore << std::endl;
-
- // Have we performed the base case yet?
- double score;
- if ((refIndex != refNode->Point()) ||
- (queryIndex != candidateQueryNode.Point()))
- {
- // Attempt to rescore before performing the base case.
- score = rule.Rescore(candidateQueryNode, *refNode, oldScore);
-
- if (score == DBL_MAX)
- {
- ++numPrunes;
- continue;
- }
-
- // If not pruned, we have to perform the base case.
- baseCase = rule.BaseCase(candidateQueryNode.Point(), refNode->Point());
- rule.UpdateAfterRecursion(candidateQueryNode, *refNode);
- }
-
- score = rule.Score(candidateQueryNode, *refNode, score);
-
- // Log::Debug << "Rescored as " << score << std::endl;
-
- if (score == DBL_MAX)
- {
- // Pruned. Move on.
- ++numPrunes;
- continue;
- }
-
- // Log::Debug << "Kept in map\n";
-
- // Add to child map.
- newScaleVector.push_back(frame);
- newScaleVector.back().score = score;
- newScaleVector.back().baseCase = baseCase;
- newScaleVector.back().queryIndex = candidateQueryNode.Point();
- newScaleVector.back().referenceIndex = refNode->Point();
- }
-
- // If we didn't add anything, then strike this vector from the map.
- if (newScaleVector.size() == 0)
- referenceMap.erase((*it).first);
-
- ++it; // Advance to next scale.
- }
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
-DualTreeTraverser<RuleType>::ReferenceRecursion(
- CoverTree& queryNode,
- std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
- StatisticType> > >& referenceMap)
-{
- typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
- MapEntryType;
-
- // First, reduce the maximum scale in the reference map down to the scale of
- // the query node.
- while ((*referenceMap.rbegin()).first > queryNode.Scale())
- {
- // Get a reference to the current largest scale.
- std::vector<MapEntryType>& scaleVector = (*referenceMap.rbegin()).second;
-
- // Before traversing all the points in this scale, sort by score.
- std::sort(scaleVector.begin(), scaleVector.end());
-
- // Now loop over each element.
- for (size_t i = 0; i < scaleVector.size(); ++i)
- {
- // Get a reference to the current element.
- const MapEntryType& frame = scaleVector.at(i);
-
- CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
- frame.referenceNode;
- const double score = frame.score;
- const size_t refIndex = frame.referenceIndex;
- const size_t refPoint = refNode->Point();
- const size_t queryIndex = frame.queryIndex;
- double baseCase = frame.baseCase;
-
-// Log::Debug << "Currently inspecting reference node " << refNode->Point()
-// << " scale " << refNode->Scale() << " earlier query index " <<
-// queryIndex << std::endl;
-
-// Log::Debug << "Old score " << score << " with refParent " << refIndex
-// << " and queryIndex " << queryIndex << "\n";
-
- // First we recalculate the score of this node to find if we can prune it.
- if (rule.Rescore(queryNode, *refNode, score) == DBL_MAX)
- {
- // Log::Warn << "Pruned after rescore\n";
- ++numPrunes;
- continue;
- }
-
- // If this is a self-child, the base case has already been evaluated.
- // We also must ensure that the base case was evaluated with this query
- // point.
- if ((refPoint != refIndex) || (queryNode.Point() != queryIndex))
- {
-// Log::Warn << "Must evaluate base case " << queryNode.Point() << " "
-// << refPoint << "\n";
- baseCase = rule.BaseCase(queryNode.Point(), refPoint);
-// Log::Debug << "Base case " << baseCase << std::endl;
- rule.UpdateAfterRecursion(queryNode, *refNode); // Kludgey.
-// Log::Debug << "Bound for point " << queryNode.Point() << " scale " <<
-// queryNode.Scale() << " is now " << queryNode.Stat().Bound() <<
-// std::endl;
- }
-
- // Create the score for the children.
- double childScore = rule.Score(queryNode, *refNode, baseCase);
-
- // Now if this childScore is DBL_MAX we can prune all children. In this
- // recursion setup pruning is all or nothing for children.
- if (childScore == DBL_MAX)
- {
-// Log::Warn << "Pruned all children.\n";
- numPrunes += refNode->NumChildren();
- continue;
- }
-
- // We must treat the self-leaf differently. The base case has already
- // been performed.
- childScore = rule.Score(queryNode, refNode->Child(0), baseCase);
-
- if (childScore != DBL_MAX)
- {
- MapEntryType newFrame;
- newFrame.referenceNode = &refNode->Child(0);
- newFrame.score = childScore;
- newFrame.baseCase = baseCase;
- newFrame.referenceIndex = refPoint;
- newFrame.queryIndex = queryNode.Point();
-
- referenceMap[newFrame.referenceNode->Scale()].push_back(newFrame);
- }
- else
- {
- ++numPrunes;
- }
-
- // Add the non-self-leaf children.
- for (size_t j = 1; j < refNode->NumChildren(); ++j)
- {
- const size_t queryIndex = queryNode.Point();
- const size_t refIndex = refNode->Child(j).Point();
-
- // Calculate the base case of each child.
- baseCase = rule.BaseCase(queryIndex, refIndex);
- rule.UpdateAfterRecursion(queryNode, refNode->Child(j));
-
- // See if we can prune it.
- childScore = rule.Score(queryNode, refNode->Child(j), baseCase);
-
- if (childScore == DBL_MAX)
- {
- ++numPrunes;
- continue;
- }
-
- MapEntryType newFrame;
- newFrame.referenceNode = &refNode->Child(j);
- newFrame.score = childScore;
- newFrame.baseCase = baseCase;
- newFrame.referenceIndex = refIndex;
- newFrame.queryIndex = queryIndex;
-
-// Log::Debug << "Push onto map child " << refNode->Child(j).Point() <<
-// " scale " << refNode->Child(j).Scale() << std::endl;
-
- referenceMap[newFrame.referenceNode->Scale()].push_back(newFrame);
- }
- }
-
- // Now clear the memory for this scale; it isn't needed anymore.
- referenceMap.erase((*referenceMap.rbegin()).first);
- }
-
-}
-
-}; // namespace tree
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp (from rev 13405, mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -0,0 +1,499 @@
+/**
+ * @file dual_tree_traverser_impl.hpp
+ * @author Ryan Curtin
+ *
+ * A dual-tree traverser for the cover tree.
+ */
+#ifndef __MLPACK_CORE_TREE_COVER_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
+#define __MLPACK_CORE_TREE_COVER_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
+
+#include <mlpack/core.hpp>
+#include <queue>
+
+namespace mlpack {
+namespace tree {
+
+//! The object placed in the map for tree traversal.
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+struct DualCoverTreeMapEntry
+{
+ //! The node this entry refers to.
+ CoverTree<MetricType, RootPointPolicy, StatisticType>* referenceNode;
+ //! The score of the node.
+ double score;
+ //! The index of the reference node used for the base case evaluation.
+ size_t referenceIndex;
+ //! The index of the query node used for the base case evaluation.
+ size_t queryIndex;
+ //! The base case evaluation.
+ double baseCase;
+
+ //! Comparison operator, for sorting within the map.
+ bool operator<(const DualCoverTreeMapEntry& other) const
+ {
+ return (score < other.score);
+ }
+};
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<typename RuleType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::
+DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
+ rule(rule),
+ numPrunes(0)
+{ /* Nothing to do. */ }
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<typename RuleType>
+void CoverTree<MetricType, RootPointPolicy, StatisticType>::
+DualTreeTraverser<RuleType>::Traverse(
+ CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
+ CoverTree<MetricType, RootPointPolicy, StatisticType>& referenceNode)
+{
+ typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
+ MapEntryType;
+
+ // Start by creating a map and adding the reference node to it.
+ std::map<int, std::vector<MapEntryType> > refMap;
+
+ MapEntryType rootRefEntry;
+
+ rootRefEntry.referenceNode = &referenceNode;
+ rootRefEntry.score = 0.0; // Must recurse into.
+ rootRefEntry.referenceIndex = referenceNode.Point();
+ rootRefEntry.queryIndex = queryNode.Point();
+ rootRefEntry.baseCase = rule.BaseCase(queryNode.Point(),
+ referenceNode.Point());
+ rule.UpdateAfterRecursion(queryNode, referenceNode);
+
+ refMap[referenceNode.Scale()].push_back(rootRefEntry);
+
+ Traverse(queryNode, refMap);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<typename RuleType>
+void CoverTree<MetricType, RootPointPolicy, StatisticType>::
+DualTreeTraverser<RuleType>::Traverse(
+ CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
+ std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
+ StatisticType> > >& referenceMap)
+{
+// Log::Debug << "Recursed into query node " << queryNode.Point() << ", scale "
+// << queryNode.Scale() << "\n";
+
+ // Convenience typedef.
+ typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
+ MapEntryType;
+
+ if (referenceMap.size() == 0)
+ return; // Nothing to do!
+
+ // First recurse down the reference nodes as necessary.
+ ReferenceRecursion(queryNode, referenceMap);
+
+ // Now, reduce the scale of the query node by recursing. But we can't recurse
+ // if the query node is a leaf node.
+ if ((queryNode.Scale() != INT_MIN) &&
+ (queryNode.Scale() >= (*referenceMap.rbegin()).first))
+ {
+ // Recurse into the non-self-children first.
+ for (size_t i = 1; i < queryNode.NumChildren(); ++i)
+ {
+ std::map<int, std::vector<MapEntryType> > childMap;
+ PruneMap(queryNode.Child(i), referenceMap, childMap);
+
+// Log::Debug << "Recurse into query child " << i << ": " <<
+// queryNode.Child(i).Point() << " scale " << queryNode.Child(i).Scale()
+// << "; this parent is " << queryNode.Point() << " scale " <<
+// queryNode.Scale() << std::endl;
+ Traverse(queryNode.Child(i), childMap);
+ }
+
+ PruneMapForSelfChild(queryNode.Child(0), referenceMap);
+
+ // Now we can use the existing map (without a copy) for the self-child.
+// Log::Warn << "Recurse into query self-child " << queryNode.Child(0).Point()
+// << " scale " << queryNode.Child(0).Scale() << "; this parent is "
+// << queryNode.Point() << " scale " << queryNode.Scale() << std::endl;
+ Traverse(queryNode.Child(0), referenceMap);
+ }
+
+ if (queryNode.Scale() != INT_MIN)
+ return; // No need to evaluate base cases at this level. It's all done.
+
+ // If we have made it this far, all we have is a bunch of base case
+ // evaluations to do.
+ Log::Assert((*referenceMap.begin()).first == INT_MIN);
+ Log::Assert(queryNode.Scale() == INT_MIN);
+ std::vector<MapEntryType>& pointVector = (*referenceMap.begin()).second;
+// Log::Debug << "Onto base case evaluations\n";
+
+ for (size_t i = 0; i < pointVector.size(); ++i)
+ {
+ // Get a reference to the frame.
+ const MapEntryType& frame = pointVector[i];
+
+ CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
+ frame.referenceNode;
+ const double oldScore = frame.score;
+ const size_t refIndex = frame.referenceIndex;
+ const size_t queryIndex = frame.queryIndex;
+// Log::Debug << "Consider query " << queryNode.Point() << ", reference "
+// << refNode->Point() << "\n";
+// Log::Debug << "Old score " << oldScore << " with refParent " << refIndex
+// << " and parent query node " << queryIndex << "\n";
+
+ // First, ensure that we have not already calculated the base case.
+ if ((refIndex == refNode->Point()) && (queryIndex == queryNode.Point()))
+ {
+// Log::Debug << "Pruned because we already did the base case and its value "
+// << " was " << frame.baseCase << std::endl;
+ ++numPrunes;
+ continue;
+ }
+
+ // Now, check if we can prune it.
+ const double rescore = rule.Rescore(queryNode, *refNode, oldScore);
+
+ if (rescore == DBL_MAX)
+ {
+// Log::Debug << "Pruned after rescoring\n";
+ ++numPrunes;
+ continue;
+ }
+
+ // If not, compute the base case.
+// Log::Debug << "Not pruned, performing base case " << queryNode.Point() <<
+// " " << pointVector[i].referenceNode->Point() << "\n";
+ rule.BaseCase(queryNode.Point(), pointVector[i].referenceNode->Point());
+ rule.UpdateAfterRecursion(queryNode, *pointVector[i].referenceNode);
+// Log::Debug << "Bound for point " << queryNode.Point() << " scale " <<
+// queryNode.Scale() << " is now " << queryNode.Stat().Bound() <<
+// std::endl;
+ }
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<typename RuleType>
+void CoverTree<MetricType, RootPointPolicy, StatisticType>::
+DualTreeTraverser<RuleType>::PruneMap(
+ CoverTree& candidateQueryNode,
+ std::map<int, std::vector<DualCoverTreeMapEntry<MetricType,
+ RootPointPolicy, StatisticType> > >& referenceMap,
+ std::map<int, std::vector<DualCoverTreeMapEntry<MetricType,
+ RootPointPolicy, StatisticType> > >& childMap)
+{
+// Log::Debug << "Prep for recurse into query child point " <<
+// queryNode.Child(i).Point() << " scale " << queryNode.Child(i).Scale()
+// << std::endl;
+ typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
+ MapEntryType;
+
+ typename std::map<int, std::vector<MapEntryType> >::reverse_iterator it =
+ referenceMap.rbegin();
+
+ while ((it != referenceMap.rend()) && ((*it).first != INT_MIN))
+ {
+ // Get a reference to the vector representing the entries at this scale.
+ const std::vector<MapEntryType>& scaleVector = (*it).second;
+
+ std::vector<MapEntryType>& newScaleVector = childMap[(*it).first];
+ newScaleVector.reserve(scaleVector.size()); // Maximum possible size.
+
+ // Loop over each entry in the vector.
+ for (size_t j = 0; j < scaleVector.size(); ++j)
+ {
+ const MapEntryType& frame = scaleVector[j];
+
+ // First evaluate if we can prune without performing the base case.
+ CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
+ frame.referenceNode;
+ const double oldScore = frame.score;
+
+// Log::Debug << "Recheck reference node " << refNode->Point() <<
+// " scale " << refNode->Scale() << " which has old score " <<
+// oldScore << " with old reference index " << frame.referenceIndex
+// << " and old query index " << frame.queryIndex << std::endl;
+
+ double score = rule.Rescore(candidateQueryNode, *refNode, oldScore);
+
+// Log::Debug << "Rescored as " << score << std::endl;
+
+ if (score == DBL_MAX)
+ {
+ // Pruned. Move on.
+ ++numPrunes;
+ continue;
+ }
+
+ // Evaluate base case.
+// Log::Debug << "Must evaluate base case " << queryNode.Child(i).Point()
+// << " " << refNode->Point() << "\n";
+ double baseCase = rule.BaseCase(candidateQueryNode.Point(),
+ refNode->Point());
+ rule.UpdateAfterRecursion(candidateQueryNode, *refNode);
+// Log::Debug << "Base case was " << baseCase << std::endl;
+
+ score = rule.Score(candidateQueryNode, *refNode, baseCase);
+
+ if (score == DBL_MAX)
+ {
+ // Pruned. Move on.
+ ++numPrunes;
+ continue;
+ }
+
+ // Add to child map.
+ newScaleVector.push_back(frame);
+ newScaleVector.back().score = score;
+ newScaleVector.back().baseCase = baseCase;
+ newScaleVector.back().referenceIndex = refNode->Point();
+ newScaleVector.back().queryIndex = candidateQueryNode.Point();
+ }
+
+ // If we didn't add anything, then strike this vector from the map.
+ if (newScaleVector.size() == 0)
+ childMap.erase((*it).first);
+
+ ++it; // Advance to next scale.
+ }
+
+ childMap[INT_MIN] = referenceMap[INT_MIN];
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<typename RuleType>
+void CoverTree<MetricType, RootPointPolicy, StatisticType>::
+DualTreeTraverser<RuleType>::PruneMapForSelfChild(
+ CoverTree& candidateQueryNode,
+ std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
+ StatisticType> > >& referenceMap)
+{
+// Log::Debug << "Prep for recurse into query self-child point " <<
+// queryNode.Child(0).Point() << " scale " << queryNode.Child(0).Scale()
+// << std::endl;
+ typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
+ MapEntryType;
+
+ // Create the child reference map. We will do this by recursing through
+ // every entry in the reference map and evaluating (or pruning) it. But
+ // in this setting we do not recurse into any children of the reference
+ // entries.
+ typename std::map<int, std::vector<MapEntryType> >::reverse_iterator it =
+ referenceMap.rbegin();
+
+ while (it != referenceMap.rend() && (*it).first != INT_MIN)
+ {
+ // Get a reference to the vector representing the entries at this scale.
+ std::vector<MapEntryType>& newScaleVector = (*it).second;
+ const std::vector<MapEntryType> scaleVector = newScaleVector;
+
+ newScaleVector.clear();
+ newScaleVector.reserve(scaleVector.size());
+
+ // Loop over each entry in the vector.
+ for (size_t j = 0; j < scaleVector.size(); ++j)
+ {
+ const MapEntryType& frame = scaleVector[j];
+
+ // First evaluate if we can prune without performing the base case.
+ CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
+ frame.referenceNode;
+ const double oldScore = frame.score;
+ double baseCase = frame.baseCase;
+ const size_t queryIndex = frame.queryIndex;
+ const size_t refIndex = frame.referenceIndex;
+
+ // Log::Debug << "Recheck reference node " << refNode->Point() <<
+ // " scale " << refNode->Scale() << " which has old score " <<
+ // oldScore << std::endl;
+
+ // Have we performed the base case yet?
+ double score;
+ if ((refIndex != refNode->Point()) ||
+ (queryIndex != candidateQueryNode.Point()))
+ {
+ // Attempt to rescore before performing the base case.
+ score = rule.Rescore(candidateQueryNode, *refNode, oldScore);
+
+ if (score == DBL_MAX)
+ {
+ ++numPrunes;
+ continue;
+ }
+
+ // If not pruned, we have to perform the base case.
+ baseCase = rule.BaseCase(candidateQueryNode.Point(), refNode->Point());
+ rule.UpdateAfterRecursion(candidateQueryNode, *refNode);
+ }
+
+ score = rule.Score(candidateQueryNode, *refNode, baseCase);
+
+ // Log::Debug << "Rescored as " << score << std::endl;
+
+ if (score == DBL_MAX)
+ {
+ // Pruned. Move on.
+ ++numPrunes;
+ continue;
+ }
+
+ // Log::Debug << "Kept in map\n";
+
+ // Add to child map.
+ newScaleVector.push_back(frame);
+ newScaleVector.back().score = score;
+ newScaleVector.back().baseCase = baseCase;
+ newScaleVector.back().queryIndex = candidateQueryNode.Point();
+ newScaleVector.back().referenceIndex = refNode->Point();
+ }
+
+ // If we didn't add anything, then strike this vector from the map.
+ if (newScaleVector.size() == 0)
+ referenceMap.erase((*it).first);
+
+ ++it; // Advance to next scale.
+ }
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<typename RuleType>
+void CoverTree<MetricType, RootPointPolicy, StatisticType>::
+DualTreeTraverser<RuleType>::ReferenceRecursion(
+ CoverTree& queryNode,
+ std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
+ StatisticType> > >& referenceMap)
+{
+ typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
+ MapEntryType;
+
+ // First, reduce the maximum scale in the reference map down to the scale of
+ // the query node.
+ while ((*referenceMap.rbegin()).first > queryNode.Scale())
+ {
+ // Get a reference to the current largest scale.
+ std::vector<MapEntryType>& scaleVector = (*referenceMap.rbegin()).second;
+
+ // Before traversing all the points in this scale, sort by score.
+ std::sort(scaleVector.begin(), scaleVector.end());
+
+ // Now loop over each element.
+ for (size_t i = 0; i < scaleVector.size(); ++i)
+ {
+ // Get a reference to the current element.
+ const MapEntryType& frame = scaleVector.at(i);
+
+ CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
+ frame.referenceNode;
+ const double score = frame.score;
+ const size_t refIndex = frame.referenceIndex;
+ const size_t refPoint = refNode->Point();
+ const size_t queryIndex = frame.queryIndex;
+ double baseCase = frame.baseCase;
+
+// Log::Debug << "Currently inspecting reference node " << refNode->Point()
+// << " scale " << refNode->Scale() << " earlier query index " <<
+// queryIndex << std::endl;
+
+// Log::Debug << "Old score " << score << " with refParent " << refIndex
+// << " and queryIndex " << queryIndex << "\n";
+
+ // First we recalculate the score of this node to find if we can prune it.
+ if (rule.Rescore(queryNode, *refNode, score) == DBL_MAX)
+ {
+ // Log::Warn << "Pruned after rescore\n";
+ ++numPrunes;
+ continue;
+ }
+
+ // If this is a self-child, the base case has already been evaluated.
+ // We also must ensure that the base case was evaluated with this query
+ // point.
+ if ((refPoint != refIndex) || (queryNode.Point() != queryIndex))
+ {
+// Log::Warn << "Must evaluate base case " << queryNode.Point() << " "
+// << refPoint << "\n";
+ baseCase = rule.BaseCase(queryNode.Point(), refPoint);
+// Log::Debug << "Base case " << baseCase << std::endl;
+ rule.UpdateAfterRecursion(queryNode, *refNode); // Kludgey.
+// Log::Debug << "Bound for point " << queryNode.Point() << " scale " <<
+// queryNode.Scale() << " is now " << queryNode.Stat().Bound() <<
+// std::endl;
+ }
+
+ // Create the score for the children.
+ double childScore = rule.Score(queryNode, *refNode, baseCase);
+
+ // Now if this childScore is DBL_MAX we can prune all children. In this
+ // recursion setup pruning is all or nothing for children.
+ if (childScore == DBL_MAX)
+ {
+// Log::Warn << "Pruned all children.\n";
+ numPrunes += refNode->NumChildren();
+ continue;
+ }
+
+ // We must treat the self-leaf differently. The base case has already
+ // been performed.
+ childScore = rule.Score(queryNode, refNode->Child(0), baseCase);
+
+ if (childScore != DBL_MAX)
+ {
+ MapEntryType newFrame;
+ newFrame.referenceNode = &refNode->Child(0);
+ newFrame.score = childScore;
+ newFrame.baseCase = baseCase;
+ newFrame.referenceIndex = refPoint;
+ newFrame.queryIndex = queryNode.Point();
+
+ referenceMap[newFrame.referenceNode->Scale()].push_back(newFrame);
+ }
+ else
+ {
+ ++numPrunes;
+ }
+
+ // Add the non-self-leaf children.
+ for (size_t j = 1; j < refNode->NumChildren(); ++j)
+ {
+ const size_t queryIndex = queryNode.Point();
+ const size_t refIndex = refNode->Child(j).Point();
+
+ // Calculate the base case of each child.
+ baseCase = rule.BaseCase(queryIndex, refIndex);
+ rule.UpdateAfterRecursion(queryNode, refNode->Child(j));
+
+ // See if we can prune it.
+ childScore = rule.Score(queryNode, refNode->Child(j), baseCase);
+
+ if (childScore == DBL_MAX)
+ {
+ ++numPrunes;
+ continue;
+ }
+
+ MapEntryType newFrame;
+ newFrame.referenceNode = &refNode->Child(j);
+ newFrame.score = childScore;
+ newFrame.baseCase = baseCase;
+ newFrame.referenceIndex = refIndex;
+ newFrame.queryIndex = queryIndex;
+
+// Log::Debug << "Push onto map child " << refNode->Child(j).Point() <<
+// " scale " << refNode->Child(j).Scale() << std::endl;
+
+ referenceMap[newFrame.referenceNode->Scale()].push_back(newFrame);
+ }
+ }
+
+ // Now clear the memory for this scale; it isn't needed anymore.
+ referenceMap.erase((*referenceMap.rbegin()).first);
+ }
+
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.2/src/mlpack/core.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core.hpp 2012-08-09 09:02:13 UTC (rev 13379)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/core.hpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -1,185 +0,0 @@
-/***
- * @file core.hpp
- *
- * Include all of the base components required to write MLPACK methods, and the
- * main MLPACK Doxygen documentation.
- */
-#ifndef __MLPACK_CORE_HPP
-#define __MLPACK_CORE_HPP
-
-/**
- * @mainpage MLPACK Documentation
- *
- * @section intro_sec Introduction
- *
- * MLPACK is an intuitive, fast, scalable C++ machine learning library, meant to
- * be a machine learning analog to LAPACK. It aims to implement a wide array of
- * machine learning methods and function as a "swiss army knife" for machine
- * learning researchers. The MLPACK development website can be found at
- * http://mlpack.org.
- *
- * MLPACK uses the Armadillo C++ matrix library (http://arma.sourceforge.net)
- * for general matrix, vector, and linear algebra support. MLPACK also uses the
- * program_options, math_c99, and unit_test_framework components of the Boost
- * library; in addition, LibXml2 is used.
- *
- * @section howto How To Use This Documentation
- *
- * This documentation is API documentation similar to Javadoc. It isn't
- * necessarily a tutorial, but it does provide detailed documentation on every
- * namespace, method, and class.
- *
- * Each MLPACK namespace generally refers to one machine learning method, so
- * browsing the list of namespaces provides some insight as to the breadth of
- * the methods contained in the library.
- *
- * To generate this documentation in your own local copy of MLPACK, you can
- * simply use Doxygen, from the root directory (@c /mlpack/trunk/ ):
- *
- * @code
- * $ doxygen
- * @endcode
- *
- * @section executables Executables
- *
- * MLPACK provides several executables so that MLPACK methods can be used
- * without any need for knowledge of C++. These executables are all
- * self-documented, and that documentation can be accessed by running the
- * executables with the '-h' or '--help' flag.
- *
- * A full list of executables is given below:
- *
- * allkfn, allknn, emst, gmm, hmm_train, hmm_loglik, hmm_viterbi, hmm_generate,
- * kernel_pca, kmeans, lars, linear_regression, local_coordinate_coding, mvu,
- * nbc, nca, pca, radical, sparse_coding
- *
- * @section tutorial Tutorials
- *
- * A few short tutorials on how to use MLPACK are given below.
- *
- * - @ref build
- * - @ref matrices
- * - @ref iodoc
- * - @ref timer
- * - @ref sample
- *
- * Tutorials on specific methods are also available.
- *
- * - @ref nstutorial
- * - @ref lrtutorial
- * - @ref rstutorial
- * - @ref dettutorial
- * - @ref emst_tutorial
- *
- * @section methods Methods in MLPACK
- *
- * The following methods are included in MLPACK:
- *
- * - Euclidean Minimum Spanning Trees - mlpack::emst::DualTreeBoruvka
- * - Gaussian Mixture Models (GMMs) - mlpack::gmm::GMM
- * - Hidden Markov Models (HMMs) - mlpack::hmm::HMM
- * - Kernel PCA - mlpack::kpca::KernelPCA
- * - K-Means Clustering - mlpack::kmeans::KMeans
- * - Least-Angle Regression (LARS/LASSO) - mlpack::regression::LARS
- * - Local Coordinate Coding - mlpack::lcc::LocalCoordinateCoding
- * - Naive Bayes Classifier - mlpack::naive_bayes::NaiveBayesClassifier
- * - Neighborhood Components Analysis (NCA) - mlpack::nca::NCA
- * - Principal Components Analysis (PCA) - mlpack::pca::PCA
- * - RADICAL (ICA) - mlpack::radical::Radical
- * - Simple Least-Squares Linear Regression -
- * mlpack::regression::LinearRegression
- * - Sparse Coding - mlpack::sparse_coding::SparseCoding
- * - Tree-based neighbor search (AllkNN, AllkFN) -
- * mlpack::neighbor::NeighborSearch
- * - Tree-based range search - mlpack::range::RangeSearch
- *
- * @section remarks Final Remarks
- *
- * This software was written in the FASTLab (http://www.fast-lab.org), which is
- * in the School of Computational Science and Engineering at the Georgia
- * Institute of Technology.
- *
- * MLPACK contributors include:
- *
- * - Ryan Curtin <gth671b at mail.gatech.edu>
- * - James Cline <james.cline at gatech.edu>
- * - Neil Slagle <nslagle3 at gatech.edu>
- * - Matthew Amidon <mamidon at gatech.edu>
- * - Vlad Grantcharov <vlad321 at gatech.edu>
- * - Ajinkya Kale <kaleajinkya at gmail.com>
- * - Bill March <march at gatech.edu>
- * - Dongryeol Lee <dongryel at cc.gatech.edu>
- * - Nishant Mehta <niche at cc.gatech.edu>
- * - Parikshit Ram <p.ram at gatech.edu>
- * - Chip Mappus <cmappus at gatech.edu>
- * - Hua Ouyang <houyang at gatech.edu>
- * - Long Quoc Tran <tqlong at gmail.com>
- * - Noah Kauffman <notoriousnoah at gmail.com>
- * - Guillermo Colon <gcolon7 at mail.gatech.edu>
- * - Wei Guan <wguan at cc.gatech.edu>
- * - Ryan Riegel <rriegel at cc.gatech.edu>
- * - Nikolaos Vasiloglou <nvasil at ieee.org>
- * - Garry Boyer <garryb at gmail.com>
- * - Andreas Löf <andreas.lof at cs.waikato.ac.nz>
- */
-
-// First, standard includes.
-#include <stdlib.h>
-#include <stdio.h>
-#include <string.h>
-#include <ctype.h>
-#include <limits.h>
-#include <float.h>
-#include <stdint.h>
-#include <iostream>
-
-// Defining __USE_MATH_DEFINES should set M_PI.
-#define _USE_MATH_DEFINES
-#include <math.h>
-
-// For tgamma().
-#include <boost/math/special_functions/gamma.hpp>
-
-// But if it's not defined, we'll do it.
-#ifndef M_PI
- #define M_PI 3.141592653589793238462643383279
-#endif
-
-// Now MLPACK-specific includes.
-#include <mlpack/core/arma_extend/arma_extend.hpp> // Includes Armadillo.
-#include <mlpack/core/util/log.hpp>
-#include <mlpack/core/util/cli.hpp>
-#include <mlpack/core/data/load.hpp>
-#include <mlpack/core/data/save.hpp>
-#include <mlpack/core/math/clamp.hpp>
-#include <mlpack/core/math/random.hpp>
-#include <mlpack/core/math/lin_alg.hpp>
-#include <mlpack/core/math/range.hpp>
-#include <mlpack/core/math/round.hpp>
-#include <mlpack/core/util/save_restore_utility.hpp>
-#include <mlpack/core/dists/discrete_distribution.hpp>
-#include <mlpack/core/dists/gaussian_distribution.hpp>
-
-// Clean up unfortunate Windows preprocessor definitions.
-// Use std::min and std::max!
-#ifdef _WIN32
- #ifdef min
- #undef min
- #endif
-
- #ifdef max
- #undef max
- #endif
-#endif
-
-// Give ourselves a nice way to force functions to be inline if we need.
-#define force_inline
-#if defined(__GNUG__)
- #undef force_inline
- #define force_inline __attribute__((always_inline))
-#elif defined(_MSC_VER)
- #undef force_inline
- #define force_inline __forceinline
-#endif
-
-#endif
Copied: mlpack/tags/mlpack-1.0.2/src/mlpack/core.hpp (from rev 13398, mlpack/trunk/src/mlpack/core.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.2/src/mlpack/core.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/core.hpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -0,0 +1,185 @@
+/***
+ * @file core.hpp
+ *
+ * Include all of the base components required to write MLPACK methods, and the
+ * main MLPACK Doxygen documentation.
+ */
+#ifndef __MLPACK_CORE_HPP
+#define __MLPACK_CORE_HPP
+
+/**
+ * @mainpage MLPACK Documentation
+ *
+ * @section intro_sec Introduction
+ *
+ * MLPACK is an intuitive, fast, scalable C++ machine learning library, meant to
+ * be a machine learning analog to LAPACK. It aims to implement a wide array of
+ * machine learning methods and function as a "swiss army knife" for machine
+ * learning researchers. The MLPACK development website can be found at
+ * http://mlpack.org.
+ *
+ * MLPACK uses the Armadillo C++ matrix library (http://arma.sourceforge.net)
+ * for general matrix, vector, and linear algebra support. MLPACK also uses the
+ * program_options, math_c99, and unit_test_framework components of the Boost
+ * library; in addition, LibXml2 is used.
+ *
+ * @section howto How To Use This Documentation
+ *
+ * This documentation is API documentation similar to Javadoc. It isn't
+ * necessarily a tutorial, but it does provide detailed documentation on every
+ * namespace, method, and class.
+ *
+ * Each MLPACK namespace generally refers to one machine learning method, so
+ * browsing the list of namespaces provides some insight as to the breadth of
+ * the methods contained in the library.
+ *
+ * To generate this documentation in your own local copy of MLPACK, you can
+ * simply use Doxygen, from the root directory (@c /mlpack/trunk/ ):
+ *
+ * @code
+ * $ doxygen
+ * @endcode
+ *
+ * @section executables Executables
+ *
+ * MLPACK provides several executables so that MLPACK methods can be used
+ * without any need for knowledge of C++. These executables are all
+ * self-documented, and that documentation can be accessed by running the
+ * executables with the '-h' or '--help' flag.
+ *
+ * A full list of executables is given below:
+ *
+ * allkfn, allknn, emst, gmm, hmm_train, hmm_loglik, hmm_viterbi, hmm_generate,
+ * kernel_pca, kmeans, lars, linear_regression, local_coordinate_coding, mvu,
+ * nbc, nca, pca, radical, sparse_coding
+ *
+ * @section tutorial Tutorials
+ *
+ * A few short tutorials on how to use MLPACK are given below.
+ *
+ * - @ref build
+ * - @ref matrices
+ * - @ref iodoc
+ * - @ref timer
+ * - @ref sample
+ *
+ * Tutorials on specific methods are also available.
+ *
+ * - @ref nstutorial
+ * - @ref lrtutorial
+ * - @ref rstutorial
+ * - @ref dettutorial
+ * - @ref emst_tutorial
+ *
+ * @section methods Methods in MLPACK
+ *
+ * The following methods are included in MLPACK:
+ *
+ * - Euclidean Minimum Spanning Trees - mlpack::emst::DualTreeBoruvka
+ * - Gaussian Mixture Models (GMMs) - mlpack::gmm::GMM
+ * - Hidden Markov Models (HMMs) - mlpack::hmm::HMM
+ * - Kernel PCA - mlpack::kpca::KernelPCA
+ * - K-Means Clustering - mlpack::kmeans::KMeans
+ * - Least-Angle Regression (LARS/LASSO) - mlpack::regression::LARS
+ * - Local Coordinate Coding - mlpack::lcc::LocalCoordinateCoding
+ * - Naive Bayes Classifier - mlpack::naive_bayes::NaiveBayesClassifier
+ * - Neighborhood Components Analysis (NCA) - mlpack::nca::NCA
+ * - Principal Components Analysis (PCA) - mlpack::pca::PCA
+ * - RADICAL (ICA) - mlpack::radical::Radical
+ * - Simple Least-Squares Linear Regression -
+ * mlpack::regression::LinearRegression
+ * - Sparse Coding - mlpack::sparse_coding::SparseCoding
+ * - Tree-based neighbor search (AllkNN, AllkFN) -
+ * mlpack::neighbor::NeighborSearch
+ * - Tree-based range search - mlpack::range::RangeSearch
+ *
+ * @section remarks Final Remarks
+ *
+ * This software was written in the FASTLab (http://www.fast-lab.org), which is
+ * in the School of Computational Science and Engineering at the Georgia
+ * Institute of Technology.
+ *
+ * MLPACK contributors include:
+ *
+ * - Ryan Curtin <gth671b at mail.gatech.edu>
+ * - James Cline <james.cline at gatech.edu>
+ * - Neil Slagle <nslagle3 at gatech.edu>
+ * - Matthew Amidon <mamidon at gatech.edu>
+ * - Vlad Grantcharov <vlad321 at gatech.edu>
+ * - Ajinkya Kale <kaleajinkya at gmail.com>
+ * - Bill March <march at gatech.edu>
+ * - Dongryeol Lee <dongryel at cc.gatech.edu>
+ * - Nishant Mehta <niche at cc.gatech.edu>
+ * - Parikshit Ram <p.ram at gatech.edu>
+ * - Chip Mappus <cmappus at gatech.edu>
+ * - Hua Ouyang <houyang at gatech.edu>
+ * - Long Quoc Tran <tqlong at gmail.com>
+ * - Noah Kauffman <notoriousnoah at gmail.com>
+ * - Guillermo Colon <gcolon7 at mail.gatech.edu>
+ * - Wei Guan <wguan at cc.gatech.edu>
+ * - Ryan Riegel <rriegel at cc.gatech.edu>
+ * - Nikolaos Vasiloglou <nvasil at ieee.org>
+ * - Garry Boyer <garryb at gmail.com>
+ * - Andreas Löf <andreas.lof at cs.waikato.ac.nz>
+ */
+
+// First, standard includes.
+#include <stdlib.h>
+#include <stdio.h>
+#include <string.h>
+#include <ctype.h>
+#include <limits.h>
+#include <float.h>
+#include <stdint.h>
+#include <iostream>
+
+// Defining __USE_MATH_DEFINES should set M_PI.
+#define _USE_MATH_DEFINES
+#include <math.h>
+
+// For tgamma().
+#include <boost/math/special_functions/gamma.hpp>
+
+// But if it's not defined, we'll do it.
+#ifndef M_PI
+ #define M_PI 3.141592653589793238462643383279
+#endif
+
+// Now MLPACK-specific includes.
+#include <mlpack/core/arma_extend/arma_extend.hpp> // Includes Armadillo.
+#include <mlpack/core/util/log.hpp>
+#include <mlpack/core/util/cli.hpp>
+#include <mlpack/core/data/load.hpp>
+#include <mlpack/core/data/save.hpp>
+#include <mlpack/core/math/clamp.hpp>
+#include <mlpack/core/math/random.hpp>
+#include <mlpack/core/math/lin_alg.hpp>
+#include <mlpack/core/math/range.hpp>
+#include <mlpack/core/math/round.hpp>
+#include <mlpack/core/util/save_restore_utility.hpp>
+#include <mlpack/core/dists/discrete_distribution.hpp>
+#include <mlpack/core/dists/gaussian_distribution.hpp>
+
+// Clean up unfortunate Windows preprocessor definitions.
+// Use std::min and std::max!
+#ifdef _WIN32
+ #ifdef min
+ #undef min
+ #endif
+
+ #ifdef max
+ #undef max
+ #endif
+#endif
+
+// Give ourselves a nice way to force functions to be inline if we need.
+#define force_inline
+#if defined(__GNUG__)
+  #undef force_inline
+  #define force_inline __attribute__((always_inline))
+#elif defined(_MSC_VER)
+  #undef force_inline
+  #define force_inline __forceinline
+#endif
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.2/src/mlpack/methods/neighbor_search/allknn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp 2012-08-09 09:02:13 UTC (rev 13379)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/methods/neighbor_search/allknn_main.cpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -1,296 +0,0 @@
-/**
- * @file allknn_main.cpp
- * @author Ryan Curtin
- *
- * Implementation of the AllkNN executable. Allows some number of standard
- * options.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/core/tree/cover_tree.hpp>
-
-#include <string>
-#include <fstream>
-#include <iostream>
-
-#include "neighbor_search.hpp"
-
-using namespace std;
-using namespace mlpack;
-using namespace mlpack::neighbor;
-using namespace mlpack::tree;
-
-// Information about the program itself.
-PROGRAM_INFO("All K-Nearest-Neighbors",
- "This program will calculate the all k-nearest-neighbors of a set of "
- "points. You may specify a separate set of reference points and query "
- "points, or just a reference set which will be used as both the reference "
- "and query set."
- "\n\n"
- "For example, the following will calculate the 5 nearest neighbors of each"
- "point in 'input.csv' and store the distances in 'distances.csv' and the "
- "neighbors in the file 'neighbors.csv':"
- "\n\n"
- "$ allknn --k=5 --reference_file=input.csv --distances_file=distances.csv\n"
- " --neighbors_file=neighbors.csv"
- "\n\n"
- "The output files are organized such that row i and column j in the "
- "neighbors output file corresponds to the index of the point in the "
- "reference set which is the i'th nearest neighbor from the point in the "
- "query set with index j. Row i and column j in the distances output file "
- "corresponds to the distance between those two points.");
-
-// Define our input parameters that this program will take.
-PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
- "r");
-PARAM_STRING_REQ("distances_file", "File to output distances into.", "d");
-PARAM_STRING_REQ("neighbors_file", "File to output neighbors into.", "n");
-
-PARAM_INT_REQ("k", "Number of furthest neighbors to find.", "k");
-
-PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
-
-PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
-PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
-PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
- "dual-tree search.", "s");
-PARAM_FLAG("cover_tree", "If true, use cover trees to perform the search.",
- "c");
-
-int main(int argc, char *argv[])
-{
- // Give CLI the command line parameters the user passed in.
- CLI::ParseCommandLine(argc, argv);
-
- // Get all the parameters.
- string referenceFile = CLI::GetParam<string>("reference_file");
-
- string distancesFile = CLI::GetParam<string>("distances_file");
- string neighborsFile = CLI::GetParam<string>("neighbors_file");
-
- int lsInt = CLI::GetParam<int>("leaf_size");
-
- size_t k = CLI::GetParam<int>("k");
-
- bool naive = CLI::HasParam("naive");
- bool singleMode = CLI::HasParam("single_mode");
-
- arma::mat referenceData;
- arma::mat queryData; // So it doesn't go out of scope.
- data::Load(referenceFile.c_str(), referenceData, true);
-
- Log::Info << "Loaded reference data from '" << referenceFile << "' ("
- << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
-
- // Sanity check on k value: must be greater than 0, must be less than the
- // number of reference points.
- if (k > referenceData.n_cols)
- {
- Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
- Log::Fatal << "than or equal to the number of reference points (";
- Log::Fatal << referenceData.n_cols << ")." << endl;
- }
-
- // Sanity check on leaf size.
- if (lsInt < 0)
- {
- Log::Fatal << "Invalid leaf size: " << lsInt << ". Must be greater "
- "than or equal to 0." << endl;
- }
- size_t leafSize = lsInt;
-
- // Naive mode overrides single mode.
- if (singleMode && naive)
- {
- Log::Warn << "--single_mode ignored because --naive is present." << endl;
- }
-
- if (naive)
- leafSize = referenceData.n_cols;
-
- arma::Mat<size_t> neighbors;
- arma::mat distances;
-
- if (!CLI::HasParam("cover_tree"))
- {
- // Because we may construct it differently, we need a pointer.
- AllkNN* allknn = NULL;
-
- // Mappings for when we build the tree.
- std::vector<size_t> oldFromNewRefs;
-
- // Build trees by hand, so we can save memory: if we pass a tree to
- // NeighborSearch, it does not copy the matrix.
- Log::Info << "Building reference tree..." << endl;
- Timer::Start("tree_building");
-
- BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >
- refTree(referenceData, oldFromNewRefs, leafSize);
- BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >*
- queryTree = NULL; // Empty for now.
-
- Timer::Stop("tree_building");
-
- std::vector<size_t> oldFromNewQueries;
-
- if (CLI::GetParam<string>("query_file") != "")
- {
- string queryFile = CLI::GetParam<string>("query_file");
-
- data::Load(queryFile.c_str(), queryData, true);
-
- if (naive && leafSize < queryData.n_cols)
- leafSize = queryData.n_cols;
-
- Log::Info << "Loaded query data from '" << queryFile << "' ("
- << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
-
- Log::Info << "Building query tree..." << endl;
-
- // Build trees by hand, so we can save memory: if we pass a tree to
- // NeighborSearch, it does not copy the matrix.
- if (!singleMode)
- {
- Timer::Start("tree_building");
-
- queryTree = new BinarySpaceTree<bound::HRectBound<2>,
- QueryStat<NearestNeighborSort> >(queryData, oldFromNewQueries,
- leafSize);
-
- Timer::Stop("tree_building");
- }
-
- allknn = new AllkNN(&refTree, queryTree, referenceData, queryData,
- singleMode);
-
- Log::Info << "Tree built." << endl;
- }
- else
- {
- allknn = new AllkNN(&refTree, referenceData, singleMode);
-
- Log::Info << "Trees built." << endl;
- }
-
- arma::mat distancesOut;
- arma::Mat<size_t> neighborsOut;
-
- Log::Info << "Computing " << k << " nearest neighbors..." << endl;
- allknn->Search(k, neighborsOut, distancesOut);
-
- Log::Info << "Neighbors computed." << endl;
-
- // We have to map back to the original indices from before the tree
- // construction.
- Log::Info << "Re-mapping indices..." << endl;
-
- neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
- distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
-
- // Do the actual remapping.
- if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
- {
- for (size_t i = 0; i < distancesOut.n_cols; ++i)
- {
- // Map distances (copy a column) and square root.
- distances.col(oldFromNewQueries[i]) = sqrt(distancesOut.col(i));
-
- // Map indices of neighbors.
- for (size_t j = 0; j < distancesOut.n_rows; ++j)
- {
- neighbors(j, oldFromNewQueries[i]) =
- oldFromNewRefs[neighborsOut(j, i)];
- }
- }
- }
- else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
- {
- // No remapping of queries is necessary. So distances are the same.
- distances = sqrt(distancesOut);
-
- // The neighbor indices must be mapped.
- for (size_t j = 0; j < neighborsOut.n_elem; ++j)
- {
- neighbors[j] = oldFromNewRefs[neighborsOut[j]];
- }
- }
- else
- {
- for (size_t i = 0; i < distancesOut.n_cols; ++i)
- {
- // Map distances (copy a column).
- distances.col(oldFromNewRefs[i]) = sqrt(distancesOut.col(i));
-
- // Map indices of neighbors.
- for (size_t j = 0; j < distancesOut.n_rows; ++j)
- {
- neighbors(j, oldFromNewRefs[i]) = oldFromNewRefs[neighborsOut(j, i)];
- }
- }
- }
-
- // Clean up.
- if (queryTree)
- delete queryTree;
-
- delete allknn;
- }
- else // Cover trees.
- {
- // Build our reference tree.
- Log::Info << "Building reference tree..." << endl;
- Timer::Start("tree_building");
- CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> > referenceTree(referenceData);
- CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> >* queryTree = NULL;
- Timer::Stop("tree_building");
-
- NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
- CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> > >* allknn = NULL;
-
- // See if we have query data.
- if (CLI::HasParam("query_file"))
- {
- string queryFile = CLI::GetParam<string>("query_file");
-
- data::Load(queryFile, queryData, true);
-
- // Build query tree.
- if (!singleMode)
- {
- Log::Info << "Building query tree..." << endl;
- Timer::Start("tree_building");
- queryTree = new CoverTree<metric::LMetric<2, true>,
- tree::FirstPointIsRoot, QueryStat<NearestNeighborSort> >(queryData);
- Timer::Stop("tree_building");
- }
-
- allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
- CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> > >(&referenceTree, queryTree,
- referenceData, queryData, singleMode);
- }
- else
- {
- allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
- CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> > >(&referenceTree, referenceData,
- singleMode);
- }
-
- Log::Info << "Computing " << k << " nearest neighbors..." << endl;
- allknn->Search(k, neighbors, distances);
-
- Log::Info << "Neighbors computed." << endl;
-
- delete allknn;
-
- if (queryTree)
- delete queryTree;
- }
-
- // Save output.
- data::Save(distancesFile, distances);
- data::Save(neighborsFile, neighbors);
-}
Copied: mlpack/tags/mlpack-1.0.2/src/mlpack/methods/neighbor_search/allknn_main.cpp (from rev 13400, mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.2/src/mlpack/methods/neighbor_search/allknn_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/methods/neighbor_search/allknn_main.cpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -0,0 +1,301 @@
+/**
+ * @file allknn_main.cpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the AllkNN executable. Allows some number of standard
+ * options.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
+
+#include <string>
+#include <fstream>
+#include <iostream>
+
+#include "neighbor_search.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::neighbor;
+using namespace mlpack::tree;
+
+// Information about the program itself.
+PROGRAM_INFO("All K-Nearest-Neighbors",
+ "This program will calculate the all k-nearest-neighbors of a set of "
+ "points using kd-trees or cover trees (cover tree support is experimental "
+ "and may not be optimally fast). You may specify a separate set of "
+ "reference points and query points, or just a reference set which will be "
+ "used as both the reference and query set."
+ "\n\n"
+    "For example, the following will calculate the 5 nearest neighbors of each "
+ "point in 'input.csv' and store the distances in 'distances.csv' and the "
+ "neighbors in the file 'neighbors.csv':"
+ "\n\n"
+ "$ allknn --k=5 --reference_file=input.csv --distances_file=distances.csv\n"
+ " --neighbors_file=neighbors.csv"
+ "\n\n"
+ "The output files are organized such that row i and column j in the "
+ "neighbors output file corresponds to the index of the point in the "
+ "reference set which is the i'th nearest neighbor from the point in the "
+ "query set with index j. Row i and column j in the distances output file "
+ "corresponds to the distance between those two points.");
+
+// Define our input parameters that this program will take.
+PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
+ "r");
+PARAM_STRING_REQ("distances_file", "File to output distances into.", "d");
+PARAM_STRING_REQ("neighbors_file", "File to output neighbors into.", "n");
+
+PARAM_INT_REQ("k", "Number of nearest neighbors to find.", "k");
+
+PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
+
+PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
+PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
+PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
+    "dual-tree search).", "s");
+PARAM_FLAG("cover_tree", "If true, use cover trees to perform the search "
+ "(experimental, may be slow).", "c");
+
+int main(int argc, char *argv[])
+{
+ // Give CLI the command line parameters the user passed in.
+ CLI::ParseCommandLine(argc, argv);
+
+ // Get all the parameters.
+ string referenceFile = CLI::GetParam<string>("reference_file");
+
+ string distancesFile = CLI::GetParam<string>("distances_file");
+ string neighborsFile = CLI::GetParam<string>("neighbors_file");
+
+ int lsInt = CLI::GetParam<int>("leaf_size");
+
+ size_t k = CLI::GetParam<int>("k");
+
+ bool naive = CLI::HasParam("naive");
+ bool singleMode = CLI::HasParam("single_mode");
+
+ arma::mat referenceData;
+ arma::mat queryData; // So it doesn't go out of scope.
+ data::Load(referenceFile.c_str(), referenceData, true);
+
+ Log::Info << "Loaded reference data from '" << referenceFile << "' ("
+ << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
+
+ // Sanity check on k value: must be greater than 0, must be less than the
+ // number of reference points.
+ if (k > referenceData.n_cols)
+ {
+ Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
+ Log::Fatal << "than or equal to the number of reference points (";
+ Log::Fatal << referenceData.n_cols << ")." << endl;
+ }
+
+ // Sanity check on leaf size.
+ if (lsInt < 0)
+ {
+ Log::Fatal << "Invalid leaf size: " << lsInt << ". Must be greater "
+ "than or equal to 0." << endl;
+ }
+ size_t leafSize = lsInt;
+
+ // Naive mode overrides single mode.
+ if (singleMode && naive)
+ {
+ Log::Warn << "--single_mode ignored because --naive is present." << endl;
+ }
+
+ if (naive)
+ leafSize = referenceData.n_cols;
+
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ if (!CLI::HasParam("cover_tree"))
+ {
+ // Because we may construct it differently, we need a pointer.
+ AllkNN* allknn = NULL;
+
+ // Mappings for when we build the tree.
+ std::vector<size_t> oldFromNewRefs;
+
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ Log::Info << "Building reference tree..." << endl;
+ Timer::Start("tree_building");
+
+ BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >
+ refTree(referenceData, oldFromNewRefs, leafSize);
+ BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >*
+ queryTree = NULL; // Empty for now.
+
+ Timer::Stop("tree_building");
+
+ std::vector<size_t> oldFromNewQueries;
+
+ if (CLI::GetParam<string>("query_file") != "")
+ {
+ string queryFile = CLI::GetParam<string>("query_file");
+
+ data::Load(queryFile.c_str(), queryData, true);
+
+ if (naive && leafSize < queryData.n_cols)
+ leafSize = queryData.n_cols;
+
+ Log::Info << "Loaded query data from '" << queryFile << "' ("
+ << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+
+ Log::Info << "Building query tree..." << endl;
+
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ if (!singleMode)
+ {
+ Timer::Start("tree_building");
+
+ queryTree = new BinarySpaceTree<bound::HRectBound<2>,
+ QueryStat<NearestNeighborSort> >(queryData, oldFromNewQueries,
+ leafSize);
+
+ Timer::Stop("tree_building");
+ }
+
+ allknn = new AllkNN(&refTree, queryTree, referenceData, queryData,
+ singleMode);
+
+      Log::Info << "Trees built." << endl;
+ }
+ else
+ {
+ allknn = new AllkNN(&refTree, referenceData, singleMode);
+
+      Log::Info << "Tree built." << endl;
+ }
+
+ arma::mat distancesOut;
+ arma::Mat<size_t> neighborsOut;
+
+ Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+ allknn->Search(k, neighborsOut, distancesOut);
+
+ Log::Info << "Neighbors computed." << endl;
+
+ // We have to map back to the original indices from before the tree
+ // construction.
+ Log::Info << "Re-mapping indices..." << endl;
+
+ neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
+ distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
+
+ // Do the actual remapping.
+ if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
+ {
+ for (size_t i = 0; i < distancesOut.n_cols; ++i)
+ {
+ // Map distances (copy a column) and square root.
+ distances.col(oldFromNewQueries[i]) = sqrt(distancesOut.col(i));
+
+ // Map indices of neighbors.
+ for (size_t j = 0; j < distancesOut.n_rows; ++j)
+ {
+ neighbors(j, oldFromNewQueries[i]) =
+ oldFromNewRefs[neighborsOut(j, i)];
+ }
+ }
+ }
+ else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
+ {
+ // No remapping of queries is necessary. So distances are the same.
+ distances = sqrt(distancesOut);
+
+ // The neighbor indices must be mapped.
+ for (size_t j = 0; j < neighborsOut.n_elem; ++j)
+ {
+ neighbors[j] = oldFromNewRefs[neighborsOut[j]];
+ }
+ }
+ else
+ {
+ for (size_t i = 0; i < distancesOut.n_cols; ++i)
+ {
+ // Map distances (copy a column).
+ distances.col(oldFromNewRefs[i]) = sqrt(distancesOut.col(i));
+
+ // Map indices of neighbors.
+ for (size_t j = 0; j < distancesOut.n_rows; ++j)
+ {
+ neighbors(j, oldFromNewRefs[i]) = oldFromNewRefs[neighborsOut(j, i)];
+ }
+ }
+ }
+
+ // Clean up.
+ if (queryTree)
+ delete queryTree;
+
+ delete allknn;
+ }
+ else // Cover trees.
+ {
+ // Make sure to notify the user that they are using cover trees.
+ Log::Info << "Using cover trees for nearest-neighbor calculation." << endl;
+
+ // Build our reference tree.
+ Log::Info << "Building reference tree..." << endl;
+ Timer::Start("tree_building");
+ CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> > referenceTree(referenceData, 1.3);
+ CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> >* queryTree = NULL;
+ Timer::Stop("tree_building");
+
+ NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+ CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> > >* allknn = NULL;
+
+ // See if we have query data.
+ if (CLI::HasParam("query_file"))
+ {
+ string queryFile = CLI::GetParam<string>("query_file");
+
+ data::Load(queryFile, queryData, true);
+
+ // Build query tree.
+ if (!singleMode)
+ {
+ Log::Info << "Building query tree..." << endl;
+ Timer::Start("tree_building");
+ queryTree = new CoverTree<metric::LMetric<2, true>,
+ tree::FirstPointIsRoot, QueryStat<NearestNeighborSort> >(queryData,
+ 1.3);
+ Timer::Stop("tree_building");
+ }
+
+ allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+ CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> > >(&referenceTree, queryTree,
+ referenceData, queryData, singleMode);
+ }
+ else
+ {
+ allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+ CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> > >(&referenceTree, referenceData,
+ singleMode);
+ }
+
+ Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+ allknn->Search(k, neighbors, distances);
+
+ Log::Info << "Neighbors computed." << endl;
+
+ delete allknn;
+
+ if (queryTree)
+ delete queryTree;
+ }
+
+ // Save output.
+ data::Save(distancesFile, distances);
+ data::Save(neighborsFile, neighbors);
+}
Deleted: mlpack/tags/mlpack-1.0.2/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp 2012-08-09 09:02:13 UTC (rev 13379)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -1,136 +0,0 @@
-/**
- * @file neighbor_search_rules.hpp
- * @author Ryan Curtin
- *
- * Defines the pruning rules and base case rules necessary to perform a
- * tree-based search (with an arbitrary tree) for the NeighborSearch class.
- */
-#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
-#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
-
-namespace mlpack {
-namespace neighbor {
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-class NeighborSearchRules
-{
- public:
- NeighborSearchRules(const arma::mat& referenceSet,
- const arma::mat& querySet,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances,
- MetricType& metric);
-
- double BaseCase(const size_t queryIndex, const size_t referenceIndex);
-
- // For single-tree traversal.
- bool CanPrune(const size_t queryIndex, TreeType& referenceNode);
-
- // For dual-tree traversal.
- bool CanPrune(TreeType& queryNode, TreeType& referenceNode);
-
- // Update bounds. Needs a better name.
- void UpdateAfterRecursion(TreeType& queryNode, TreeType& referenceNode);
-
- /**
- * Get the score for recursion order. A low score indicates priority for
- * recursion, while DBL_MAX indicates that the node should not be recursed
- * into at all (it should be pruned).
- *
- * @param queryIndex Index of query point.
- * @param referenceNode Candidate node to be recursed into.
- */
- double Score(const size_t queryIndex, TreeType& referenceNode) const;
-
- /**
- * Get the score for recursion order, passing the base case result (in the
- * situation where it may be needed to calculate the recursion order). A low
- * score indicates priority for recursion, while DBL_MAX indicates that the
- * node should not be recursed into at all (it should be pruned).
- *
- * @param queryIndex Index of query point.
- * @param referenceNode Candidate node to be recursed into.
- * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
- */
- double Score(const size_t queryIndex,
- TreeType& referenceNode,
- const double baseCaseResult) const;
-
- /**
- * Re-evaluate the score for recursion order. A low score indicates priority
- * for recursion, while DBL_MAX indicates that the node should not be recursed
- * into at all (it should be pruned). This is used when the score has already
- * been calculated, but another recursion may have modified the bounds for
- * pruning. So the old score is checked against the new pruning bound.
- *
- * @param queryIndex Index of query point.
- * @param referenceNode Candidate node to be recursed into.
- * @param oldScore Old score produced by Score() (or Rescore()).
- */
- double Rescore(const size_t queryIndex,
- TreeType& referenceNode,
- const double oldScore) const;
-
- /**
- * Get the score for recursion order. A low score indicates priority for
- * recursionm while DBL_MAX indicates that the node should not be recursed
- * into at all (it should be pruned).
- *
- * @param queryNode Candidate query node to recurse into.
- * @param referenceNode Candidate reference node to recurse into.
- */
- double Score(TreeType& queryNode, TreeType& referenceNode) const;
-
- /**
- * Re-evaluate the score for recursion order. A low score indicates priority
- * for recursion, while DBL_MAX indicates that the node should not be recursed
- * into at all (it should be pruned). This is used when the score has already
- * been calculated, but another recursion may have modified the bounds for
- * pruning. So the old score is checked against the new pruning bound.
- *
- * @param queryNode Candidate query node to recurse into.
- * @param referenceNode Candidate reference node to recurse into.
- * @param oldScore Old score produced by Socre() (or Rescore()).
- */
- double Rescore(TreeType& queryNode,
- TreeType& referenceNode,
- const double oldScore) const;
-
- private:
- //! The reference set.
- const arma::mat& referenceSet;
-
- //! The query set.
- const arma::mat& querySet;
-
- //! The matrix the resultant neighbor indices should be stored in.
- arma::Mat<size_t>& neighbors;
-
- //! The matrix the resultant neighbor distances should be stored in.
- arma::mat& distances;
-
- //! The instantiated metric.
- MetricType& metric;
-
- /**
- * Insert a point into the neighbors and distances matrices; this is a helper
- * function.
- *
- * @param queryIndex Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
- * @param neighbor Index of reference point which is being inserted.
- * @param distance Distance from query point to reference point.
- */
- void InsertNeighbor(const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance);
-};
-
-}; // namespace neighbor
-}; // namespace mlpack
-
-// Include implementation.
-#include "neighbor_search_rules_impl.hpp"
-
-#endif // __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
Copied: mlpack/tags/mlpack-1.0.2/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp (from rev 13385, mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.2/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -0,0 +1,144 @@
+/**
+ * @file neighbor_search_rules.hpp
+ * @author Ryan Curtin
+ *
+ * Defines the pruning rules and base case rules necessary to perform a
+ * tree-based search (with an arbitrary tree) for the NeighborSearch class.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
+
+namespace mlpack {
+namespace neighbor {
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+class NeighborSearchRules
+{
+ public:
+ NeighborSearchRules(const arma::mat& referenceSet,
+ const arma::mat& querySet,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances,
+ MetricType& metric);
+
+ double BaseCase(const size_t queryIndex, const size_t referenceIndex);
+
+ // Update bounds. Needs a better name.
+ void UpdateAfterRecursion(TreeType& queryNode, TreeType& referenceNode);
+
+ /**
+ * Get the score for recursion order. A low score indicates priority for
+ * recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned).
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ */
+ double Score(const size_t queryIndex, TreeType& referenceNode) const;
+
+ /**
+ * Get the score for recursion order, passing the base case result (in the
+ * situation where it may be needed to calculate the recursion order). A low
+ * score indicates priority for recursion, while DBL_MAX indicates that the
+ * node should not be recursed into at all (it should be pruned).
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
+ */
+ double Score(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double baseCaseResult) const;
+
+ /**
+ * Re-evaluate the score for recursion order. A low score indicates priority
+ * for recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned). This is used when the score has already
+ * been calculated, but another recursion may have modified the bounds for
+ * pruning. So the old score is checked against the new pruning bound.
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ * @param oldScore Old score produced by Score() (or Rescore()).
+ */
+ double Rescore(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double oldScore) const;
+
+ /**
+ * Get the score for recursion order. A low score indicates priority for
+   * recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned).
+ *
+ * @param queryNode Candidate query node to recurse into.
+ * @param referenceNode Candidate reference node to recurse into.
+ */
+ double Score(TreeType& queryNode, TreeType& referenceNode) const;
+
+ /**
+ * Get the score for recursion order, passing the base case result (in the
+ * situation where it may be needed to calculate the recursion order). A low
+ * score indicates priority for recursion, while DBL_MAX indicates that the
+ * node should not be recursed into at all (it should be pruned).
+ *
+ * @param queryNode Candidate query node to recurse into.
+ * @param referenceNode Candidate reference node to recurse into.
+ * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
+ */
+ double Score(TreeType& queryNode,
+ TreeType& referenceNode,
+ const double baseCaseResult) const;
+
+ /**
+ * Re-evaluate the score for recursion order. A low score indicates priority
+ * for recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned). This is used when the score has already
+ * been calculated, but another recursion may have modified the bounds for
+ * pruning. So the old score is checked against the new pruning bound.
+ *
+ * @param queryNode Candidate query node to recurse into.
+ * @param referenceNode Candidate reference node to recurse into.
+   * @param oldScore Old score produced by Score() (or Rescore()).
+ */
+ double Rescore(TreeType& queryNode,
+ TreeType& referenceNode,
+ const double oldScore) const;
+
+ private:
+ //! The reference set.
+ const arma::mat& referenceSet;
+
+ //! The query set.
+ const arma::mat& querySet;
+
+ //! The matrix the resultant neighbor indices should be stored in.
+ arma::Mat<size_t>& neighbors;
+
+ //! The matrix the resultant neighbor distances should be stored in.
+ arma::mat& distances;
+
+ //! The instantiated metric.
+ MetricType& metric;
+
+ /**
+ * Insert a point into the neighbors and distances matrices; this is a helper
+ * function.
+ *
+ * @param queryIndex Index of point whose neighbors we are inserting into.
+ * @param pos Position in list to insert into.
+ * @param neighbor Index of reference point which is being inserted.
+ * @param distance Distance from query point to reference point.
+ */
+ void InsertNeighbor(const size_t queryIndex,
+ const size_t pos,
+ const size_t neighbor,
+ const double distance);
+};
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+// Include implementation.
+#include "neighbor_search_rules_impl.hpp"
+
+#endif // __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
Deleted: mlpack/tags/mlpack-1.0.2/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2012-08-09 09:02:13 UTC (rev 13379)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -1,218 +0,0 @@
-/**
- * @file nearest_neighbor_rules_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of NearestNeighborRules.
- */
-#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
-#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
-
-// In case it hasn't been included yet.
-#include "neighbor_search_rules.hpp"
-
-namespace mlpack {
-namespace neighbor {
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
- const arma::mat& referenceSet,
- const arma::mat& querySet,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances,
- MetricType& metric) :
- referenceSet(referenceSet),
- querySet(querySet),
- neighbors(neighbors),
- distances(distances),
- metric(metric)
-{ /* Nothing left to do. */ }
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline force_inline // Absolutely MUST be inline so optimizations can happen.
-double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
-BaseCase(const size_t queryIndex, const size_t referenceIndex)
-{
- // If the datasets are the same, then this search is only using one dataset
- // and we should not return identical points.
- if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
- return 0.0;
-
- double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
- referenceSet.unsafe_col(referenceIndex));
-
- // If this distance is better than any of the current candidates, the
- // SortDistance() function will give us the position to insert it into.
- arma::vec queryDist = distances.unsafe_col(queryIndex);
- size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
-
- // SortDistance() returns (size_t() - 1) if we shouldn't add it.
- if (insertPosition != (size_t() - 1))
- InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance);
-
- return distance;
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline bool NeighborSearchRules<SortPolicy, MetricType, TreeType>::CanPrune(
- const size_t queryIndex,
- TreeType& referenceNode)
-{
- // Find the best distance between the query point and the node.
- const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
- const double distance =
- SortPolicy::BestPointToNodeDistance(queryPoint, &referenceNode);
- const double bestDistance = distances(distances.n_rows - 1, queryIndex);
-
- // If this is better than the best distance we've seen so far, maybe there
- // will be something down this node.
- return !(SortPolicy::IsBetter(distance, bestDistance));
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline bool NeighborSearchRules<SortPolicy, MetricType, TreeType>::CanPrune(
- TreeType& queryNode,
- TreeType& referenceNode)
-{
- const double distance = SortPolicy::BestNodeToNodeDistance(
- &queryNode, &referenceNode);
- const double bestDistance = queryNode.Stat().Bound();
-
- return !(SortPolicy::IsBetter(distance, bestDistance));
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearchRules<
- SortPolicy,
- MetricType,
- TreeType>::
-UpdateAfterRecursion(TreeType& queryNode, TreeType& /* referenceNode */)
-{
- // Find the worst distance that the children found (including any points), and
- // update the bound accordingly.
- double worstDistance = SortPolicy::BestDistance();
-
- // First look through children nodes.
- for (size_t i = 0; i < queryNode.NumChildren(); ++i)
- {
- if (SortPolicy::IsBetter(worstDistance, queryNode.Child(i).Stat().Bound()))
- worstDistance = queryNode.Child(i).Stat().Bound();
- }
-
- // Now look through children points.
- for (size_t i = 0; i < queryNode.NumPoints(); ++i)
- {
- if (SortPolicy::IsBetter(worstDistance,
- distances(distances.n_rows - 1, queryNode.Point(i))))
- worstDistance = distances(distances.n_rows - 1, queryNode.Point(i));
- }
-
- // Take the worst distance from all of these, and update our bound to reflect
- // that.
- queryNode.Stat().Bound() = worstDistance;
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
- const size_t queryIndex,
- TreeType& referenceNode) const
-{
- const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
- const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
- &referenceNode);
- const double bestDistance = distances(distances.n_rows - 1, queryIndex);
-
- return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
- const size_t queryIndex,
- TreeType& referenceNode,
- const double baseCaseResult) const
-{
- const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
- const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
- &referenceNode, baseCaseResult);
- const double bestDistance = distances(distances.n_rows - 1, queryIndex);
-
- return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
- const size_t queryIndex,
- TreeType& /* referenceNode */,
- const double oldScore) const
-{
- // If we are already pruning, still prune.
- if (oldScore == DBL_MAX)
- return oldScore;
-
- // Just check the score again against the distances.
- const double bestDistance = distances(distances.n_rows - 1, queryIndex);
-
- return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
- TreeType& queryNode,
- TreeType& referenceNode) const
-{
- const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
- &referenceNode);
- const double bestDistance = queryNode.Stat().Bound();
-
- return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
- TreeType& queryNode,
- TreeType& /* referenceNode */,
- const double oldScore) const
-{
- if (oldScore == DBL_MAX)
- return oldScore;
-
- const double bestDistance = queryNode.Stat().Bound();
-
- return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
-}
-
-/**
- * Helper function to insert a point into the neighbors and distances matrices.
- *
- * @param queryIndex Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
- * @param neighbor Index of reference point which is being inserted.
- * @param distance Distance from query point to reference point.
- */
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearchRules<SortPolicy, MetricType, TreeType>::InsertNeighbor(
- const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance)
-{
- // We only memmove() if there is actually a need to shift something.
- if (pos < (distances.n_rows - 1))
- {
- int len = (distances.n_rows - 1) - pos;
- memmove(distances.colptr(queryIndex) + (pos + 1),
- distances.colptr(queryIndex) + pos,
- sizeof(double) * len);
- memmove(neighbors.colptr(queryIndex) + (pos + 1),
- neighbors.colptr(queryIndex) + pos,
- sizeof(size_t) * len);
- }
-
- // Now put the new information in the right index.
- distances(pos, queryIndex) = distance;
- neighbors(pos, queryIndex) = neighbor;
-}
-
-}; // namespace neighbor
-}; // namespace mlpack
-
-#endif // __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
Copied: mlpack/tags/mlpack-1.0.2/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp (from rev 13385, mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.2/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -0,0 +1,203 @@
+/**
+ * @file neighbor_search_rules_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of NearestNeighborRules.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "neighbor_search_rules.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
+ const arma::mat& referenceSet,
+ const arma::mat& querySet,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances,
+ MetricType& metric) :
+ referenceSet(referenceSet),
+ querySet(querySet),
+ neighbors(neighbors),
+ distances(distances),
+ metric(metric)
+{ /* Nothing left to do. */ }
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline force_inline // Absolutely MUST be inline so optimizations can happen.
+double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
+BaseCase(const size_t queryIndex, const size_t referenceIndex)
+{
+ // If the datasets are the same, then this search is only using one dataset
+ // and we should not return identical points.
+ if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
+ return 0.0;
+
+ double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
+ referenceSet.unsafe_col(referenceIndex));
+
+ // If this distance is better than any of the current candidates, the
+ // SortDistance() function will give us the position to insert it into.
+ arma::vec queryDist = distances.unsafe_col(queryIndex);
+ size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
+
+ // SortDistance() returns (size_t() - 1) if we shouldn't add it.
+ if (insertPosition != (size_t() - 1))
+ InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance);
+
+ return distance;
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearchRules<
+ SortPolicy,
+ MetricType,
+ TreeType>::
+UpdateAfterRecursion(TreeType& queryNode, TreeType& /* referenceNode */)
+{
+ // Find the worst distance that the children found (including any points), and
+ // update the bound accordingly.
+ double worstDistance = SortPolicy::BestDistance();
+
+ // First look through children nodes.
+ for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+ {
+ if (SortPolicy::IsBetter(worstDistance, queryNode.Child(i).Stat().Bound()))
+ worstDistance = queryNode.Child(i).Stat().Bound();
+ }
+
+ // Now look through children points.
+ for (size_t i = 0; i < queryNode.NumPoints(); ++i)
+ {
+ if (SortPolicy::IsBetter(worstDistance,
+ distances(distances.n_rows - 1, queryNode.Point(i))))
+ worstDistance = distances(distances.n_rows - 1, queryNode.Point(i));
+ }
+
+ // Take the worst distance from all of these, and update our bound to reflect
+ // that.
+ queryNode.Stat().Bound() = worstDistance;
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
+ const size_t queryIndex,
+ TreeType& referenceNode) const
+{
+ const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
+ const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
+ &referenceNode);
+ const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+
+ return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
+ const size_t queryIndex,
+ TreeType& referenceNode,
+ const double baseCaseResult) const
+{
+ const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
+ const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
+ &referenceNode, baseCaseResult);
+ const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+
+ return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
+ const size_t queryIndex,
+ TreeType& /* referenceNode */,
+ const double oldScore) const
+{
+ // If we are already pruning, still prune.
+ if (oldScore == DBL_MAX)
+ return oldScore;
+
+ // Just check the score again against the distances.
+ const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+
+ return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
+ TreeType& queryNode,
+ TreeType& referenceNode) const
+{
+ const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
+ &referenceNode);
+ const double bestDistance = queryNode.Stat().Bound();
+
+ return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
+ TreeType& queryNode,
+ TreeType& referenceNode,
+ const double baseCaseResult) const
+{
+ const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
+ &referenceNode, baseCaseResult);
+ const double bestDistance = queryNode.Stat().Bound();
+
+ return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
+ TreeType& queryNode,
+ TreeType& /* referenceNode */,
+ const double oldScore) const
+{
+ if (oldScore == DBL_MAX)
+ return oldScore;
+
+ const double bestDistance = queryNode.Stat().Bound();
+
+ return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
+}
+
+/**
+ * Helper function to insert a point into the neighbors and distances matrices.
+ *
+ * @param queryIndex Index of point whose neighbors we are inserting into.
+ * @param pos Position in list to insert into.
+ * @param neighbor Index of reference point which is being inserted.
+ * @param distance Distance from query point to reference point.
+ */
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearchRules<SortPolicy, MetricType, TreeType>::InsertNeighbor(
+ const size_t queryIndex,
+ const size_t pos,
+ const size_t neighbor,
+ const double distance)
+{
+ // We only memmove() if there is actually a need to shift something.
+ if (pos < (distances.n_rows - 1))
+ {
+ int len = (distances.n_rows - 1) - pos;
+ memmove(distances.colptr(queryIndex) + (pos + 1),
+ distances.colptr(queryIndex) + pos,
+ sizeof(double) * len);
+ memmove(neighbors.colptr(queryIndex) + (pos + 1),
+ neighbors.colptr(queryIndex) + pos,
+ sizeof(size_t) * len);
+ }
+
+ // Now put the new information in the right index.
+ distances(pos, queryIndex) = distance;
+ neighbors(pos, queryIndex) = neighbor;
+}
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+#endif // __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
Deleted: mlpack/tags/mlpack-1.0.2/src/mlpack/methods/nmf/nmf_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/nmf/nmf_main.cpp 2012-08-09 09:02:13 UTC (rev 13379)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/methods/nmf/nmf_main.cpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -1,132 +0,0 @@
-/**
- * @file nmf_main.cpp
- * @author Mohan Rajendran
- *
- * Main executable to run NMF.
- */
-#include <mlpack/core.hpp>
-
-#include "nmf.hpp"
-
-#include "random_init.hpp"
-#include "mult_dist_update_rules.hpp"
-#include "mult_div_update_rules.hpp"
-#include "als_update_rules.hpp"
-
-using namespace mlpack;
-using namespace mlpack::nmf;
-using namespace std;
-
-// Document program.
-PROGRAM_INFO("Non-negative Matrix Factorization", "This program performs "
- "non-negative matrix factorization on the given dataset, storing the "
- "resulting decomposed matrices in the specified files. For an input "
- "dataset V, NMF decomposes V into two matrices W and H such that "
- "\n\n"
- "V = W * H"
- "\n\n"
- "where all elements in W and H are non-negative. If V is of size (n x m),"
- " then W will be of size (n x r) and H will be of size (r x m), where r is "
- "the rank of the factorization (specified by --rank)."
- "\n\n"
- "Optionally, the desired update rules for each NMF iteration can be chosen "
- "from the following list:"
- "\n\n"
- " - multdist: multiplicative distance-based update rules (Lee and Seung "
- "1999)\n"
- " - multdiv: multiplicative divergence-based update rules (Lee and Seung "
- "1999)\n"
- " - als: alternating least squares update rules (Paatero and Tapper 1994)"
- "\n\n"
- "The maximum number of iterations is specified with --max_iterations, and "
- "the minimum residue required for algorithm termination is specified with "
- "--min_residue.");
-
-// Parameters for program.
-PARAM_STRING_REQ("input_file", "Input dataset to perform NMF on.", "i");
-PARAM_STRING_REQ("w_file", "File to save the calculated W matrix to.", "w");
-PARAM_STRING_REQ("h_file", "File to save the calculated H matrix to.", "h");
-PARAM_INT_REQ("rank", "Rank of the factorization.", "r");
-
-PARAM_INT("max_iterations", "Number of iterations before NMF terminates (0 runs"
- " until convergence.", "m", 10000);
-PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
-PARAM_DOUBLE("min_residue", "The minimum root mean square residue allowed for "
- "each iteration, below which the program terminates.", "e", 1e-5);
-
-PARAM_STRING("update_rules", "Update rules for each iteration; ( multdist | "
- "multdiv | als ).", "u", "multdist");
-
-int main(int argc, char** argv)
-{
- // Parse command line.
- CLI::ParseCommandLine(argc, argv);
-
- // Initialize random seed.
- if (CLI::GetParam<int>("seed") != 0)
- math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
- else
- math::RandomSeed((size_t) std::time(NULL));
-
- // Gather parameters.
- const string inputFile = CLI::GetParam<string>("input_file");
- const string hOutputFile = CLI::GetParam<string>("h_file");
- const string wOutputFile = CLI::GetParam<string>("w_file");
- const size_t r = CLI::GetParam<int>("rank");
- const size_t maxIterations = CLI::GetParam<int>("max_iterations");
- const double minResidue = CLI::GetParam<double>("min_residue");
- const string updateRules = CLI::GetParam<string>("update_rules");
-
- // Validate rank.
- if (r < 1)
- {
- Log::Fatal << "The rank of the factorization cannot be less than 1."
- << std::endl;
- }
-
- if ((updateRules != "multdist") &&
- (updateRules != "multdiv") &&
- (updateRules != "als"))
- {
- Log::Fatal << "Invalid update rules ('" << updateRules << "'); must be '"
- << "multdist', 'multdiv', or 'als'." << std::endl;
- }
-
- // Load input dataset.
- arma::mat V;
- data::Load(inputFile, V, true);
-
- arma::mat W;
- arma::mat H;
-
- // Perform NMF with the specified update rules.
- if (updateRules == "multdist")
- {
- Log::Info << "Performing NMF with multiplicative distance-based update "
- << "rules." << std::endl;
- NMF<> nmf(maxIterations, minResidue);
- nmf.Apply(V, r, W, H);
- }
- else if (updateRules == "multdiv")
- {
- Log::Info << "Performing NMF with multiplicative divergence-based update "
- << "rules." << std::endl;
- NMF<RandomInitialization,
- WMultiplicativeDivergenceRule,
- HMultiplicativeDivergenceRule> nmf(maxIterations, minResidue);
- nmf.Apply(V, r, W, H);
- }
- else if (updateRules == "als")
- {
- Log::Info << "Performing NMF with alternating least squared update rules."
- << std::endl;
- NMF<RandomInitialization,
- WAlternatingLeastSquaresRule,
- HAlternatingLeastSquaresRule> nmf(maxIterations, minResidue);
- nmf.Apply(V, r, W, H);
- }
-
- // Save results.
- data::Save(wOutputFile, W, false);
- data::Save(hOutputFile, H, false);
-}
Copied: mlpack/tags/mlpack-1.0.2/src/mlpack/methods/nmf/nmf_main.cpp (from rev 13387, mlpack/trunk/src/mlpack/methods/nmf/nmf_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.2/src/mlpack/methods/nmf/nmf_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/methods/nmf/nmf_main.cpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -0,0 +1,132 @@
+/**
+ * @file nmf_main.cpp
+ * @author Mohan Rajendran
+ *
+ * Main executable to run NMF.
+ */
+#include <mlpack/core.hpp>
+
+#include "nmf.hpp"
+
+#include "random_init.hpp"
+#include "mult_dist_update_rules.hpp"
+#include "mult_div_update_rules.hpp"
+#include "als_update_rules.hpp"
+
+using namespace mlpack;
+using namespace mlpack::nmf;
+using namespace std;
+
+// Document program.
+PROGRAM_INFO("Non-negative Matrix Factorization", "This program performs "
+ "non-negative matrix factorization on the given dataset, storing the "
+ "resulting decomposed matrices in the specified files. For an input "
+ "dataset V, NMF decomposes V into two matrices W and H such that "
+ "\n\n"
+ "V = W * H"
+ "\n\n"
+ "where all elements in W and H are non-negative. If V is of size (n x m),"
+ " then W will be of size (n x r) and H will be of size (r x m), where r is "
+ "the rank of the factorization (specified by --rank)."
+ "\n\n"
+ "Optionally, the desired update rules for each NMF iteration can be chosen "
+ "from the following list:"
+ "\n\n"
+ " - multdist: multiplicative distance-based update rules (Lee and Seung "
+ "1999)\n"
+ " - multdiv: multiplicative divergence-based update rules (Lee and Seung "
+ "1999)\n"
+ " - als: alternating least squares update rules (Paatero and Tapper 1994)"
+ "\n\n"
+ "The maximum number of iterations is specified with --max_iterations, and "
+ "the minimum residue required for algorithm termination is specified with "
+ "--min_residue.");
+
+// Parameters for program.
+PARAM_STRING_REQ("input_file", "Input dataset to perform NMF on.", "i");
+PARAM_STRING_REQ("w_file", "File to save the calculated W matrix to.", "W");
+PARAM_STRING_REQ("h_file", "File to save the calculated H matrix to.", "H");
+PARAM_INT_REQ("rank", "Rank of the factorization.", "r");
+
+PARAM_INT("max_iterations", "Number of iterations before NMF terminates (0 runs"
+ " until convergence).", "m", 10000);
+PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
+PARAM_DOUBLE("min_residue", "The minimum root mean square residue allowed for "
+ "each iteration, below which the program terminates.", "e", 1e-5);
+
+PARAM_STRING("update_rules", "Update rules for each iteration; ( multdist | "
+ "multdiv | als ).", "u", "multdist");
+
+int main(int argc, char** argv)
+{
+ // Parse command line.
+ CLI::ParseCommandLine(argc, argv);
+
+ // Initialize random seed.
+ if (CLI::GetParam<int>("seed") != 0)
+ math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
+ else
+ math::RandomSeed((size_t) std::time(NULL));
+
+ // Gather parameters.
+ const string inputFile = CLI::GetParam<string>("input_file");
+ const string hOutputFile = CLI::GetParam<string>("h_file");
+ const string wOutputFile = CLI::GetParam<string>("w_file");
+ const size_t r = CLI::GetParam<int>("rank");
+ const size_t maxIterations = CLI::GetParam<int>("max_iterations");
+ const double minResidue = CLI::GetParam<double>("min_residue");
+ const string updateRules = CLI::GetParam<string>("update_rules");
+
+ // Validate rank.
+ if (r < 1)
+ {
+ Log::Fatal << "The rank of the factorization cannot be less than 1."
+ << std::endl;
+ }
+
+ if ((updateRules != "multdist") &&
+ (updateRules != "multdiv") &&
+ (updateRules != "als"))
+ {
+ Log::Fatal << "Invalid update rules ('" << updateRules << "'); must be '"
+ << "multdist', 'multdiv', or 'als'." << std::endl;
+ }
+
+ // Load input dataset.
+ arma::mat V;
+ data::Load(inputFile, V, true);
+
+ arma::mat W;
+ arma::mat H;
+
+ // Perform NMF with the specified update rules.
+ if (updateRules == "multdist")
+ {
+ Log::Info << "Performing NMF with multiplicative distance-based update "
+ << "rules." << std::endl;
+ NMF<> nmf(maxIterations, minResidue);
+ nmf.Apply(V, r, W, H);
+ }
+ else if (updateRules == "multdiv")
+ {
+ Log::Info << "Performing NMF with multiplicative divergence-based update "
+ << "rules." << std::endl;
+ NMF<RandomInitialization,
+ WMultiplicativeDivergenceRule,
+ HMultiplicativeDivergenceRule> nmf(maxIterations, minResidue);
+ nmf.Apply(V, r, W, H);
+ }
+ else if (updateRules == "als")
+ {
+ Log::Info << "Performing NMF with alternating least squared update rules."
+ << std::endl;
+ NMF<RandomInitialization,
+ WAlternatingLeastSquaresRule,
+ HAlternatingLeastSquaresRule> nmf(maxIterations, minResidue);
+ nmf.Apply(V, r, W, H);
+ }
+
+ // Save results.
+ data::Save(wOutputFile, W, false);
+ data::Save(hOutputFile, H, false);
+}
Deleted: mlpack/tags/mlpack-1.0.2/src/mlpack/tests/allknn_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/allknn_test.cpp 2012-08-09 09:02:13 UTC (rev 13379)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/tests/allknn_test.cpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -1,516 +0,0 @@
-/**
- * @file allknn_test.cpp
- *
- * Test file for AllkNN class.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
-#include <mlpack/core/tree/cover_tree.hpp>
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::neighbor;
-
-BOOST_AUTO_TEST_SUITE(AllkNNTest);
-
-/**
- * Simple nearest-neighbors test with small, synthetic dataset. This is an
- * exhaustive test, which checks that each method for performing the calculation
- * (dual-tree, single-tree, naive) produces the correct results. An
- * eleven-point dataset and the ten nearest neighbors are taken. The dataset is
- * in one dimension for simplicity -- the correct functionality of distance
- * functions is not tested here.
- */
-BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
-{
- // Set up our data.
- arma::mat data(1, 11);
- data[0] = 0.05; // Row addressing is unnecessary (they are all 0).
- data[1] = 0.35;
- data[2] = 0.15;
- data[3] = 1.25;
- data[4] = 5.05;
- data[5] = -0.22;
- data[6] = -2.00;
- data[7] = -1.30;
- data[8] = 0.45;
- data[9] = 0.90;
- data[10] = 1.00;
-
- // We will loop through three times, one for each method of performing the
- // calculation.
- for (int i = 0; i < 3; i++)
- {
- AllkNN* allknn;
- arma::mat dataMutable = data;
- switch (i)
- {
- case 0: // Use the dual-tree method.
- allknn = new AllkNN(dataMutable, false, false, 1);
- break;
- case 1: // Use the single-tree method.
- allknn = new AllkNN(dataMutable, false, true, 1);
- break;
- case 2: // Use the naive method.
- allknn = new AllkNN(dataMutable, true);
- break;
- }
-
- // Now perform the actual calculation.
- arma::Mat<size_t> neighbors;
- arma::mat distances;
- allknn->Search(10, neighbors, distances);
-
- // Now the exhaustive check for correctness. This will be long. We must
- // also remember that the distances returned are squared distances. As a
- // result, distance comparisons are written out as (distance * distance) for
- // readability.
-
- // Neighbors of point 0.
- BOOST_REQUIRE(neighbors(0, 0) == 2);
- BOOST_REQUIRE_CLOSE(distances(0, 0), (0.10 * 0.10), 1e-5);
- BOOST_REQUIRE(neighbors(1, 0) == 5);
- BOOST_REQUIRE_CLOSE(distances(1, 0), (0.27 * 0.27), 1e-5);
- BOOST_REQUIRE(neighbors(2, 0) == 1);
- BOOST_REQUIRE_CLOSE(distances(2, 0), (0.30 * 0.30), 1e-5);
- BOOST_REQUIRE(neighbors(3, 0) == 8);
- BOOST_REQUIRE_CLOSE(distances(3, 0), (0.40 * 0.40), 1e-5);
- BOOST_REQUIRE(neighbors(4, 0) == 9);
- BOOST_REQUIRE_CLOSE(distances(4, 0), (0.85 * 0.85), 1e-5);
- BOOST_REQUIRE(neighbors(5, 0) == 10);
- BOOST_REQUIRE_CLOSE(distances(5, 0), (0.95 * 0.95), 1e-5);
- BOOST_REQUIRE(neighbors(6, 0) == 3);
- BOOST_REQUIRE_CLOSE(distances(6, 0), (1.20 * 1.20), 1e-5);
- BOOST_REQUIRE(neighbors(7, 0) == 7);
- BOOST_REQUIRE_CLOSE(distances(7, 0), (1.35 * 1.35), 1e-5);
- BOOST_REQUIRE(neighbors(8, 0) == 6);
- BOOST_REQUIRE_CLOSE(distances(8, 0), (2.05 * 2.05), 1e-5);
- BOOST_REQUIRE(neighbors(9, 0) == 4);
- BOOST_REQUIRE_CLOSE(distances(9, 0), (5.00 * 5.00), 1e-5);
-
- // Neighbors of point 1.
- BOOST_REQUIRE(neighbors(0, 1) == 8);
- BOOST_REQUIRE_CLOSE(distances(0, 1), (0.10 * 0.10), 1e-5);
- BOOST_REQUIRE(neighbors(1, 1) == 2);
- BOOST_REQUIRE_CLOSE(distances(1, 1), (0.20 * 0.20), 1e-5);
- BOOST_REQUIRE(neighbors(2, 1) == 0);
- BOOST_REQUIRE_CLOSE(distances(2, 1), (0.30 * 0.30), 1e-5);
- BOOST_REQUIRE(neighbors(3, 1) == 9);
- BOOST_REQUIRE_CLOSE(distances(3, 1), (0.55 * 0.55), 1e-5);
- BOOST_REQUIRE(neighbors(4, 1) == 5);
- BOOST_REQUIRE_CLOSE(distances(4, 1), (0.57 * 0.57), 1e-5);
- BOOST_REQUIRE(neighbors(5, 1) == 10);
- BOOST_REQUIRE_CLOSE(distances(5, 1), (0.65 * 0.65), 1e-5);
- BOOST_REQUIRE(neighbors(6, 1) == 3);
- BOOST_REQUIRE_CLOSE(distances(6, 1), (0.90 * 0.90), 1e-5);
- BOOST_REQUIRE(neighbors(7, 1) == 7);
- BOOST_REQUIRE_CLOSE(distances(7, 1), (1.65 * 1.65), 1e-5);
- BOOST_REQUIRE(neighbors(8, 1) == 6);
- BOOST_REQUIRE_CLOSE(distances(8, 1), (2.35 * 2.35), 1e-5);
- BOOST_REQUIRE(neighbors(9, 1) == 4);
- BOOST_REQUIRE_CLOSE(distances(9, 1), (4.70 * 4.70), 1e-5);
-
- // Neighbors of point 2.
- BOOST_REQUIRE(neighbors(0, 2) == 0);
- BOOST_REQUIRE_CLOSE(distances(0, 2), (0.10 * 0.10), 1e-5);
- BOOST_REQUIRE(neighbors(1, 2) == 1);
- BOOST_REQUIRE_CLOSE(distances(1, 2), (0.20 * 0.20), 1e-5);
- BOOST_REQUIRE(neighbors(2, 2) == 8);
- BOOST_REQUIRE_CLOSE(distances(2, 2), (0.30 * 0.30), 1e-5);
- BOOST_REQUIRE(neighbors(3, 2) == 5);
- BOOST_REQUIRE_CLOSE(distances(3, 2), (0.37 * 0.37), 1e-5);
- BOOST_REQUIRE(neighbors(4, 2) == 9);
- BOOST_REQUIRE_CLOSE(distances(4, 2), (0.75 * 0.75), 1e-5);
- BOOST_REQUIRE(neighbors(5, 2) == 10);
- BOOST_REQUIRE_CLOSE(distances(5, 2), (0.85 * 0.85), 1e-5);
- BOOST_REQUIRE(neighbors(6, 2) == 3);
- BOOST_REQUIRE_CLOSE(distances(6, 2), (1.10 * 1.10), 1e-5);
- BOOST_REQUIRE(neighbors(7, 2) == 7);
- BOOST_REQUIRE_CLOSE(distances(7, 2), (1.45 * 1.45), 1e-5);
- BOOST_REQUIRE(neighbors(8, 2) == 6);
- BOOST_REQUIRE_CLOSE(distances(8, 2), (2.15 * 2.15), 1e-5);
- BOOST_REQUIRE(neighbors(9, 2) == 4);
- BOOST_REQUIRE_CLOSE(distances(9, 2), (4.90 * 4.90), 1e-5);
-
- // Neighbors of point 3.
- BOOST_REQUIRE(neighbors(0, 3) == 10);
- BOOST_REQUIRE_CLOSE(distances(0, 3), (0.25 * 0.25), 1e-5);
- BOOST_REQUIRE(neighbors(1, 3) == 9);
- BOOST_REQUIRE_CLOSE(distances(1, 3), (0.35 * 0.35), 1e-5);
- BOOST_REQUIRE(neighbors(2, 3) == 8);
- BOOST_REQUIRE_CLOSE(distances(2, 3), (0.80 * 0.80), 1e-5);
- BOOST_REQUIRE(neighbors(3, 3) == 1);
- BOOST_REQUIRE_CLOSE(distances(3, 3), (0.90 * 0.90), 1e-5);
- BOOST_REQUIRE(neighbors(4, 3) == 2);
- BOOST_REQUIRE_CLOSE(distances(4, 3), (1.10 * 1.10), 1e-5);
- BOOST_REQUIRE(neighbors(5, 3) == 0);
- BOOST_REQUIRE_CLOSE(distances(5, 3), (1.20 * 1.20), 1e-5);
- BOOST_REQUIRE(neighbors(6, 3) == 5);
- BOOST_REQUIRE_CLOSE(distances(6, 3), (1.47 * 1.47), 1e-5);
- BOOST_REQUIRE(neighbors(7, 3) == 7);
- BOOST_REQUIRE_CLOSE(distances(7, 3), (2.55 * 2.55), 1e-5);
- BOOST_REQUIRE(neighbors(8, 3) == 6);
- BOOST_REQUIRE_CLOSE(distances(8, 3), (3.25 * 3.25), 1e-5);
- BOOST_REQUIRE(neighbors(9, 3) == 4);
- BOOST_REQUIRE_CLOSE(distances(9, 3), (3.80 * 3.80), 1e-5);
-
- // Neighbors of point 4.
- BOOST_REQUIRE(neighbors(0, 4) == 3);
- BOOST_REQUIRE_CLOSE(distances(0, 4), (3.80 * 3.80), 1e-5);
- BOOST_REQUIRE(neighbors(1, 4) == 10);
- BOOST_REQUIRE_CLOSE(distances(1, 4), (4.05 * 4.05), 1e-5);
- BOOST_REQUIRE(neighbors(2, 4) == 9);
- BOOST_REQUIRE_CLOSE(distances(2, 4), (4.15 * 4.15), 1e-5);
- BOOST_REQUIRE(neighbors(3, 4) == 8);
- BOOST_REQUIRE_CLOSE(distances(3, 4), (4.60 * 4.60), 1e-5);
- BOOST_REQUIRE(neighbors(4, 4) == 1);
- BOOST_REQUIRE_CLOSE(distances(4, 4), (4.70 * 4.70), 1e-5);
- BOOST_REQUIRE(neighbors(5, 4) == 2);
- BOOST_REQUIRE_CLOSE(distances(5, 4), (4.90 * 4.90), 1e-5);
- BOOST_REQUIRE(neighbors(6, 4) == 0);
- BOOST_REQUIRE_CLOSE(distances(6, 4), (5.00 * 5.00), 1e-5);
- BOOST_REQUIRE(neighbors(7, 4) == 5);
- BOOST_REQUIRE_CLOSE(distances(7, 4), (5.27 * 5.27), 1e-5);
- BOOST_REQUIRE(neighbors(8, 4) == 7);
- BOOST_REQUIRE_CLOSE(distances(8, 4), (6.35 * 6.35), 1e-5);
- BOOST_REQUIRE(neighbors(9, 4) == 6);
- BOOST_REQUIRE_CLOSE(distances(9, 4), (7.05 * 7.05), 1e-5);
-
- // Neighbors of point 5.
- BOOST_REQUIRE(neighbors(0, 5) == 0);
- BOOST_REQUIRE_CLOSE(distances(0, 5), (0.27 * 0.27), 1e-5);
- BOOST_REQUIRE(neighbors(1, 5) == 2);
- BOOST_REQUIRE_CLOSE(distances(1, 5), (0.37 * 0.37), 1e-5);
- BOOST_REQUIRE(neighbors(2, 5) == 1);
- BOOST_REQUIRE_CLOSE(distances(2, 5), (0.57 * 0.57), 1e-5);
- BOOST_REQUIRE(neighbors(3, 5) == 8);
- BOOST_REQUIRE_CLOSE(distances(3, 5), (0.67 * 0.67), 1e-5);
- BOOST_REQUIRE(neighbors(4, 5) == 7);
- BOOST_REQUIRE_CLOSE(distances(4, 5), (1.08 * 1.08), 1e-5);
- BOOST_REQUIRE(neighbors(5, 5) == 9);
- BOOST_REQUIRE_CLOSE(distances(5, 5), (1.12 * 1.12), 1e-5);
- BOOST_REQUIRE(neighbors(6, 5) == 10);
- BOOST_REQUIRE_CLOSE(distances(6, 5), (1.22 * 1.22), 1e-5);
- BOOST_REQUIRE(neighbors(7, 5) == 3);
- BOOST_REQUIRE_CLOSE(distances(7, 5), (1.47 * 1.47), 1e-5);
- BOOST_REQUIRE(neighbors(8, 5) == 6);
- BOOST_REQUIRE_CLOSE(distances(8, 5), (1.78 * 1.78), 1e-5);
- BOOST_REQUIRE(neighbors(9, 5) == 4);
- BOOST_REQUIRE_CLOSE(distances(9, 5), (5.27 * 5.27), 1e-5);
-
- // Neighbors of point 6.
- BOOST_REQUIRE(neighbors(0, 6) == 7);
- BOOST_REQUIRE_CLOSE(distances(0, 6), (0.70 * 0.70), 1e-5);
- BOOST_REQUIRE(neighbors(1, 6) == 5);
- BOOST_REQUIRE_CLOSE(distances(1, 6), (1.78 * 1.78), 1e-5);
- BOOST_REQUIRE(neighbors(2, 6) == 0);
- BOOST_REQUIRE_CLOSE(distances(2, 6), (2.05 * 2.05), 1e-5);
- BOOST_REQUIRE(neighbors(3, 6) == 2);
- BOOST_REQUIRE_CLOSE(distances(3, 6), (2.15 * 2.15), 1e-5);
- BOOST_REQUIRE(neighbors(4, 6) == 1);
- BOOST_REQUIRE_CLOSE(distances(4, 6), (2.35 * 2.35), 1e-5);
- BOOST_REQUIRE(neighbors(5, 6) == 8);
- BOOST_REQUIRE_CLOSE(distances(5, 6), (2.45 * 2.45), 1e-5);
- BOOST_REQUIRE(neighbors(6, 6) == 9);
- BOOST_REQUIRE_CLOSE(distances(6, 6), (2.90 * 2.90), 1e-5);
- BOOST_REQUIRE(neighbors(7, 6) == 10);
- BOOST_REQUIRE_CLOSE(distances(7, 6), (3.00 * 3.00), 1e-5);
- BOOST_REQUIRE(neighbors(8, 6) == 3);
- BOOST_REQUIRE_CLOSE(distances(8, 6), (3.25 * 3.25), 1e-5);
- BOOST_REQUIRE(neighbors(9, 6) == 4);
- BOOST_REQUIRE_CLOSE(distances(9, 6), (7.05 * 7.05), 1e-5);
-
- // Neighbors of point 7.
- BOOST_REQUIRE(neighbors(0, 7) == 6);
- BOOST_REQUIRE_CLOSE(distances(0, 7), (0.70 * 0.70), 1e-5);
- BOOST_REQUIRE(neighbors(1, 7) == 5);
- BOOST_REQUIRE_CLOSE(distances(1, 7), (1.08 * 1.08), 1e-5);
- BOOST_REQUIRE(neighbors(2, 7) == 0);
- BOOST_REQUIRE_CLOSE(distances(2, 7), (1.35 * 1.35), 1e-5);
- BOOST_REQUIRE(neighbors(3, 7) == 2);
- BOOST_REQUIRE_CLOSE(distances(3, 7), (1.45 * 1.45), 1e-5);
- BOOST_REQUIRE(neighbors(4, 7) == 1);
- BOOST_REQUIRE_CLOSE(distances(4, 7), (1.65 * 1.65), 1e-5);
- BOOST_REQUIRE(neighbors(5, 7) == 8);
- BOOST_REQUIRE_CLOSE(distances(5, 7), (1.75 * 1.75), 1e-5);
- BOOST_REQUIRE(neighbors(6, 7) == 9);
- BOOST_REQUIRE_CLOSE(distances(6, 7), (2.20 * 2.20), 1e-5);
- BOOST_REQUIRE(neighbors(7, 7) == 10);
- BOOST_REQUIRE_CLOSE(distances(7, 7), (2.30 * 2.30), 1e-5);
- BOOST_REQUIRE(neighbors(8, 7) == 3);
- BOOST_REQUIRE_CLOSE(distances(8, 7), (2.55 * 2.55), 1e-5);
- BOOST_REQUIRE(neighbors(9, 7) == 4);
- BOOST_REQUIRE_CLOSE(distances(9, 7), (6.35 * 6.35), 1e-5);
-
- // Neighbors of point 8.
- BOOST_REQUIRE(neighbors(0, 8) == 1);
- BOOST_REQUIRE_CLOSE(distances(0, 8), (0.10 * 0.10), 1e-5);
- BOOST_REQUIRE(neighbors(1, 8) == 2);
- BOOST_REQUIRE_CLOSE(distances(1, 8), (0.30 * 0.30), 1e-5);
- BOOST_REQUIRE(neighbors(2, 8) == 0);
- BOOST_REQUIRE_CLOSE(distances(2, 8), (0.40 * 0.40), 1e-5);
- BOOST_REQUIRE(neighbors(3, 8) == 9);
- BOOST_REQUIRE_CLOSE(distances(3, 8), (0.45 * 0.45), 1e-5);
- BOOST_REQUIRE(neighbors(4, 8) == 10);
- BOOST_REQUIRE_CLOSE(distances(4, 8), (0.55 * 0.55), 1e-5);
- BOOST_REQUIRE(neighbors(5, 8) == 5);
- BOOST_REQUIRE_CLOSE(distances(5, 8), (0.67 * 0.67), 1e-5);
- BOOST_REQUIRE(neighbors(6, 8) == 3);
- BOOST_REQUIRE_CLOSE(distances(6, 8), (0.80 * 0.80), 1e-5);
- BOOST_REQUIRE(neighbors(7, 8) == 7);
- BOOST_REQUIRE_CLOSE(distances(7, 8), (1.75 * 1.75), 1e-5);
- BOOST_REQUIRE(neighbors(8, 8) == 6);
- BOOST_REQUIRE_CLOSE(distances(8, 8), (2.45 * 2.45), 1e-5);
- BOOST_REQUIRE(neighbors(9, 8) == 4);
- BOOST_REQUIRE_CLOSE(distances(9, 8), (4.60 * 4.60), 1e-5);
-
- // Neighbors of point 9.
- BOOST_REQUIRE(neighbors(0, 9) == 10);
- BOOST_REQUIRE_CLOSE(distances(0, 9), (0.10 * 0.10), 1e-5);
- BOOST_REQUIRE(neighbors(1, 9) == 3);
- BOOST_REQUIRE_CLOSE(distances(1, 9), (0.35 * 0.35), 1e-5);
- BOOST_REQUIRE(neighbors(2, 9) == 8);
- BOOST_REQUIRE_CLOSE(distances(2, 9), (0.45 * 0.45), 1e-5);
- BOOST_REQUIRE(neighbors(3, 9) == 1);
- BOOST_REQUIRE_CLOSE(distances(3, 9), (0.55 * 0.55), 1e-5);
- BOOST_REQUIRE(neighbors(4, 9) == 2);
- BOOST_REQUIRE_CLOSE(distances(4, 9), (0.75 * 0.75), 1e-5);
- BOOST_REQUIRE(neighbors(5, 9) == 0);
- BOOST_REQUIRE_CLOSE(distances(5, 9), (0.85 * 0.85), 1e-5);
- BOOST_REQUIRE(neighbors(6, 9) == 5);
- BOOST_REQUIRE_CLOSE(distances(6, 9), (1.12 * 1.12), 1e-5);
- BOOST_REQUIRE(neighbors(7, 9) == 7);
- BOOST_REQUIRE_CLOSE(distances(7, 9), (2.20 * 2.20), 1e-5);
- BOOST_REQUIRE(neighbors(8, 9) == 6);
- BOOST_REQUIRE_CLOSE(distances(8, 9), (2.90 * 2.90), 1e-5);
- BOOST_REQUIRE(neighbors(9, 9) == 4);
- BOOST_REQUIRE_CLOSE(distances(9, 9), (4.15 * 4.15), 1e-5);
-
- // Neighbors of point 10.
- BOOST_REQUIRE(neighbors(0, 10) == 9);
- BOOST_REQUIRE_CLOSE(distances(0, 10), (0.10 * 0.10), 1e-5);
- BOOST_REQUIRE(neighbors(1, 10) == 3);
- BOOST_REQUIRE_CLOSE(distances(1, 10), (0.25 * 0.25), 1e-5);
- BOOST_REQUIRE(neighbors(2, 10) == 8);
- BOOST_REQUIRE_CLOSE(distances(2, 10), (0.55 * 0.55), 1e-5);
- BOOST_REQUIRE(neighbors(3, 10) == 1);
- BOOST_REQUIRE_CLOSE(distances(3, 10), (0.65 * 0.65), 1e-5);
- BOOST_REQUIRE(neighbors(4, 10) == 2);
- BOOST_REQUIRE_CLOSE(distances(4, 10), (0.85 * 0.85), 1e-5);
- BOOST_REQUIRE(neighbors(5, 10) == 0);
- BOOST_REQUIRE_CLOSE(distances(5, 10), (0.95 * 0.95), 1e-5);
- BOOST_REQUIRE(neighbors(6, 10) == 5);
- BOOST_REQUIRE_CLOSE(distances(6, 10), (1.22 * 1.22), 1e-5);
- BOOST_REQUIRE(neighbors(7, 10) == 7);
- BOOST_REQUIRE_CLOSE(distances(7, 10), (2.30 * 2.30), 1e-5);
- BOOST_REQUIRE(neighbors(8, 10) == 6);
- BOOST_REQUIRE_CLOSE(distances(8, 10), (3.00 * 3.00), 1e-5);
- BOOST_REQUIRE(neighbors(9, 10) == 4);
- BOOST_REQUIRE_CLOSE(distances(9, 10), (4.05 * 4.05), 1e-5);
-
- // Clean the memory.
- delete allknn;
- }
-}
-
-/**
- * Test the dual-tree nearest-neighbors method with the naive method. This
- * uses both a query and reference dataset.
- *
- * Errors are produced if the results are not identical.
- */
-BOOST_AUTO_TEST_CASE(DualTreeVsNaive1)
-{
- arma::mat dataForTree;
-
- // Hard-coded filename: bad!
- if (!data::Load("test_data_3_1000.csv", dataForTree))
- BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
-
- // Set up matrices to work with.
- arma::mat dualQuery(dataForTree);
- arma::mat dualReferences(dataForTree);
- arma::mat naiveQuery(dataForTree);
- arma::mat naiveReferences(dataForTree);
-
- AllkNN allknn(dualQuery, dualReferences);
-
- AllkNN naive(naiveQuery, naiveReferences, true);
-
- arma::Mat<size_t> resultingNeighborsTree;
- arma::mat distancesTree;
- allknn.Search(15, resultingNeighborsTree, distancesTree);
-
- arma::Mat<size_t> resultingNeighborsNaive;
- arma::mat distancesNaive;
- naive.Search(15, resultingNeighborsNaive, distancesNaive);
-
- for (size_t i = 0; i < resultingNeighborsTree.n_elem; i++)
- {
- BOOST_REQUIRE(resultingNeighborsTree[i] == resultingNeighborsNaive[i]);
- BOOST_REQUIRE_CLOSE(distancesTree[i], distancesNaive[i], 1e-5);
- }
-}
-
-/**
- * Test the dual-tree nearest-neighbors method with the naive method. This uses
- * only a reference dataset.
- *
- * Errors are produced if the results are not identical.
- */
-BOOST_AUTO_TEST_CASE(DualTreeVsNaive2)
-{
- arma::mat dataForTree;
-
- // Hard-coded filename: bad!
- // Code duplication: also bad!
- if (!data::Load("test_data_3_1000.csv", dataForTree))
- BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
-
- // Set up matrices to work with (may not be necessary with no ALIAS_MATRIX?).
- arma::mat dualQuery(dataForTree);
- arma::mat naiveQuery(dataForTree);
-
- AllkNN allknn(dualQuery);
-
- // Set naive mode.
- AllkNN naive(naiveQuery, true);
-
- arma::Mat<size_t> resultingNeighborsTree;
- arma::mat distancesTree;
- allknn.Search(15, resultingNeighborsTree, distancesTree);
-
- arma::Mat<size_t> resultingNeighborsNaive;
- arma::mat distancesNaive;
- naive.Search(15, resultingNeighborsNaive, distancesNaive);
-
- for (size_t i = 0; i < resultingNeighborsTree.n_elem; i++)
- {
- BOOST_REQUIRE(resultingNeighborsTree[i] == resultingNeighborsNaive[i]);
- BOOST_REQUIRE_CLOSE(distancesTree[i], distancesNaive[i], 1e-5);
- }
-}
-
-/**
- * Test the single-tree nearest-neighbors method with the naive method. This
- * uses only a reference dataset.
- *
- * Errors are produced if the results are not identical.
- */
-BOOST_AUTO_TEST_CASE(SingleTreeVsNaive)
-{
- arma::mat dataForTree;
-
- // Hard-coded filename: bad!
- // Code duplication: also bad!
- if (!data::Load("test_data_3_1000.csv", dataForTree))
- BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
-
- // Set up matrices to work with (may not be necessary with no ALIAS_MATRIX?).
- arma::mat singleQuery(dataForTree);
- arma::mat naiveQuery(dataForTree);
-
- AllkNN allknn(singleQuery, false, true);
-
- // Set up computation for naive mode.
- AllkNN naive(naiveQuery, true);
-
- arma::Mat<size_t> resultingNeighborsTree;
- arma::mat distancesTree;
- allknn.Search(15, resultingNeighborsTree, distancesTree);
-
- arma::Mat<size_t> resultingNeighborsNaive;
- arma::mat distancesNaive;
- naive.Search(15, resultingNeighborsNaive, distancesNaive);
-
- for (size_t i = 0; i < resultingNeighborsTree.n_elem; i++)
- {
- BOOST_REQUIRE(resultingNeighborsTree[i] == resultingNeighborsNaive[i]);
- BOOST_REQUIRE_CLOSE(distancesTree[i], distancesNaive[i], 1e-5);
- }
-}
-
-/**
- * Test the cover tree single-tree nearest-neighbors method against the naive
- * method. This uses only a random reference dataset.
- *
- * Errors are produced if the results are not identical.
- */
-BOOST_AUTO_TEST_CASE(SingleCoverTreeTest)
-{
- arma::mat data;
- data.randu(75, 1000); // 75 dimensional, 1000 points.
-
- arma::mat naiveQuery(data); // For naive AllkNN.
-
- tree::CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> > tree = tree::CoverTree<
- metric::LMetric<2>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> >(data);
-
- NeighborSearch<NearestNeighborSort, metric::LMetric<2>,
- tree::CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> > >
- coverTreeSearch(&tree, data, true);
-
- AllkNN naive(naiveQuery, true);
-
- arma::Mat<size_t> coverTreeNeighbors;
- arma::mat coverTreeDistances;
- coverTreeSearch.Search(15, coverTreeNeighbors, coverTreeDistances);
-
- arma::Mat<size_t> naiveNeighbors;
- arma::mat naiveDistances;
- naive.Search(15, naiveNeighbors, naiveDistances);
-
- for (size_t i = 0; i < coverTreeNeighbors.n_elem; ++i)
- {
- BOOST_REQUIRE_EQUAL(coverTreeNeighbors[i], naiveNeighbors[i]);
- BOOST_REQUIRE_CLOSE(coverTreeDistances[i], naiveDistances[i], 1e-5);
- }
-}
-
-/**
- * Test the cover tree dual-tree nearest neighbors method against the naive
- * method.
- */
-BOOST_AUTO_TEST_CASE(DualCoverTreeTest)
-{
- arma::mat data;
- srand(time(NULL));
- data.randn(3, 1000);
-
- arma::mat kdtreeData(data);
-
- AllkNN tree(kdtreeData);
-
- arma::Mat<size_t> kdNeighbors;
- arma::mat kdDistances;
- tree.Search(5, kdNeighbors, kdDistances);
-
- tree::CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> > referenceTree = tree::CoverTree<
- metric::LMetric<2>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> >(data);
-
- NeighborSearch<NearestNeighborSort, metric::LMetric<2>,
- tree::CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> > >
- coverTreeSearch(&referenceTree, data);
-
- arma::Mat<size_t> coverNeighbors;
- arma::mat coverDistances;
- coverTreeSearch.Search(5, coverNeighbors, coverDistances);
-
- for (size_t i = 0; i < coverNeighbors.n_cols; ++i)
- {
- for (size_t j = 0; j < coverNeighbors.n_rows; ++j)
- {
- BOOST_REQUIRE_EQUAL(coverNeighbors(j, i), kdNeighbors(j, i));
- BOOST_REQUIRE_CLOSE(coverDistances(j, i), kdDistances(j, i), 1e-5);
- }
- }
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.2/src/mlpack/tests/allknn_test.cpp (from rev 13398, mlpack/trunk/src/mlpack/tests/allknn_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.2/src/mlpack/tests/allknn_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/tests/allknn_test.cpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -0,0 +1,525 @@
+/**
+ * @file allknn_test.cpp
+ *
+ * Test file for AllkNN class.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::neighbor;
+
+BOOST_AUTO_TEST_SUITE(AllkNNTest);
+
+/**
+ * Simple nearest-neighbors test with small, synthetic dataset. This is an
+ * exhaustive test, which checks that each method for performing the calculation
+ * (dual-tree, single-tree, naive) produces the correct results. An
+ * eleven-point dataset and the ten nearest neighbors are taken. The dataset is
+ * in one dimension for simplicity -- the correct functionality of distance
+ * functions is not tested here.
+ */
+BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
+{
+ // Set up our data.
+ arma::mat data(1, 11);
+ data[0] = 0.05; // Row addressing is unnecessary (they are all 0).
+ data[1] = 0.35;
+ data[2] = 0.15;
+ data[3] = 1.25;
+ data[4] = 5.05;
+ data[5] = -0.22;
+ data[6] = -2.00;
+ data[7] = -1.30;
+ data[8] = 0.45;
+ data[9] = 0.90;
+ data[10] = 1.00;
+
+ // We will loop through three times, one for each method of performing the
+ // calculation.
+ for (int i = 0; i < 3; i++)
+ {
+ AllkNN* allknn;
+ arma::mat dataMutable = data;
+ switch (i)
+ {
+ case 0: // Use the dual-tree method.
+ allknn = new AllkNN(dataMutable, false, false, 1);
+ break;
+ case 1: // Use the single-tree method.
+ allknn = new AllkNN(dataMutable, false, true, 1);
+ break;
+ case 2: // Use the naive method.
+ allknn = new AllkNN(dataMutable, true);
+ break;
+ }
+
+ // Now perform the actual calculation.
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+ allknn->Search(10, neighbors, distances);
+
+ // Now the exhaustive check for correctness. This will be long. We must
+ // also remember that the distances returned are squared distances. As a
+ // result, distance comparisons are written out as (distance * distance) for
+ // readability.
+
+ // Neighbors of point 0.
+ BOOST_REQUIRE(neighbors(0, 0) == 2);
+ BOOST_REQUIRE_CLOSE(distances(0, 0), (0.10 * 0.10), 1e-5);
+ BOOST_REQUIRE(neighbors(1, 0) == 5);
+ BOOST_REQUIRE_CLOSE(distances(1, 0), (0.27 * 0.27), 1e-5);
+ BOOST_REQUIRE(neighbors(2, 0) == 1);
+ BOOST_REQUIRE_CLOSE(distances(2, 0), (0.30 * 0.30), 1e-5);
+ BOOST_REQUIRE(neighbors(3, 0) == 8);
+ BOOST_REQUIRE_CLOSE(distances(3, 0), (0.40 * 0.40), 1e-5);
+ BOOST_REQUIRE(neighbors(4, 0) == 9);
+ BOOST_REQUIRE_CLOSE(distances(4, 0), (0.85 * 0.85), 1e-5);
+ BOOST_REQUIRE(neighbors(5, 0) == 10);
+ BOOST_REQUIRE_CLOSE(distances(5, 0), (0.95 * 0.95), 1e-5);
+ BOOST_REQUIRE(neighbors(6, 0) == 3);
+ BOOST_REQUIRE_CLOSE(distances(6, 0), (1.20 * 1.20), 1e-5);
+ BOOST_REQUIRE(neighbors(7, 0) == 7);
+ BOOST_REQUIRE_CLOSE(distances(7, 0), (1.35 * 1.35), 1e-5);
+ BOOST_REQUIRE(neighbors(8, 0) == 6);
+ BOOST_REQUIRE_CLOSE(distances(8, 0), (2.05 * 2.05), 1e-5);
+ BOOST_REQUIRE(neighbors(9, 0) == 4);
+ BOOST_REQUIRE_CLOSE(distances(9, 0), (5.00 * 5.00), 1e-5);
+
+ // Neighbors of point 1.
+ BOOST_REQUIRE(neighbors(0, 1) == 8);
+ BOOST_REQUIRE_CLOSE(distances(0, 1), (0.10 * 0.10), 1e-5);
+ BOOST_REQUIRE(neighbors(1, 1) == 2);
+ BOOST_REQUIRE_CLOSE(distances(1, 1), (0.20 * 0.20), 1e-5);
+ BOOST_REQUIRE(neighbors(2, 1) == 0);
+ BOOST_REQUIRE_CLOSE(distances(2, 1), (0.30 * 0.30), 1e-5);
+ BOOST_REQUIRE(neighbors(3, 1) == 9);
+ BOOST_REQUIRE_CLOSE(distances(3, 1), (0.55 * 0.55), 1e-5);
+ BOOST_REQUIRE(neighbors(4, 1) == 5);
+ BOOST_REQUIRE_CLOSE(distances(4, 1), (0.57 * 0.57), 1e-5);
+ BOOST_REQUIRE(neighbors(5, 1) == 10);
+ BOOST_REQUIRE_CLOSE(distances(5, 1), (0.65 * 0.65), 1e-5);
+ BOOST_REQUIRE(neighbors(6, 1) == 3);
+ BOOST_REQUIRE_CLOSE(distances(6, 1), (0.90 * 0.90), 1e-5);
+ BOOST_REQUIRE(neighbors(7, 1) == 7);
+ BOOST_REQUIRE_CLOSE(distances(7, 1), (1.65 * 1.65), 1e-5);
+ BOOST_REQUIRE(neighbors(8, 1) == 6);
+ BOOST_REQUIRE_CLOSE(distances(8, 1), (2.35 * 2.35), 1e-5);
+ BOOST_REQUIRE(neighbors(9, 1) == 4);
+ BOOST_REQUIRE_CLOSE(distances(9, 1), (4.70 * 4.70), 1e-5);
+
+ // Neighbors of point 2.
+ BOOST_REQUIRE(neighbors(0, 2) == 0);
+ BOOST_REQUIRE_CLOSE(distances(0, 2), (0.10 * 0.10), 1e-5);
+ BOOST_REQUIRE(neighbors(1, 2) == 1);
+ BOOST_REQUIRE_CLOSE(distances(1, 2), (0.20 * 0.20), 1e-5);
+ BOOST_REQUIRE(neighbors(2, 2) == 8);
+ BOOST_REQUIRE_CLOSE(distances(2, 2), (0.30 * 0.30), 1e-5);
+ BOOST_REQUIRE(neighbors(3, 2) == 5);
+ BOOST_REQUIRE_CLOSE(distances(3, 2), (0.37 * 0.37), 1e-5);
+ BOOST_REQUIRE(neighbors(4, 2) == 9);
+ BOOST_REQUIRE_CLOSE(distances(4, 2), (0.75 * 0.75), 1e-5);
+ BOOST_REQUIRE(neighbors(5, 2) == 10);
+ BOOST_REQUIRE_CLOSE(distances(5, 2), (0.85 * 0.85), 1e-5);
+ BOOST_REQUIRE(neighbors(6, 2) == 3);
+ BOOST_REQUIRE_CLOSE(distances(6, 2), (1.10 * 1.10), 1e-5);
+ BOOST_REQUIRE(neighbors(7, 2) == 7);
+ BOOST_REQUIRE_CLOSE(distances(7, 2), (1.45 * 1.45), 1e-5);
+ BOOST_REQUIRE(neighbors(8, 2) == 6);
+ BOOST_REQUIRE_CLOSE(distances(8, 2), (2.15 * 2.15), 1e-5);
+ BOOST_REQUIRE(neighbors(9, 2) == 4);
+ BOOST_REQUIRE_CLOSE(distances(9, 2), (4.90 * 4.90), 1e-5);
+
+ // Neighbors of point 3.
+ BOOST_REQUIRE(neighbors(0, 3) == 10);
+ BOOST_REQUIRE_CLOSE(distances(0, 3), (0.25 * 0.25), 1e-5);
+ BOOST_REQUIRE(neighbors(1, 3) == 9);
+ BOOST_REQUIRE_CLOSE(distances(1, 3), (0.35 * 0.35), 1e-5);
+ BOOST_REQUIRE(neighbors(2, 3) == 8);
+ BOOST_REQUIRE_CLOSE(distances(2, 3), (0.80 * 0.80), 1e-5);
+ BOOST_REQUIRE(neighbors(3, 3) == 1);
+ BOOST_REQUIRE_CLOSE(distances(3, 3), (0.90 * 0.90), 1e-5);
+ BOOST_REQUIRE(neighbors(4, 3) == 2);
+ BOOST_REQUIRE_CLOSE(distances(4, 3), (1.10 * 1.10), 1e-5);
+ BOOST_REQUIRE(neighbors(5, 3) == 0);
+ BOOST_REQUIRE_CLOSE(distances(5, 3), (1.20 * 1.20), 1e-5);
+ BOOST_REQUIRE(neighbors(6, 3) == 5);
+ BOOST_REQUIRE_CLOSE(distances(6, 3), (1.47 * 1.47), 1e-5);
+ BOOST_REQUIRE(neighbors(7, 3) == 7);
+ BOOST_REQUIRE_CLOSE(distances(7, 3), (2.55 * 2.55), 1e-5);
+ BOOST_REQUIRE(neighbors(8, 3) == 6);
+ BOOST_REQUIRE_CLOSE(distances(8, 3), (3.25 * 3.25), 1e-5);
+ BOOST_REQUIRE(neighbors(9, 3) == 4);
+ BOOST_REQUIRE_CLOSE(distances(9, 3), (3.80 * 3.80), 1e-5);
+
+ // Neighbors of point 4.
+ BOOST_REQUIRE(neighbors(0, 4) == 3);
+ BOOST_REQUIRE_CLOSE(distances(0, 4), (3.80 * 3.80), 1e-5);
+ BOOST_REQUIRE(neighbors(1, 4) == 10);
+ BOOST_REQUIRE_CLOSE(distances(1, 4), (4.05 * 4.05), 1e-5);
+ BOOST_REQUIRE(neighbors(2, 4) == 9);
+ BOOST_REQUIRE_CLOSE(distances(2, 4), (4.15 * 4.15), 1e-5);
+ BOOST_REQUIRE(neighbors(3, 4) == 8);
+ BOOST_REQUIRE_CLOSE(distances(3, 4), (4.60 * 4.60), 1e-5);
+ BOOST_REQUIRE(neighbors(4, 4) == 1);
+ BOOST_REQUIRE_CLOSE(distances(4, 4), (4.70 * 4.70), 1e-5);
+ BOOST_REQUIRE(neighbors(5, 4) == 2);
+ BOOST_REQUIRE_CLOSE(distances(5, 4), (4.90 * 4.90), 1e-5);
+ BOOST_REQUIRE(neighbors(6, 4) == 0);
+ BOOST_REQUIRE_CLOSE(distances(6, 4), (5.00 * 5.00), 1e-5);
+ BOOST_REQUIRE(neighbors(7, 4) == 5);
+ BOOST_REQUIRE_CLOSE(distances(7, 4), (5.27 * 5.27), 1e-5);
+ BOOST_REQUIRE(neighbors(8, 4) == 7);
+ BOOST_REQUIRE_CLOSE(distances(8, 4), (6.35 * 6.35), 1e-5);
+ BOOST_REQUIRE(neighbors(9, 4) == 6);
+ BOOST_REQUIRE_CLOSE(distances(9, 4), (7.05 * 7.05), 1e-5);
+
+ // Neighbors of point 5.
+ BOOST_REQUIRE(neighbors(0, 5) == 0);
+ BOOST_REQUIRE_CLOSE(distances(0, 5), (0.27 * 0.27), 1e-5);
+ BOOST_REQUIRE(neighbors(1, 5) == 2);
+ BOOST_REQUIRE_CLOSE(distances(1, 5), (0.37 * 0.37), 1e-5);
+ BOOST_REQUIRE(neighbors(2, 5) == 1);
+ BOOST_REQUIRE_CLOSE(distances(2, 5), (0.57 * 0.57), 1e-5);
+ BOOST_REQUIRE(neighbors(3, 5) == 8);
+ BOOST_REQUIRE_CLOSE(distances(3, 5), (0.67 * 0.67), 1e-5);
+ BOOST_REQUIRE(neighbors(4, 5) == 7);
+ BOOST_REQUIRE_CLOSE(distances(4, 5), (1.08 * 1.08), 1e-5);
+ BOOST_REQUIRE(neighbors(5, 5) == 9);
+ BOOST_REQUIRE_CLOSE(distances(5, 5), (1.12 * 1.12), 1e-5);
+ BOOST_REQUIRE(neighbors(6, 5) == 10);
+ BOOST_REQUIRE_CLOSE(distances(6, 5), (1.22 * 1.22), 1e-5);
+ BOOST_REQUIRE(neighbors(7, 5) == 3);
+ BOOST_REQUIRE_CLOSE(distances(7, 5), (1.47 * 1.47), 1e-5);
+ BOOST_REQUIRE(neighbors(8, 5) == 6);
+ BOOST_REQUIRE_CLOSE(distances(8, 5), (1.78 * 1.78), 1e-5);
+ BOOST_REQUIRE(neighbors(9, 5) == 4);
+ BOOST_REQUIRE_CLOSE(distances(9, 5), (5.27 * 5.27), 1e-5);
+
+ // Neighbors of point 6.
+ BOOST_REQUIRE(neighbors(0, 6) == 7);
+ BOOST_REQUIRE_CLOSE(distances(0, 6), (0.70 * 0.70), 1e-5);
+ BOOST_REQUIRE(neighbors(1, 6) == 5);
+ BOOST_REQUIRE_CLOSE(distances(1, 6), (1.78 * 1.78), 1e-5);
+ BOOST_REQUIRE(neighbors(2, 6) == 0);
+ BOOST_REQUIRE_CLOSE(distances(2, 6), (2.05 * 2.05), 1e-5);
+ BOOST_REQUIRE(neighbors(3, 6) == 2);
+ BOOST_REQUIRE_CLOSE(distances(3, 6), (2.15 * 2.15), 1e-5);
+ BOOST_REQUIRE(neighbors(4, 6) == 1);
+ BOOST_REQUIRE_CLOSE(distances(4, 6), (2.35 * 2.35), 1e-5);
+ BOOST_REQUIRE(neighbors(5, 6) == 8);
+ BOOST_REQUIRE_CLOSE(distances(5, 6), (2.45 * 2.45), 1e-5);
+ BOOST_REQUIRE(neighbors(6, 6) == 9);
+ BOOST_REQUIRE_CLOSE(distances(6, 6), (2.90 * 2.90), 1e-5);
+ BOOST_REQUIRE(neighbors(7, 6) == 10);
+ BOOST_REQUIRE_CLOSE(distances(7, 6), (3.00 * 3.00), 1e-5);
+ BOOST_REQUIRE(neighbors(8, 6) == 3);
+ BOOST_REQUIRE_CLOSE(distances(8, 6), (3.25 * 3.25), 1e-5);
+ BOOST_REQUIRE(neighbors(9, 6) == 4);
+ BOOST_REQUIRE_CLOSE(distances(9, 6), (7.05 * 7.05), 1e-5);
+
+ // Neighbors of point 7.
+ BOOST_REQUIRE(neighbors(0, 7) == 6);
+ BOOST_REQUIRE_CLOSE(distances(0, 7), (0.70 * 0.70), 1e-5);
+ BOOST_REQUIRE(neighbors(1, 7) == 5);
+ BOOST_REQUIRE_CLOSE(distances(1, 7), (1.08 * 1.08), 1e-5);
+ BOOST_REQUIRE(neighbors(2, 7) == 0);
+ BOOST_REQUIRE_CLOSE(distances(2, 7), (1.35 * 1.35), 1e-5);
+ BOOST_REQUIRE(neighbors(3, 7) == 2);
+ BOOST_REQUIRE_CLOSE(distances(3, 7), (1.45 * 1.45), 1e-5);
+ BOOST_REQUIRE(neighbors(4, 7) == 1);
+ BOOST_REQUIRE_CLOSE(distances(4, 7), (1.65 * 1.65), 1e-5);
+ BOOST_REQUIRE(neighbors(5, 7) == 8);
+ BOOST_REQUIRE_CLOSE(distances(5, 7), (1.75 * 1.75), 1e-5);
+ BOOST_REQUIRE(neighbors(6, 7) == 9);
+ BOOST_REQUIRE_CLOSE(distances(6, 7), (2.20 * 2.20), 1e-5);
+ BOOST_REQUIRE(neighbors(7, 7) == 10);
+ BOOST_REQUIRE_CLOSE(distances(7, 7), (2.30 * 2.30), 1e-5);
+ BOOST_REQUIRE(neighbors(8, 7) == 3);
+ BOOST_REQUIRE_CLOSE(distances(8, 7), (2.55 * 2.55), 1e-5);
+ BOOST_REQUIRE(neighbors(9, 7) == 4);
+ BOOST_REQUIRE_CLOSE(distances(9, 7), (6.35 * 6.35), 1e-5);
+
+ // Neighbors of point 8.
+ BOOST_REQUIRE(neighbors(0, 8) == 1);
+ BOOST_REQUIRE_CLOSE(distances(0, 8), (0.10 * 0.10), 1e-5);
+ BOOST_REQUIRE(neighbors(1, 8) == 2);
+ BOOST_REQUIRE_CLOSE(distances(1, 8), (0.30 * 0.30), 1e-5);
+ BOOST_REQUIRE(neighbors(2, 8) == 0);
+ BOOST_REQUIRE_CLOSE(distances(2, 8), (0.40 * 0.40), 1e-5);
+ BOOST_REQUIRE(neighbors(3, 8) == 9);
+ BOOST_REQUIRE_CLOSE(distances(3, 8), (0.45 * 0.45), 1e-5);
+ BOOST_REQUIRE(neighbors(4, 8) == 10);
+ BOOST_REQUIRE_CLOSE(distances(4, 8), (0.55 * 0.55), 1e-5);
+ BOOST_REQUIRE(neighbors(5, 8) == 5);
+ BOOST_REQUIRE_CLOSE(distances(5, 8), (0.67 * 0.67), 1e-5);
+ BOOST_REQUIRE(neighbors(6, 8) == 3);
+ BOOST_REQUIRE_CLOSE(distances(6, 8), (0.80 * 0.80), 1e-5);
+ BOOST_REQUIRE(neighbors(7, 8) == 7);
+ BOOST_REQUIRE_CLOSE(distances(7, 8), (1.75 * 1.75), 1e-5);
+ BOOST_REQUIRE(neighbors(8, 8) == 6);
+ BOOST_REQUIRE_CLOSE(distances(8, 8), (2.45 * 2.45), 1e-5);
+ BOOST_REQUIRE(neighbors(9, 8) == 4);
+ BOOST_REQUIRE_CLOSE(distances(9, 8), (4.60 * 4.60), 1e-5);
+
+ // Neighbors of point 9.
+ BOOST_REQUIRE(neighbors(0, 9) == 10);
+ BOOST_REQUIRE_CLOSE(distances(0, 9), (0.10 * 0.10), 1e-5);
+ BOOST_REQUIRE(neighbors(1, 9) == 3);
+ BOOST_REQUIRE_CLOSE(distances(1, 9), (0.35 * 0.35), 1e-5);
+ BOOST_REQUIRE(neighbors(2, 9) == 8);
+ BOOST_REQUIRE_CLOSE(distances(2, 9), (0.45 * 0.45), 1e-5);
+ BOOST_REQUIRE(neighbors(3, 9) == 1);
+ BOOST_REQUIRE_CLOSE(distances(3, 9), (0.55 * 0.55), 1e-5);
+ BOOST_REQUIRE(neighbors(4, 9) == 2);
+ BOOST_REQUIRE_CLOSE(distances(4, 9), (0.75 * 0.75), 1e-5);
+ BOOST_REQUIRE(neighbors(5, 9) == 0);
+ BOOST_REQUIRE_CLOSE(distances(5, 9), (0.85 * 0.85), 1e-5);
+ BOOST_REQUIRE(neighbors(6, 9) == 5);
+ BOOST_REQUIRE_CLOSE(distances(6, 9), (1.12 * 1.12), 1e-5);
+ BOOST_REQUIRE(neighbors(7, 9) == 7);
+ BOOST_REQUIRE_CLOSE(distances(7, 9), (2.20 * 2.20), 1e-5);
+ BOOST_REQUIRE(neighbors(8, 9) == 6);
+ BOOST_REQUIRE_CLOSE(distances(8, 9), (2.90 * 2.90), 1e-5);
+ BOOST_REQUIRE(neighbors(9, 9) == 4);
+ BOOST_REQUIRE_CLOSE(distances(9, 9), (4.15 * 4.15), 1e-5);
+
+ // Neighbors of point 10.
+ BOOST_REQUIRE(neighbors(0, 10) == 9);
+ BOOST_REQUIRE_CLOSE(distances(0, 10), (0.10 * 0.10), 1e-5);
+ BOOST_REQUIRE(neighbors(1, 10) == 3);
+ BOOST_REQUIRE_CLOSE(distances(1, 10), (0.25 * 0.25), 1e-5);
+ BOOST_REQUIRE(neighbors(2, 10) == 8);
+ BOOST_REQUIRE_CLOSE(distances(2, 10), (0.55 * 0.55), 1e-5);
+ BOOST_REQUIRE(neighbors(3, 10) == 1);
+ BOOST_REQUIRE_CLOSE(distances(3, 10), (0.65 * 0.65), 1e-5);
+ BOOST_REQUIRE(neighbors(4, 10) == 2);
+ BOOST_REQUIRE_CLOSE(distances(4, 10), (0.85 * 0.85), 1e-5);
+ BOOST_REQUIRE(neighbors(5, 10) == 0);
+ BOOST_REQUIRE_CLOSE(distances(5, 10), (0.95 * 0.95), 1e-5);
+ BOOST_REQUIRE(neighbors(6, 10) == 5);
+ BOOST_REQUIRE_CLOSE(distances(6, 10), (1.22 * 1.22), 1e-5);
+ BOOST_REQUIRE(neighbors(7, 10) == 7);
+ BOOST_REQUIRE_CLOSE(distances(7, 10), (2.30 * 2.30), 1e-5);
+ BOOST_REQUIRE(neighbors(8, 10) == 6);
+ BOOST_REQUIRE_CLOSE(distances(8, 10), (3.00 * 3.00), 1e-5);
+ BOOST_REQUIRE(neighbors(9, 10) == 4);
+ BOOST_REQUIRE_CLOSE(distances(9, 10), (4.05 * 4.05), 1e-5);
+
+ // Clean the memory.
+ delete allknn;
+ }
+}
+
+/**
+ * Test the dual-tree nearest-neighbors method with the naive method. This
+ * uses both a query and reference dataset.
+ *
+ * Errors are produced if the results are not identical.
+ */
+BOOST_AUTO_TEST_CASE(DualTreeVsNaive1)
+{
+ arma::mat dataForTree;
+
+ // Hard-coded filename: bad!
+ if (!data::Load("test_data_3_1000.csv", dataForTree))
+ BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
+
+ // Set up matrices to work with.
+ arma::mat dualQuery(dataForTree);
+ arma::mat dualReferences(dataForTree);
+ arma::mat naiveQuery(dataForTree);
+ arma::mat naiveReferences(dataForTree);
+
+ AllkNN allknn(dualQuery, dualReferences);
+
+ AllkNN naive(naiveQuery, naiveReferences, true);
+
+ arma::Mat<size_t> resultingNeighborsTree;
+ arma::mat distancesTree;
+ allknn.Search(15, resultingNeighborsTree, distancesTree);
+
+ arma::Mat<size_t> resultingNeighborsNaive;
+ arma::mat distancesNaive;
+ naive.Search(15, resultingNeighborsNaive, distancesNaive);
+
+ for (size_t i = 0; i < resultingNeighborsTree.n_elem; i++)
+ {
+ BOOST_REQUIRE(resultingNeighborsTree[i] == resultingNeighborsNaive[i]);
+ BOOST_REQUIRE_CLOSE(distancesTree[i], distancesNaive[i], 1e-5);
+ }
+}
+
+/**
+ * Test the dual-tree nearest-neighbors method with the naive method. This uses
+ * only a reference dataset.
+ *
+ * Errors are produced if the results are not identical.
+ */
+BOOST_AUTO_TEST_CASE(DualTreeVsNaive2)
+{
+ arma::mat dataForTree;
+
+ // Hard-coded filename: bad!
+ // Code duplication: also bad!
+ if (!data::Load("test_data_3_1000.csv", dataForTree))
+ BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
+
+ // Set up matrices to work with (may not be necessary with no ALIAS_MATRIX?).
+ arma::mat dualQuery(dataForTree);
+ arma::mat naiveQuery(dataForTree);
+
+ AllkNN allknn(dualQuery);
+
+ // Set naive mode.
+ AllkNN naive(naiveQuery, true);
+
+ arma::Mat<size_t> resultingNeighborsTree;
+ arma::mat distancesTree;
+ allknn.Search(15, resultingNeighborsTree, distancesTree);
+
+ arma::Mat<size_t> resultingNeighborsNaive;
+ arma::mat distancesNaive;
+ naive.Search(15, resultingNeighborsNaive, distancesNaive);
+
+ for (size_t i = 0; i < resultingNeighborsTree.n_elem; i++)
+ {
+ BOOST_REQUIRE(resultingNeighborsTree[i] == resultingNeighborsNaive[i]);
+ BOOST_REQUIRE_CLOSE(distancesTree[i], distancesNaive[i], 1e-5);
+ }
+}
+
+/**
+ * Test the single-tree nearest-neighbors method with the naive method. This
+ * uses only a reference dataset.
+ *
+ * Errors are produced if the results are not identical.
+ */
+BOOST_AUTO_TEST_CASE(SingleTreeVsNaive)
+{
+ arma::mat dataForTree;
+
+ // Hard-coded filename: bad!
+ // Code duplication: also bad!
+ if (!data::Load("test_data_3_1000.csv", dataForTree))
+ BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
+
+ // Set up matrices to work with (may not be necessary with no ALIAS_MATRIX?).
+ arma::mat singleQuery(dataForTree);
+ arma::mat naiveQuery(dataForTree);
+
+ AllkNN allknn(singleQuery, false, true);
+
+ // Set up computation for naive mode.
+ AllkNN naive(naiveQuery, true);
+
+ arma::Mat<size_t> resultingNeighborsTree;
+ arma::mat distancesTree;
+ allknn.Search(15, resultingNeighborsTree, distancesTree);
+
+ arma::Mat<size_t> resultingNeighborsNaive;
+ arma::mat distancesNaive;
+ naive.Search(15, resultingNeighborsNaive, distancesNaive);
+
+ for (size_t i = 0; i < resultingNeighborsTree.n_elem; i++)
+ {
+ BOOST_REQUIRE(resultingNeighborsTree[i] == resultingNeighborsNaive[i]);
+ BOOST_REQUIRE_CLOSE(distancesTree[i], distancesNaive[i], 1e-5);
+ }
+}
+
+/**
+ * Test the cover tree single-tree nearest-neighbors method against the naive
+ * method. This uses only a random reference dataset.
+ *
+ * Errors are produced if the results are not identical.
+ */
+BOOST_AUTO_TEST_CASE(SingleCoverTreeTest)
+{
+ arma::mat data;
+ data.randu(75, 1000); // 75 dimensional, 1000 points.
+
+ arma::mat naiveQuery(data); // For naive AllkNN.
+
+ tree::CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> > tree = tree::CoverTree<
+ metric::LMetric<2>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> >(data);
+
+ NeighborSearch<NearestNeighborSort, metric::LMetric<2>,
+ tree::CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> > >
+ coverTreeSearch(&tree, data, true);
+
+ AllkNN naive(naiveQuery, true);
+
+ arma::Mat<size_t> coverTreeNeighbors;
+ arma::mat coverTreeDistances;
+ coverTreeSearch.Search(15, coverTreeNeighbors, coverTreeDistances);
+
+ arma::Mat<size_t> naiveNeighbors;
+ arma::mat naiveDistances;
+ naive.Search(15, naiveNeighbors, naiveDistances);
+
+ for (size_t i = 0; i < coverTreeNeighbors.n_elem; ++i)
+ {
+ BOOST_REQUIRE_EQUAL(coverTreeNeighbors[i], naiveNeighbors[i]);
+ BOOST_REQUIRE_CLOSE(coverTreeDistances[i], naiveDistances[i], 1e-5);
+ }
+}
+
+/**
+ * Test the cover tree dual-tree nearest neighbors method against the naive
+ * method.
+ */
+BOOST_AUTO_TEST_CASE(DualCoverTreeTest)
+{
+ arma::mat dataset;
+// srand(time(NULL));
+// dataset.randn(5, 5000);
+ data::Load("test_data_3_1000.csv", dataset);
+
+ arma::mat kdtreeData(dataset);
+
+ AllkNN tree(kdtreeData);
+
+ arma::Mat<size_t> kdNeighbors;
+ arma::mat kdDistances;
+ tree.Search(5, kdNeighbors, kdDistances);
+
+ tree::CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> > referenceTree = tree::CoverTree<
+ metric::LMetric<2, true>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> >(dataset);
+
+ NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+ tree::CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> > >
+ coverTreeSearch(&referenceTree, dataset);
+
+ arma::Mat<size_t> coverNeighbors;
+ arma::mat coverDistances;
+ coverTreeSearch.Search(5, coverNeighbors, coverDistances);
+
+ for (size_t i = 0; i < coverNeighbors.n_cols; ++i)
+ {
+// Log::Debug << "cover neighbors col " << i << "\n" <<
+// trans(coverNeighbors.col(i));
+// Log::Debug << "cover distances col " << i << "\n" <<
+// trans(coverDistances.col(i));
+// Log::Debug << "kd neighbors col " << i << "\n" <<
+// trans(kdNeighbors.col(i));
+// Log::Debug << "kd distances col " << i << "\n" <<
+// trans(kdDistances.col(i));
+ for (size_t j = 0; j < coverNeighbors.n_rows; ++j)
+ {
+ BOOST_REQUIRE_EQUAL(coverNeighbors(j, i), kdNeighbors(j, i));
+ BOOST_REQUIRE_CLOSE(coverDistances(j, i), sqrt(kdDistances(j, i)), 1e-5);
+ }
+ }
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.2/src/mlpack/tests/sparse_coding_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/sparse_coding_test.cpp 2012-08-09 09:02:13 UTC (rev 13379)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/tests/sparse_coding_test.cpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -1,139 +0,0 @@
-/**
- * @file sparse_coding_test.cpp
- *
- * Test for Sparse Coding
- */
-
-// Note: We don't use BOOST_REQUIRE_CLOSE in the code below because we need
-// to use FPC_WEAK, and it's not at all intuitive how to do that.
-
-#include <mlpack/core.hpp>
-#include <mlpack/methods/sparse_coding/sparse_coding.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace arma;
-using namespace mlpack;
-using namespace mlpack::regression;
-using namespace mlpack::sparse_coding;
-
-BOOST_AUTO_TEST_SUITE(SparseCodingTest);
-
-void SCVerifyCorrectness(vec beta, vec errCorr, double lambda)
-{
- const double tol = 1e-12;
- size_t nDims = beta.n_elem;
- for(size_t j = 0; j < nDims; j++)
- {
- if (beta(j) == 0)
- {
- // Make sure that errCorr(j) <= lambda.
- BOOST_REQUIRE_SMALL(std::max(fabs(errCorr(j)) - lambda, 0.0), tol);
- }
- else if (beta(j) < 0)
- {
- // Make sure that errCorr(j) == lambda.
- BOOST_REQUIRE_SMALL(errCorr(j) - lambda, tol);
- }
- else // beta(j) > 0.
- {
- // Make sure that errCorr(j) == -lambda.
- BOOST_REQUIRE_SMALL(errCorr(j) + lambda, tol);
- }
- }
-}
-
-BOOST_AUTO_TEST_CASE(SparseCodingTestCodingStepLasso)
-{
- double lambda1 = 0.1;
- uword nAtoms = 25;
-
- mat X;
- X.load("mnist_first250_training_4s_and_9s.arm");
- uword nPoints = X.n_cols;
-
- // Normalize each point since these are images.
- for (uword i = 0; i < nPoints; ++i) {
- X.col(i) /= norm(X.col(i), 2);
- }
-
- SparseCoding<> sc(X, nAtoms, lambda1);
- sc.OptimizeCode();
-
- mat D = sc.Dictionary();
- mat Z = sc.Codes();
-
- for (uword i = 0; i < nPoints; ++i)
- {
- vec errCorr = trans(D) * (D * Z.unsafe_col(i) - X.unsafe_col(i));
- SCVerifyCorrectness(Z.unsafe_col(i), errCorr, lambda1);
- }
-}
-
-BOOST_AUTO_TEST_CASE(SparseCodingTestCodingStepElasticNet)
-{
- double lambda1 = 0.1;
- double lambda2 = 0.2;
- uword nAtoms = 25;
-
- mat X;
- X.load("mnist_first250_training_4s_and_9s.arm");
- uword nPoints = X.n_cols;
-
- // Normalize each point since these are images.
- for (uword i = 0; i < nPoints; ++i)
- X.col(i) /= norm(X.col(i), 2);
-
- SparseCoding<> sc(X, nAtoms, lambda1, lambda2);
- sc.OptimizeCode();
-
- mat D = sc.Dictionary();
- mat Z = sc.Codes();
-
- for(uword i = 0; i < nPoints; ++i)
- {
- vec errCorr =
- (trans(D) * D + lambda2 * eye(nAtoms, nAtoms)) * Z.unsafe_col(i)
- - trans(D) * X.unsafe_col(i);
-
- SCVerifyCorrectness(Z.unsafe_col(i), errCorr, lambda1);
- }
-}
-
-BOOST_AUTO_TEST_CASE(SparseCodingTestDictionaryStep)
-{
- const double tol = 1e-7;
-
- double lambda1 = 0.1;
- uword nAtoms = 25;
-
- mat X;
- X.load("mnist_first250_training_4s_and_9s.arm");
- uword nPoints = X.n_cols;
-
- // Normalize each point since these are images.
- for (uword i = 0; i < nPoints; ++i)
- X.col(i) /= norm(X.col(i), 2);
-
- SparseCoding<> sc(X, nAtoms, lambda1);
- sc.OptimizeCode();
-
- mat D = sc.Dictionary();
- mat Z = sc.Codes();
-
- uvec adjacencies = find(Z);
- double normGradient = sc.OptimizeDictionary(adjacencies, 1e-12);
-
- BOOST_REQUIRE_SMALL(normGradient, tol);
-}
-
-/*
-BOOST_AUTO_TEST_CASE(SparseCodingTestWhole)
-{
-
-}
-*/
-
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.2/src/mlpack/tests/sparse_coding_test.cpp (from rev 13399, mlpack/trunk/src/mlpack/tests/sparse_coding_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.2/src/mlpack/tests/sparse_coding_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/tests/sparse_coding_test.cpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -0,0 +1,139 @@
+/**
+ * @file sparse_coding_test.cpp
+ *
+ * Test for Sparse Coding
+ */
+
+// Note: We don't use BOOST_REQUIRE_CLOSE in the code below because we need
+// to use FPC_WEAK, and it's not at all intuitive how to do that.
+
+#include <mlpack/core.hpp>
+#include <mlpack/methods/sparse_coding/sparse_coding.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace arma;
+using namespace mlpack;
+using namespace mlpack::regression;
+using namespace mlpack::sparse_coding;
+
+BOOST_AUTO_TEST_SUITE(SparseCodingTest);
+
+void SCVerifyCorrectness(vec beta, vec errCorr, double lambda)
+{
+ const double tol = 1e-12;
+ size_t nDims = beta.n_elem;
+ for(size_t j = 0; j < nDims; j++)
+ {
+ if (beta(j) == 0)
+ {
+ // Make sure that errCorr(j) <= lambda.
+ BOOST_REQUIRE_SMALL(std::max(fabs(errCorr(j)) - lambda, 0.0), tol);
+ }
+ else if (beta(j) < 0)
+ {
+ // Make sure that errCorr(j) == lambda.
+ BOOST_REQUIRE_SMALL(errCorr(j) - lambda, tol);
+ }
+ else // beta(j) > 0.
+ {
+ // Make sure that errCorr(j) == -lambda.
+ BOOST_REQUIRE_SMALL(errCorr(j) + lambda, tol);
+ }
+ }
+}
+
+BOOST_AUTO_TEST_CASE(SparseCodingTestCodingStepLasso)
+{
+ double lambda1 = 0.1;
+ uword nAtoms = 25;
+
+ mat X;
+ X.load("mnist_first250_training_4s_and_9s.arm");
+ uword nPoints = X.n_cols;
+
+ // Normalize each point since these are images.
+ for (uword i = 0; i < nPoints; ++i) {
+ X.col(i) /= norm(X.col(i), 2);
+ }
+
+ SparseCoding<> sc(X, nAtoms, lambda1);
+ sc.OptimizeCode();
+
+ mat D = sc.Dictionary();
+ mat Z = sc.Codes();
+
+ for (uword i = 0; i < nPoints; ++i)
+ {
+ vec errCorr = trans(D) * (D * Z.unsafe_col(i) - X.unsafe_col(i));
+ SCVerifyCorrectness(Z.unsafe_col(i), errCorr, lambda1);
+ }
+}
+
+BOOST_AUTO_TEST_CASE(SparseCodingTestCodingStepElasticNet)
+{
+ double lambda1 = 0.1;
+ double lambda2 = 0.2;
+ uword nAtoms = 25;
+
+ mat X;
+ X.load("mnist_first250_training_4s_and_9s.arm");
+ uword nPoints = X.n_cols;
+
+ // Normalize each point since these are images.
+ for (uword i = 0; i < nPoints; ++i)
+ X.col(i) /= norm(X.col(i), 2);
+
+ SparseCoding<> sc(X, nAtoms, lambda1, lambda2);
+ sc.OptimizeCode();
+
+ mat D = sc.Dictionary();
+ mat Z = sc.Codes();
+
+ for(uword i = 0; i < nPoints; ++i)
+ {
+ vec errCorr =
+ (trans(D) * D + lambda2 * eye(nAtoms, nAtoms)) * Z.unsafe_col(i)
+ - trans(D) * X.unsafe_col(i);
+
+ SCVerifyCorrectness(Z.unsafe_col(i), errCorr, lambda1);
+ }
+}
+
+BOOST_AUTO_TEST_CASE(SparseCodingTestDictionaryStep)
+{
+ const double tol = 2e-7;
+
+ double lambda1 = 0.1;
+ uword nAtoms = 25;
+
+ mat X;
+ X.load("mnist_first250_training_4s_and_9s.arm");
+ uword nPoints = X.n_cols;
+
+ // Normalize each point since these are images.
+ for (uword i = 0; i < nPoints; ++i)
+ X.col(i) /= norm(X.col(i), 2);
+
+ SparseCoding<> sc(X, nAtoms, lambda1);
+ sc.OptimizeCode();
+
+ mat D = sc.Dictionary();
+ mat Z = sc.Codes();
+
+ uvec adjacencies = find(Z);
+ double normGradient = sc.OptimizeDictionary(adjacencies, 1e-12);
+
+ BOOST_REQUIRE_SMALL(normGradient, tol);
+}
+
+/*
+BOOST_AUTO_TEST_CASE(SparseCodingTestWhole)
+{
+
+}
+*/
+
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.2/src/mlpack/tests/tree_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/tree_test.cpp 2012-08-09 09:02:13 UTC (rev 13379)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/tests/tree_test.cpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -1,1671 +0,0 @@
-/**
- * @file tree_test.cpp
- *
- * Tests for tree-building methods.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/core/tree/bounds.hpp>
-#include <mlpack/core/tree/binary_space_tree/binary_space_tree.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-#include <mlpack/core/tree/cover_tree/cover_tree.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::math;
-using namespace mlpack::tree;
-using namespace mlpack::metric;
-using namespace mlpack::bound;
-
-BOOST_AUTO_TEST_SUITE(TreeTest);
-
-/**
- * Ensure that a bound, by default, is empty and has no dimensionality.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundEmptyConstructor)
-{
- HRectBound<2> b;
-
- BOOST_REQUIRE_EQUAL((int) b.Dim(), 0);
-}
-
-/**
- * Ensure that when we specify the dimensionality in the constructor, it is
- * correct, and the bounds are all the empty set.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundDimConstructor)
-{
- HRectBound<2> b(2); // We'll do this with 2 and 5 dimensions.
-
- BOOST_REQUIRE_EQUAL(b.Dim(), 2);
- BOOST_REQUIRE_SMALL(b[0].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(b[1].Width(), 1e-5);
-
- b = HRectBound<2>(5);
-
- BOOST_REQUIRE_EQUAL(b.Dim(), 5);
- BOOST_REQUIRE_SMALL(b[0].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(b[1].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(b[2].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(b[3].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(b[4].Width(), 1e-5);
-}
-
-/**
- * Test the copy constructor.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundCopyConstructor)
-{
- HRectBound<2> b(2);
- b[0] = Range(0.0, 2.0);
- b[1] = Range(2.0, 3.0);
-
- HRectBound<2> c(b);
-
- BOOST_REQUIRE_EQUAL(c.Dim(), 2);
- BOOST_REQUIRE_SMALL(c[0].Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(c[0].Hi(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c[1].Lo(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c[1].Hi(), 3.0, 1e-5);
-}
-
-/**
- * Test the assignment operator.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundAssignmentOperator)
-{
- HRectBound<2> b(2);
- b[0] = Range(0.0, 2.0);
- b[1] = Range(2.0, 3.0);
-
- HRectBound<2> c(4);
-
- c = b;
-
- BOOST_REQUIRE_EQUAL(c.Dim(), 2);
- BOOST_REQUIRE_SMALL(c[0].Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(c[0].Hi(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c[1].Lo(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c[1].Hi(), 3.0, 1e-5);
-}
-
-/**
- * Test that clearing the dimensions resets the bound to empty.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundClear)
-{
- HRectBound<2> b(2); // We'll do this with two dimensions only.
-
- b[0] = Range(0.0, 2.0);
- b[1] = Range(2.0, 4.0);
-
- // Now we just need to make sure that we clear the range.
- b.Clear();
-
- BOOST_REQUIRE_SMALL(b[0].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(b[1].Width(), 1e-5);
-}
-
-/**
- * Ensure that we get the correct centroid for our bound.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundCentroid)
-{
- // Create a simple 3-dimensional bound.
- HRectBound<2> b(3);
-
- b[0] = Range(0.0, 5.0);
- b[1] = Range(-2.0, -1.0);
- b[2] = Range(-10.0, 50.0);
-
- arma::vec centroid;
-
- b.Centroid(centroid);
-
- BOOST_REQUIRE_EQUAL(centroid.n_elem, 3);
- BOOST_REQUIRE_CLOSE(centroid[0], 2.5, 1e-5);
- BOOST_REQUIRE_CLOSE(centroid[1], -1.5, 1e-5);
- BOOST_REQUIRE_CLOSE(centroid[2], 20.0, 1e-5);
-}
-
-/**
- * Ensure that we calculate the correct minimum distance between a point and a
- * bound.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundMinDistancePoint)
-{
- // We'll do the calculation in five dimensions, and we'll use three cases for
- // the point: point is outside the bound; point is on the edge of the bound;
- // point is inside the bound. In the latter two cases, the distance should be
- // zero.
- HRectBound<2> b(5);
-
- b[0] = Range(0.0, 2.0);
- b[1] = Range(1.0, 5.0);
- b[2] = Range(-2.0, 2.0);
- b[3] = Range(-5.0, -2.0);
- b[4] = Range(1.0, 2.0);
-
- arma::vec point = "-2.0 0.0 10.0 3.0 3.0";
-
- // This will be the Euclidean squared distance.
- BOOST_REQUIRE_CLOSE(b.MinDistance(point), 95.0, 1e-5);
-
- point = "2.0 5.0 2.0 -5.0 1.0";
-
- BOOST_REQUIRE_SMALL(b.MinDistance(point), 1e-5);
-
- point = "1.0 2.0 0.0 -2.0 1.5";
-
- BOOST_REQUIRE_SMALL(b.MinDistance(point), 1e-5);
-}
-
-/**
- * Ensure that we calculate the correct minimum distance between a bound and
- * another bound.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundMinDistanceBound)
-{
- // We'll do the calculation in five dimensions, and we can use six cases.
- // The other bound is completely outside the bound; the other bound is on the
- // edge of the bound; the other bound partially overlaps the bound; the other
- // bound fully overlaps the bound; the other bound is entirely inside the
- // bound; the other bound entirely envelops the bound.
- HRectBound<2> b(5);
-
- b[0] = Range(0.0, 2.0);
- b[1] = Range(1.0, 5.0);
- b[2] = Range(-2.0, 2.0);
- b[3] = Range(-5.0, -2.0);
- b[4] = Range(1.0, 2.0);
-
- HRectBound<2> c(5);
-
- // The other bound is completely outside the bound.
- c[0] = Range(-5.0, -2.0);
- c[1] = Range(6.0, 7.0);
- c[2] = Range(-2.0, 2.0);
- c[3] = Range(2.0, 5.0);
- c[4] = Range(3.0, 4.0);
-
- BOOST_REQUIRE_CLOSE(b.MinDistance(c), 22.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c.MinDistance(b), 22.0, 1e-5);
-
- // The other bound is on the edge of the bound.
- c[0] = Range(-2.0, 0.0);
- c[1] = Range(0.0, 1.0);
- c[2] = Range(-3.0, -2.0);
- c[3] = Range(-10.0, -5.0);
- c[4] = Range(2.0, 3.0);
-
- BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
- BOOST_REQUIRE_SMALL(c.MinDistance(b), 1e-5);
-
- // The other bound partially overlaps the bound.
- c[0] = Range(-2.0, 1.0);
- c[1] = Range(0.0, 2.0);
- c[2] = Range(-2.0, 2.0);
- c[3] = Range(-8.0, -4.0);
- c[4] = Range(0.0, 4.0);
-
- BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
- BOOST_REQUIRE_SMALL(c.MinDistance(b), 1e-5);
-
- // The other bound fully overlaps the bound.
- BOOST_REQUIRE_SMALL(b.MinDistance(b), 1e-5);
- BOOST_REQUIRE_SMALL(c.MinDistance(c), 1e-5);
-
- // The other bound is entirely inside the bound / the other bound entirely
- // envelops the bound.
- c[0] = Range(-1.0, 3.0);
- c[1] = Range(0.0, 6.0);
- c[2] = Range(-3.0, 3.0);
- c[3] = Range(-7.0, 0.0);
- c[4] = Range(0.0, 5.0);
-
- BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
- BOOST_REQUIRE_SMALL(c.MinDistance(b), 1e-5);
-
- // Now we must be sure that the minimum distance to itself is 0.
- BOOST_REQUIRE_SMALL(b.MinDistance(b), 1e-5);
- BOOST_REQUIRE_SMALL(c.MinDistance(c), 1e-5);
-}
-
-/**
- * Ensure that we calculate the correct maximum distance between a bound and a
- * point. This uses the same test cases as the MinDistance test.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundMaxDistancePoint)
-{
- // We'll do the calculation in five dimensions, and we'll use three cases for
- // the point: point is outside the bound; point is on the edge of the bound;
- // point is inside the bound. In the latter two cases, the distance should be
- // zero.
- HRectBound<2> b(5);
-
- b[0] = Range(0.0, 2.0);
- b[1] = Range(1.0, 5.0);
- b[2] = Range(-2.0, 2.0);
- b[3] = Range(-5.0, -2.0);
- b[4] = Range(1.0, 2.0);
-
- arma::vec point = "-2.0 0.0 10.0 3.0 3.0";
-
- // This will be the Euclidean squared distance.
- BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 253.0, 1e-5);
-
- point = "2.0 5.0 2.0 -5.0 1.0";
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 46.0, 1e-5);
-
- point = "1.0 2.0 0.0 -2.0 1.5";
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 23.25, 1e-5);
-}
-
-/**
- * Ensure that we calculate the correct maximum distance between a bound and
- * another bound. This uses the same test cases as the MinDistance test.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundMaxDistanceBound)
-{
- // We'll do the calculation in five dimensions, and we can use six cases.
- // The other bound is completely outside the bound; the other bound is on the
- // edge of the bound; the other bound partially overlaps the bound; the other
- // bound fully overlaps the bound; the other bound is entirely inside the
- // bound; the other bound entirely envelops the bound.
- HRectBound<2> b(5);
-
- b[0] = Range(0.0, 2.0);
- b[1] = Range(1.0, 5.0);
- b[2] = Range(-2.0, 2.0);
- b[3] = Range(-5.0, -2.0);
- b[4] = Range(1.0, 2.0);
-
- HRectBound<2> c(5);
-
- // The other bound is completely outside the bound.
- c[0] = Range(-5.0, -2.0);
- c[1] = Range(6.0, 7.0);
- c[2] = Range(-2.0, 2.0);
- c[3] = Range(2.0, 5.0);
- c[4] = Range(3.0, 4.0);
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 210.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c.MaxDistance(b), 210.0, 1e-5);
-
- // The other bound is on the edge of the bound.
- c[0] = Range(-2.0, 0.0);
- c[1] = Range(0.0, 1.0);
- c[2] = Range(-3.0, -2.0);
- c[3] = Range(-10.0, -5.0);
- c[4] = Range(2.0, 3.0);
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 134.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c.MaxDistance(b), 134.0, 1e-5);
-
- // The other bound partially overlaps the bound.
- c[0] = Range(-2.0, 1.0);
- c[1] = Range(0.0, 2.0);
- c[2] = Range(-2.0, 2.0);
- c[3] = Range(-8.0, -4.0);
- c[4] = Range(0.0, 4.0);
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 102.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c.MaxDistance(b), 102.0, 1e-5);
-
- // The other bound fully overlaps the bound.
- BOOST_REQUIRE_CLOSE(b.MaxDistance(b), 46.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c.MaxDistance(c), 61.0, 1e-5);
-
- // The other bound is entirely inside the bound / the other bound entirely
- // envelops the bound.
- c[0] = Range(-1.0, 3.0);
- c[1] = Range(0.0, 6.0);
- c[2] = Range(-3.0, 3.0);
- c[3] = Range(-7.0, 0.0);
- c[4] = Range(0.0, 5.0);
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 100.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c.MaxDistance(b), 100.0, 1e-5);
-
- // Identical bounds. This will be the sum of the squared widths in each
- // dimension.
- BOOST_REQUIRE_CLOSE(b.MaxDistance(b), 46.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c.MaxDistance(c), 162.0, 1e-5);
-
- // One last additional case. If the bound encloses only one point, the
- // maximum distance between it and itself is 0.
- HRectBound<2> d(2);
-
- d[0] = Range(2.0, 2.0);
- d[1] = Range(3.0, 3.0);
-
- BOOST_REQUIRE_SMALL(d.MaxDistance(d), 1e-5);
-}
-
-/**
- * Ensure that the ranges returned by RangeDistance() are equal to the minimum
- * and maximum distance. We will perform this test by creating random bounds
- * and comparing the behavior to MinDistance() and MaxDistance() -- so this test
- * is assuming that those passed and operate correctly.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundRangeDistanceBound)
-{
- for (int i = 0; i < 50; i++)
- {
- size_t dim = math::RandInt(20);
-
- HRectBound<2> a(dim);
- HRectBound<2> b(dim);
-
- // We will set the low randomly and the width randomly for each dimension of
- // each bound.
- arma::vec loA(dim);
- arma::vec widthA(dim);
-
- loA.randu();
- widthA.randu();
-
- arma::vec lo_b(dim);
- arma::vec width_b(dim);
-
- lo_b.randu();
- width_b.randu();
-
- for (size_t j = 0; j < dim; j++)
- {
- a[j] = Range(loA[j], loA[j] + widthA[j]);
- b[j] = Range(lo_b[j], lo_b[j] + width_b[j]);
- }
-
- // Now ensure that MinDistance and MaxDistance report the same.
- Range r = a.RangeDistance(b);
- Range s = b.RangeDistance(a);
-
- BOOST_REQUIRE_CLOSE(r.Lo(), s.Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(r.Hi(), s.Hi(), 1e-5);
-
- BOOST_REQUIRE_CLOSE(r.Lo(), a.MinDistance(b), 1e-5);
- BOOST_REQUIRE_CLOSE(r.Hi(), a.MaxDistance(b), 1e-5);
-
- BOOST_REQUIRE_CLOSE(s.Lo(), b.MinDistance(a), 1e-5);
- BOOST_REQUIRE_CLOSE(s.Hi(), b.MaxDistance(a), 1e-5);
- }
-}
-
-/**
- * Ensure that the ranges returned by RangeDistance() are equal to the minimum
- * and maximum distance. We will perform this test by creating random bounds
- * and comparing the behavior to MinDistance() and MaxDistance() -- so this test
- * is assuming that those passed and operate correctly. This is for the
- * bound-to-point case.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundRangeDistancePoint)
-{
- for (int i = 0; i < 20; i++)
- {
- size_t dim = math::RandInt(20);
-
- HRectBound<2> a(dim);
-
- // We will set the low randomly and the width randomly for each dimension of
- // each bound.
- arma::vec loA(dim);
- arma::vec widthA(dim);
-
- loA.randu();
- widthA.randu();
-
- for (size_t j = 0; j < dim; j++)
- a[j] = Range(loA[j], loA[j] + widthA[j]);
-
- // Now run the test on a few points.
- for (int j = 0; j < 10; j++)
- {
- arma::vec point(dim);
-
- point.randu();
-
- Range r = a.RangeDistance(point);
-
- BOOST_REQUIRE_CLOSE(r.Lo(), a.MinDistance(point), 1e-5);
- BOOST_REQUIRE_CLOSE(r.Hi(), a.MaxDistance(point), 1e-5);
- }
- }
-}
-
-/**
- * Test that we can expand the bound to include a new point.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundOrOperatorPoint)
-{
- // Because this should be independent in each dimension, we can essentially
- // run five test cases at once.
- HRectBound<2> b(5);
-
- b[0] = Range(1.0, 3.0);
- b[1] = Range(2.0, 4.0);
- b[2] = Range(-2.0, -1.0);
- b[3] = Range(0.0, 0.0);
- b[4] = Range(); // Empty range.
-
- arma::vec point = "2.0 4.0 2.0 -1.0 6.0";
-
- b |= point;
-
- BOOST_REQUIRE_CLOSE(b[0].Lo(), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[0].Hi(), 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[1].Lo(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[1].Hi(), 4.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[2].Lo(), -2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[2].Hi(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[3].Lo(), -1.0, 1e-5);
- BOOST_REQUIRE_SMALL(b[3].Hi(), 1e-5);
- BOOST_REQUIRE_CLOSE(b[4].Lo(), 6.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[4].Hi(), 6.0, 1e-5);
-}
-
-/**
- * Test that we can expand the bound to include another bound.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundOrOperatorBound)
-{
- // Because this should be independent in each dimension, we can run many tests
- // at once.
- HRectBound<2> b(8);
-
- b[0] = Range(1.0, 3.0);
- b[1] = Range(2.0, 4.0);
- b[2] = Range(-2.0, -1.0);
- b[3] = Range(4.0, 5.0);
- b[4] = Range(2.0, 4.0);
- b[5] = Range(0.0, 0.0);
- b[6] = Range();
- b[7] = Range(1.0, 3.0);
-
- HRectBound<2> c(8);
-
- c[0] = Range(-3.0, -1.0); // Entirely less than the other bound.
- c[1] = Range(0.0, 2.0); // Touching edges.
- c[2] = Range(-3.0, -1.5); // Partially overlapping.
- c[3] = Range(4.0, 5.0); // Identical.
- c[4] = Range(1.0, 5.0); // Entirely enclosing.
- c[5] = Range(2.0, 2.0); // A single point.
- c[6] = Range(1.0, 3.0);
- c[7] = Range(); // Empty set.
-
- HRectBound<2> d = c;
-
- b |= c;
- d |= b;
-
- BOOST_REQUIRE_CLOSE(b[0].Lo(), -3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[0].Hi(), 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[0].Lo(), -3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[0].Hi(), 3.0, 1e-5);
-
- BOOST_REQUIRE_CLOSE(b[1].Lo(), 0.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[1].Hi(), 4.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[1].Lo(), 0.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[1].Hi(), 4.0, 1e-5);
-
- BOOST_REQUIRE_CLOSE(b[2].Lo(), -3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[2].Hi(), -1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[2].Lo(), -3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[2].Hi(), -1.0, 1e-5);
-
- BOOST_REQUIRE_CLOSE(b[3].Lo(), 4.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[3].Hi(), 5.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[3].Lo(), 4.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[3].Hi(), 5.0, 1e-5);
-
- BOOST_REQUIRE_CLOSE(b[4].Lo(), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[4].Hi(), 5.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[4].Lo(), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[4].Hi(), 5.0, 1e-5);
-
- BOOST_REQUIRE_SMALL(b[5].Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(b[5].Hi(), 2.0, 1e-5);
- BOOST_REQUIRE_SMALL(d[5].Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(d[5].Hi(), 2.0, 1e-5);
-
- BOOST_REQUIRE_CLOSE(b[6].Lo(), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[6].Hi(), 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[6].Lo(), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[6].Hi(), 3.0, 1e-5);
-
- BOOST_REQUIRE_CLOSE(b[7].Lo(), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[7].Hi(), 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[7].Lo(), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[7].Hi(), 3.0, 1e-5);
-}
-
-/**
- * Test that the Contains() function correctly figures out whether or not a
- * point is in a bound.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundContains)
-{
- // We can test a couple different points: completely outside the bound,
- // adjacent in one dimension to the bound, adjacent in all dimensions to the
- // bound, and inside the bound.
- HRectBound<2> b(3);
-
- b[0] = Range(0.0, 2.0);
- b[1] = Range(0.0, 2.0);
- b[2] = Range(0.0, 2.0);
-
- // Completely outside the range.
- arma::vec point = "-1.0 4.0 4.0";
- BOOST_REQUIRE(!b.Contains(point));
-
- // Completely outside, but one dimension is in the range.
- point = "-1.0 4.0 1.0";
- BOOST_REQUIRE(!b.Contains(point));
-
- // Outside, but one dimension is on the edge.
- point = "-1.0 0.0 3.0";
- BOOST_REQUIRE(!b.Contains(point));
-
- // Two dimensions are on the edge, but one is outside.
- point = "0.0 0.0 3.0";
- BOOST_REQUIRE(!b.Contains(point));
-
- // Completely on the edge (should be contained).
- point = "0.0 0.0 0.0";
- BOOST_REQUIRE(b.Contains(point));
-
- // Inside the range.
- point = "0.3 1.0 0.4";
- BOOST_REQUIRE(b.Contains(point));
-}
-
-BOOST_AUTO_TEST_CASE(TestBallBound)
-{
- BallBound<> b1;
- BallBound<> b2;
-
- // Create two balls with a center distance of 1 from each other.
- // Give the first one a radius of 0.3 and the second a radius of 0.4.
- b1.Center().set_size(3);
- b1.Center()[0] = 1;
- b1.Center()[1] = 2;
- b1.Center()[2] = 3;
- b1.Radius() = 0.3;
-
- b2.Center().set_size(3);
- b2.Center()[0] = 1;
- b2.Center()[1] = 2;
- b2.Center()[2] = 4;
- b2.Radius() = 0.4;
-
- BOOST_REQUIRE_CLOSE(b1.MinDistance(b2), 1-0.3-0.4, 1e-5);
- BOOST_REQUIRE_CLOSE(b1.RangeDistance(b2).Hi(), 1+0.3+0.4, 1e-5);
- BOOST_REQUIRE_CLOSE(b1.RangeDistance(b2).Lo(), 1-0.3-0.4, 1e-5);
- BOOST_REQUIRE_CLOSE(b1.RangeDistance(b2).Hi(), 1+0.3+0.4, 1e-5);
- BOOST_REQUIRE_CLOSE(b1.RangeDistance(b2).Lo(), 1-0.3-0.4, 1e-5);
-
- BOOST_REQUIRE_CLOSE(b2.MinDistance(b1), 1-0.3-0.4, 1e-5);
- BOOST_REQUIRE_CLOSE(b2.MaxDistance(b1), 1+0.3+0.4, 1e-5);
- BOOST_REQUIRE_CLOSE(b2.RangeDistance(b1).Hi(), 1+0.3+0.4, 1e-5);
- BOOST_REQUIRE_CLOSE(b2.RangeDistance(b1).Lo(), 1-0.3-0.4, 1e-5);
-
- BOOST_REQUIRE(b1.Contains(b1.Center()));
- BOOST_REQUIRE(!b1.Contains(b2.Center()));
-
- BOOST_REQUIRE(!b2.Contains(b1.Center()));
- BOOST_REQUIRE(b2.Contains(b2.Center()));
- arma::vec b2point(3); // A point that's within the radius but not the center.
- b2point[0] = 1.1;
- b2point[1] = 2.1;
- b2point[2] = 4.1;
-
- BOOST_REQUIRE(b2.Contains(b2point));
-
- BOOST_REQUIRE_SMALL(b1.MinDistance(b1.Center()), 1e-5);
- BOOST_REQUIRE_CLOSE(b1.MinDistance(b2.Center()), 1 - 0.3, 1e-5);
- BOOST_REQUIRE_CLOSE(b2.MinDistance(b1.Center()), 1 - 0.4, 1e-5);
- BOOST_REQUIRE_CLOSE(b2.MaxDistance(b1.Center()), 1 + 0.4, 1e-5);
- BOOST_REQUIRE_CLOSE(b1.MaxDistance(b2.Center()), 1 + 0.3, 1e-5);
-}
-
-/**
- * Ensure that a bound, by default, is empty and has no dimensionality, and the
- * box size vector is empty.
- */
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundEmptyConstructor)
-{
- PeriodicHRectBound<2> b;
-
- BOOST_REQUIRE_EQUAL(b.Dim(), 0);
- BOOST_REQUIRE_EQUAL(b.Box().n_elem, 0);
-}
-
-/**
- * Ensure that when we specify the dimensionality in the constructor, it is
- * correct, and the bounds are all the empty set.
- */
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundBoxConstructor)
-{
- PeriodicHRectBound<2> b(arma::vec("5 6"));
-
- BOOST_REQUIRE_EQUAL(b.Dim(), 2);
- BOOST_REQUIRE_SMALL(b[0].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(b[1].Width(), 1e-5);
- BOOST_REQUIRE_EQUAL(b.Box().n_elem, 2);
- BOOST_REQUIRE_CLOSE(b.Box()[0], 5.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b.Box()[1], 6.0, 1e-5);
-
- PeriodicHRectBound<2> d(arma::vec("2 3 4 5 6"));
-
- BOOST_REQUIRE_EQUAL(d.Dim(), 5);
- BOOST_REQUIRE_SMALL(d[0].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(d[1].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(d[2].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(d[3].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(d[4].Width(), 1e-5);
- BOOST_REQUIRE_EQUAL(d.Box().n_elem, 5);
- BOOST_REQUIRE_CLOSE(d.Box()[0], 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Box()[1], 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Box()[2], 4.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Box()[3], 5.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Box()[4], 6.0, 1e-5);
-}
-
-/**
- * Test the copy constructor.
- */
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundCopyConstructor)
-{
- PeriodicHRectBound<2> b(arma::vec("3 4"));
- b[0] = Range(0.0, 2.0);
- b[1] = Range(2.0, 3.0);
-
- PeriodicHRectBound<2> c(b);
-
- BOOST_REQUIRE_EQUAL(c.Dim(), 2);
- BOOST_REQUIRE_SMALL(c[0].Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(c[0].Hi(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c[1].Lo(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c[1].Hi(), 3.0, 1e-5);
- BOOST_REQUIRE_EQUAL(c.Box().n_elem, 2);
- BOOST_REQUIRE_CLOSE(c.Box()[0], 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c.Box()[1], 4.0, 1e-5);
-}
-
-/**
- * Test the assignment operator.
- *
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundAssignmentOperator)
-{
- PeriodicHRectBound<2> b(arma::vec("3 4"));
- b[0] = Range(0.0, 2.0);
- b[1] = Range(2.0, 3.0);
-
- PeriodicHRectBound<2> c(arma::vec("3 4 5 6"));
-
- c = b;
-
- BOOST_REQUIRE_EQUAL(c.Dim(), 2);
- BOOST_REQUIRE_SMALL(c[0].Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(c[0].Hi(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c[1].Lo(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c[1].Hi(), 3.0, 1e-5);
- BOOST_REQUIRE_EQUAL(c.Box().n_elem, 2);
- BOOST_REQUIRE_CLOSE(c.Box()[0], 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c.Box()[1], 4.0, 1e-5);
-}*/
-
-/**
- * Ensure that we can set the box size correctly.
- *
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundSetBoxSize)
-{
- PeriodicHRectBound<2> b(arma::vec("1 2"));
-
- b.SetBoxSize(arma::vec("10 12"));
-
- BOOST_REQUIRE_EQUAL(b.Box().n_elem, 2);
- BOOST_REQUIRE_CLOSE(b.Box()[0], 10.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b.Box()[1], 12.0, 1e-5);
-}*/
-
-/**
- * Ensure that we can clear the dimensions correctly. This does not involve the
- * box size at all, so the test can be identical to the HRectBound test.
- *
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundClear)
-{
- // We'll do this with two dimensions only.
- PeriodicHRectBound<2> b(arma::vec("5 5"));
-
- b[0] = Range(0.0, 2.0);
- b[1] = Range(2.0, 4.0);
-
- // Now we just need to make sure that we clear the range.
- b.Clear();
-
- BOOST_REQUIRE_SMALL(b[0].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(b[1].Width(), 1e-5);
-}*/
-
-/**
- * Ensure that we get the correct centroid for our bound.
- *
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundCentroid) {
- // Create a simple 3-dimensional bound. The centroid is not affected by the
- // periodic coordinates.
- PeriodicHRectBound<2> b(arma::vec("100 100 100"));
-
- b[0] = Range(0.0, 5.0);
- b[1] = Range(-2.0, -1.0);
- b[2] = Range(-10.0, 50.0);
-
- arma::vec centroid;
-
- b.Centroid(centroid);
-
- BOOST_REQUIRE_EQUAL(centroid.n_elem, 3);
- BOOST_REQUIRE_CLOSE(centroid[0], 2.5, 1e-5);
- BOOST_REQUIRE_CLOSE(centroid[1], -1.5, 1e-5);
- BOOST_REQUIRE_CLOSE(centroid[2], 20.0, 1e-5);
-}*/
-
-/**
- * Correctly calculate the minimum distance between the bound and a point in
- * periodic coordinates. We have to account for the shifts necessary in
- * periodic coordinates too, so that makes testing this a little more difficult.
- *
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundMinDistancePoint)
-{
- // First, we'll start with a simple 2-dimensional case where the point is
- // inside the bound, then on the edge of the bound, then barely outside the
- // bound. The box size will be large enough that this is basically the
- // HRectBound case.
- PeriodicHRectBound<2> b(arma::vec("100 100"));
-
- b[0] = Range(0.0, 5.0);
- b[1] = Range(2.0, 4.0);
-
- // Inside the bound.
- arma::vec point = "2.5 3.0";
-
- BOOST_REQUIRE_SMALL(b.MinDistance(point), 1e-5);
-
- // On the edge.
- point = "5.0 4.0";
-
- BOOST_REQUIRE_SMALL(b.MinDistance(point), 1e-5);
-
- // And just a little outside the bound.
- point = "6.0 5.0";
-
- BOOST_REQUIRE_CLOSE(b.MinDistance(point), 2.0, 1e-5);
-
- // Now we start to invoke the periodicity. This point will "alias" to (-1,
- // -1).
- point = "99.0 99.0";
-
- BOOST_REQUIRE_CLOSE(b.MinDistance(point), 10.0, 1e-5);
-
- // We will perform several tests on a one-dimensional bound.
- PeriodicHRectBound<2> a(arma::vec("5.0"));
- point.set_size(1);
-
- a[0] = Range(2.0, 4.0); // Entirely inside box.
- point[0] = 7.5; // Inside first right image of the box.
-
- BOOST_REQUIRE_SMALL(a.MinDistance(point), 1e-5);
-
- a[0] = Range(0.0, 5.0); // Fills box fully.
- point[0] = 19.3; // Inside the box, which covers everything.
-
- BOOST_REQUIRE_SMALL(a.MinDistance(point), 1e-5);
-
- a[0] = Range(-10.0, 10.0); // Larger than the box.
- point[0] = -500.0; // Inside the box, which covers everything.
-
- BOOST_REQUIRE_SMALL(a.MinDistance(point), 1e-5);
-
- a[0] = Range(-2.0, 1.0); // Crosses over an edge.
- point[0] = 2.9; // The first right image of the bound starts at 3.0.
-
- BOOST_REQUIRE_CLOSE(a.MinDistance(point), 0.01, 1e-5);
-
- a[0] = Range(2.0, 4.0); // Inside box.
- point[0] = 0.0; // Closest to the first left image of the bound.
-
- BOOST_REQUIRE_CLOSE(a.MinDistance(point), 1.0, 1e-5);
-
- a[0] = Range(0.0, 2.0); // On edge of box.
- point[0] = 7.1; // 0.1 away from the first right image of the bound.
-
- BOOST_REQUIRE_CLOSE(a.MinDistance(point), 0.01, 1e-5);
-
- PeriodicHRectBound<2> d(arma::vec("0.0"));
- d[0] = Range(-10.0, 10.0); // Box is of infinite size.
- point[0] = 810.0; // 800 away from the only image of the box.
-
- BOOST_REQUIRE_CLOSE(d.MinDistance(point), 640000.0, 1e-5);
-
- PeriodicHRectBound<2> e(arma::vec("-5.0"));
- e[0] = Range(2.0, 4.0); // Box size of -5 should function the same as 5.
- point[0] = -10.8; // Should alias to 4.2.
-
- BOOST_REQUIRE_CLOSE(e.MinDistance(point), 0.04, 1e-5);
-
- // Switch our bound to a higher dimensionality. This should ensure that the
- // dimensions are independent like they should be.
- PeriodicHRectBound<2> c(arma::vec("5.0 5.0 5.0 5.0 5.0 5.0 0.0 -5.0"));
-
- c[0] = Range(2.0, 4.0); // Entirely inside box.
- c[1] = Range(0.0, 5.0); // Fills box fully.
- c[2] = Range(-10.0, 10.0); // Larger than the box.
- c[3] = Range(-2.0, 1.0); // Crosses over an edge.
- c[4] = Range(2.0, 4.0); // Inside box.
- c[5] = Range(0.0, 2.0); // On edge of box.
- c[6] = Range(-10.0, 10.0); // Box is of infinite size.
- c[7] = Range(2.0, 4.0); // Box size of -5 should function the same as 5.
-
- point.set_size(8);
- point[0] = 7.5; // Inside first right image of the box.
- point[1] = 19.3; // Inside the box, which covers everything.
- point[2] = -500.0; // Inside the box, which covers everything.
- point[3] = 2.9; // The first right image of the bound starts at 3.0.
- point[4] = 0.0; // Closest to the first left image of the bound.
- point[5] = 7.1; // 0.1 away from the first right image of the bound.
- point[6] = 810.0; // 800 away from the only image of the box.
- point[7] = -10.8; // Should alias to 4.2.
-
- BOOST_REQUIRE_CLOSE(c.MinDistance(point), 640001.06, 1e-10);
-}*/
-
-/**
- * Correctly calculate the minimum distance between the bound and another bound in
- * periodic coordinates. We have to account for the shifts necessary in
- * periodic coordinates too, so that makes testing this a little more difficult.
- *
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundMinDistanceBound)
-{
- // First, we'll start with a simple 2-dimensional case where the bounds are nonoverlapping,
- // then one bound is on the edge of the other bound,
- // then overlapping, then one range entirely covering the other. The box size will be large enough that this is basically the
- // HRectBound case.
- PeriodicHRectBound<2> b(arma::vec("100 100"));
- PeriodicHRectBound<2> c(arma::vec("100 100"));
-
- b[0] = Range(0.0, 5.0);
- b[1] = Range(2.0, 4.0);
-
- // Inside the bound.
- c[0] = Range(7.0, 9.0);
- c[1] = Range(10.0,12.0);
-
-
- BOOST_REQUIRE_CLOSE(b.MinDistance(c), 40.0, 1e-5);
-
- // On the edge.
- c[0] = Range(5.0, 8.0);
- c[1] = Range(4.0, 6.0);
-
- BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
-
- // Overlapping the bound.
- c[0] = Range(3.0, 6.0);
- c[1] = Range(1.0, 3.0);
-
- BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
-
- // One range entirely covering the other
-
- c[0] = Range(0.0, 6.0);
- c[1] = Range(1.0, 7.0);
-
- BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
-
- // Now we start to invoke the periodicity. These bounds "alias" to (-3.0,
- // -1.0) and (5.0, 6.0).
-
- c[0] = Range(97.0, 99.0);
- c[1] = Range(105.0, 106.0);
-
- BOOST_REQUIRE_CLOSE(b.MinDistance(c), 2.0, 1e-5);
-
- // We will perform several tests on a one-dimensional bound and smaller box size and mostly overlapping.
- PeriodicHRectBound<2> a(arma::vec("5.0"));
- PeriodicHRectBound<2> d(a);
-
- a[0] = Range(2.0, 4.0); // Entirely inside box.
- d[0] = Range(7.5, 10.0); // In the right image of the box, overlapping ranges.
-
- BOOST_REQUIRE_SMALL(a.MinDistance(d), 1e-5);
-
- a[0] = Range(0.0, 5.0); // Fills box fully.
- d[0] = Range(19.3, 21.0); // Two intervals inside the box, same as range of b[0].
-
- BOOST_REQUIRE_SMALL(a.MinDistance(d), 1e-5);
-
- a[0] = Range(-10.0, 10.0); // Larger than the box.
- d[0] = Range(-500.0, -498.0); // Inside the box, which covers everything.
-
- BOOST_REQUIRE_SMALL(a.MinDistance(d), 1e-5);
-
- a[0] = Range(-2.0, 1.0); // Crosses over an edge.
- d[0] = Range(2.9, 5.1); // The first right image of the bound starts at 3.0. Overlapping
-
- BOOST_REQUIRE_SMALL(a.MinDistance(d), 1e-5);
-
- a[0] = Range(-1.0, 1.0); // Crosses over an edge.
- d[0] = Range(11.9, 12.5); // The first right image of the bound starts at 4.0.
- BOOST_REQUIRE_CLOSE(a.MinDistance(d), 0.81, 1e-5);
-
- a[0] = Range(2.0, 3.0);
- d[0] = Range(9.5, 11);
- BOOST_REQUIRE_CLOSE(a.MinDistance(d), 1.0, 1e-5);
-
-}*/
-
-/**
- * Correctly calculate the maximum distance between the bound and a point in
- * periodic coordinates. We have to account for the shifts necessary in
- * periodic coordinates too, so that makes testing this a little more difficult.
- *
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundMaxDistancePoint)
-{
- // First, we'll start with a simple 2-dimensional case where the point is
- // inside the bound, then on the edge of the bound, then barely outside the
- // bound. The box size will be large enough that this is basically the
- // HRectBound case.
- PeriodicHRectBound<2> b(arma::vec("100 100"));
-
- b[0] = Range(0.0, 5.0);
- b[1] = Range(2.0, 4.0);
-
- // Inside the bound.
- arma::vec point = "2.5 3.0";
-
- //BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 7.25, 1e-5);
-
- // On the edge.
- point = "5.0 4.0";
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 29.0, 1e-5);
-
- // And just a little outside the bound.
- point = "6.0 5.0";
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 45.0, 1e-5);
-
- // Now we start to invoke the periodicity. This point will "alias" to (-1,
- // -1).
- point = "99.0 99.0";
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 61.0, 1e-5);
-
- // We will perform several tests on a one-dimensional bound and smaller box size.
- PeriodicHRectBound<2> a(arma::vec("5.0"));
- point.set_size(1);
-
- a[0] = Range(2.0, 4.0); // Entirely inside box.
- point[0] = 7.5; // Inside first right image of the box.
-
- BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 2.25, 1e-5);
-
- a[0] = Range(0.0, 5.0); // Fills box fully.
- point[0] = 19.3; // Inside the box, which covers everything.
-
- BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 18.49, 1e-5);
-
- a[0] = Range(-10.0, 10.0); // Larger than the box.
- point[0] = -500.0; // Inside the box, which covers everything.
-
- BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 25.0, 1e-5);
-
- a[0] = Range(-2.0, 1.0); // Crosses over an edge.
- point[0] = 2.9; // The first right image of the bound starts at 3.0.
-
- BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 8.41, 1e-5);
-
- a[0] = Range(2.0, 4.0); // Inside box.
- point[0] = 0.0; // Farthest from the first right image of the bound.
-
- BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 25.0, 1e-5);
-
- a[0] = Range(0.0, 2.0); // On edge of box.
- point[0] = 7.1; // 2.1 away from the first left image of the bound.
-
- BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 4.41, 1e-5);
-
- PeriodicHRectBound<2> d(arma::vec("0.0"));
- d[0] = Range(-10.0, 10.0); // Box is of infinite size.
- point[0] = 810.0; // 820 away from the only left image of the box.
-
- BOOST_REQUIRE_CLOSE(d.MinDistance(point), 672400.0, 1e-5);
-
- PeriodicHRectBound<2> e(arma::vec("-5.0"));
- e[0] = Range(2.0, 4.0); // Box size of -5 should function the same as 5.
- point[0] = -10.8; // Should alias to 4.2.
-
- BOOST_REQUIRE_CLOSE(e.MaxDistance(point), 4.84, 1e-5);
-
- // Switch our bound to a higher dimensionality. This should ensure that the
- // dimensions are independent like they should be.
- PeriodicHRectBound<2> c(arma::vec("5.0 5.0 5.0 5.0 5.0 5.0 0.0 -5.0"));
-
- c[0] = Range(2.0, 4.0); // Entirely inside box.
- c[1] = Range(0.0, 5.0); // Fills box fully.
- c[2] = Range(-10.0, 10.0); // Larger than the box.
- c[3] = Range(-2.0, 1.0); // Crosses over an edge.
- c[4] = Range(2.0, 4.0); // Inside box.
- c[5] = Range(0.0, 2.0); // On edge of box.
- c[6] = Range(-10.0, 10.0); // Box is of infinite size.
- c[7] = Range(2.0, 4.0); // Box size of -5 should function the same as 5.
-
- point.set_size(8);
- point[0] = 7.5; // Inside first right image of the box.
- point[1] = 19.3; // Inside the box, which covers everything.
- point[2] = -500.0; // Inside the box, which covers everything.
- point[3] = 2.9; // The first right image of the bound starts at 3.0.
- point[4] = 0.0; // Closest to the first left image of the bound.
- point[5] = 7.1; // 0.1 away from the first right image of the bound.
- point[6] = 810.0; // 800 away from the only image of the box.
- point[7] = -10.8; // Should alias to 4.2.
-
- BOOST_REQUIRE_CLOSE(c.MaxDistance(point), 672630.65, 1e-10);
-}*/
-
-/**
- * Correctly calculate the maximum distance between the bound and another bound in
- * periodic coordinates. We have to account for the shifts necessary in
- * periodic coordinates too, so that makes testing this a little more difficult.
- *
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundMaxDistanceBound)
-{
- // First, we'll start with a simple 2-dimensional case where the bounds are nonoverlapping,
- // then one bound is on the edge of the other bound,
- // then overlapping, then one range entirely covering the other. The box size will be large enough that this is basically the
- // HRectBound case.
- PeriodicHRectBound<2> b(arma::vec("100 100"));
- PeriodicHRectBound<2> c(b);
-
- b[0] = Range(0.0, 5.0);
- b[1] = Range(2.0, 4.0);
-
- // Inside the bound.
-
- c[0] = Range(7.0, 9.0);
- c[1] = Range(10.0,12.0);
-
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 181.0, 1e-5);
-
- // On the edge.
-
- c[0] = Range(5.0, 8.0);
- c[1] = Range(4.0, 6.0);
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 80.0, 1e-5);
-
- // Overlapping the bound.
-
- c[0] = Range(3.0, 6.0);
- c[1] = Range(1.0, 3.0);
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 45.0, 1e-5);
-
- // One range entirely covering the other
-
- c[0] = Range(0.0, 6.0);
- c[1] = Range(1.0, 7.0);
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 61.0, 1e-5);
-
- // Now we start to invoke the periodicity. These bounds "alias" to (-3.0,
- // -1.0) and (5.0, 6.0).
-
- c[0] = Range(97.0, 99.0);
- c[1] = Range(105.0, 106.0);
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 80.0, 1e-5);
-
- // We will perform several tests on a one-dimensional bound and smaller box size.
- PeriodicHRectBound<2> a(arma::vec("5.0"));
- PeriodicHRectBound<2> d(a);
-
- a[0] = Range(2.0, 4.0); // Entirely inside box.
- d[0] = Range(7.5, 10); // In the right image of the box, overlapping ranges.
-
- BOOST_REQUIRE_CLOSE(a.MaxDistance(d), 9.0, 1e-5);
-
- a[0] = Range(0.0, 5.0); // Fills box fully.
- d[0] = Range(19.3, 21.0); // Two intervals inside the box, same as range of b[0].
-
- BOOST_REQUIRE_CLOSE(a.MaxDistance(d), 18.49, 1e-5);
-
- a[0] = Range(-10.0, 10.0); // Larger than the box.
- d[0] = Range(-500.0, -498.0); // Inside the box, which covers everything.
-
- BOOST_REQUIRE_CLOSE(a.MaxDistance(d), 9.0, 1e-5);
-
- a[0] = Range(-2.0, 1.0); // Crosses over an edge.
- d[0] = Range(2.9, 5.1); // The first right image of the bound starts at 3.0.
-
- BOOST_REQUIRE_CLOSE(a.MaxDistance(d), 24.01, 1e-5);
-}*/
-
-
-/**
- * It seems as though Bill has stumbled across a bug where
- * BinarySpaceTree<>::count() returns something different than
- * BinarySpaceTree<>::count_. So, let's build a simple tree and make sure they
- * are the same.
- */
-BOOST_AUTO_TEST_CASE(TreeCountMismatch)
-{
- arma::mat dataset = "2.0 5.0 9.0 4.0 8.0 7.0;"
- "3.0 4.0 6.0 7.0 1.0 2.0 ";
-
- // Leaf size of 1.
- BinarySpaceTree<HRectBound<2> > rootNode(dataset, 1);
-
- BOOST_REQUIRE(rootNode.Count() == 6);
- BOOST_REQUIRE(rootNode.Left()->Count() == 3);
- BOOST_REQUIRE(rootNode.Left()->Left()->Count() == 2);
- BOOST_REQUIRE(rootNode.Left()->Left()->Left()->Count() == 1);
- BOOST_REQUIRE(rootNode.Left()->Left()->Right()->Count() == 1);
- BOOST_REQUIRE(rootNode.Left()->Right()->Count() == 1);
- BOOST_REQUIRE(rootNode.Right()->Count() == 3);
- BOOST_REQUIRE(rootNode.Right()->Left()->Count() == 2);
- BOOST_REQUIRE(rootNode.Right()->Left()->Left()->Count() == 1);
- BOOST_REQUIRE(rootNode.Right()->Left()->Right()->Count() == 1);
- BOOST_REQUIRE(rootNode.Right()->Right()->Count() == 1);
-}
-
-// Forward declaration of methods we need for the next test.
-template<typename TreeType>
-bool CheckPointBounds(TreeType* node, const arma::mat& data);
-
-template<typename TreeType>
-void GenerateVectorOfTree(TreeType* node,
- size_t depth,
- std::vector<TreeType*>& v);
-
-template<int t_pow>
-bool DoBoundsIntersect(HRectBound<t_pow>& a,
- HRectBound<t_pow>& b,
- size_t ia,
- size_t ib);
-
-/**
- * Exhaustive kd-tree test based on #125.
- *
- * - Generate a random dataset of a random size.
- * - Build a tree on that dataset.
- * - Ensure all the permutation indices map back to the correct points.
- * - Verify that each point is contained inside all of the bounds of its parent
- * nodes.
- * - Verify that each bound at a particular level of the tree does not overlap
- * with any other bounds at that level.
- *
- * Then, we do that whole process a handful of times.
- */
-BOOST_AUTO_TEST_CASE(KdTreeTest)
-{
- typedef BinarySpaceTree<HRectBound<2> > TreeType;
-
- size_t maxRuns = 10; // Ten total tests.
- size_t pointIncrements = 1000; // Range is from 2000 points to 11000.
-
- // We use the default leaf size of 20.
- for(size_t run = 0; run < maxRuns; run++)
- {
- size_t dimensions = run + 2;
- size_t maxPoints = (run + 1) * pointIncrements;
-
- size_t size = maxPoints;
- arma::mat dataset = arma::mat(dimensions, size);
- arma::mat datacopy; // Used to test mappings.
-
- // Mappings for post-sort verification of data.
- std::vector<size_t> newToOld;
- std::vector<size_t> oldToNew;
-
- // Generate data.
- dataset.randu();
- datacopy = dataset; // Save a copy.
-
- // Build the tree itself.
- TreeType root(dataset, newToOld, oldToNew);
-
- // Ensure the size of the tree is correct.
- BOOST_REQUIRE_EQUAL(root.Count(), size);
-
- // Check the forward and backward mappings for correctness.
- for(size_t i = 0; i < size; i++)
- {
- for(size_t j = 0; j < dimensions; j++)
- {
- BOOST_REQUIRE_EQUAL(dataset(j, i), datacopy(j, newToOld[i]));
- BOOST_REQUIRE_EQUAL(dataset(j, oldToNew[i]), datacopy(j, i));
- }
- }
-
- // Now check that each point is contained inside of all bounds above it.
- CheckPointBounds(&root, dataset);
-
- // Now check that no peers overlap.
- std::vector<TreeType*> v;
- GenerateVectorOfTree(&root, 1, v);
-
- // Start with the first pair.
- size_t depth = 2;
- // Compare each peer against every other peer.
- while (depth < v.size())
- {
- for (size_t i = depth; i < 2 * depth && i < v.size(); i++)
- for (size_t j = i + 1; j < 2 * depth && j < v.size(); j++)
- if (v[i] != NULL && v[j] != NULL)
- BOOST_REQUIRE(!DoBoundsIntersect(v[i]->Bound(), v[j]->Bound(),
- i, j));
-
- depth *= 2;
- }
- }
-
- arma::mat dataset = arma::mat(25, 1000);
- for (size_t col = 0; col < dataset.n_cols; ++col)
- for (size_t row = 0; row < dataset.n_rows; ++row)
- dataset(row, col) = row + col;
-
- TreeType root(dataset);
- // Check the tree size.
- BOOST_REQUIRE_EQUAL(root.TreeSize(), 127);
- // Check the tree depth.
- BOOST_REQUIRE_EQUAL(root.TreeDepth(), 7);
-}
-
-// Recursively checks that each node contains all points that it claims to have.
-template<typename TreeType>
-bool CheckPointBounds(TreeType* node, const arma::mat& data)
-{
- if (node == NULL) // We have passed a leaf node.
- return true;
-
- TreeType* left = node->Left();
- TreeType* right = node->Right();
-
- size_t begin = node->Begin();
- size_t count = node->Count();
-
- // Check that each point which this tree claims is actually inside the tree.
- for (size_t index = begin; index < begin + count; index++)
- if (!node->Bound().Contains(data.col(index)))
- return false;
-
- return CheckPointBounds(left, data) && CheckPointBounds(right, data);
-}
-
-template<int t_pow>
-bool DoBoundsIntersect(HRectBound<t_pow>& a,
- HRectBound<t_pow>& b,
- size_t /* ia */,
- size_t /* ib */)
-{
- size_t dimensionality = a.Dim();
-
- Range r_a;
- Range r_b;
-
- for (size_t i = 0; i < dimensionality; i++)
- {
- r_a = a[i];
- r_b = b[i];
- if (r_a < r_b || r_a > r_b) // If a does not overlap b at all.
- return false;
- }
-
- return true;
-}
-
-template<typename TreeType>
-void GenerateVectorOfTree(TreeType* node,
- size_t depth,
- std::vector<TreeType*>& v)
-{
- if (node == NULL)
- return;
-
- if (depth >= v.size())
- v.resize(2 * depth + 1, NULL); // Resize to right size; fill with NULL.
-
- v[depth] = node;
-
- // Recurse to the left and right children.
- GenerateVectorOfTree(node->Left(), depth * 2, v);
- GenerateVectorOfTree(node->Right(), depth * 2 + 1, v);
-
- return;
-}
-
-/**
- * Exhaustive sparse kd-tree test based on #125.
- *
- * - Generate a random dataset of a random size.
- * - Build a tree on that dataset.
- * - Ensure all the permutation indices map back to the correct points.
- * - Verify that each point is contained inside all of the bounds of its parent
- * nodes.
- * - Verify that each bound at a particular level of the tree does not overlap
- * with any other bounds at that level.
- *
- * Then, we do that whole process a handful of times.
- */
-BOOST_AUTO_TEST_CASE(ExhaustiveSparseKDTreeTest)
-{
- typedef BinarySpaceTree<HRectBound<2>, EmptyStatistic, arma::SpMat<double> >
- TreeType;
-
- size_t maxRuns = 2; // Two total tests.
- size_t pointIncrements = 200; // Range is from 200 points to 400.
-
- // We use the default leaf size of 20.
- for(size_t run = 0; run < maxRuns; run++)
- {
- size_t dimensions = run + 2;
- size_t maxPoints = (run + 1) * pointIncrements;
-
- size_t size = maxPoints;
- arma::SpMat<double> dataset = arma::SpMat<double>(dimensions, size);
- arma::SpMat<double> datacopy; // Used to test mappings.
-
- // Mappings for post-sort verification of data.
- std::vector<size_t> newToOld;
- std::vector<size_t> oldToNew;
-
- // Generate data.
- dataset.randu();
- datacopy = dataset; // Save a copy.
-
- // Build the tree itself.
- TreeType root(dataset, newToOld, oldToNew);
-
- // Ensure the size of the tree is correct.
- BOOST_REQUIRE_EQUAL(root.Count(), size);
-
- // Check the forward and backward mappings for correctness.
- for(size_t i = 0; i < size; i++)
- {
- for(size_t j = 0; j < dimensions; j++)
- {
- BOOST_REQUIRE_EQUAL(dataset(j, i), datacopy(j, newToOld[i]));
- BOOST_REQUIRE_EQUAL(dataset(j, oldToNew[i]), datacopy(j, i));
- }
- }
-
- // Now check that each point is contained inside of all bounds above it.
- CheckPointBounds(&root, dataset);
-
- // Now check that no peers overlap.
- std::vector<TreeType*> v;
- GenerateVectorOfTree(&root, 1, v);
-
- // Start with the first pair.
- size_t depth = 2;
- // Compare each peer against every other peer.
- while (depth < v.size())
- {
- for (size_t i = depth; i < 2 * depth && i < v.size(); i++)
- for (size_t j = i + 1; j < 2 * depth && j < v.size(); j++)
- if (v[i] != NULL && v[j] != NULL)
- BOOST_REQUIRE(!DoBoundsIntersect(v[i]->Bound(), v[j]->Bound(),
- i, j));
-
- depth *= 2;
- }
- }
-
- arma::SpMat<double> dataset(25, 1000);
- for (size_t col = 0; col < dataset.n_cols; ++col)
- for (size_t row = 0; row < dataset.n_rows; ++row)
- dataset(row, col) = row + col;
-
- TreeType root(dataset);
- // Check the tree size.
- BOOST_REQUIRE_EQUAL(root.TreeSize(), 127);
- // Check the tree depth.
- BOOST_REQUIRE_EQUAL(root.TreeDepth(), 7);
-}
-
-template<typename TreeType>
-void RecurseTreeCountLeaves(const TreeType& node, arma::vec& counts)
-{
- for (size_t i = 0; i < node.NumChildren(); ++i)
- {
- if (node.Child(i).NumChildren() == 0)
- counts[node.Child(i).Point()]++;
- else
- RecurseTreeCountLeaves<TreeType>(node.Child(i), counts);
- }
-}
-
-template<typename TreeType>
-void CheckSelfChild(const TreeType& node)
-{
- if (node.NumChildren() == 0)
- return; // No self-child applicable here.
-
- bool found = false;
- for (size_t i = 0; i < node.NumChildren(); ++i)
- {
- if (node.Child(i).Point() == node.Point())
- found = true;
-
- // Recursively check the children.
- CheckSelfChild(node.Child(i));
- }
-
- // Ensure this has its own self-child.
- BOOST_REQUIRE_EQUAL(found, true);
-}
-
-template<typename TreeType, typename MetricType>
-void CheckCovering(const TreeType& node)
-{
- // Return if a leaf. No checking necessary.
- if (node.NumChildren() == 0)
- return;
-
- const arma::mat& dataset = node.Dataset();
- const size_t nodePoint = node.Point();
-
- // To ensure that this node satisfies the covering principle, we must ensure
- // that the distance to each child is less than pow(expansionConstant, scale).
- double maxDistance = pow(node.ExpansionConstant(), node.Scale());
- for (size_t i = 0; i < node.NumChildren(); ++i)
- {
- const size_t childPoint = node.Child(i).Point();
-
- double distance = MetricType::Evaluate(dataset.col(nodePoint),
- dataset.col(childPoint));
-
- BOOST_REQUIRE_LE(distance, maxDistance);
-
- // Check the child.
- CheckCovering<TreeType, MetricType>(node.Child(i));
- }
-}
-
-template<typename TreeType, typename MetricType>
-void CheckIndividualSeparation(const TreeType& constantNode,
- const TreeType& node)
-{
- // Don't check points at a lower scale.
- if (node.Scale() < constantNode.Scale())
- return;
-
- // If at a higher scale, recurse.
- if (node.Scale() > constantNode.Scale())
- {
- for (size_t i = 0; i < node.NumChildren(); ++i)
- {
- // Don't recurse into leaves.
- if (node.Child(i).NumChildren() > 0)
- CheckIndividualSeparation<TreeType, MetricType>(constantNode,
- node.Child(i));
- }
-
- return;
- }
-
- // Don't compare the same point against itself.
- if (node.Point() == constantNode.Point())
- return;
-
- // Now we know we are at the same scale, so make the comparison.
- const arma::mat& dataset = constantNode.Dataset();
- const size_t constantPoint = constantNode.Point();
- const size_t nodePoint = node.Point();
-
- // Make sure the distance is at least the following value (in accordance with
- // the separation principle of cover trees).
- double minDistance = pow(constantNode.ExpansionConstant(),
- constantNode.Scale());
-
- double distance = MetricType::Evaluate(dataset.col(constantPoint),
- dataset.col(nodePoint));
-
- BOOST_REQUIRE_GE(distance, minDistance);
-}
-
-template<typename TreeType, typename MetricType>
-void CheckSeparation(const TreeType& node, const TreeType& root)
-{
- // Check the separation between this point and all other points on this scale.
- CheckIndividualSeparation<TreeType, MetricType>(node, root);
-
- // Check the children, but only if they are not leaves. Leaves don't need to
- // be checked.
- for (size_t i = 0; i < node.NumChildren(); ++i)
- if (node.Child(i).NumChildren() > 0)
- CheckSeparation<TreeType, MetricType>(node.Child(i), root);
-}
-
-
-/**
- * Create a simple cover tree and then make sure it is valid.
- */
-BOOST_AUTO_TEST_CASE(SimpleCoverTreeConstructionTest)
-{
- // 20-point dataset.
- arma::mat data = arma::trans(arma::mat("0.0 0.0;"
- "1.0 0.0;"
- "0.5 0.5;"
- "2.0 2.0;"
- "-1.0 2.0;"
- "3.0 0.0;"
- "1.5 5.5;"
- "-2.0 -2.0;"
- "-1.5 1.5;"
- "0.0 4.0;"
- "2.0 1.0;"
- "2.0 1.2;"
- "-3.0 -2.5;"
- "-5.0 -5.0;"
- "3.5 1.5;"
- "2.0 2.5;"
- "-1.0 -1.0;"
- "-3.5 1.5;"
- "3.5 -1.5;"
- "2.0 1.0;"));
-
- // The root point will be the first point, (0, 0).
- CoverTree<> tree(data); // Expansion constant of 2.0.
-
- // The furthest point from the root will be (-5, -5), with a squared distance
- // of 50. This means the scale of the root node should be 6 (because 2^6 =
- // 64).
- BOOST_REQUIRE_EQUAL(tree.Scale(), 6);
-
- // Now loop through the tree and ensure that each leaf is only created once.
- arma::vec counts;
- counts.zeros(20);
- RecurseTreeCountLeaves(tree, counts);
-
- // Each point should only have one leaf node representing it.
- for (size_t i = 0; i < 20; ++i)
- BOOST_REQUIRE_EQUAL(counts[i], 1);
-
- // Each non-leaf should have a self-child.
- CheckSelfChild<CoverTree<> >(tree);
-
- // Each node must satisfy the covering principle (its children must be less
- // than or equal to a certain distance apart).
- CheckCovering<CoverTree<>, LMetric<2> >(tree);
-
- // Each node's children must be separated by at least a certain value.
- CheckSeparation<CoverTree<>, LMetric<2> >(tree, tree);
-}
-
-/**
- * Create a large cover tree and make sure it's accurate.
- */
-BOOST_AUTO_TEST_CASE(CoverTreeConstructionTest)
-{
- arma::mat dataset;
- // 50-dimensional, 1000 point.
- dataset.randu(50, 1000);
-
- CoverTree<> tree(dataset);
-
- // Ensure each leaf is only created once.
- arma::vec counts;
- counts.zeros(1000);
- RecurseTreeCountLeaves(tree, counts);
-
- for (size_t i = 0; i < 1000; ++i)
- BOOST_REQUIRE_EQUAL(counts[i], 1);
-
- // Each non-leaf should have a self-child.
- CheckSelfChild<CoverTree<> >(tree);
-
- // Each node must satisfy the covering principle (its children must be less
- // than or equal to a certain distance apart).
- CheckCovering<CoverTree<>, LMetric<2> >(tree);
-
- // Each node's children must be separated by at least a certain value.
- CheckSeparation<CoverTree<>, LMetric<2> >(tree, tree);
-}
-
-/**
- * Make sure cover trees work in different metric spaces.
- */
-BOOST_AUTO_TEST_CASE(CoverTreeAlternateMetricTest)
-{
- arma::mat dataset;
- // 5-dimensional, 300-point dataset.
- dataset.randu(5, 300);
-
- CoverTree<LMetric<1> > tree(dataset);
-
- // Ensure each leaf is only created once.
- arma::vec counts;
- counts.zeros(300);
- RecurseTreeCountLeaves<CoverTree<LMetric<1> > >(tree, counts);
-
- for (size_t i = 0; i < 300; ++i)
- BOOST_REQUIRE_EQUAL(counts[i], 1);
-
- // Each non-leaf should have a self-child.
- CheckSelfChild<CoverTree<LMetric<1> > >(tree);
-
- // Each node must satisfy the covering principle (its children must be less
- // than or equal to a certain distance apart).
- CheckCovering<CoverTree<LMetric<1> >, LMetric<1> >(tree);
-
- // Each node's children must be separated by at least a certain value.
- CheckSeparation<CoverTree<LMetric<1> >, LMetric<1> >(tree, tree);
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.2/src/mlpack/tests/tree_test.cpp (from rev 13404, mlpack/trunk/src/mlpack/tests/tree_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.2/src/mlpack/tests/tree_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.2/src/mlpack/tests/tree_test.cpp 2012-08-15 17:40:09 UTC (rev 13406)
@@ -0,0 +1,1671 @@
+/**
+ * @file tree_test.cpp
+ *
+ * Tests for tree-building methods.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/tree/bounds.hpp>
+#include <mlpack/core/tree/binary_space_tree/binary_space_tree.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/core/tree/cover_tree/cover_tree.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::math;
+using namespace mlpack::tree;
+using namespace mlpack::metric;
+using namespace mlpack::bound;
+
+BOOST_AUTO_TEST_SUITE(TreeTest);
+
+/**
+ * Ensure that a bound, by default, is empty and has no dimensionality.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundEmptyConstructor)
+{
+ HRectBound<2> b;
+
+ BOOST_REQUIRE_EQUAL((int) b.Dim(), 0);
+}
+
+/**
+ * Ensure that when we specify the dimensionality in the constructor, it is
+ * correct, and the bounds are all the empty set.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundDimConstructor)
+{
+ HRectBound<2> b(2); // We'll do this with 2 and 5 dimensions.
+
+ BOOST_REQUIRE_EQUAL(b.Dim(), 2);
+ BOOST_REQUIRE_SMALL(b[0].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(b[1].Width(), 1e-5);
+
+ b = HRectBound<2>(5);
+
+ BOOST_REQUIRE_EQUAL(b.Dim(), 5);
+ BOOST_REQUIRE_SMALL(b[0].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(b[1].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(b[2].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(b[3].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(b[4].Width(), 1e-5);
+}
+
+/**
+ * Test the copy constructor.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundCopyConstructor)
+{
+ HRectBound<2> b(2);
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(2.0, 3.0);
+
+ HRectBound<2> c(b);
+
+ BOOST_REQUIRE_EQUAL(c.Dim(), 2);
+ BOOST_REQUIRE_SMALL(c[0].Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(c[0].Hi(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c[1].Lo(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c[1].Hi(), 3.0, 1e-5);
+}
+
+/**
+ * Test the assignment operator.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundAssignmentOperator)
+{
+ HRectBound<2> b(2);
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(2.0, 3.0);
+
+ HRectBound<2> c(4);
+
+ c = b;
+
+ BOOST_REQUIRE_EQUAL(c.Dim(), 2);
+ BOOST_REQUIRE_SMALL(c[0].Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(c[0].Hi(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c[1].Lo(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c[1].Hi(), 3.0, 1e-5);
+}
+
+/**
+ * Test that clearing the dimensions resets the bound to empty.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundClear)
+{
+ HRectBound<2> b(2); // We'll do this with two dimensions only.
+
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(2.0, 4.0);
+
+ // Now we just need to make sure that we clear the range.
+ b.Clear();
+
+ BOOST_REQUIRE_SMALL(b[0].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(b[1].Width(), 1e-5);
+}
+
+/**
+ * Ensure that we get the correct centroid for our bound.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundCentroid)
+{
+ // Create a simple 3-dimensional bound.
+ HRectBound<2> b(3);
+
+ b[0] = Range(0.0, 5.0);
+ b[1] = Range(-2.0, -1.0);
+ b[2] = Range(-10.0, 50.0);
+
+ arma::vec centroid;
+
+ b.Centroid(centroid);
+
+ BOOST_REQUIRE_EQUAL(centroid.n_elem, 3);
+ BOOST_REQUIRE_CLOSE(centroid[0], 2.5, 1e-5);
+ BOOST_REQUIRE_CLOSE(centroid[1], -1.5, 1e-5);
+ BOOST_REQUIRE_CLOSE(centroid[2], 20.0, 1e-5);
+}
+
+/**
+ * Ensure that we calculate the correct minimum distance between a point and a
+ * bound.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundMinDistancePoint)
+{
+ // We'll do the calculation in five dimensions, and we'll use three cases for
+ // the point: point is outside the bound; point is on the edge of the bound;
+ // point is inside the bound. In the latter two cases, the distance should be
+ // zero.
+ HRectBound<2> b(5);
+
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(1.0, 5.0);
+ b[2] = Range(-2.0, 2.0);
+ b[3] = Range(-5.0, -2.0);
+ b[4] = Range(1.0, 2.0);
+
+ arma::vec point = "-2.0 0.0 10.0 3.0 3.0";
+
+ // This will be the Euclidean squared distance.
+ BOOST_REQUIRE_CLOSE(b.MinDistance(point), 95.0, 1e-5);
+
+ point = "2.0 5.0 2.0 -5.0 1.0";
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(point), 1e-5);
+
+ point = "1.0 2.0 0.0 -2.0 1.5";
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(point), 1e-5);
+}
+
+/**
+ * Ensure that we calculate the correct minimum distance between a bound and
+ * another bound.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundMinDistanceBound)
+{
+ // We'll do the calculation in five dimensions, and we can use six cases.
+ // The other bound is completely outside the bound; the other bound is on the
+ // edge of the bound; the other bound partially overlaps the bound; the other
+ // bound fully overlaps the bound; the other bound is entirely inside the
+ // bound; the other bound entirely envelops the bound.
+ HRectBound<2> b(5);
+
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(1.0, 5.0);
+ b[2] = Range(-2.0, 2.0);
+ b[3] = Range(-5.0, -2.0);
+ b[4] = Range(1.0, 2.0);
+
+ HRectBound<2> c(5);
+
+ // The other bound is completely outside the bound.
+ c[0] = Range(-5.0, -2.0);
+ c[1] = Range(6.0, 7.0);
+ c[2] = Range(-2.0, 2.0);
+ c[3] = Range(2.0, 5.0);
+ c[4] = Range(3.0, 4.0);
+
+ BOOST_REQUIRE_CLOSE(b.MinDistance(c), 22.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c.MinDistance(b), 22.0, 1e-5);
+
+ // The other bound is on the edge of the bound.
+ c[0] = Range(-2.0, 0.0);
+ c[1] = Range(0.0, 1.0);
+ c[2] = Range(-3.0, -2.0);
+ c[3] = Range(-10.0, -5.0);
+ c[4] = Range(2.0, 3.0);
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
+ BOOST_REQUIRE_SMALL(c.MinDistance(b), 1e-5);
+
+ // The other bound partially overlaps the bound.
+ c[0] = Range(-2.0, 1.0);
+ c[1] = Range(0.0, 2.0);
+ c[2] = Range(-2.0, 2.0);
+ c[3] = Range(-8.0, -4.0);
+ c[4] = Range(0.0, 4.0);
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
+ BOOST_REQUIRE_SMALL(c.MinDistance(b), 1e-5);
+
+ // The other bound fully overlaps the bound.
+ BOOST_REQUIRE_SMALL(b.MinDistance(b), 1e-5);
+ BOOST_REQUIRE_SMALL(c.MinDistance(c), 1e-5);
+
+ // The other bound is entirely inside the bound / the other bound entirely
+ // envelops the bound.
+ c[0] = Range(-1.0, 3.0);
+ c[1] = Range(0.0, 6.0);
+ c[2] = Range(-3.0, 3.0);
+ c[3] = Range(-7.0, 0.0);
+ c[4] = Range(0.0, 5.0);
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
+ BOOST_REQUIRE_SMALL(c.MinDistance(b), 1e-5);
+
+ // Now we must be sure that the minimum distance to itself is 0.
+ BOOST_REQUIRE_SMALL(b.MinDistance(b), 1e-5);
+ BOOST_REQUIRE_SMALL(c.MinDistance(c), 1e-5);
+}
+
+/**
+ * Ensure that we calculate the correct maximum distance between a bound and a
+ * point. This uses the same test cases as the MinDistance test.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundMaxDistancePoint)
+{
+ // We'll do the calculation in five dimensions, and we'll use three cases for
+ // the point: point is outside the bound; point is on the edge of the bound;
+ // point is inside the bound. In the latter two cases, the distance should be
+ // zero.
+ HRectBound<2> b(5);
+
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(1.0, 5.0);
+ b[2] = Range(-2.0, 2.0);
+ b[3] = Range(-5.0, -2.0);
+ b[4] = Range(1.0, 2.0);
+
+ arma::vec point = "-2.0 0.0 10.0 3.0 3.0";
+
+ // This will be the Euclidean squared distance.
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 253.0, 1e-5);
+
+ point = "2.0 5.0 2.0 -5.0 1.0";
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 46.0, 1e-5);
+
+ point = "1.0 2.0 0.0 -2.0 1.5";
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 23.25, 1e-5);
+}
+
+/**
+ * Ensure that we calculate the correct maximum distance between a bound and
+ * another bound. This uses the same test cases as the MinDistance test.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundMaxDistanceBound)
+{
+ // We'll do the calculation in five dimensions, and we can use six cases.
+ // The other bound is completely outside the bound; the other bound is on the
+ // edge of the bound; the other bound partially overlaps the bound; the other
+ // bound fully overlaps the bound; the other bound is entirely inside the
+ // bound; the other bound entirely envelops the bound.
+ HRectBound<2> b(5);
+
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(1.0, 5.0);
+ b[2] = Range(-2.0, 2.0);
+ b[3] = Range(-5.0, -2.0);
+ b[4] = Range(1.0, 2.0);
+
+ HRectBound<2> c(5);
+
+ // The other bound is completely outside the bound.
+ c[0] = Range(-5.0, -2.0);
+ c[1] = Range(6.0, 7.0);
+ c[2] = Range(-2.0, 2.0);
+ c[3] = Range(2.0, 5.0);
+ c[4] = Range(3.0, 4.0);
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 210.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c.MaxDistance(b), 210.0, 1e-5);
+
+ // The other bound is on the edge of the bound.
+ c[0] = Range(-2.0, 0.0);
+ c[1] = Range(0.0, 1.0);
+ c[2] = Range(-3.0, -2.0);
+ c[3] = Range(-10.0, -5.0);
+ c[4] = Range(2.0, 3.0);
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 134.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c.MaxDistance(b), 134.0, 1e-5);
+
+ // The other bound partially overlaps the bound.
+ c[0] = Range(-2.0, 1.0);
+ c[1] = Range(0.0, 2.0);
+ c[2] = Range(-2.0, 2.0);
+ c[3] = Range(-8.0, -4.0);
+ c[4] = Range(0.0, 4.0);
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 102.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c.MaxDistance(b), 102.0, 1e-5);
+
+ // The other bound fully overlaps the bound.
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(b), 46.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c.MaxDistance(c), 61.0, 1e-5);
+
+ // The other bound is entirely inside the bound / the other bound entirely
+ // envelops the bound.
+ c[0] = Range(-1.0, 3.0);
+ c[1] = Range(0.0, 6.0);
+ c[2] = Range(-3.0, 3.0);
+ c[3] = Range(-7.0, 0.0);
+ c[4] = Range(0.0, 5.0);
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 100.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c.MaxDistance(b), 100.0, 1e-5);
+
+ // Identical bounds. This will be the sum of the squared widths in each
+ // dimension.
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(b), 46.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c.MaxDistance(c), 162.0, 1e-5);
+
+ // One last additional case. If the bound encloses only one point, the
+ // maximum distance between it and itself is 0.
+ HRectBound<2> d(2);
+
+ d[0] = Range(2.0, 2.0);
+ d[1] = Range(3.0, 3.0);
+
+ BOOST_REQUIRE_SMALL(d.MaxDistance(d), 1e-5);
+}
+
+/**
+ * Ensure that the ranges returned by RangeDistance() are equal to the minimum
+ * and maximum distance. We will perform this test by creating random bounds
+ * and comparing the behavior to MinDistance() and MaxDistance() -- so this test
+ * is assuming that those passed and operate correctly.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundRangeDistanceBound)
+{
+ for (int i = 0; i < 50; i++)
+ {
+ size_t dim = math::RandInt(20);
+
+ HRectBound<2> a(dim);
+ HRectBound<2> b(dim);
+
+ // We will set the low randomly and the width randomly for each dimension of
+ // each bound.
+ arma::vec loA(dim);
+ arma::vec widthA(dim);
+
+ loA.randu();
+ widthA.randu();
+
+ arma::vec lo_b(dim);
+ arma::vec width_b(dim);
+
+ lo_b.randu();
+ width_b.randu();
+
+ for (size_t j = 0; j < dim; j++)
+ {
+ a[j] = Range(loA[j], loA[j] + widthA[j]);
+ b[j] = Range(lo_b[j], lo_b[j] + width_b[j]);
+ }
+
+ // Now ensure that MinDistance and MaxDistance report the same.
+ Range r = a.RangeDistance(b);
+ Range s = b.RangeDistance(a);
+
+ BOOST_REQUIRE_CLOSE(r.Lo(), s.Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(r.Hi(), s.Hi(), 1e-5);
+
+ BOOST_REQUIRE_CLOSE(r.Lo(), a.MinDistance(b), 1e-5);
+ BOOST_REQUIRE_CLOSE(r.Hi(), a.MaxDistance(b), 1e-5);
+
+ BOOST_REQUIRE_CLOSE(s.Lo(), b.MinDistance(a), 1e-5);
+ BOOST_REQUIRE_CLOSE(s.Hi(), b.MaxDistance(a), 1e-5);
+ }
+}
+
+/**
+ * Ensure that the ranges returned by RangeDistance() are equal to the minimum
+ * and maximum distance. We will perform this test by creating random bounds
+ * and comparing the behavior to MinDistance() and MaxDistance() -- so this test
+ * is assuming that those passed and operate correctly. This is for the
+ * bound-to-point case.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundRangeDistancePoint)
+{
+ for (int i = 0; i < 20; i++)
+ {
+ size_t dim = math::RandInt(20);
+
+ HRectBound<2> a(dim);
+
+ // We will set the low randomly and the width randomly for each dimension of
+ // each bound.
+ arma::vec loA(dim);
+ arma::vec widthA(dim);
+
+ loA.randu();
+ widthA.randu();
+
+ for (size_t j = 0; j < dim; j++)
+ a[j] = Range(loA[j], loA[j] + widthA[j]);
+
+ // Now run the test on a few points.
+ for (int j = 0; j < 10; j++)
+ {
+ arma::vec point(dim);
+
+ point.randu();
+
+ Range r = a.RangeDistance(point);
+
+ BOOST_REQUIRE_CLOSE(r.Lo(), a.MinDistance(point), 1e-5);
+ BOOST_REQUIRE_CLOSE(r.Hi(), a.MaxDistance(point), 1e-5);
+ }
+ }
+}
+
+/**
+ * Test that we can expand the bound to include a new point.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundOrOperatorPoint)
+{
+ // Because this should be independent in each dimension, we can essentially
+ // run five test cases at once.
+ HRectBound<2> b(5);
+
+ b[0] = Range(1.0, 3.0);
+ b[1] = Range(2.0, 4.0);
+ b[2] = Range(-2.0, -1.0);
+ b[3] = Range(0.0, 0.0);
+ b[4] = Range(); // Empty range.
+
+ // Dimension 0: point inside (no change); 1: on the upper edge; 2: above the
+ // range; 3: below a degenerate range; 4: into an empty range, which should
+ // become exactly that point.
+ arma::vec point = "2.0 4.0 2.0 -1.0 6.0";
+
+ b |= point;
+
+ BOOST_REQUIRE_CLOSE(b[0].Lo(), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[0].Hi(), 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[1].Lo(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[1].Hi(), 4.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[2].Lo(), -2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[2].Hi(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[3].Lo(), -1.0, 1e-5);
+ BOOST_REQUIRE_SMALL(b[3].Hi(), 1e-5);
+ BOOST_REQUIRE_CLOSE(b[4].Lo(), 6.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[4].Hi(), 6.0, 1e-5);
+}
+
+/**
+ * Test that we can expand the bound to include another bound.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundOrOperatorBound)
+{
+ // Because this should be independent in each dimension, we can run many tests
+ // at once.
+ HRectBound<2> b(8);
+
+ b[0] = Range(1.0, 3.0);
+ b[1] = Range(2.0, 4.0);
+ b[2] = Range(-2.0, -1.0);
+ b[3] = Range(4.0, 5.0);
+ b[4] = Range(2.0, 4.0);
+ b[5] = Range(0.0, 0.0);
+ b[6] = Range();
+ b[7] = Range(1.0, 3.0);
+
+ HRectBound<2> c(8);
+
+ c[0] = Range(-3.0, -1.0); // Entirely less than the other bound.
+ c[1] = Range(0.0, 2.0); // Touching edges.
+ c[2] = Range(-3.0, -1.5); // Partially overlapping.
+ c[3] = Range(4.0, 5.0); // Identical.
+ c[4] = Range(1.0, 5.0); // Entirely enclosing.
+ c[5] = Range(2.0, 2.0); // A single point.
+ c[6] = Range(1.0, 3.0);
+ c[7] = Range(); // Empty set.
+
+ // d starts as a copy of c; after b |= c, expanding d by the enlarged b must
+ // give the same union in every dimension -- the paired assertions on b and d
+ // below verify this.
+ HRectBound<2> d = c;
+
+ b |= c;
+ d |= b;
+
+ BOOST_REQUIRE_CLOSE(b[0].Lo(), -3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[0].Hi(), 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[0].Lo(), -3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[0].Hi(), 3.0, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(b[1].Lo(), 0.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[1].Hi(), 4.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[1].Lo(), 0.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[1].Hi(), 4.0, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(b[2].Lo(), -3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[2].Hi(), -1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[2].Lo(), -3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[2].Hi(), -1.0, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(b[3].Lo(), 4.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[3].Hi(), 5.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[3].Lo(), 4.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[3].Hi(), 5.0, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(b[4].Lo(), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[4].Hi(), 5.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[4].Lo(), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[4].Hi(), 5.0, 1e-5);
+
+ BOOST_REQUIRE_SMALL(b[5].Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(b[5].Hi(), 2.0, 1e-5);
+ BOOST_REQUIRE_SMALL(d[5].Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(d[5].Hi(), 2.0, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(b[6].Lo(), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[6].Hi(), 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[6].Lo(), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[6].Hi(), 3.0, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(b[7].Lo(), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[7].Hi(), 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[7].Lo(), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[7].Hi(), 3.0, 1e-5);
+}
+
+/**
+ * Test that the Contains() function correctly figures out whether or not a
+ * point is in a bound. Bound edges count as contained (see the all-edges case
+ * below).
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundContains)
+{
+ // We can test a couple different points: completely outside the bound,
+ // adjacent in one dimension to the bound, adjacent in all dimensions to the
+ // bound, and inside the bound.
+ HRectBound<2> b(3);
+
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(0.0, 2.0);
+ b[2] = Range(0.0, 2.0);
+
+ // Completely outside the range.
+ arma::vec point = "-1.0 4.0 4.0";
+ BOOST_REQUIRE(!b.Contains(point));
+
+ // Completely outside, but one dimension is in the range.
+ point = "-1.0 4.0 1.0";
+ BOOST_REQUIRE(!b.Contains(point));
+
+ // Outside, but one dimension is on the edge.
+ point = "-1.0 0.0 3.0";
+ BOOST_REQUIRE(!b.Contains(point));
+
+ // Two dimensions are on the edge, but one is outside.
+ point = "0.0 0.0 3.0";
+ BOOST_REQUIRE(!b.Contains(point));
+
+ // Completely on the edge (should be contained).
+ point = "0.0 0.0 0.0";
+ BOOST_REQUIRE(b.Contains(point));
+
+ // Inside the range.
+ point = "0.3 1.0 0.4";
+ BOOST_REQUIRE(b.Contains(point));
+}
+
+/**
+ * Test BallBound distance and containment calculations using two balls whose
+ * centers are a unit distance apart, with radii 0.3 and 0.4.
+ */
+BOOST_AUTO_TEST_CASE(TestBallBound)
+{
+ BallBound<> b1;
+ BallBound<> b2;
+
+ // Create two balls with a center distance of 1 from each other.
+ // Give the first one a radius of 0.3 and the second a radius of 0.4.
+ b1.Center().set_size(3);
+ b1.Center()[0] = 1;
+ b1.Center()[1] = 2;
+ b1.Center()[2] = 3;
+ b1.Radius() = 0.3;
+
+ b2.Center().set_size(3);
+ b2.Center()[0] = 1;
+ b2.Center()[1] = 2;
+ b2.Center()[2] = 4;
+ b2.Radius() = 0.4;
+
+ // Bound-to-bound distances: the minimum is the center distance minus both
+ // radii; the maximum is the center distance plus both radii. Check
+ // MaxDistance() directly (previously the RangeDistance() checks were
+ // duplicated and b1.MaxDistance(b2) was never tested).
+ BOOST_REQUIRE_CLOSE(b1.MinDistance(b2), 1-0.3-0.4, 1e-5);
+ BOOST_REQUIRE_CLOSE(b1.MaxDistance(b2), 1+0.3+0.4, 1e-5);
+ BOOST_REQUIRE_CLOSE(b1.RangeDistance(b2).Hi(), 1+0.3+0.4, 1e-5);
+ BOOST_REQUIRE_CLOSE(b1.RangeDistance(b2).Lo(), 1-0.3-0.4, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(b2.MinDistance(b1), 1-0.3-0.4, 1e-5);
+ BOOST_REQUIRE_CLOSE(b2.MaxDistance(b1), 1+0.3+0.4, 1e-5);
+ BOOST_REQUIRE_CLOSE(b2.RangeDistance(b1).Hi(), 1+0.3+0.4, 1e-5);
+ BOOST_REQUIRE_CLOSE(b2.RangeDistance(b1).Lo(), 1-0.3-0.4, 1e-5);
+
+ // Containment: each ball contains its own center but not the other's.
+ BOOST_REQUIRE(b1.Contains(b1.Center()));
+ BOOST_REQUIRE(!b1.Contains(b2.Center()));
+
+ BOOST_REQUIRE(!b2.Contains(b1.Center()));
+ BOOST_REQUIRE(b2.Contains(b2.Center()));
+ arma::vec b2point(3); // A point that's within the radius but not the center.
+ b2point[0] = 1.1;
+ b2point[1] = 2.1;
+ b2point[2] = 4.1;
+
+ BOOST_REQUIRE(b2.Contains(b2point));
+
+ // Bound-to-point distances: zero from a contained point; otherwise the
+ // center distance minus (min) or plus (max) the ball's radius.
+ BOOST_REQUIRE_SMALL(b1.MinDistance(b1.Center()), 1e-5);
+ BOOST_REQUIRE_CLOSE(b1.MinDistance(b2.Center()), 1 - 0.3, 1e-5);
+ BOOST_REQUIRE_CLOSE(b2.MinDistance(b1.Center()), 1 - 0.4, 1e-5);
+ BOOST_REQUIRE_CLOSE(b2.MaxDistance(b1.Center()), 1 + 0.4, 1e-5);
+ BOOST_REQUIRE_CLOSE(b1.MaxDistance(b2.Center()), 1 + 0.3, 1e-5);
+}
+
+/**
+ * Ensure that a bound, by default, is empty and has no dimensionality, and the
+ * box size vector is empty.
+ */
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundEmptyConstructor)
+{
+ PeriodicHRectBound<2> b;
+
+ // A default-constructed bound has no dimensions and no periodic box.
+ BOOST_REQUIRE_EQUAL(b.Dim(), 0);
+ BOOST_REQUIRE_EQUAL(b.Box().n_elem, 0);
+}
+
+/**
+ * Ensure that when we specify the dimensionality in the constructor, it is
+ * correct, and the bounds are all the empty set.
+ */
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundBoxConstructor)
+{
+ PeriodicHRectBound<2> b(arma::vec("5 6"));
+
+ // The dimensionality follows the box vector; each range starts with zero
+ // width (the empty set), and the box sizes are stored as given.
+ BOOST_REQUIRE_EQUAL(b.Dim(), 2);
+ BOOST_REQUIRE_SMALL(b[0].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(b[1].Width(), 1e-5);
+ BOOST_REQUIRE_EQUAL(b.Box().n_elem, 2);
+ BOOST_REQUIRE_CLOSE(b.Box()[0], 5.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b.Box()[1], 6.0, 1e-5);
+
+ // Repeat with a higher-dimensional box to be sure.
+ PeriodicHRectBound<2> d(arma::vec("2 3 4 5 6"));
+
+ BOOST_REQUIRE_EQUAL(d.Dim(), 5);
+ BOOST_REQUIRE_SMALL(d[0].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(d[1].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(d[2].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(d[3].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(d[4].Width(), 1e-5);
+ BOOST_REQUIRE_EQUAL(d.Box().n_elem, 5);
+ BOOST_REQUIRE_CLOSE(d.Box()[0], 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Box()[1], 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Box()[2], 4.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Box()[3], 5.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Box()[4], 6.0, 1e-5);
+}
+
+/**
+ * Test the copy constructor.
+ */
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundCopyConstructor)
+{
+ PeriodicHRectBound<2> b(arma::vec("3 4"));
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(2.0, 3.0);
+
+ PeriodicHRectBound<2> c(b);
+
+ // The copy must duplicate both the per-dimension ranges and the periodic
+ // box sizes.
+ BOOST_REQUIRE_EQUAL(c.Dim(), 2);
+ BOOST_REQUIRE_SMALL(c[0].Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(c[0].Hi(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c[1].Lo(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c[1].Hi(), 3.0, 1e-5);
+ BOOST_REQUIRE_EQUAL(c.Box().n_elem, 2);
+ BOOST_REQUIRE_CLOSE(c.Box()[0], 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c.Box()[1], 4.0, 1e-5);
+}
+
+/**
+ * Test the assignment operator.
+ *
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundAssignmentOperator)
+{
+ PeriodicHRectBound<2> b(arma::vec("3 4"));
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(2.0, 3.0);
+
+ PeriodicHRectBound<2> c(arma::vec("3 4 5 6"));
+
+ c = b;
+
+ BOOST_REQUIRE_EQUAL(c.Dim(), 2);
+ BOOST_REQUIRE_SMALL(c[0].Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(c[0].Hi(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c[1].Lo(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c[1].Hi(), 3.0, 1e-5);
+ BOOST_REQUIRE_EQUAL(c.Box().n_elem, 2);
+ BOOST_REQUIRE_CLOSE(c.Box()[0], 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c.Box()[1], 4.0, 1e-5);
+}*/
+
+/**
+ * Ensure that we can set the box size correctly.
+ *
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundSetBoxSize)
+{
+ PeriodicHRectBound<2> b(arma::vec("1 2"));
+
+ b.SetBoxSize(arma::vec("10 12"));
+
+ BOOST_REQUIRE_EQUAL(b.Box().n_elem, 2);
+ BOOST_REQUIRE_CLOSE(b.Box()[0], 10.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b.Box()[1], 12.0, 1e-5);
+}*/
+
+/**
+ * Ensure that we can clear the dimensions correctly. This does not involve the
+ * box size at all, so the test can be identical to the HRectBound test.
+ *
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundClear)
+{
+ // We'll do this with two dimensions only.
+ PeriodicHRectBound<2> b(arma::vec("5 5"));
+
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(2.0, 4.0);
+
+ // Now we just need to make sure that we clear the range.
+ b.Clear();
+
+ BOOST_REQUIRE_SMALL(b[0].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(b[1].Width(), 1e-5);
+}*/
+
+/**
+ * Ensure that we get the correct centroid for our bound.
+ *
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundCentroid) {
+ // Create a simple 3-dimensional bound. The centroid is not affected by the
+ // periodic coordinates.
+ PeriodicHRectBound<2> b(arma::vec("100 100 100"));
+
+ b[0] = Range(0.0, 5.0);
+ b[1] = Range(-2.0, -1.0);
+ b[2] = Range(-10.0, 50.0);
+
+ arma::vec centroid;
+
+ b.Centroid(centroid);
+
+ BOOST_REQUIRE_EQUAL(centroid.n_elem, 3);
+ BOOST_REQUIRE_CLOSE(centroid[0], 2.5, 1e-5);
+ BOOST_REQUIRE_CLOSE(centroid[1], -1.5, 1e-5);
+ BOOST_REQUIRE_CLOSE(centroid[2], 20.0, 1e-5);
+}*/
+
+/**
+ * Correctly calculate the minimum distance between the bound and a point in
+ * periodic coordinates. We have to account for the shifts necessary in
+ * periodic coordinates too, so that makes testing this a little more difficult.
+ *
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundMinDistancePoint)
+{
+ // First, we'll start with a simple 2-dimensional case where the point is
+ // inside the bound, then on the edge of the bound, then barely outside the
+ // bound. The box size will be large enough that this is basically the
+ // HRectBound case.
+ PeriodicHRectBound<2> b(arma::vec("100 100"));
+
+ b[0] = Range(0.0, 5.0);
+ b[1] = Range(2.0, 4.0);
+
+ // Inside the bound.
+ arma::vec point = "2.5 3.0";
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(point), 1e-5);
+
+ // On the edge.
+ point = "5.0 4.0";
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(point), 1e-5);
+
+ // And just a little outside the bound.
+ point = "6.0 5.0";
+
+ BOOST_REQUIRE_CLOSE(b.MinDistance(point), 2.0, 1e-5);
+
+ // Now we start to invoke the periodicity. This point will "alias" to (-1,
+ // -1).
+ point = "99.0 99.0";
+
+ BOOST_REQUIRE_CLOSE(b.MinDistance(point), 10.0, 1e-5);
+
+ // We will perform several tests on a one-dimensional bound.
+ PeriodicHRectBound<2> a(arma::vec("5.0"));
+ point.set_size(1);
+
+ a[0] = Range(2.0, 4.0); // Entirely inside box.
+ point[0] = 7.5; // Inside first right image of the box.
+
+ BOOST_REQUIRE_SMALL(a.MinDistance(point), 1e-5);
+
+ a[0] = Range(0.0, 5.0); // Fills box fully.
+ point[0] = 19.3; // Inside the box, which covers everything. (Was point[1]: out of bounds for this 1-element vector.)
+
+ BOOST_REQUIRE_SMALL(a.MinDistance(point), 1e-5);
+
+ a[0] = Range(-10.0, 10.0); // Larger than the box.
+ point[0] = -500.0; // Inside the box, which covers everything.
+
+ BOOST_REQUIRE_SMALL(a.MinDistance(point), 1e-5);
+
+ a[0] = Range(-2.0, 1.0); // Crosses over an edge.
+ point[0] = 2.9; // The first right image of the bound starts at 3.0.
+
+ BOOST_REQUIRE_CLOSE(a.MinDistance(point), 0.01, 1e-5);
+
+ a[0] = Range(2.0, 4.0); // Inside box.
+ point[0] = 0.0; // Closest to the first left image of the bound.
+
+ BOOST_REQUIRE_CLOSE(a.MinDistance(point), 1.0, 1e-5);
+
+ a[0] = Range(0.0, 2.0); // On edge of box.
+ point[0] = 7.1; // 0.1 away from the first right image of the bound.
+
+ BOOST_REQUIRE_CLOSE(a.MinDistance(point), 0.01, 1e-5);
+
+ PeriodicHRectBound<2> d(arma::vec("0.0"));
+ d[0] = Range(-10.0, 10.0); // Box is of infinite size.
+ point[0] = 810.0; // 800 away from the only image of the box.
+
+ BOOST_REQUIRE_CLOSE(d.MinDistance(point), 640000.0, 1e-5);
+
+ PeriodicHRectBound<2> e(arma::vec("-5.0"));
+ e[0] = Range(2.0, 4.0); // Box size of -5 should function the same as 5.
+ point[0] = -10.8; // Should alias to 4.2.
+
+ BOOST_REQUIRE_CLOSE(e.MinDistance(point), 0.04, 1e-5);
+
+ // Switch our bound to a higher dimensionality. This should ensure that the
+ // dimensions are independent like they should be.
+ PeriodicHRectBound<2> c(arma::vec("5.0 5.0 5.0 5.0 5.0 5.0 0.0 -5.0"));
+
+ c[0] = Range(2.0, 4.0); // Entirely inside box.
+ c[1] = Range(0.0, 5.0); // Fills box fully.
+ c[2] = Range(-10.0, 10.0); // Larger than the box.
+ c[3] = Range(-2.0, 1.0); // Crosses over an edge.
+ c[4] = Range(2.0, 4.0); // Inside box.
+ c[5] = Range(0.0, 2.0); // On edge of box.
+ c[6] = Range(-10.0, 10.0); // Box is of infinite size.
+ c[7] = Range(2.0, 4.0); // Box size of -5 should function the same as 5.
+
+ point.set_size(8);
+ point[0] = 7.5; // Inside first right image of the box.
+ point[1] = 19.3; // Inside the box, which covers everything.
+ point[2] = -500.0; // Inside the box, which covers everything.
+ point[3] = 2.9; // The first right image of the bound starts at 3.0.
+ point[4] = 0.0; // Closest to the first left image of the bound.
+ point[5] = 7.1; // 0.1 away from the first right image of the bound.
+ point[6] = 810.0; // 800 away from the only image of the box.
+ point[7] = -10.8; // Should alias to 4.2.
+
+ BOOST_REQUIRE_CLOSE(c.MinDistance(point), 640001.06, 1e-10);
+}*/
+
+/**
+ * Correctly calculate the minimum distance between the bound and another bound in
+ * periodic coordinates. We have to account for the shifts necessary in
+ * periodic coordinates too, so that makes testing this a little more difficult.
+ *
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundMinDistanceBound)
+{
+ // First, we'll start with a simple 2-dimensional case where the bounds are nonoverlapping,
+ // then one bound is on the edge of the other bound,
+ // then overlapping, then one range entirely covering the other. The box size will be large enough that this is basically the
+ // HRectBound case.
+ PeriodicHRectBound<2> b(arma::vec("100 100"));
+ PeriodicHRectBound<2> c(arma::vec("100 100"));
+
+ b[0] = Range(0.0, 5.0);
+ b[1] = Range(2.0, 4.0);
+
+ // Inside the bound.
+ c[0] = Range(7.0, 9.0);
+ c[1] = Range(10.0,12.0);
+
+
+ BOOST_REQUIRE_CLOSE(b.MinDistance(c), 40.0, 1e-5);
+
+ // On the edge.
+ c[0] = Range(5.0, 8.0);
+ c[1] = Range(4.0, 6.0);
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
+
+ // Overlapping the bound.
+ c[0] = Range(3.0, 6.0);
+ c[1] = Range(1.0, 3.0);
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
+
+ // One range entirely covering the other
+
+ c[0] = Range(0.0, 6.0);
+ c[1] = Range(1.0, 7.0);
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
+
+ // Now we start to invoke the periodicity. These bounds "alias" to (-3.0,
+ // -1.0) and (5.0, 6.0).
+
+ c[0] = Range(97.0, 99.0);
+ c[1] = Range(105.0, 106.0);
+
+ BOOST_REQUIRE_CLOSE(b.MinDistance(c), 2.0, 1e-5);
+
+ // We will perform several tests on a one-dimensional bound and smaller box size and mostly overlapping.
+ PeriodicHRectBound<2> a(arma::vec("5.0"));
+ PeriodicHRectBound<2> d(a);
+
+ a[0] = Range(2.0, 4.0); // Entirely inside box.
+ d[0] = Range(7.5, 10.0); // In the right image of the box, overlapping ranges.
+
+ BOOST_REQUIRE_SMALL(a.MinDistance(d), 1e-5);
+
+ a[0] = Range(0.0, 5.0); // Fills box fully.
+ d[0] = Range(19.3, 21.0); // Two intervals inside the box, same as range of b[0].
+
+ BOOST_REQUIRE_SMALL(a.MinDistance(d), 1e-5);
+
+ a[0] = Range(-10.0, 10.0); // Larger than the box.
+ d[0] = Range(-500.0, -498.0); // Inside the box, which covers everything.
+
+ BOOST_REQUIRE_SMALL(a.MinDistance(d), 1e-5);
+
+ a[0] = Range(-2.0, 1.0); // Crosses over an edge.
+ d[0] = Range(2.9, 5.1); // The first right image of the bound starts at 3.0. Overlapping
+
+ BOOST_REQUIRE_SMALL(a.MinDistance(d), 1e-5);
+
+ a[0] = Range(-1.0, 1.0); // Crosses over an edge.
+ d[0] = Range(11.9, 12.5); // The first right image of the bound starts at 4.0.
+ BOOST_REQUIRE_CLOSE(a.MinDistance(d), 0.81, 1e-5);
+
+ a[0] = Range(2.0, 3.0);
+ d[0] = Range(9.5, 11);
+ BOOST_REQUIRE_CLOSE(a.MinDistance(d), 1.0, 1e-5);
+
+}*/
+
+/**
+ * Correctly calculate the maximum distance between the bound and a point in
+ * periodic coordinates. We have to account for the shifts necessary in
+ * periodic coordinates too, so that makes testing this a little more difficult.
+ *
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundMaxDistancePoint)
+{
+ // First, we'll start with a simple 2-dimensional case where the point is
+ // inside the bound, then on the edge of the bound, then barely outside the
+ // bound. The box size will be large enough that this is basically the
+ // HRectBound case.
+ PeriodicHRectBound<2> b(arma::vec("100 100"));
+
+ b[0] = Range(0.0, 5.0);
+ b[1] = Range(2.0, 4.0);
+
+ // Inside the bound.
+ arma::vec point = "2.5 3.0";
+
+ //BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 7.25, 1e-5);
+
+ // On the edge.
+ point = "5.0 4.0";
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 29.0, 1e-5);
+
+ // And just a little outside the bound.
+ point = "6.0 5.0";
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 45.0, 1e-5);
+
+ // Now we start to invoke the periodicity. This point will "alias" to (-1,
+ // -1).
+ point = "99.0 99.0";
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 61.0, 1e-5);
+
+ // We will perform several tests on a one-dimensional bound and smaller box size.
+ PeriodicHRectBound<2> a(arma::vec("5.0"));
+ point.set_size(1);
+
+ a[0] = Range(2.0, 4.0); // Entirely inside box.
+ point[0] = 7.5; // Inside first right image of the box.
+
+ BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 2.25, 1e-5);
+
+ a[0] = Range(0.0, 5.0); // Fills box fully.
+ point[0] = 19.3; // Inside the box, which covers everything. (Was point[1]: out of bounds for this 1-element vector.)
+
+ BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 18.49, 1e-5);
+
+ a[0] = Range(-10.0, 10.0); // Larger than the box.
+ point[0] = -500.0; // Inside the box, which covers everything.
+
+ BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 25.0, 1e-5);
+
+ a[0] = Range(-2.0, 1.0); // Crosses over an edge.
+ point[0] = 2.9; // The first right image of the bound starts at 3.0.
+
+ BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 8.41, 1e-5);
+
+ a[0] = Range(2.0, 4.0); // Inside box.
+ point[0] = 0.0; // Farthest from the first right image of the bound.
+
+ BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 25.0, 1e-5);
+
+ a[0] = Range(0.0, 2.0); // On edge of box.
+ point[0] = 7.1; // 2.1 away from the first left image of the bound.
+
+ BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 4.41, 1e-5);
+
+ PeriodicHRectBound<2> d(arma::vec("0.0"));
+ d[0] = Range(-10.0, 10.0); // Box is of infinite size.
+ point[0] = 810.0; // 820 away from the only left image of the box.
+
+ BOOST_REQUIRE_CLOSE(d.MinDistance(point), 672400.0, 1e-5);
+
+ PeriodicHRectBound<2> e(arma::vec("-5.0"));
+ e[0] = Range(2.0, 4.0); // Box size of -5 should function the same as 5.
+ point[0] = -10.8; // Should alias to 4.2.
+
+ BOOST_REQUIRE_CLOSE(e.MaxDistance(point), 4.84, 1e-5);
+
+ // Switch our bound to a higher dimensionality. This should ensure that the
+ // dimensions are independent like they should be.
+ PeriodicHRectBound<2> c(arma::vec("5.0 5.0 5.0 5.0 5.0 5.0 0.0 -5.0"));
+
+ c[0] = Range(2.0, 4.0); // Entirely inside box.
+ c[1] = Range(0.0, 5.0); // Fills box fully.
+ c[2] = Range(-10.0, 10.0); // Larger than the box.
+ c[3] = Range(-2.0, 1.0); // Crosses over an edge.
+ c[4] = Range(2.0, 4.0); // Inside box.
+ c[5] = Range(0.0, 2.0); // On edge of box.
+ c[6] = Range(-10.0, 10.0); // Box is of infinite size.
+ c[7] = Range(2.0, 4.0); // Box size of -5 should function the same as 5.
+
+ point.set_size(8);
+ point[0] = 7.5; // Inside first right image of the box.
+ point[1] = 19.3; // Inside the box, which covers everything.
+ point[2] = -500.0; // Inside the box, which covers everything.
+ point[3] = 2.9; // The first right image of the bound starts at 3.0.
+ point[4] = 0.0; // Closest to the first left image of the bound.
+ point[5] = 7.1; // 0.1 away from the first right image of the bound.
+ point[6] = 810.0; // 800 away from the only image of the box.
+ point[7] = -10.8; // Should alias to 4.2.
+
+ BOOST_REQUIRE_CLOSE(c.MaxDistance(point), 672630.65, 1e-10);
+}*/
+
+/**
+ * Correctly calculate the maximum distance between the bound and another bound in
+ * periodic coordinates. We have to account for the shifts necessary in
+ * periodic coordinates too, so that makes testing this a little more difficult.
+ *
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundMaxDistanceBound)
+{
+ // First, we'll start with a simple 2-dimensional case where the bounds are nonoverlapping,
+ // then one bound is on the edge of the other bound,
+ // then overlapping, then one range entirely covering the other. The box size will be large enough that this is basically the
+ // HRectBound case.
+ PeriodicHRectBound<2> b(arma::vec("100 100"));
+ PeriodicHRectBound<2> c(b);
+
+ b[0] = Range(0.0, 5.0);
+ b[1] = Range(2.0, 4.0);
+
+ // Inside the bound.
+
+ c[0] = Range(7.0, 9.0);
+ c[1] = Range(10.0,12.0);
+
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 181.0, 1e-5);
+
+ // On the edge.
+
+ c[0] = Range(5.0, 8.0);
+ c[1] = Range(4.0, 6.0);
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 80.0, 1e-5);
+
+ // Overlapping the bound.
+
+ c[0] = Range(3.0, 6.0);
+ c[1] = Range(1.0, 3.0);
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 45.0, 1e-5);
+
+ // One range entirely covering the other
+
+ c[0] = Range(0.0, 6.0);
+ c[1] = Range(1.0, 7.0);
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 61.0, 1e-5);
+
+ // Now we start to invoke the periodicity. These bounds "alias" to (-3.0,
+ // -1.0) and (5.0, 6.0).
+
+ c[0] = Range(97.0, 99.0);
+ c[1] = Range(105.0, 106.0);
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 80.0, 1e-5);
+
+ // We will perform several tests on a one-dimensional bound and smaller box size.
+ PeriodicHRectBound<2> a(arma::vec("5.0"));
+ PeriodicHRectBound<2> d(a);
+
+ a[0] = Range(2.0, 4.0); // Entirely inside box.
+ d[0] = Range(7.5, 10); // In the right image of the box, overlapping ranges.
+
+ BOOST_REQUIRE_CLOSE(a.MaxDistance(d), 9.0, 1e-5);
+
+ a[0] = Range(0.0, 5.0); // Fills box fully.
+ d[0] = Range(19.3, 21.0); // Two intervals inside the box, same as range of b[0].
+
+ BOOST_REQUIRE_CLOSE(a.MaxDistance(d), 18.49, 1e-5);
+
+ a[0] = Range(-10.0, 10.0); // Larger than the box.
+ d[0] = Range(-500.0, -498.0); // Inside the box, which covers everything.
+
+ BOOST_REQUIRE_CLOSE(a.MaxDistance(d), 9.0, 1e-5);
+
+ a[0] = Range(-2.0, 1.0); // Crosses over an edge.
+ d[0] = Range(2.9, 5.1); // The first right image of the bound starts at 3.0.
+
+ BOOST_REQUIRE_CLOSE(a.MaxDistance(d), 24.01, 1e-5);
+}*/
+
+
+/**
+ * It seems as though Bill has stumbled across a bug where
+ * BinarySpaceTree<>::count() returns something different than
+ * BinarySpaceTree<>::count_. So, let's build a simple tree and make sure they
+ * are the same.
+ */
+BOOST_AUTO_TEST_CASE(TreeCountMismatch)
+{
+ // Six 2-dimensional points, one per column.
+ arma::mat dataset = "2.0 5.0 9.0 4.0 8.0 7.0;"
+ "3.0 4.0 6.0 7.0 1.0 2.0 ";
+
+ // Leaf size of 1.
+ BinarySpaceTree<HRectBound<2> > rootNode(dataset, 1);
+
+ // With a leaf size of 1 every node is split down to single points, so each
+ // subtree's count is known exactly.
+ BOOST_REQUIRE(rootNode.Count() == 6);
+ BOOST_REQUIRE(rootNode.Left()->Count() == 3);
+ BOOST_REQUIRE(rootNode.Left()->Left()->Count() == 2);
+ BOOST_REQUIRE(rootNode.Left()->Left()->Left()->Count() == 1);
+ BOOST_REQUIRE(rootNode.Left()->Left()->Right()->Count() == 1);
+ BOOST_REQUIRE(rootNode.Left()->Right()->Count() == 1);
+ BOOST_REQUIRE(rootNode.Right()->Count() == 3);
+ BOOST_REQUIRE(rootNode.Right()->Left()->Count() == 2);
+ BOOST_REQUIRE(rootNode.Right()->Left()->Left()->Count() == 1);
+ BOOST_REQUIRE(rootNode.Right()->Left()->Right()->Count() == 1);
+ BOOST_REQUIRE(rootNode.Right()->Right()->Count() == 1);
+}
+
+// Forward declaration of methods we need for the next test; the definitions
+// appear after the test cases below.
+// Returns true if every point a node claims is inside that node's bound.
+template<typename TreeType>
+bool CheckPointBounds(TreeType* node, const arma::mat& data);
+
+// Flattens a tree into a heap-indexed vector (root at index 1).
+template<typename TreeType>
+void GenerateVectorOfTree(TreeType* node,
+ size_t depth,
+ std::vector<TreeType*>& v);
+
+// Returns true if two hyperrectangle bounds overlap in every dimension.
+template<int t_pow>
+bool DoBoundsIntersect(HRectBound<t_pow>& a,
+ HRectBound<t_pow>& b,
+ size_t ia,
+ size_t ib);
+
+/**
+ * Exhaustive kd-tree test based on #125.
+ *
+ * - Generate a random dataset of a random size.
+ * - Build a tree on that dataset.
+ * - Ensure all the permutation indices map back to the correct points.
+ * - Verify that each point is contained inside all of the bounds of its parent
+ * nodes.
+ * - Verify that each bound at a particular level of the tree does not overlap
+ * with any other bounds at that level.
+ *
+ * Then, we do that whole process a handful of times.
+ */
+BOOST_AUTO_TEST_CASE(KdTreeTest)
+{
+ typedef BinarySpaceTree<HRectBound<2> > TreeType;
+
+ size_t maxRuns = 10; // Ten total tests.
+ size_t pointIncrements = 1000; // Range is from 2000 points to 11000.
+
+ // We use the default leaf size of 20.
+ for(size_t run = 0; run < maxRuns; run++)
+ {
+ size_t dimensions = run + 2;
+ size_t maxPoints = (run + 1) * pointIncrements;
+
+ size_t size = maxPoints;
+ arma::mat dataset = arma::mat(dimensions, size);
+ arma::mat datacopy; // Used to test mappings.
+
+ // Mappings for post-sort verification of data.
+ std::vector<size_t> newToOld;
+ std::vector<size_t> oldToNew;
+
+ // Generate data.
+ dataset.randu();
+ datacopy = dataset; // Save a copy.
+
+ // Build the tree itself.
+ TreeType root(dataset, newToOld, oldToNew);
+
+ // Ensure the size of the tree is correct.
+ BOOST_REQUIRE_EQUAL(root.Count(), size);
+
+ // Check the forward and backward mappings for correctness.
+ for(size_t i = 0; i < size; i++)
+ {
+ for(size_t j = 0; j < dimensions; j++)
+ {
+ BOOST_REQUIRE_EQUAL(dataset(j, i), datacopy(j, newToOld[i]));
+ BOOST_REQUIRE_EQUAL(dataset(j, oldToNew[i]), datacopy(j, i));
+ }
+ }
+
+ // Now check that each point is contained inside of all bounds above it.
+ // The return value was previously discarded, so a containment failure could
+ // never fail the test; assert it explicitly.
+ BOOST_REQUIRE(CheckPointBounds(&root, dataset));
+
+ // Now check that no peers overlap. v is heap-indexed: the root is at index
+ // 1, and the children of index i are at 2i and 2i + 1.
+ std::vector<TreeType*> v;
+ GenerateVectorOfTree(&root, 1, v);
+
+ // Start with the first pair.
+ size_t depth = 2;
+ // Compare each peer against every other peer.
+ while (depth < v.size())
+ {
+ for (size_t i = depth; i < 2 * depth && i < v.size(); i++)
+ for (size_t j = i + 1; j < 2 * depth && j < v.size(); j++)
+ if (v[i] != NULL && v[j] != NULL)
+ BOOST_REQUIRE(!DoBoundsIntersect(v[i]->Bound(), v[j]->Bound(),
+ i, j));
+
+ depth *= 2;
+ }
+ }
+
+ // Build a deterministic tree and verify the reported size and depth.
+ arma::mat dataset = arma::mat(25, 1000);
+ for (size_t col = 0; col < dataset.n_cols; ++col)
+ for (size_t row = 0; row < dataset.n_rows; ++row)
+ dataset(row, col) = row + col;
+
+ TreeType root(dataset);
+ // Check the tree size.
+ BOOST_REQUIRE_EQUAL(root.TreeSize(), 127);
+ // Check the tree depth.
+ BOOST_REQUIRE_EQUAL(root.TreeDepth(), 7);
+}
+
+// Walk the tree and verify that every point a node claims lies inside that
+// node's bound; returns true when the whole subtree is consistent.
+template<typename TreeType>
+bool CheckPointBounds(TreeType* node, const arma::mat& data)
+{
+ if (node == NULL)
+ return true; // Ran off the bottom of the tree; nothing to check.
+
+ // Every point in [Begin(), Begin() + Count()) must lie in this node's bound.
+ const size_t end = node->Begin() + node->Count();
+ for (size_t i = node->Begin(); i < end; ++i)
+ if (!node->Bound().Contains(data.col(i)))
+ return false; // The bound does not hold one of this node's own points.
+
+ // Both subtrees must also be consistent.
+ return CheckPointBounds(node->Left(), data) &&
+ CheckPointBounds(node->Right(), data);
+}
+
+// Two hyperrectangles intersect only if their ranges overlap in every single
+// dimension; one disjoint dimension is enough to separate them.
+template<int t_pow>
+bool DoBoundsIntersect(HRectBound<t_pow>& a,
+ HRectBound<t_pow>& b,
+ size_t /* ia */,
+ size_t /* ib */)
+{
+ for (size_t d = 0; d < a.Dim(); d++)
+ {
+ Range rangeA = a[d];
+ Range rangeB = b[d];
+
+ if (rangeA < rangeB || rangeA > rangeB)
+ return false; // Disjoint in this dimension, so no intersection.
+ }
+
+ return true;
+}
+
+// Store each node of the tree at its heap-style index in v: the root sits at
+// index 1 and a node at index i has its children at 2i and 2i + 1. Slots with
+// no node stay NULL.
+template<typename TreeType>
+void GenerateVectorOfTree(TreeType* node,
+ size_t depth,
+ std::vector<TreeType*>& v)
+{
+ if (node == NULL)
+ return;
+
+ // Grow the vector on demand, NULL-filling the new slots.
+ if (depth >= v.size())
+ v.resize(2 * depth + 1, NULL);
+ v[depth] = node;
+
+ GenerateVectorOfTree(node->Left(), 2 * depth, v);
+ GenerateVectorOfTree(node->Right(), 2 * depth + 1, v);
+}
+
+/**
+ * Exhaustive sparse kd-tree test based on #125.
+ *
+ * - Generate a random dataset of a random size.
+ * - Build a tree on that dataset.
+ * - Ensure all the permutation indices map back to the correct points.
+ * - Verify that each point is contained inside all of the bounds of its parent
+ * nodes.
+ * - Verify that each bound at a particular level of the tree does not overlap
+ * with any other bounds at that level.
+ *
+ * Then, we do that whole process a handful of times.
+ */
+BOOST_AUTO_TEST_CASE(ExhaustiveSparseKDTreeTest)
+{
+ typedef BinarySpaceTree<HRectBound<2>, EmptyStatistic, arma::SpMat<double> >
+ TreeType;
+
+ size_t maxRuns = 2; // Two total tests.
+ size_t pointIncrements = 200; // Range is from 200 points to 400.
+
+ // We use the default leaf size of 20.
+ for(size_t run = 0; run < maxRuns; run++)
+ {
+ size_t dimensions = run + 2;
+ size_t maxPoints = (run + 1) * pointIncrements;
+
+ size_t size = maxPoints;
+ arma::SpMat<double> dataset = arma::SpMat<double>(dimensions, size);
+ arma::SpMat<double> datacopy; // Used to test mappings.
+
+ // Mappings for post-sort verification of data.
+ std::vector<size_t> newToOld;
+ std::vector<size_t> oldToNew;
+
+ // Generate data.
+ dataset.randu();
+ datacopy = dataset; // Save a copy.
+
+ // Build the tree itself.
+ TreeType root(dataset, newToOld, oldToNew);
+
+ // Ensure the size of the tree is correct.
+ BOOST_REQUIRE_EQUAL(root.Count(), size);
+
+ // Check the forward and backward mappings for correctness.
+ for(size_t i = 0; i < size; i++)
+ {
+ for(size_t j = 0; j < dimensions; j++)
+ {
+ BOOST_REQUIRE_EQUAL(dataset(j, i), datacopy(j, newToOld[i]));
+ BOOST_REQUIRE_EQUAL(dataset(j, oldToNew[i]), datacopy(j, i));
+ }
+ }
+
+ // Now check that each point is contained inside of all bounds above it.
+ // The return value was previously discarded, so a containment failure could
+ // never fail the test; assert it explicitly.
+ BOOST_REQUIRE(CheckPointBounds(&root, dataset));
+
+ // Now check that no peers overlap. v is heap-indexed: the root is at index
+ // 1, and the children of index i are at 2i and 2i + 1.
+ std::vector<TreeType*> v;
+ GenerateVectorOfTree(&root, 1, v);
+
+ // Start with the first pair.
+ size_t depth = 2;
+ // Compare each peer against every other peer.
+ while (depth < v.size())
+ {
+ for (size_t i = depth; i < 2 * depth && i < v.size(); i++)
+ for (size_t j = i + 1; j < 2 * depth && j < v.size(); j++)
+ if (v[i] != NULL && v[j] != NULL)
+ BOOST_REQUIRE(!DoBoundsIntersect(v[i]->Bound(), v[j]->Bound(),
+ i, j));
+
+ depth *= 2;
+ }
+ }
+
+ // Build a deterministic tree and verify the reported size and depth.
+ arma::SpMat<double> dataset(25, 1000);
+ for (size_t col = 0; col < dataset.n_cols; ++col)
+ for (size_t row = 0; row < dataset.n_rows; ++row)
+ dataset(row, col) = row + col;
+
+ TreeType root(dataset);
+ // Check the tree size.
+ BOOST_REQUIRE_EQUAL(root.TreeSize(), 127);
+ // Check the tree depth.
+ BOOST_REQUIRE_EQUAL(root.TreeDepth(), 7);
+}
+
+// Tally how many times each point appears as a leaf of the given cover tree,
+// accumulating into counts (indexed by point id).
+template<typename TreeType>
+void RecurseTreeCountLeaves(const TreeType& node, arma::vec& counts)
+{
+ for (size_t c = 0; c < node.NumChildren(); ++c)
+ {
+ const TreeType& child = node.Child(c);
+
+ if (child.NumChildren() == 0)
+ counts[child.Point()]++; // A leaf: tally its point.
+ else
+ RecurseTreeCountLeaves<TreeType>(child, counts);
+ }
+}
+
+// Every non-leaf node must have a child that holds its own point (the
+// self-child property); verify this recursively over the whole subtree.
+template<typename TreeType>
+void CheckSelfChild(const TreeType& node)
+{
+ if (node.NumChildren() == 0)
+ return; // Leaves have no children, so the property is vacuous.
+
+ bool hasSelfChild = false;
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ {
+ hasSelfChild |= (node.Child(i).Point() == node.Point());
+
+ // Check the subtree rooted at this child too.
+ CheckSelfChild(node.Child(i));
+ }
+
+ // This node must have a self-child.
+ BOOST_REQUIRE_EQUAL(hasSelfChild, true);
+}
+
+// Verify the covering principle of a cover tree: every child's point must lie
+// within pow(base, scale) of its parent's point. Recurses over the subtree.
+template<typename TreeType, typename MetricType>
+void CheckCovering(const TreeType& node)
+{
+ if (node.NumChildren() == 0)
+ return; // A leaf covers nothing; no checking necessary.
+
+ const arma::mat& data = node.Dataset();
+ const double allowed = pow(node.Base(), node.Scale());
+
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ {
+ // The child may be no farther than pow(base, scale) from this node.
+ const double dist = MetricType::Evaluate(data.col(node.Point()),
+ data.col(node.Child(i).Point()));
+
+ BOOST_REQUIRE_LE(dist, allowed);
+
+ // Descend into the child.
+ CheckCovering<TreeType, MetricType>(node.Child(i));
+ }
+}
+
+// Compare constantNode against every node of the subtree rooted at 'node' that
+// lives at the same scale, asserting the separation invariant: two distinct
+// points at scale s must be at least base^s apart (per MetricType::Evaluate).
+// The traversal descends from higher scales until it reaches constantNode's
+// scale, then performs the pairwise check.
+template<typename TreeType, typename MetricType>
+void CheckIndividualSeparation(const TreeType& constantNode,
+ const TreeType& node)
+{
+ // Don't check points at a lower scale.
+ if (node.Scale() < constantNode.Scale())
+ return;
+
+ // If at a higher scale, recurse.
+ if (node.Scale() > constantNode.Scale())
+ {
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ {
+ // Don't recurse into leaves.
+ if (node.Child(i).NumChildren() > 0)
+ CheckIndividualSeparation<TreeType, MetricType>(constantNode,
+ node.Child(i));
+ }
+
+ return;
+ }
+
+ // Don't compare the same point against itself.
+ if (node.Point() == constantNode.Point())
+ return;
+
+ // Now we know we are at the same scale, so make the comparison.
+ const arma::mat& dataset = constantNode.Dataset();
+ const size_t constantPoint = constantNode.Point();
+ const size_t nodePoint = node.Point();
+
+ // Make sure the distance is at least the following value (in accordance with
+ // the separation principle of cover trees).
+ double minDistance = pow(constantNode.Base(),
+ constantNode.Scale());
+
+ double distance = MetricType::Evaluate(dataset.col(constantPoint),
+ dataset.col(nodePoint));
+
+ BOOST_REQUIRE_GE(distance, minDistance);
+}
+
+// Drive the separation check over the whole tree: for 'node' and then for each
+// of its non-leaf descendants, compare that node against every same-scale node
+// reachable from 'root'.  Effectively an all-pairs check per scale level —
+// quadratic, but acceptable for test-sized datasets.
+template<typename TreeType, typename MetricType>
+void CheckSeparation(const TreeType& node, const TreeType& root)
+{
+ // Check the separation between this point and all other points on this scale.
+ CheckIndividualSeparation<TreeType, MetricType>(node, root);
+
+ // Check the children, but only if they are not leaves. Leaves don't need to
+ // be checked.
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ if (node.Child(i).NumChildren() > 0)
+ CheckSeparation<TreeType, MetricType>(node.Child(i), root);
+}
+
+
+/**
+ * Create a simple cover tree and then make sure it is valid: each point gets
+ * exactly one leaf, every non-leaf has a self-child, and the covering and
+ * separation invariants hold throughout.
+ */
+BOOST_AUTO_TEST_CASE(SimpleCoverTreeConstructionTest)
+{
+ // 20-point dataset.
+ arma::mat data = arma::trans(arma::mat("0.0 0.0;"
+ "1.0 0.0;"
+ "0.5 0.5;"
+ "2.0 2.0;"
+ "-1.0 2.0;"
+ "3.0 0.0;"
+ "1.5 5.5;"
+ "-2.0 -2.0;"
+ "-1.5 1.5;"
+ "0.0 4.0;"
+ "2.0 1.0;"
+ "2.0 1.2;"
+ "-3.0 -2.5;"
+ "-5.0 -5.0;"
+ "3.5 1.5;"
+ "2.0 2.5;"
+ "-1.0 -1.0;"
+ "-3.5 1.5;"
+ "3.5 -1.5;"
+ "2.0 1.0;"));
+
+ // The root point will be the first point, (0, 0).
+ CoverTree<> tree(data); // Expansion constant of 2.0.
+
+ // The furthest point from the root will be (-5, -5), with a squared distance
+ // of 50. This means the scale of the root node should be 6 (because 2^6 =
+ // 64).
+ // NOTE(review): the scale reasoning above assumes the default metric returns
+ // squared distances (so 64 >= 50 covers everything) — confirm against the
+ // CoverTree<> default metric.
+ BOOST_REQUIRE_EQUAL(tree.Scale(), 6);
+
+ // Now loop through the tree and ensure that each leaf is only created once.
+ arma::vec counts;
+ counts.zeros(20);
+ RecurseTreeCountLeaves(tree, counts);
+
+ // Each point should only have one leaf node representing it.
+ for (size_t i = 0; i < 20; ++i)
+ BOOST_REQUIRE_EQUAL(counts[i], 1);
+
+ // Each non-leaf should have a self-child.
+ CheckSelfChild<CoverTree<> >(tree);
+
+ // Each node must satisfy the covering principle (its children must be less
+ // than or equal to a certain distance apart).
+ CheckCovering<CoverTree<>, LMetric<2> >(tree);
+
+ // Each node's children must be separated by at least a certain value.
+ CheckSeparation<CoverTree<>, LMetric<2> >(tree, tree);
+}
+
+/**
+ * Create a large cover tree and make sure it's accurate.  Same invariant
+ * checks as the simple test (unique leaves, self-children, covering and
+ * separation), but on 1000 random 50-dimensional points instead of a
+ * hand-crafted dataset.
+ */
+BOOST_AUTO_TEST_CASE(CoverTreeConstructionTest)
+{
+ arma::mat dataset;
+ // 50-dimensional, 1000 point.
+ dataset.randu(50, 1000);
+
+ CoverTree<> tree(dataset);
+
+ // Ensure each leaf is only created once.
+ arma::vec counts;
+ counts.zeros(1000);
+ RecurseTreeCountLeaves(tree, counts);
+
+ for (size_t i = 0; i < 1000; ++i)
+ BOOST_REQUIRE_EQUAL(counts[i], 1);
+
+ // Each non-leaf should have a self-child.
+ CheckSelfChild<CoverTree<> >(tree);
+
+ // Each node must satisfy the covering principle (its children must be less
+ // than or equal to a certain distance apart).
+ CheckCovering<CoverTree<>, LMetric<2> >(tree);
+
+ // Each node's children must be separated by at least a certain value.
+ CheckSeparation<CoverTree<>, LMetric<2> >(tree, tree);
+}
+
+/**
+ * Make sure cover trees work in different metric spaces.  Re-runs the full
+ * set of invariant checks with the L1 (Manhattan) metric — LMetric<1, true>,
+ * where the second template argument requests the actual (rooted) distance —
+ * instead of the default metric used by the other tests.
+ */
+BOOST_AUTO_TEST_CASE(CoverTreeAlternateMetricTest)
+{
+ arma::mat dataset;
+ // 5-dimensional, 300-point dataset.
+ dataset.randu(5, 300);
+
+ CoverTree<LMetric<1, true> > tree(dataset);
+
+ // Ensure each leaf is only created once.
+ arma::vec counts;
+ counts.zeros(300);
+ RecurseTreeCountLeaves<CoverTree<LMetric<1, true> > >(tree, counts);
+
+ for (size_t i = 0; i < 300; ++i)
+ BOOST_REQUIRE_EQUAL(counts[i], 1);
+
+ // Each non-leaf should have a self-child.
+ CheckSelfChild<CoverTree<LMetric<1, true> > >(tree);
+
+ // Each node must satisfy the covering principle (its children must be less
+ // than or equal to a certain distance apart).
+ CheckCovering<CoverTree<LMetric<1, true> >, LMetric<1, true> >(tree);
+
+ // Each node's children must be separated by at least a certain value.
+ CheckSeparation<CoverTree<LMetric<1, true> >, LMetric<1, true> >(tree, tree);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn mailing list.